diff --git a/.github/actions/rust-toolchain-setup/action.yml b/.github/actions/rust-toolchain-setup/action.yml new file mode 100644 index 0000000000000..bf73fede16c7f --- /dev/null +++ b/.github/actions/rust-toolchain-setup/action.yml @@ -0,0 +1,44 @@ +# yaml-language-server: $schema=https://json.schemastore.org/github-action.json + +name: 'Rust toolchain setup' +description: 'Common setup steps for GitHub workflows for Rust projects' + +runs: + using: composite + steps: + - uses: dtolnay/rust-toolchain@1.71.0 + with: + components: clippy, rustfmt + - uses: extractions/setup-just@v1 + with: + just-version: '1.15.0' # optional semver specification, otherwise latest + + ### + ### Linux setup + ### + - name: rustup + # We need to use the nightly rust tool change to enable registry-auth / to connect to ADO feeds. + if: ${{ (runner.os == 'Linux') }} + run: | + rustup set profile minimal + rustup install + shell: bash + # - name: Cargo login + # if: ${{ (runner.os == 'Linux') }} + # run: just cargo-login-ci + # shell: bash + + ### + ### Windows setup + ### + - name: rustup + # We need to use the nightly rust tool change to enable registry-auth / to connect to ADO feeds. + if: ${{ (runner.os == 'Windows') }} + run: | + rustup set profile minimal + rustup install + shell: pwsh + # - name: Cargo login + # if: ${{ (runner.os == 'Windows') }} + # run: just cargo-login-ci-windows + # shell: pwsh diff --git a/.github/stale.yml b/.github/stale.yml deleted file mode 100644 index d89f0cdd91e52..0000000000000 --- a/.github/stale.yml +++ /dev/null @@ -1,22 +0,0 @@ -# Number of days of inactivity before an issue becomes stale -daysUntilStale: 30 - -# Number of days of inactivity before a stale issue is closed -daysUntilClose: 7 - -# Issues with these labels will never be considered stale -exemptLabels: - - contributions welcome - - feature request - - regression - -# Label to use when marking an issue as stale -staleLabel: stale - -# Comment to post when marking an issue as stale. Set to `false` to disable -markComment: > - This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details. - -# Comment to post when closing a stale issue. Set to `false` to disable -closeComment: > - This issue has been automatically closed due to inactivity. Please reactivate if further support is needed. diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml new file mode 100644 index 0000000000000..6c3f2eb0fbbe1 --- /dev/null +++ b/.github/workflows/rust-ci.yml @@ -0,0 +1,132 @@ +name: Rust + +on: [pull_request] + +env: + CARGO_TERM_COLOR: always + RUST_LOG: onnxruntime=debug,onnxruntime-sys=debug + RUST_BACKTRACE: 1 + MANIFEST_PATH: ${{ github.workspace }}/rust/Cargo.toml + +jobs: + fmt: + name: Rustfmt + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: ./.github/actions/rust-toolchain-setup + - name: vendor onnxruntime source + run: just vendor + - name: fmt + run: cargo fmt --all -- --check + + download: + name: Download prebuilt ONNX Runtime archive from build.rs + runs-on: ubuntu-latest + env: + ORT_RUST_STRATEGY=download + steps: + - uses: actions/checkout@v4 + - uses: ./.github/actions/rust-toolchain-setup + - run: rustup target install x86_64-unknown-linux-gnu + - run: rustup target install x86_64-apple-darwin + - run: rustup target install i686-pc-windows-msvc + - run: rustup target install x86_64-pc-windows-msvc + # ****************************************************************** + - name: Download prebuilt archive (CPU, x86_64-unknown-linux-gnu) + run: cargo build --target x86_64-unknown-linux-gnu --manifest-path ${{ env.MANIFEST_PATH }} + - name: Verify prebuilt archive downloaded (CPU, x86_64-unknown-linux-gnu) + run: ls -lh target/x86_64-unknown-linux-gnu/debug/build/onnxruntime-sys-*/out/onnxruntime-linux-x64-1.*.tgz + # ****************************************************************** + - name: Download prebuilt archive (CPU, x86_64-apple-darwin) + run: cargo build --target x86_64-apple-darwin --manifest-path ${{ env.MANIFEST_PATH }} + - name: Verify prebuilt archive downloaded (CPU, x86_64-apple-darwin) + run: ls -lh target/x86_64-apple-darwin/debug/build/onnxruntime-sys-*/out/onnxruntime-osx-x64-1.*.tgz + # ****************************************************************** + - name: Download prebuilt archive (CPU, i686-pc-windows-msvc) + run: cargo build --target i686-pc-windows-msvc --manifest-path ${{ env.MANIFEST_PATH }} + - name: Verify prebuilt archive downloaded (CPU, i686-pc-windows-msvc) + run: ls -lh target/i686-pc-windows-msvc/debug/build/onnxruntime-sys-*/out/onnxruntime-win-x86-1.*.zip + # ****************************************************************** + - name: Download prebuilt archive (CPU, x86_64-pc-windows-msvc) + run: cargo build --target x86_64-pc-windows-msvc --manifest-path ${{ env.MANIFEST_PATH }} + - name: Verify prebuilt archive downloaded (CPU, x86_64-pc-windows-msvc) + run: ls -lh target/x86_64-pc-windows-msvc/debug/build/onnxruntime-sys-*/out/onnxruntime-win-x64-1.*.zip + # ****************************************************************** + - name: Download prebuilt archive (GPU, x86_64-unknown-linux-gnu) + env: + ORT_USE_CUDA: "yes" + run: cargo build --target x86_64-unknown-linux-gnu --manifest-path ${{ env.MANIFEST_PATH }} + - name: Verify prebuilt archive downloaded (GPU, x86_64-unknown-linux-gnu) + run: ls -lh target/x86_64-unknown-linux-gnu/debug/build/onnxruntime-sys-*/out/onnxruntime-linux-x64-gpu-1.*.tgz + # ****************************************************************** + - name: Download prebuilt archive (GPU, x86_64-pc-windows-msvc) + env: + ORT_USE_CUDA: "yes" + run: cargo build --target x86_64-pc-windows-msvc --manifest-path ${{ env.MANIFEST_PATH }} + - name: Verify prebuilt archive downloaded (GPU, x86_64-pc-windows-msvc) + run: ls -lh target/x86_64-pc-windows-msvc/debug/build/onnxruntime-sys-*/out/onnxruntime-win-gpu-x64-1.*.zip + + test: + name: Test Suite + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + target: + [ + x86_64-unknown-linux-gnu, + x86_64-apple-darwin, + x86_64-pc-windows-msvc, + i686-pc-windows-msvc, + ] + include: + - target: x86_64-unknown-linux-gnu + os: ubuntu-latest + - target: x86_64-apple-darwin + os: macos-latest + - target: x86_64-pc-windows-msvc + os: windows-latest + - target: i686-pc-windows-msvc + os: windows-latest + env: + CARGO_BUILD_TARGET: ${{ matrix.target }} + steps: + - uses: actions/checkout@v4 + - uses: ./.github/actions/rust-toolchain-setup + - name: vendor onnxruntime source + run: just vendor + - run: rustup target install ${{ matrix.target }} + - name: Install additional packages (macOS) + if: contains(matrix.target, 'x86_64-apple-darwin') + run: brew install libomp + - name: Build (cargo build) + run: cargo build --all --manifest-path ${{ env.MANIFEST_PATH }} + - name: Build tests (cargo test) + run: cargo test --no-run --manifest-path ${{ env.MANIFEST_PATH }} + - name: Build onnxruntime with 'model-fetching' feature + run: cargo build --manifest-path ${{ env.MANIFEST_PATH }} --features model-fetching + - name: Test onnxruntime-sys + run: cargo build --package onnxruntime-sys -- --test-threads=1 --nocapture + - name: Test onnxruntime + run: cargo test --manifest-path ${{ env.MANIFEST_PATH }} --features model-fetching -- --test-threads=1 --nocapture + + clippy: + name: Clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: ./.github/actions/rust-toolchain-setup + - name: vendor onnxruntime source + run: just vendor + - run: clippy --all-features --manifest-path ${{ env.MANIFEST_PATH }} -- -D warnings + + package-sys: + name: Package onnxruntime-sys + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: ./.github/actions/rust-toolchain-setup + - name: vendor onnxruntime source + run: just vendor + - run: cargo package --allow-dirty --package onnxruntime-sys diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 0000000000000..95607f297c6bd --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,39 @@ +name: Close stale issues +on: + # Allows you to dictate when you want this workflow to run using cron syntax (times in UTC) + schedule: + - cron: "0 15 * * *" + # Allows you to run this workflow manually from the Actions tab + # workflow_dispatch: + +jobs: + close-stale-issues: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - uses: actions/stale@v8.0.0 + with: + # Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale + exempt-issue-labels: contributions welcome, feature request, regression + # Override exempt-all-assignees but only to exempt the issues with an assignee to be marked as stale automatically + exempt-all-issue-assignees: true + # Used to ignore the issues and pull requests created before the start date + # Start date should be April 19, 2022 - corresponds to the day previous stale bot stopped working + start-date: '2022-04-19T00:00:00Z' + # Number of days without activity before the actions/stale action labels an issue + days-before-issue-stale: 30 + # Number of days without activity before the actions/stale action closes an issue + days-before-issue-close: 30 + # Label you want to apply to issues that have been inactive for the amount of time specified by days-before-issue-stale + stale-issue-label: "stale" + # Comment that you want to add to issues that are labeled by the actions/stale action + stale-issue-message: "This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details." + # Comment that you want to add to issues that are closed by the actions/stale action + close-issue-message: "This issue has been automatically closed due to inactivity. Please reactivate if further support is needed." + # If you never want this action to label PRs, set this value to -1 + days-before-pr-stale: -1 + # If you never want this action to close PRs, set this value to -1 + days-before-pr-close: -1 + repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.lintrunner.toml b/.lintrunner.toml index 86be8d0d0bd38..4e5d077b08ff4 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -45,6 +45,7 @@ exclude_patterns = [ 'cmake/external/**', # ignore generated flatbuffers code 'onnxruntime/core/flatbuffers/ort_flatbuffers_py/**', + 'orttraining/orttraining/python/training/optim/_ds_code_store.py', ] command = [ 'python', @@ -76,6 +77,7 @@ exclude_patterns = [ 'cmake/**', 'orttraining/*', 'onnxruntime/core/flatbuffers/**', + 'orttraining/orttraining/python/training/optim/_ds_code_store.py', ] command = [ 'python', @@ -97,33 +99,6 @@ init_command = [ ] is_formatter = true -[[linter]] -code = 'PYLINT' -include_patterns = [ - # TODO: Opt in to pylint by adding paths here -] -exclude_patterns = [ -] -command = [ - 'python', - '-m', - 'lintrunner_adapters', - 'run', - 'pylint_linter', - '--rcfile=pyproject.toml', - '--', - '@{{PATHSFILE}}' -] -init_command = [ - 'python', - '-m', - 'lintrunner_adapters', - 'run', - 'pip_init', - '--dry-run={{DRYRUN}}', - '--requirement=requirements-lintrunner.txt', -] - [[linter]] code = 'RUSTFMT' include_patterns = ['**/*.rs'] diff --git a/.vscode/settings.json b/.vscode/settings.json index b7a1292efb2c6..2f2adc78f6de9 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -13,30 +13,13 @@ "editor.codeActionsOnSave": { "source.organizeImports": true }, + "editor.defaultFormatter": "ms-python.black-formatter" }, // Enable Python linting and Pylance type checking "python.analysis.typeCheckingMode": "basic", - "python.formatting.provider": "black", - "python.formatting.blackArgs": [ - "--line-length", - "120" - ], - "python.sortImports.args": [ - "--profile", - "black", - "--line-length", - "120" - ], - "python.linting.enabled": true, - "python.linting.flake8Enabled": true, - "python.linting.pylintEnabled": true, - "python.linting.pydocstyleEnabled": true, - "python.linting.pydocstyleArgs": [ - "--convention=google" - ], - "python.linting.banditEnabled": true, "cpplint.lineLength": 120, "cpplint.filters": [ "-build/include_subdir", "-runtime/references" ] +} diff --git a/Package.swift b/Package.swift deleted file mode 100644 index f8bf33001ea24..0000000000000 --- a/Package.swift +++ /dev/null @@ -1,109 +0,0 @@ -// swift-tools-version: 5.6 -// The swift-tools-version declares the minimum version of Swift required to build this package and MUST be the first -// line of this file. 5.6 is required to support zip files for the pod archive binaryTarget. -// -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -// -// A user of the Swift Package Manager (SPM) package will consume this file directly from the ORT github repository. -// For context, the end user's config will look something like: -// -// dependencies: [ -// .package(url: "https://github.com/microsoft/onnxruntime", branch: "rel-1.15.0"), -// ... -// ], -// -// NOTE: The direct consumption creates a somewhat complicated setup to 'release' a new version of the ORT SPM package. -// TBD: how to manage the release process - -import PackageDescription -import class Foundation.ProcessInfo - -let package = Package( - name: "onnxruntime", - platforms: [.iOS(.v12)], - products: [ - .library(name: "onnxruntime", - type: .static, - targets: ["OnnxRuntimeBindings"]), - ], - dependencies: [], - targets: [ - .target(name: "OnnxRuntimeBindings", - dependencies: ["onnxruntime"], - path: "objectivec", - exclude: ["test", "docs", "ReadMe.md", "format_objc.sh", - "ort_checkpoint.mm", - "ort_checkpoint_internal.h", - "ort_training_session_internal.h", - "ort_training_session.mm", - "include/ort_checkpoint.h", - "include/ort_training_session.h", - "include/onnxruntime_training.h"], - cxxSettings: [ - .define("SPM_BUILD"), - .unsafeFlags(["-std=c++17", - "-fobjc-arc-exceptions" - ]), - ], linkerSettings: [ - .unsafeFlags(["-ObjC"]), - ]), - .testTarget(name: "OnnxRuntimeBindingsTests", - dependencies: ["OnnxRuntimeBindings"], - path: "swift/OnnxRuntimeBindingsTests", - resources: [ - .copy("Resources/single_add.basic.ort") - ]), - ] -) - -// Add the ORT iOS Pod archive as a binary target. -// -// There are 2 scenarios: -// -// Release branch of ORT github repo: -// Target will be set to the released pod archive and its checksum. -// -// Any other branch/tag of ORT github repo: -// Invalid by default. We do not have a pod archive that is guaranteed to work -// as the objective-c bindings may have changed since the pod archive was released. - -// CI or local testing where you have built/obtained the iOS Pod archive matching the current source code. -// Requires the ORT_IOS_POD_LOCAL_PATH environment variable to be set to specify the location of the pod. -if let pod_archive_path = ProcessInfo.processInfo.environment["ORT_IOS_POD_LOCAL_PATH"] { - // ORT_IOS_POD_LOCAL_PATH MUST be a path that is relative to Package.swift. - // - // To build locally, tools/ci_build/github/apple/build_and_assemble_ios_pods.py can be used - // See https://onnxruntime.ai/docs/build/custom.html#ios - // Example command: - // python3 tools/ci_build/github/apple/build_and_assemble_ios_pods.py \ - // --variant Full \ - // --build-settings-file tools/ci_build/github/apple/default_full_ios_framework_build_settings.json - // - // This should produce the pod archive in build/ios_pod_staging, and ORT_IOS_POD_LOCAL_PATH can be set to - // "build/ios_pod_staging/pod-archive-onnxruntime-c-???.zip" where '???' is replaced by the version info in the - // actual filename. - package.targets.append(Target.binaryTarget(name: "onnxruntime", path: pod_archive_path)) - -} else { - // When creating the release version: - // - remove the fatalError - // - uncomment the package.targets.append call - // - update the major/minor/patch version info in the url - // - insert the checksum info from the onnxruntime-ios-packaging-pipeline CI's 'Print ORT iOS Pod checksum' - // stage output (or download the pod archive artifact from the CI and run `shasum -a 256 ` - // to manually calculate it). - // The checksum length and chars should look something like - // "c89cd106ff02eb3892243acd7c4f2bd8e68c2c94f2751b5e35f98722e10c042b" - // - // package.targets.append( - // Target.binaryTarget(name: "onnxruntime", - // url: "https://onnxruntimepackages.z14.web.core.windows.net/pod-archive-onnxruntime-c-.zip", - // checksum: "Insert checksum here") - // ) - - fatalError("It is not valid to use a non-release branch from https://github.com/microsoft/onnxruntime.\n" + - "Please use a release branch (e.g. rel-1.15.0), or build the ONNX Runtime iOS pod archive locally " + - "and set the ORT_IOS_POD_LOCAL_PATH environment variable.\n" + - "See Package.swift for more information on using a local pod archive.") -} diff --git a/cgmanifests/cgmanifest.json b/cgmanifests/cgmanifest.json index 2a3de3bb0ee51..e8dbc9cf9eff6 100644 --- a/cgmanifests/cgmanifest.json +++ b/cgmanifests/cgmanifest.json @@ -568,7 +568,7 @@ "component": { "type": "git", "git": { - "commitHash": "d10b27fe37736d2944630ecd7557cefa95cf87c9", + "commitHash": "e7248b26a1ed53fa030c5c459f7ea095dfd276ac", "repositoryUrl": "https://gitlab.com/libeigen/eigen.git" } } diff --git a/cgmanifests/generate_cgmanifest.py b/cgmanifests/generate_cgmanifest.py index a9eaacc6f2938..81181d3ccfb20 100644 --- a/cgmanifests/generate_cgmanifest.py +++ b/cgmanifests/generate_cgmanifest.py @@ -90,55 +90,6 @@ def add_github_dep(name, parsed_url): git_deps[dep] = name -with open( - os.path.join(REPO_DIR, "tools", "ci_build", "github", "linux", "docker", "Dockerfile.manylinux2_28_cuda11"), -) as f: - for line in f: - if not line.strip(): - package_name = None - package_filename = None - package_url = None - if package_filename is None: - m = re.match(r"RUN\s+export\s+(.+?)_ROOT=(\S+).*", line) - if m is not None: - package_name = m.group(1) - package_filename = m.group(2) - else: - m = re.match(r"RUN\s+export\s+(.+?)_VERSION=(\S+).*", line) - if m is not None: - package_name = m.group(1) - package_filename = m.group(2) - elif package_url is None: - m = re.match(r"(.+?)_DOWNLOAD_URL=(\S+)", line) - if m is not None: - package_url = m.group(2) - if package_name == "LIBXCRYPT": - package_url = m.group(2) + "/v" + package_filename + ".tar.gz" - elif package_name == "CMAKE": - package_url = m.group(2) + "/v" + package_filename + "/cmake-" + package_filename + ".tar.gz" - else: - package_url = m.group(2) + "/" + package_filename + ".tar.gz" - parsed_url = urlparse(package_url) - if parsed_url.hostname == "github.com": - add_github_dep("manylinux dependency " + package_name, parsed_url) - else: - registration = { - "Component": { - "Type": "other", - "other": { - "Name": package_name.lower(), - "Version": package_filename.split("-")[-1], - "DownloadUrl": package_url, - }, - "comments": "manylinux dependency", - } - } - registrations.append(registration) - package_name = None - package_filename = None - package_url = None - - def normalize_path_separators(path): return path.replace(os.path.sep, "/") diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 6f1ca84e1a304..12fbb291c3a70 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -2,82 +2,6 @@ "$schema": "https://json.schemastore.org/component-detection-manifest.json", "Version": 1, "Registrations": [ - { - "Component": { - "Type": "other", - "other": { - "Name": "autoconf", - "Version": "2.71", - "DownloadUrl": "http://ftp.gnu.org/gnu/autoconf/autoconf-2.71.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "Component": { - "Type": "other", - "other": { - "Name": "automake", - "Version": "1.16.5", - "DownloadUrl": "http://ftp.gnu.org/gnu/automake/automake-1.16.5.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "Component": { - "Type": "other", - "other": { - "Name": "libtool", - "Version": "2.4.7", - "DownloadUrl": "http://ftp.gnu.org/gnu/libtool/libtool-2.4.7.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "Component": { - "Type": "other", - "other": { - "Name": "git", - "Version": "2.36.2", - "DownloadUrl": "https://www.kernel.org/pub/software/scm/git/git-2.36.2.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "Component": { - "Type": "other", - "other": { - "Name": "sqlite_autoconf", - "Version": "3390200", - "DownloadUrl": "https://www.sqlite.org/2022/sqlite-autoconf-3390200.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "Component": { - "Type": "other", - "other": { - "Name": "openssl", - "Version": "1.1.1q", - "DownloadUrl": "https://www.openssl.org/source/openssl-1.1.1q.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "50cf2b6dd4fdf04309445f2eec8de7051d953abf", - "repositoryUrl": "https://github.com/besser82/libxcrypt.git" - }, - "comments": "manylinux dependency LIBXCRYPT" - } - }, { "component": { "type": "git", @@ -102,7 +26,7 @@ "component": { "type": "git", "git": { - "commitHash": "e2525550194ce3d8a2c4a3af451c9d9b3ae6650e", + "commitHash": "b86cc54efce19530fb953e4b21f57e6b3888534c", "repositoryUrl": "https://github.com/onnx/onnx.git" }, "comments": "git submodule at cmake/external/onnx" @@ -212,7 +136,7 @@ "component": { "type": "git", "git": { - "commitHash": "003c580e696a774afdc984996ee909b7c8d8128c", + "commitHash": "0da379fc4808f9601faef392352018c741c0f297", "repositoryUrl": "https://github.com/google/XNNPACK.git" }, "comments": "googlexnnpack" @@ -272,7 +196,7 @@ "component": { "type": "git", "git": { - "commitHash": "0462dc31ae78f48744b6141ae376df1f96d3f459", + "commitHash": "a43ce67187bab219520fd80f21af8bbd4354bc8c", "repositoryUrl": "https://github.com/onnx/onnx-tensorrt.git" }, "comments": "onnx_tensorrt" @@ -302,7 +226,7 @@ "component": { "type": "git", "git": { - "commitHash": "1787867f6183f056420e532eec640cba25efafea", + "commitHash": "4fe0e1e183925bf8cfa6aae24237e724a96479b8", "repositoryUrl": "https://github.com/Maratyszcza/pthreadpool.git" }, "comments": "pthreadpool" @@ -362,7 +286,7 @@ "component": { "type": "git", "git": { - "commitHash": "c4f6b8c6bc94ff69048492fb34df0dfaf1983933", + "commitHash": "6f47420213f757831fae65c686aa471749fa8d60", "repositoryUrl": "https://github.com/NVIDIA/cutlass.git" }, "comments": "cutlass" @@ -392,7 +316,7 @@ "component": { "type": "git", "git": { - "commitHash": "d52ec01652b7d620386251db92455968d8d90bdc", + "commitHash": "a4f72a314a85732ed67d5aa8d1088d207a7e0e61", "repositoryUrl": "https://github.com/ROCmSoftwarePlatform/composable_kernel.git" }, "comments": "composable_kernel" diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index cf7565869e446..d2c56e455be4c 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -39,7 +39,12 @@ include(CMakeDependentOption) include(FetchContent) include(CheckFunctionExists) +# TODO: update this once all system adapt c++20 +if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") +set(CMAKE_CXX_STANDARD 20) +else() set(CMAKE_CXX_STANDARD 17) +endif() set_property(GLOBAL PROPERTY USE_FOLDERS ON) # NOTE: POSITION INDEPENDENT CODE hurts performance, and it only make sense on POSIX systems @@ -68,6 +73,12 @@ option(onnxruntime_ENABLE_PYTHON "Enable python buildings" OFF) # Enable it may cause LNK1169 error option(onnxruntime_ENABLE_MEMLEAK_CHECKER "Experimental: Enable memory leak checker in Windows debug build" OFF) option(onnxruntime_USE_CUDA "Build with CUDA support" OFF) +# Enable ONNX Runtime CUDA EP's internal unit tests that directly access the EP's internal functions instead of through +# OpKernels. When the option is ON, we will have two copies of GTest library in the same process. It is not a typical +# use. If you hit any problem with that, please do not report it to GTest. Turn OFF the following build option instead. +cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS;LINUX" OFF) + +option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" OFF) option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF) option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) option(onnxruntime_USE_COREML "Build with CoreML support" OFF) @@ -103,9 +114,7 @@ option(onnxruntime_ENABLE_LTO "Enable link time optimization" OFF) option(onnxruntime_CROSS_COMPILING "Cross compiling onnx runtime" OFF) option(onnxruntime_GCOV_COVERAGE "Compile with options necessary to run code coverage" OFF) option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF) - -#It's preferred to turn it OFF when onnxruntime is dynamically linked to PROTOBUF. But Tensort always required the full version of protobuf. -cmake_dependent_option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF "NOT onnxruntime_USE_TENSORRT" ON) +option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF) option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir") option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF) option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF) @@ -142,10 +151,11 @@ option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OF option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF) option(onnxruntime_DISABLE_FLOAT8_TYPES "Disable float 8 types" OFF) option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF) -cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON" OFF) +cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON;NOT onnxruntime_USE_CUDA" OFF) # For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone cmake_dependent_option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." ON "onnxruntime_MINIMAL_BUILD;NOT onnxruntime_ENABLE_PYTHON" OFF) -option(onnxruntime_DISABLE_ABSEIL "Do not link to Abseil. Redefine Inlined containers to STD containers." OFF) +# Even when onnxruntime_DISABLE_ABSEIL is ON, ONNX Runtime still needs to link to abseil. +option(onnxruntime_DISABLE_ABSEIL "Do not use Abseil data structures in ONNX Runtime source code. Redefine Inlined containers to STD containers." OFF) option(onnxruntime_EXTENDED_MINIMAL_BUILD "onnxruntime_MINIMAL_BUILD with support for execution providers that compile kernels." OFF) option(onnxruntime_MINIMAL_BUILD_CUSTOM_OPS "Add custom operator kernels support to a minimal build." OFF) @@ -265,10 +275,6 @@ if (onnxruntime_ENABLE_TRAINING_APIS) endif() endif() -if (onnxruntime_USE_CUDA) - set(onnxruntime_DISABLE_RTTI OFF) -endif() - if (onnxruntime_USE_ROCM) if (WIN32) message(FATAL_ERROR "ROCM does not support build in Windows!") @@ -519,7 +525,21 @@ if(NOT WIN32 AND NOT CMAKE_SYSTEM_NAME STREQUAL "Android") find_package(Iconv REQUIRED) set(ICONV_LIB Iconv::Iconv) endif() + find_package(Patch) +if (WIN32 AND NOT Patch_FOUND) + # work around CI machines missing patch from the git install by falling back to the binary in this repo. + # replicate what happens in https://github.com/Kitware/CMake/blob/master/Modules/FindPatch.cmake but without + # the hardcoded suffixes in the path to the patch binary. + find_program(Patch_EXECUTABLE NAMES patch PATHS ${PROJECT_SOURCE_DIR}/external/git.Win32.2.41.03.patch) + if(Patch_EXECUTABLE) + set(Patch_FOUND 1) + if (NOT TARGET Patch::patch) + add_executable(Patch::patch IMPORTED) + set_property(TARGET Patch::patch PROPERTY IMPORTED_LOCATION ${Patch_EXECUTABLE}) + endif() + endif() +endif() if(Patch_FOUND) message("Patch found: ${Patch_EXECUTABLE}") endif() @@ -665,6 +685,9 @@ set(ORT_PROVIDER_FLAGS) set(ORT_PROVIDER_CMAKE_FLAGS) if (onnxruntime_USE_CUDA) + if (onnxruntime_USE_CUDA_NHWC_OPS) + add_compile_definitions(ENABLE_CUDA_NHWC_OPS) + endif() enable_language(CUDA) message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") @@ -1276,14 +1299,6 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_CONFIG_CPU_FP16=1) endif() - if (onnxruntime_USE_OPENVINO_VPUX_FP16) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP16=1) - endif() - - if (onnxruntime_USE_OPENVINO_VPUX_U8) - add_definitions(-DOPENVINO_CONFIG_VPUX_U8=1) - endif() - if (onnxruntime_USE_OPENVINO_GPU_FP32_NP) add_definitions(-DOPENVINO_CONFIG_GPU_FP32=1) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) @@ -1304,16 +1319,6 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) endif() - if (onnxruntime_USE_OPENVINO_VPUX_FP32_NP) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP32=1) - add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) - endif() - - if (onnxruntime_USE_OPENVINO_VPUX_FP16_NP) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP16=1) - add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) - endif() - if (onnxruntime_USE_OPENVINO_HETERO) add_definitions(-DOPENVINO_CONFIG_HETERO=1) add_definitions(-DDEVICE_NAME="${onnxruntime_USE_OPENVINO_DEVICE}") @@ -1474,29 +1479,32 @@ if (onnxruntime_ENABLE_TRAINING) list(APPEND onnxruntime_EXTERNAL_LIBRARIES tensorboard) endif() -if (UNIX AND onnxruntime_USE_MPI) - if (EXISTS "${onnxruntime_MPI_HOME}") - set(MPI_HOME "${onnxruntime_MPI_HOME}") - elseif (EXISTS "/bert_ort/openmpi") - set(MPI_HOME "/bert_ort/openmpi") - endif() - - find_package(MPI) +if (UNIX AND onnxruntime_USE_NCCL) + # MPI is INDEPENDENT of NCCL for now. You can build NCLL without MPI and launch multi-GPU with your own launcher. + if (onnxruntime_USE_MPI) + if (EXISTS "${onnxruntime_MPI_HOME}") + set(MPI_HOME "${onnxruntime_MPI_HOME}") + elseif (EXISTS "/bert_ort/openmpi") + set(MPI_HOME "/bert_ort/openmpi") + endif() + find_package(MPI) - if (MPI_CXX_FOUND) - message( STATUS "MPI Version: ${MPI_CXX_VERSION}") - message( STATUS "MPI (include: ${MPI_CXX_INCLUDE_DIRS}, library: ${MPI_CXX_LIBRARIES})" ) - mark_as_advanced(MPI_CXX_INCLUDE_DIRS MPI_CXX_LIBRARIES) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${MPI_CXX_LIBRARIES} ${MPI_CXX_LINK_FLAGS}) - else () - message( - FATAL_ERROR - "MPI is not found. Please define onnxruntime_MPI_HOME to specify the path of MPI. Otherwise, NCCL will be disabled." - ) + if (MPI_CXX_FOUND) + message( STATUS "MPI Version: ${MPI_CXX_VERSION}") + message( STATUS "MPI (include: ${MPI_CXX_INCLUDE_DIRS}, library: ${MPI_CXX_LIBRARIES})" ) + mark_as_advanced(MPI_CXX_INCLUDE_DIRS MPI_CXX_LIBRARIES) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${MPI_CXX_LIBRARIES} ${MPI_CXX_LINK_FLAGS}) + else () + message( + FATAL_ERROR + "MPI is not found. Please define onnxruntime_MPI_HOME to specify the path of MPI. Otherwise, NCCL will be disabled." + "or you can remove --use_mpi from build args to disable MPI." + ) + endif() endif() - # Find NCCL and MPI - if (onnxruntime_USE_NCCL AND MPI_CXX_FOUND) + # Find NCCL + if (onnxruntime_USE_NCCL) if (onnxruntime_USE_CUDA) set(NCCL_LIBNAME "nccl") elseif (onnxruntime_USE_ROCM) diff --git a/cmake/deps.txt b/cmake/deps.txt index 279b5ca649dba..e065cacdfc423 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -3,30 +3,40 @@ #The columns are separated by ";" because a list in cmake is just a ";" separated group of strings. #Names should be in lower case. They will be used as variable names in cmake. #URLs can be either https URLs or local file paths in cmake-style(directory separator is a forward slash character). -#SHA1 hashes can be generated by running sha1sum command. +#SHA1 hashes can be generated by running sha1sum command on linux. PowerShell can also be used: +# (Get-FileHash -Algorithm SHA1 ).Hash.ToLower() #If you need to change abseil's version to a different one, you may also want to update external\abseil-cpp.natvis #since the file contains a version string: "lts_20230802". However, the file is for debugging purposes only and would #not affect built binaries. +# +# NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI. +# See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 +# abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20230802.0.zip;04271dfbfac59269b6939e1e9d5faf0d18a7ba91 cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 -eigen;https://gitlab.com/libeigen/eigen/-/archive/3.4/eigen-3.4.zip;ee201b07085203ea7bd8eb97cbcb31b07cfa3efb +# This Eigen commit id matches the eigen archive being consumed from https://gitlab.com/libeigen/eigen/-/archive/3.4/eigen-3.4.zip +# prior to the 3.4.1 RC changing the bits and invalidating the hash. +# it contains changes on top of 3.4.0 which are required to fix build issues. +# Until the 3.4.1 release this is the best option we have. +# Issue link: https://gitlab.com/libeigen/eigen/-/issues/2744 +eigen;https://gitlab.com/libeigen/eigen/-/archive/e7248b26a1ed53fa030c5c459f7ea095dfd276ac/eigen-e7248b26a1ed53fa030c5c459f7ea095dfd276ac.zip;be8be39fdbc6e60e94fa7870b280707069b5b81a flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v1.12.0.zip;ba0a75fd12dbef8f6557a74e611b7a3d0c5fe7bf fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494 fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908 google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752 googletest;https://github.com/google/googletest/archive/refs/tags/v1.14.0.zip;0ac421f2ec11af38b0fff0f1992184032731a8bc -googlexnnpack;https://github.com/google/XNNPACK/archive/003c580e696a774afdc984996ee909b7c8d8128c.zip;9f192e3f15e1e37ae9c78d53eeea47e45c5eb31c +googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73 json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 -onnx;https://github.com/onnx/onnx/archive/e2525550194ce3d8a2c4a3af451c9d9b3ae6650e.zip;782f23d788185887f520a90535513e244218e928 +onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) -onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/0462dc31ae78f48744b6141ae376df1f96d3f459.zip;5ff086361956cceb81ed17453a1fd8db2aa4328d +onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874 @@ -35,13 +45,13 @@ protoc_linux_x86;https://github.com/protocolbuffers/protobuf/releases/download/v protoc_linux_aarch64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-linux-aarch_64.zip;df9d45470b0b8cf939dd2f0ec6b88e9cafc4d617 protoc_mac_universal;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-osx-universal_binary.zip;23710c3d1c2036d8d65a6a22234372fa2d7af9ef psimd;https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip;1f5454b01f06f9656b77e4a5e2e31d7422487013 -pthreadpool;https://github.com/Maratyszcza/pthreadpool/archive/1787867f6183f056420e532eec640cba25efafea.zip;e43e80781560c5ab404a4da20f34d846f5f5d101 +pthreadpool;https://github.com/Maratyszcza/pthreadpool/archive/4fe0e1e183925bf8cfa6aae24237e724a96479b8.zip;07a0aa91dd9bf86f31b95497e00f31d8a261a4bd pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.10.1.zip;769b6aa67a77f17a770960f604b727645b6f6a13 pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/959002f82d7962a473d8bf301845f2af720e0aa4.zip;85da3caa60eb2b148613b443fbc2bfdc30689965 re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91b531ee7f3e45f004c6a502a5374 safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 -cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.0.0.zip;0f95b3c1fc1bd1175c4a90b2c9e39074d1bccefd +cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c -composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/d52ec01652b7d620386251db92455968d8d90bdc.zip;6b5ce8edf3625f8817086c194fbf94b664e1b0e0 \ No newline at end of file +composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 diff --git a/cmake/deps_update_and_upload.py b/cmake/deps_update_and_upload.py new file mode 100644 index 0000000000000..d357284d91225 --- /dev/null +++ b/cmake/deps_update_and_upload.py @@ -0,0 +1,56 @@ +# in case deps.txt is updated, run this file to update and upload the dependencies so that CI can use them. +# Before running the script, increase the version number found at: +# https://aiinfra.visualstudio.com/Lotus/_artifacts/feed/Lotus/UPack/onnxruntime_build_dependencies/versions +# Run without --do-upload once to verify downloading. Use --do-upload when you are ready to publish. +# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 --do-upload +# update version number in tools\ci_build\github\azure-pipelines\templates\download-deps.yml +import re +import subprocess +import os +import argparse +import tempfile + +parser = argparse.ArgumentParser(description="Update dependencies and publish to Azure Artifacts") +parser.add_argument( + "--root-path", type=str, default=tempfile.gettempdir(), help="Target root path for downloaded files" +) +parser.add_argument("--version", type=str, default="1.0.82", help="Package version to publish") +parser.add_argument("--do-upload", action="store_true", help="Upload the package to Azure Artifacts") +args = parser.parse_args() + +with open("cmake/deps.txt") as file: + text = file.read() + +lines = [line for line in text.split("\n") if not line.startswith("#") and ";" in line] + +root_path = args.root_path + +for line in lines: + url = re.sub("^[^;]+?;https://([^;]+?);.*", r"https://\1", line) + filename = re.sub("^[^;]+?;https://([^;]+?);.*", r"\1", line) + full_path = os.path.join(root_path, filename) + subprocess.run(["curl", "-sSL", "--create-dirs", "-o", full_path, url]) # noqa: PLW1510 + +package_name = "onnxruntime_build_dependencies" +version = args.version + +# Check if the user is logged in to Azure +result = subprocess.run("az account show", shell=True, capture_output=True, text=True) # noqa: PLW1510 +if "No subscriptions found" in result.stderr: + # Prompt the user to log in to Azure + print("You are not logged in to Azure. Please log in to continue.") + subprocess.run("az login", shell=True) # noqa: PLW1510 + +# Publish the package to Azure Artifacts if --no-upload is not specified + +cmd = f'az artifacts universal publish --organization https://dev.azure.com/onnxruntime --feed onnxruntime --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' +if args.do_upload: + subprocess.run(cmd, shell=True) # noqa: PLW1510 +else: + print("would have run: " + cmd) + +cmd = f'az artifacts universal publish --organization https://dev.azure.com/aiinfra --feed Lotus --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' +if args.do_upload: + subprocess.run(cmd, shell=True) # noqa: PLW1510 +else: + print("would have run: " + cmd) diff --git a/cmake/external/abseil-cpp.natvis b/cmake/external/abseil-cpp.natvis index e923d5862ec2e..1e5a36fb9efb9 100644 --- a/cmake/external/abseil-cpp.natvis +++ b/cmake/external/abseil-cpp.natvis @@ -25,30 +25,55 @@ - empty - {{ size={size_} }} + + + + + + size={ _size() } + size=({_size()}) - size_ - capacity_ - + _size() + _capacity() + - size_ + _size() - - slots_[nslot] + + _slots()[nslot] nslot++ - + + + + + *($T1 *){value} + (*($T1 *){value}) + + *($T1 *){value} + + + + + + *($T1 *)this + (*($T1 *)this) + + *($T1 *)this + + + - {{ {value.first}:{value.second} }} + {value.first}, {value.second} + ({value.first}, {value.second}) - value.first - value.second + value.first + value.second diff --git a/cmake/external/composable_kernel.cmake b/cmake/external/composable_kernel.cmake index 7168cd1a22c53..b4e6c834c83ab 100644 --- a/cmake/external/composable_kernel.cmake +++ b/cmake/external/composable_kernel.cmake @@ -12,13 +12,14 @@ if(NOT composable_kernel_POPULATED) FetchContent_Populate(composable_kernel) set(BUILD_DEV OFF CACHE BOOL "Disable -Weverything, otherwise, error: 'constexpr' specifier is incompatible with C++98 [-Werror,-Wc++98-compat]" FORCE) # Exclude i8 device gemm instances due to excessive long compilation time and not being used - set(DTYPES fp32 fp16 bf16) + set(DTYPES fp32 fp16 bf16 fp8) set(INSTANCES_ONLY ON) add_subdirectory(${composable_kernel_SOURCE_DIR} ${composable_kernel_BINARY_DIR} EXCLUDE_FROM_ALL) add_library(onnxruntime_composable_kernel_includes INTERFACE) target_include_directories(onnxruntime_composable_kernel_includes INTERFACE ${composable_kernel_SOURCE_DIR}/include + ${composable_kernel_BINARY_DIR}/include ${composable_kernel_SOURCE_DIR}/library/include) target_compile_definitions(onnxruntime_composable_kernel_includes INTERFACE __fp32__ __fp16__ __bf16__) endif() diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index 8c5d81d638ced..983eecdd88235 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -4,7 +4,6 @@ if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTIO cutlass URL ${DEP_URL_cutlass} URL_HASH SHA1=${DEP_SHA1_cutlass} - PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass.patch ) FetchContent_GetProperties(cutlass) diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake index 5dd6c8e8dfe84..b123adb624fa4 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake @@ -1,23 +1,14 @@ - if (onnxruntime_USE_PREINSTALLED_EIGEN) add_library(eigen INTERFACE) file(TO_CMAKE_PATH ${eigen_SOURCE_PATH} eigen_INCLUDE_DIRS) target_include_directories(eigen INTERFACE ${eigen_INCLUDE_DIRS}) else () - if (onnxruntime_USE_ACL AND (NOT onnxruntime_USE_ACL_2308)) - FetchContent_Declare( - eigen - URL ${DEP_URL_eigen} - URL_HASH SHA1=${DEP_SHA1_eigen} - PATCH_COMMAND ${Patch_EXECUTABLE} --ignore-space-change --ignore-whitespace < ${PROJECT_SOURCE_DIR}/patches/eigen/Fix_Eigen_Build_Break.patch - ) - else() - FetchContent_Declare( - eigen - URL ${DEP_URL_eigen} - URL_HASH SHA1=${DEP_SHA1_eigen} - ) - endif() + FetchContent_Declare( + eigen + URL ${DEP_URL_eigen} + URL_HASH SHA1=${DEP_SHA1_eigen} + ) + FetchContent_Populate(eigen) set(eigen_INCLUDE_DIRS "${eigen_SOURCE_DIR}") endif() diff --git a/cmake/external/git.Win32.2.41.03.patch/msys-2.0.dll b/cmake/external/git.Win32.2.41.03.patch/msys-2.0.dll new file mode 100644 index 0000000000000..686afedb50bf3 Binary files /dev/null and b/cmake/external/git.Win32.2.41.03.patch/msys-2.0.dll differ diff --git a/cmake/external/git.Win32.2.41.03.patch/msys-gcc_s-1.dll b/cmake/external/git.Win32.2.41.03.patch/msys-gcc_s-1.dll new file mode 100644 index 0000000000000..1750b9ce92d72 Binary files /dev/null and b/cmake/external/git.Win32.2.41.03.patch/msys-gcc_s-1.dll differ diff --git a/cmake/external/git.Win32.2.41.03.patch/patch.exe b/cmake/external/git.Win32.2.41.03.patch/patch.exe new file mode 100644 index 0000000000000..8d784cb5d7e40 Binary files /dev/null and b/cmake/external/git.Win32.2.41.03.patch/patch.exe differ diff --git a/cmake/external/onnx b/cmake/external/onnx index e2525550194ce..b86cc54efce19 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit e2525550194ce3d8a2c4a3af451c9d9b3ae6650e +Subproject commit b86cc54efce19530fb953e4b21f57e6b3888534c diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index e1671bcf43ed9..0fa5163dc06bf 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -37,8 +37,12 @@ if (onnxruntime_BUILD_UNIT_TESTS) set(gtest_disable_pthreads ON) endif() set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) - # Set it to ON will cause crashes in onnxruntime_test_all when onnxruntime_USE_CUDA is ON - set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE) + if (CMAKE_SYSTEM_NAME STREQUAL "iOS") + # Needs to update onnxruntime/test/xctest/xcgtest.mm + set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE) + else() + set(GTEST_HAS_ABSL ON CACHE BOOL "" FORCE) + endif() # gtest and gmock FetchContent_Declare( googletest @@ -331,6 +335,7 @@ if(onnxruntime_USE_CUDA) URL ${DEP_URL_microsoft_gsl} URL_HASH SHA1=${DEP_SHA1_microsoft_gsl} PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/gsl/1064.patch + FIND_PACKAGE_ARGS 4.0 NAMES Microsoft.GSL ) else() FetchContent_Declare( diff --git a/cmake/external/xnnpack.cmake b/cmake/external/xnnpack.cmake index 7455584f1a625..e661aa51bfc17 100644 --- a/cmake/external/xnnpack.cmake +++ b/cmake/external/xnnpack.cmake @@ -25,17 +25,23 @@ set(FXDIV_SOURCE_DIR ${fxdiv_SOURCE_DIR}) FetchContent_Declare(pthreadpool URL ${DEP_URL_pthreadpool} URL_HASH SHA1=${DEP_SHA1_pthreadpool}) onnxruntime_fetchcontent_makeavailable(pthreadpool) -FetchContent_Declare(googlexnnpack URL ${DEP_URL_googlexnnpack} URL_HASH SHA1=${DEP_SHA1_googlexnnpack} -PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/xnnpack/AddEmscriptenAndIosSupport.patch) +FetchContent_Declare(googlexnnpack URL ${DEP_URL_googlexnnpack} URL_HASH SHA1=${DEP_SHA1_googlexnnpack} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/xnnpack/AddEmscriptenAndIosSupport.patch + ) onnxruntime_fetchcontent_makeavailable(googlexnnpack) set(XNNPACK_DIR ${googlexnnpack_SOURCE_DIR}) set(XNNPACK_INCLUDE_DIR ${XNNPACK_DIR}/include) set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK pthreadpool) + # the XNNPACK CMake setup doesn't include the WASM kernels so we have to manually set those up if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + # See source lists in _deps/googlexnnpack-src/BUILD.bazel for wasm_prod_microkernels + message("Adding WebAssembly Source Files to XNNPACK") + set(wasm_srcs "") + file(READ "${XNNPACK_DIR}/BUILD.bazel" xnnpack_bazel_config) # Replace newlines with semicolon so that it is treated as a list by CMake @@ -70,25 +76,23 @@ if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set(${target_srcs} ${bazel_srcs} PARENT_SCOPE) endfunction() - GetSrcListFromBazel("PROD_SCALAR_WASM_MICROKERNEL_SRCS" prod_scalar_wasm_srcs) - GetSrcListFromBazel("ALL_WASM_MICROKERNEL_SRCS" all_wasm_srcs) - GetSrcListFromBazel("WASM32_ASM_MICROKERNEL_SRCS" wasm32_asm_srcs) + GetSrcListFromBazel("OPERATOR_SRCS" operator_srcs) + GetSrcListFromBazel("TABLE_SRCS" table_srcs) + list(APPEND wasm_srcs ${operator_srcs} ${table_srcs}) - message(DEBUG "prod_scalar_wasm_srcs: ${prod_scalar_wasm_srcs}\n") - message(DEBUG "all_wasm_srcs: ${all_wasm_srcs}\n") - message(DEBUG "wasm32_asm_srcs: ${wasm32_asm_srcs}\n") + # kernels + list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/scalar.c) + list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasm.c) - message("Adding WebAssembly Source Files to XNNPACK") - set(wasm_srcs "") - list(APPEND wasm_srcs ${prod_scalar_wasm_srcs}) - list(APPEND wasm_srcs ${all_wasm_srcs}) - list(APPEND wasm_srcs ${wasm32_asm_srcs}) + if(onnxruntime_ENABLE_WEBASSEMBLY_SIMD) + list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasmsimd.c) + target_compile_options(XNNPACK PRIVATE "-msimd128") + endif() + message(DEBUG "wasm_srcs: ${wasm_srcs}\n") target_sources(XNNPACK PRIVATE ${wasm_srcs}) - if(onnxruntime_ENABLE_WEBASSEMBLY_SIMD) - GetSrcListFromBazel("ALL_WASMSIMD_MICROKERNEL_SRCS" all_wasmsimd_srcs) - message(DEBUG "all_wasmsimd_srcs: ${all_wasmsimd_srcs}") - target_sources(XNNPACK PRIVATE ${all_wasmsimd_srcs}) - endif() + # add flags from BAZEL.build + target_compile_options(XNNPACK PRIVATE "-fno-fast-math") + target_compile_options(XNNPACK PRIVATE "-fno-math-errno") endif() diff --git a/cmake/linux_arm32_crosscompile_toolchain.cmake b/cmake/linux_arm32_crosscompile_toolchain.cmake new file mode 100644 index 0000000000000..0183262a8875e --- /dev/null +++ b/cmake/linux_arm32_crosscompile_toolchain.cmake @@ -0,0 +1,9 @@ + #This file is just a sample. You may need to modify it before using. + SET(CMAKE_SYSTEM_NAME Linux) + SET(CMAKE_SYSTEM_VERSION 1) + SET(CMAKE_C_COMPILER arm-none-linux-gnueabihf-gcc) + SET(CMAKE_CXX_COMPILER arm-none-linux-gnueabihf-g++) + SET(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) + SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) + SET(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) + SET(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) \ No newline at end of file diff --git a/cmake/linux_arm64_crosscompile_toolchain.cmake b/cmake/linux_arm64_crosscompile_toolchain.cmake new file mode 100644 index 0000000000000..1a492bbc269e7 --- /dev/null +++ b/cmake/linux_arm64_crosscompile_toolchain.cmake @@ -0,0 +1,9 @@ + #This file is just a sample. You may need to modify it before using. + SET(CMAKE_SYSTEM_NAME Linux) + SET(CMAKE_SYSTEM_VERSION 1) + SET(CMAKE_C_COMPILER aarch64-none-linux-gnu-gcc) + SET(CMAKE_CXX_COMPILER aarch64-none-linux-gnu-g++) + SET(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) + SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) + SET(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) + SET(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) \ No newline at end of file diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 59ebf8eca4306..c900f4d4b09a5 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -18,35 +18,21 @@ if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS") set(OUTPUT_STYLE xcode) endif() -set(ONNXRUNTIME_PUBLIC_HEADERS - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_c_api.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_api.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_inline.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h" -) - -if (onnxruntime_ENABLE_TRAINING_APIS) - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h") - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h") - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h") -endif() - -# This macro is to get the path of header files for mobile packaging, for iOS and Android -macro(get_mobile_api_headers _HEADERS) - # include both c and cxx api - set(${_HEADERS} +# Gets the public C/C++ API header files +function(get_c_cxx_api_headers HEADERS_VAR) + set(_headers "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_c_api.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_api.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_inline.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h" ) if (onnxruntime_ENABLE_TRAINING_APIS) - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h") - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h") - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h") + list(APPEND _headers "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h") + list(APPEND _headers "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h") + list(APPEND _headers "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h") endif() # need to add header files for enabled EPs @@ -54,10 +40,13 @@ macro(get_mobile_api_headers _HEADERS) file(GLOB _provider_headers CONFIGURE_DEPENDS "${REPO_ROOT}/include/onnxruntime/core/providers/${f}/*.h" ) - list(APPEND ${_HEADERS} "${_provider_headers}") - unset(_provider_headers) + list(APPEND _headers ${_provider_headers}) endforeach() -endmacro() + + set(${HEADERS_VAR} ${_headers} PARENT_SCOPE) +endfunction() + +get_c_cxx_api_headers(ONNXRUNTIME_PUBLIC_HEADERS) #If you want to verify if there is any extra line in symbols.txt, run # nm -C -g --defined libonnxruntime.so |grep -v '\sA\s' | cut -f 3 -d ' ' | sort @@ -84,11 +73,9 @@ if(WIN32) "${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc" ) elseif(onnxruntime_BUILD_APPLE_FRAMEWORK) - get_mobile_api_headers(APPLE_FRAMEWORK_HEADERS) - # apple framework requires the header file be part of the library onnxruntime_add_shared_library(onnxruntime - ${APPLE_FRAMEWORK_HEADERS} + ${ONNXRUNTIME_PUBLIC_HEADERS} "${CMAKE_CURRENT_BINARY_DIR}/generated_source.c" ) @@ -107,10 +94,9 @@ elseif(onnxruntime_BUILD_APPLE_FRAMEWORK) set_target_properties(onnxruntime PROPERTIES FRAMEWORK TRUE FRAMEWORK_VERSION A - PUBLIC_HEADER "${APPLE_FRAMEWORK_HEADERS}" - MACOSX_FRAMEWORK_INFO_PLIST ${CMAKE_CURRENT_BINARY_DIR}/Info.plist - VERSION ${ORT_VERSION} - SOVERSION ${ORT_VERSION} + MACOSX_FRAMEWORK_INFO_PLIST ${INFO_PLIST_PATH} + SOVERSION ${ORT_VERSION} + # Note: The PUBLIC_HEADER and VERSION properties for the 'onnxruntime' target will be set later in this file. ) else() onnxruntime_add_shared_library(onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c) @@ -180,11 +166,10 @@ endif() # we need to copy C/C++ API headers to be packed into Android AAR package if(CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_BUILD_JAVA) - get_mobile_api_headers(ANDROID_AAR_HEADERS) set(ANDROID_HEADERS_DIR ${CMAKE_CURRENT_BINARY_DIR}/android/headers) file(MAKE_DIRECTORY ${ANDROID_HEADERS_DIR}) # copy the header files one by one - foreach(h_ ${ANDROID_AAR_HEADERS}) + foreach(h_ ${ONNXRUNTIME_PUBLIC_HEADERS}) get_filename_component(HEADER_NAME_ ${h_} NAME) add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${h_} ${ANDROID_HEADERS_DIR}/${HEADER_NAME_}) endforeach() @@ -232,6 +217,7 @@ if (onnxruntime_USE_EXTENSIONS) list(APPEND onnxruntime_INTERNAL_LIBRARIES onnxruntime_extensions ocos_operators + noexcep_operators ) endif() @@ -255,7 +241,7 @@ install(TARGETS onnxruntime PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) @@ -296,44 +282,73 @@ endif() # Assemble the Apple static framework (iOS and macOS) if(onnxruntime_BUILD_APPLE_FRAMEWORK) + set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-${CMAKE_OSX_SYSROOT}) + + # Setup the various directories required. Remove any existing ones so we start with a clean directory. set(STATIC_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/static_libraries) - file(MAKE_DIRECTORY ${STATIC_LIB_DIR}) + set(STATIC_LIB_TEMP_DIR ${STATIC_LIB_DIR}/temp) + add_custom_command(TARGET onnxruntime PRE_BUILD COMMAND ${CMAKE_COMMAND} -E rm -rf ${STATIC_LIB_DIR}) + add_custom_command(TARGET onnxruntime PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory ${STATIC_LIB_DIR}) + add_custom_command(TARGET onnxruntime PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory ${STATIC_LIB_TEMP_DIR}) + + set(STATIC_FRAMEWORK_DIR ${STATIC_FRAMEWORK_OUTPUT_DIR}/static_framework/onnxruntime.framework) + add_custom_command(TARGET onnxruntime PRE_BUILD COMMAND ${CMAKE_COMMAND} -E rm -rf ${STATIC_FRAMEWORK_DIR}) + add_custom_command(TARGET onnxruntime PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory ${STATIC_FRAMEWORK_DIR}) - # Remove the existing files in the STATIC_LIB_DIR folder - file(GLOB _OLD_STATIC_LIBS ${STATIC_LIB_DIR}/*.a) - file(REMOVE "${_OLD_STATIC_LIBS}") + # replicate XCode's Single Object Pre-Link + # link the internal onnxruntime .o files with the external .a files into a single relocatable object + # to enforce symbol visibility. doing it this way limits the symbols included from the .a files to symbols used + # by the ORT .o files. - # Go through all the static libraries, and create symbolic links - foreach(_LIB ${onnxruntime_INTERNAL_LIBRARIES} ${onnxruntime_EXTERNAL_LIBRARIES}) + # If it's an onnxruntime library, extract .o files to a separate directory for each library to avoid any clashes + # with filenames (e.g. utils.o) + foreach(_LIB ${onnxruntime_INTERNAL_LIBRARIES} ) GET_TARGET_PROPERTY(_LIB_TYPE ${_LIB} TYPE) if(_LIB_TYPE STREQUAL "STATIC_LIBRARY") - add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E create_symlink $ ${STATIC_LIB_DIR}/$) + set(CUR_STATIC_LIB_OBJ_DIR ${STATIC_LIB_TEMP_DIR}/$) + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory ${CUR_STATIC_LIB_OBJ_DIR}) + + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ar ARGS -x $ + WORKING_DIRECTORY ${CUR_STATIC_LIB_OBJ_DIR}) endif() endforeach() - if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") - set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-${CMAKE_OSX_SYSROOT}) - else() # macOS - set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) - endif() + # for external libraries we create a symlink to the .a file + foreach(_LIB ${onnxruntime_EXTERNAL_LIBRARIES}) + GET_TARGET_PROPERTY(_LIB_TYPE ${_LIB} TYPE) + if(_LIB_TYPE STREQUAL "STATIC_LIBRARY") + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ${CMAKE_COMMAND} -E create_symlink + $ ${STATIC_LIB_DIR}/$) + endif() + endforeach() - # Assemble the static framework - set(STATIC_FRAMEWORK_DIR ${STATIC_FRAMEWORK_OUTPUT_DIR}/static_framework/onnxruntime.framework) - set(STATIC_FRAMEWORK_HEADER_DIR ${STATIC_FRAMEWORK_DIR}/Headers) - file(MAKE_DIRECTORY ${STATIC_FRAMEWORK_DIR}) - # Remove all files under STATIC_FRAMEWORK_DIR (if any) - file(GLOB_RECURSE _OLD_STATIC_FRAMEWORK ${STATIC_FRAMEWORK_DIR}/*.*) - file(REMOVE "${_OLD_STATIC_FRAMEWORK}") + # do the pre-link with `ld -r` to create a single relocatable object with correct symbol visibility + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ld ARGS -r -o ${STATIC_LIB_DIR}/prelinked_objects.o */*.o ../*.a + WORKING_DIRECTORY ${STATIC_LIB_TEMP_DIR}) + + # create the static library + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND libtool -static -o ${STATIC_FRAMEWORK_DIR}/onnxruntime prelinked_objects.o + WORKING_DIRECTORY ${STATIC_LIB_DIR}) + # Assemble the other pieces of the static framework + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ${CMAKE_COMMAND} -E + copy_if_different ${INFO_PLIST_PATH} ${STATIC_FRAMEWORK_DIR}/Info.plist) + + # add the framework header files + set(STATIC_FRAMEWORK_HEADER_DIR ${STATIC_FRAMEWORK_DIR}/Headers) file(MAKE_DIRECTORY ${STATIC_FRAMEWORK_HEADER_DIR}) - # copy the header files one by one, and the Info.plist - foreach(h_ ${APPLE_FRAMEWORK_HEADERS}) + foreach(h_ ${ONNXRUNTIME_PUBLIC_HEADERS}) get_filename_component(HEADER_NAME_ ${h_} NAME) - add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${h_} ${STATIC_FRAMEWORK_HEADER_DIR}/${HEADER_NAME_}) + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ${CMAKE_COMMAND} -E + copy_if_different ${h_} ${STATIC_FRAMEWORK_HEADER_DIR}/${HEADER_NAME_}) endforeach() - add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${INFO_PLIST_PATH} ${STATIC_FRAMEWORK_DIR}/Info.plist) - # link the static library - add_custom_command(TARGET onnxruntime POST_BUILD COMMAND libtool -static -o ${STATIC_FRAMEWORK_DIR}/onnxruntime *.a WORKING_DIRECTORY ${STATIC_LIB_DIR}) endif() diff --git a/cmake/onnxruntime_config.h.in b/cmake/onnxruntime_config.h.in index 2aef9dcf209e0..e3ea767401ddc 100644 --- a/cmake/onnxruntime_config.h.in +++ b/cmake/onnxruntime_config.h.in @@ -22,5 +22,5 @@ #cmakedefine HAS_UNUSED_BUT_SET_VARIABLE #cmakedefine HAS_UNUSED_VARIABLE #cmakedefine HAS_USELESS_CAST -#cmakedefine ORT_BUILD_INFO u8"@ORT_BUILD_INFO@" -#cmakedefine ORT_VERSION u8"@ORT_VERSION@" +#cmakedefine ORT_BUILD_INFO "@ORT_BUILD_INFO@" +#cmakedefine ORT_VERSION "@ORT_VERSION@" diff --git a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake index 735c86956ec4f..3f532ec2c3261 100644 --- a/cmake/onnxruntime_graph.cmake +++ b/cmake/onnxruntime_graph.cmake @@ -20,6 +20,8 @@ if (onnxruntime_MINIMAL_BUILD) "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_deprecated_operators.cc" "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.h" "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.cc" + "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/shape_inference_functions.h" + "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/shape_inference_functions.cc" "${ONNXRUNTIME_ROOT}/core/graph/function_template.h" "${ONNXRUNTIME_ROOT}/core/graph/function_utils.h" "${ONNXRUNTIME_ROOT}/core/graph/function_utils.cc" diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index e0ccc504d7b27..04efa5c2b4f6d 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -33,6 +33,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qpostprocessor.cpp ${MLAS_SRC_DIR}/qlgavgpool.cpp ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp + ${MLAS_SRC_DIR}/sqnbitgemm.cpp ) if (NOT onnxruntime_ORT_MINIMAL_BUILD) @@ -68,6 +69,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ) set(mlas_platform_preprocess_srcs @@ -334,17 +336,24 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ) if (NOT APPLE) set(mlas_platform_srcs ${mlas_platform_srcs} ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S ${MLAS_SRC_DIR}/activate_fp16.cpp ${MLAS_SRC_DIR}/dwconv.cpp ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/pooling_fp16.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") @@ -581,7 +590,7 @@ set_target_properties(onnxruntime_mlas PROPERTIES FOLDER "ONNXRuntime") if (WIN32) target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") if (onnxruntime_ENABLE_STATIC_ANALYSIS) - target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize" 131072>) + target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize 131072>") endif() endif() diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index 3da4198573d54..6f09583199ffd 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -86,6 +86,8 @@ if (onnxruntime_ENABLE_TRAINING) "${ORTTRAINING_SOURCE_DIR}/core/optimizer/*.cc" "${ORTTRAINING_SOURCE_DIR}/core/optimizer/compute_optimizer/*.h" "${ORTTRAINING_SOURCE_DIR}/core/optimizer/compute_optimizer/*.cc" + "${ORTTRAINING_SOURCE_DIR}/core/optimizer/memory_optimizer/*.h" + "${ORTTRAINING_SOURCE_DIR}/core/optimizer/memory_optimizer/*.cc" ) endif() @@ -109,6 +111,9 @@ onnxruntime_add_include_to_target(onnxruntime_optimizer onnxruntime_common onnxr target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT}) if (onnxruntime_ENABLE_TRAINING) target_include_directories(onnxruntime_optimizer PRIVATE ${ORTTRAINING_ROOT}) + if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) + onnxruntime_add_include_to_target(onnxruntime_optimizer Python::Module) + endif() endif() if (onnxruntime_ENABLE_TRITON) target_link_libraries(onnxruntime_optimizer PRIVATE nlohmann_json::nlohmann_json) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index b9e7873132089..8d3ea403fb74b 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -59,50 +59,6 @@ function(add_op_reduction_include_dirs target) endfunction() -file(GLOB_RECURSE onnxruntime_providers_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/cpu/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/cpu/*.cc" -) - -if(onnxruntime_DISABLE_ML_OPS) - list(FILTER onnxruntime_providers_srcs EXCLUDE REGEX ".*/ml/.*") -endif() - -file(GLOB_RECURSE onnxruntime_cpu_contrib_ops_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/*.h" - "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/*.cc" -) - -file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.h" - "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cc" -) - -file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cu" - "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cuh" -) - -file(GLOB_RECURSE onnxruntime_rocm_contrib_ops_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/*.h" - "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/*.cc" -) - -file(GLOB_RECURSE onnxruntime_rocm_contrib_ops_cu_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/*.cu" - "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/*.cuh" -) - -file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/contrib_ops/js/*.h" - "${ONNXRUNTIME_ROOT}/contrib_ops/js/*.cc" -) - -file(GLOB onnxruntime_providers_common_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/op_kernel_type_control_overrides.inc" -) if(onnxruntime_USE_VITISAI) set(PROVIDERS_VITISAI onnxruntime_providers_vitisai) endif() @@ -155,9 +111,6 @@ endif() if(onnxruntime_USE_WEBNN) set(PROVIDERS_WEBNN onnxruntime_providers_webnn) endif() -if(onnxruntime_USE_SNPE) - include(onnxruntime_snpe_provider.cmake) -endif() if (onnxruntime_USE_CANN) set(PROVIDERS_CANN onnxruntime_providers_cann) endif() @@ -165,1725 +118,88 @@ if (onnxruntime_USE_AZURE) set(PROVIDERS_AZURE onnxruntime_providers_azure) endif() -source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_common_srcs} ${onnxruntime_providers_srcs}) - -set(onnxruntime_providers_src ${onnxruntime_providers_common_srcs} ${onnxruntime_providers_srcs}) - -# disable contrib ops conditionally -if(NOT onnxruntime_DISABLE_CONTRIB_OPS) - if (NOT onnxruntime_ENABLE_ATEN) - list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs - "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op.h" - "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op.cc" - "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc" - ) - endif() - # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio - source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs}) - list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs}) -endif() - -if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING) - file(GLOB_RECURSE onnxruntime_cpu_training_ops_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.cc" - ) - - source_group(TREE ${ORTTRAINING_ROOT}/ FILES ${onnxruntime_cpu_training_ops_srcs}) - list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_training_ops_srcs}) - - file(GLOB_RECURSE onnxruntime_cpu_full_training_only_srcs - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/collective/*.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/collective/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/communication/*.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/communication/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/controlflow/record.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/controlflow/record.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/controlflow/wait.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/controlflow/wait.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/controlflow/yield.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/controlflow/yield.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/gist/*.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/gist/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/tensorboard/*.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/tensorboard/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/torch/*.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/torch/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/triton/triton_op.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/triton/triton_op.h" - ) - - list(REMOVE_ITEM onnxruntime_providers_src ${onnxruntime_cpu_full_training_only_srcs}) -endif() - -if (onnxruntime_ENABLE_ATEN) - file(GLOB_RECURSE onnxruntime_providers_dlpack_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/dlpack/dlpack_converter.cc" - "${ONNXRUNTIME_ROOT}/core/dlpack/dlpack_converter.h" - ) - set(onnxruntime_providers_dlpack_srcs ${onnxruntime_providers_dlpack_srcs}) - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_dlpack_srcs}) - list(APPEND onnxruntime_providers_src ${onnxruntime_providers_dlpack_srcs}) -endif() - -if (onnxruntime_ENABLE_TRAINING) - file(GLOB_RECURSE onnxruntime_cpu_training_ops_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.cc" - "${ORTTRAINING_SOURCE_DIR}/core/framework/*.h" - "${ORTTRAINING_SOURCE_DIR}/core/framework/*.cc" - "${ORTTRAINING_SOURCE_DIR}/core/framework/adasum/*" - "${ORTTRAINING_SOURCE_DIR}/core/framework/communication/*" - ) - - # This is already built in framework.cmake - file(GLOB_RECURSE onnxruntime_training_framework_excude_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/core/framework/torch/*.h" - "${ORTTRAINING_SOURCE_DIR}/core/framework/torch/*.cc" - "${ORTTRAINING_SOURCE_DIR}/core/framework/triton/*.h" - "${ORTTRAINING_SOURCE_DIR}/core/framework/triton/*.cc" - ) - - list(REMOVE_ITEM onnxruntime_cpu_training_ops_srcs ${onnxruntime_training_framework_excude_srcs}) - - source_group(TREE ${ORTTRAINING_ROOT}/ FILES ${onnxruntime_cpu_training_ops_srcs}) - list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_training_ops_srcs}) -endif() - -if (onnxruntime_REDUCED_OPS_BUILD) - substitute_op_reduction_srcs(onnxruntime_providers_src) -endif() -onnxruntime_add_static_library(onnxruntime_providers ${onnxruntime_providers_src}) -if (onnxruntime_REDUCED_OPS_BUILD) - add_op_reduction_include_dirs(onnxruntime_providers) -endif() - -if (HAS_BITWISE_INSTEAD_OF_LOGICAL) - target_compile_options(onnxruntime_providers PRIVATE "-Wno-bitwise-instead-of-logical") -endif() - -if (MSVC) - target_compile_options(onnxruntime_providers PRIVATE "/bigobj") -# if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) -# target_compile_options(onnxruntime_providers PRIVATE "/wd4244") -# endif() -endif() -onnxruntime_add_include_to_target(onnxruntime_providers onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) - -if (onnxruntime_BUILD_MS_EXPERIMENTAL_OPS) - target_compile_definitions(onnxruntime_providers PRIVATE BUILD_MS_EXPERIMENTAL_OPS=1) -endif() - -if(HAS_DEPRECATED_COPY) - #temporarily ignore this warning - #see: https://en.wikipedia.org/wiki/Rule_of_three_(C%2B%2B_programming) - set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/math/matmul_integer.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy) - set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/math/quantize_linear_matmul.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy) - set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/nn/qlinearconv.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy) - set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/nn/conv_integer.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy) - set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/generator/random.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy) - set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/tensor/onehot.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy) - set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/tensor/where_op.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy) -endif() - -# This is enabled only for Adasum files in training mode. -# The flags won't be applied globally since some high-precision training and inferencing ops will incur precision loss. -if (onnxruntime_ENABLE_CPU_FP16_OPS) - set_source_files_properties(${ORTTRAINING_SOURCE_DIR}/core/framework/adasum/adasum_mpi.cc PROPERTIES COMPILE_FLAGS " -fassociative-math -ffast-math -ftree-vectorize -funsafe-math-optimizations -mf16c -mavx -mfma ") - set_source_files_properties(${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/collective/adasum_kernels.cc PROPERTIES COMPILE_FLAGS " -fassociative-math -ffast-math -ftree-vectorize -funsafe-math-optimizations -mf16c -mavx -mfma ") - set_source_files_properties(${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/adasum_kernels.cc PROPERTIES COMPILE_FLAGS " -fassociative-math -ffast-math -ftree-vectorize -funsafe-math-optimizations -mf16c -mavx -mfma ") -endif() - -target_include_directories(onnxruntime_providers PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS}) -onnxruntime_add_include_to_target(onnxruntime_providers re2::re2) -add_dependencies(onnxruntime_providers onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) - -if (onnxruntime_ENABLE_TRAINING_OPS) - target_include_directories(onnxruntime_providers PRIVATE ${ORTTRAINING_ROOT}) -endif() - -if (onnxruntime_ENABLE_ATEN) - target_compile_definitions(onnxruntime_providers PRIVATE ENABLE_ATEN) - # DLPack is a header-only dependency - set(DLPACK_INCLUDE_DIR ${dlpack_SOURCE_DIR}/include) - target_include_directories(onnxruntime_providers PRIVATE ${DLPACK_INCLUDE_DIR}) -endif() - -if (onnxruntime_ENABLE_TRAINING) - add_dependencies(onnxruntime_providers tensorboard) - onnxruntime_add_include_to_target(onnxruntime_providers tensorboard) - if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP OR onnxruntime_ENABLE_TRITON) - onnxruntime_add_include_to_target(onnxruntime_providers Python::Module) - endif() - if (onnxruntime_USE_NCCL OR onnxruntime_USE_MPI) - target_include_directories(onnxruntime_providers PUBLIC ${MPI_CXX_INCLUDE_DIRS}) - endif() -endif() - -install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/cpu/cpu_provider_factory.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) -set_target_properties(onnxruntime_providers PROPERTIES LINKER_LANGUAGE CXX) -set_target_properties(onnxruntime_providers PROPERTIES FOLDER "ONNXRuntime") - -if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD - AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin|iOS" - AND NOT CMAKE_SYSTEM_NAME STREQUAL "Android" - AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - file(GLOB onnxruntime_providers_shared_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/shared/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/*.cc" - ) - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_shared_cc_srcs}) - onnxruntime_add_shared_library(onnxruntime_providers_shared ${onnxruntime_providers_shared_cc_srcs} "${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc") - set_target_properties(onnxruntime_providers_shared PROPERTIES FOLDER "ONNXRuntime") - set_target_properties(onnxruntime_providers_shared PROPERTIES LINKER_LANGUAGE CXX) - - target_compile_definitions(onnxruntime_providers_shared PRIVATE VER_MAJOR=${VERSION_MAJOR_PART}) - target_compile_definitions(onnxruntime_providers_shared PRIVATE VER_MINOR=${VERSION_MINOR_PART}) - target_compile_definitions(onnxruntime_providers_shared PRIVATE VER_BUILD=${VERSION_BUILD_PART}) - target_compile_definitions(onnxruntime_providers_shared PRIVATE VER_PRIVATE=${VERSION_PRIVATE_PART}) - target_compile_definitions(onnxruntime_providers_shared PRIVATE VER_STRING=\"${VERSION_STRING}\") - target_compile_definitions(onnxruntime_providers_shared PRIVATE FILE_NAME=\"onnxruntime_providers_shared.dll\") - - - # On Apple/Unix we don't directly link with this library as we load it with RTLD_GLOBAL, so this is only set to the actual library on WIN32 - # But, in exchange we need to manually add Boost::mp11 to include dirs for every EP. - # It is because "provider_api.h" includes core/framework/op_kernel.h which includes op_kernel.h which includes "boost/mp11.hpp" - set(ONNXRUNTIME_PROVIDERS_SHARED) - - if(APPLE) - set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/shared/exported_symbols.lst") - elseif(UNIX) - set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/shared/version_script.lds -Xlinker --gc-sections") - elseif(WIN32) - set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/shared/symbols.def") - set(ONNXRUNTIME_PROVIDERS_SHARED onnxruntime_providers_shared) - else() - message(FATAL_ERROR "onnxruntime_providers_shared unknown platform, need to specify shared library exports for it") - endif() - - install(TARGETS onnxruntime_providers_shared - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR} - ) +if(onnxruntime_USE_SNPE) + include(onnxruntime_snpe_provider.cmake) endif() +include(onnxruntime_providers_cpu.cmake) if (onnxruntime_USE_CUDA) - file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" - ) - # Remove pch files - list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs - "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.h" - "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.cc" - ) - - # The shared_library files are in a separate list since they use precompiled headers, and the above files have them disabled. - file(GLOB_RECURSE onnxruntime_providers_cuda_shared_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" - ) - file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu" - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh" - ) - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) - set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) - - # disable contrib ops conditionally - if(NOT onnxruntime_DISABLE_CONTRIB_OPS) - if (NOT onnxruntime_ENABLE_ATEN) - list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs - "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/aten_ops/aten_op.cc" - ) - endif() - if (NOT onnxruntime_USE_NCCL) - list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs - "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/nccl_kernels.cc" - ) - endif() - # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio - source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs}) - list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs}) - endif() - - if (onnxruntime_ENABLE_TRAINING_OPS) - file(GLOB_RECURSE onnxruntime_cuda_training_ops_cc_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cc" - ) - - file(GLOB_RECURSE onnxruntime_cuda_training_ops_cu_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cu" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cuh" - ) - - source_group(TREE ${ORTTRAINING_ROOT} FILES ${onnxruntime_cuda_training_ops_cc_srcs} ${onnxruntime_cuda_training_ops_cu_srcs}) - list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_training_ops_cc_srcs} ${onnxruntime_cuda_training_ops_cu_srcs}) - - if(NOT onnxruntime_ENABLE_TRAINING) - file(GLOB_RECURSE onnxruntime_cuda_full_training_only_srcs - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/*.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/record.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/record.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/wait.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/wait.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/yield.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/gist/*.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/gist/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/gist/*.cu" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/torch/*.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/torch/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/triton/triton_op.cc" - ) - - list(REMOVE_ITEM onnxruntime_providers_cuda_src ${onnxruntime_cuda_full_training_only_srcs}) - elseif(WIN32 OR NOT onnxruntime_USE_NCCL) - # NCCL is not support in Windows build - file(GLOB_RECURSE onnxruntime_cuda_nccl_op_srcs - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_common.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_kernels.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/megatron.cc" - ) - - list(REMOVE_ITEM onnxruntime_providers_cuda_src ${onnxruntime_cuda_nccl_op_srcs}) - endif() - endif() - - if (onnxruntime_REDUCED_OPS_BUILD) - substitute_op_reduction_srcs(onnxruntime_providers_cuda_src) - endif() - # cuda_provider_interface.cc is removed from the object target: onnxruntime_providers_cuda_obj and - # add to the lib onnxruntime_providers_cuda separatedly. - # onnxruntime_providers_cuda_ut can share all the object files with onnxruntime_providers_cuda except cuda_provider_interface.cc. - set(cuda_provider_interface_src ${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_provider_interface.cc) - list(REMOVE_ITEM onnxruntime_providers_cuda_src ${cuda_provider_interface_src}) - onnxruntime_add_object_library(onnxruntime_providers_cuda_obj ${onnxruntime_providers_cuda_src}) - onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${cuda_provider_interface_src} $) - # config_cuda_provider_shared_module can be used to config onnxruntime_providers_cuda_obj, onnxruntime_providers_cuda & onnxruntime_providers_cuda_ut. - # This function guarantees that all 3 targets have the same configurations. - function(config_cuda_provider_shared_module target) - if (onnxruntime_REDUCED_OPS_BUILD) - add_op_reduction_include_dirs(${target}) - endif() - - if (HAS_GUARD_CF) - target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /guard:cf>") - endif() - if (HAS_QSPECTRE) - target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /Qspectre>") - endif() - foreach(ORT_FLAG ${ORT_WARNING_FLAGS}) - target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler \"${ORT_FLAG}\">") - endforeach() - # CUDA 11.3+ supports parallel compilation - # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver-threads - if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.3) - target_compile_options(${target} PRIVATE "$<$:SHELL:--threads \"${onnxruntime_NVCC_THREADS}\">") - endif() - if (UNIX) - target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler -Wno-reorder>" - "$<$>:-Wno-reorder>") - target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler -Wno-error=sign-compare>" - "$<$>:-Wno-error=sign-compare>") - else() - #mutex.cuh(91): warning C4834: discarding return value of function with 'nodiscard' attribute - target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4834>") - target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4127>") - endif() - - onnxruntime_add_include_to_target(${target} onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers) - if (onnxruntime_ENABLE_TRAINING_OPS) - onnxruntime_add_include_to_target(${target} onnxruntime_training) - if (onnxruntime_ENABLE_TRAINING) - target_link_libraries(${target} PRIVATE onnxruntime_training) - endif() - if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP OR onnxruntime_ENABLE_TRITON) - onnxruntime_add_include_to_target(${target} Python::Module) - endif() - endif() - - add_dependencies(${target} onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) - if(onnxruntime_CUDNN_HOME) - target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) - target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) - endif() - - if (onnxruntime_USE_TRITON_KERNEL) - # compile triton kernel, generate .a and .h files - include(onnxruntime_compile_triton_kernel.cmake) - compile_triton_kernel(triton_kernel_obj_file triton_kernel_header_dir) - add_dependencies(${target} onnxruntime_triton_kernel) - target_compile_definitions(${target} PRIVATE USE_TRITON_KERNEL) - target_include_directories(${target} PRIVATE ${triton_kernel_header_dir}) - target_link_libraries(${target} PUBLIC -Wl,--whole-archive ${triton_kernel_obj_file} -Wl,--no-whole-archive) - # lib cuda needed by cuLaunchKernel - target_link_libraries(${target} PRIVATE cuda) - endif() - - if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) - include(cutlass) - target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) - endif() - - target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) - # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found - set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA) - set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime") - - if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling - target_include_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/include) - target_link_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/lib64) - target_link_libraries(${target} PRIVATE cupti) - endif() - - if (onnxruntime_ENABLE_NVTX_PROFILE AND NOT WIN32) - target_link_libraries(${target} PRIVATE nvToolsExt) - endif() - - if (onnxruntime_ENABLE_TRAINING_OPS) - target_include_directories(${target} PRIVATE ${ORTTRAINING_ROOT} ${MPI_CXX_INCLUDE_DIRS}) - endif() - - if(onnxruntime_USE_MPI) - target_link_libraries(${target} PRIVATE ${MPI_LIBRARIES} ${MPI_CXX_LINK_FLAGS}) - endif() - - if (onnxruntime_USE_NCCL) - target_include_directories(${target} PRIVATE ${NCCL_INCLUDE_DIRS}) - target_link_libraries(${target} PRIVATE ${NCCL_LIBRARIES}) - endif() - - if (WIN32) - # *.cu cannot use PCH - if (NOT onnxruntime_BUILD_CACHE) - target_precompile_headers(${target} PUBLIC - "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.h" - "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.cc" - ) - endif() - - # minimize the Windows includes. - # this avoids an issue with CUDA 11.6 where 'small' is defined in the windows and cuda headers. - target_compile_definitions(${target} PRIVATE "WIN32_LEAN_AND_MEAN") - - # disable a warning from the CUDA headers about unreferenced local functions - #target_compile_options(${target} PRIVATE /wd4505) - set(onnxruntime_providers_cuda_static_library_flags - -IGNORE:4221 # LNK4221: This object file does not define any previously undefined public symbols, so it will not be used by any link operation that consumes this library - ) - set_target_properties(${target} PROPERTIES - STATIC_LIBRARY_FLAGS "${onnxruntime_providers_cuda_static_library_flags}") - endif() - - if(APPLE) - set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/cuda/exported_symbols.lst") - target_link_libraries(${target} PRIVATE nsync::nsync_cpp) - elseif(UNIX) - set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/cuda/version_script.lds -Xlinker --gc-sections") - target_link_libraries(${target} PRIVATE nsync::nsync_cpp) - elseif(WIN32) - set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/cuda/symbols.def") - else() - message(FATAL_ERROR "${target} unknown platform, need to specify shared library exports for it") - endif() - - if (onnxruntime_ENABLE_ATEN) - target_compile_definitions(${target} PRIVATE ENABLE_ATEN) - endif() - endfunction() - config_cuda_provider_shared_module(onnxruntime_providers_cuda_obj) - config_cuda_provider_shared_module(onnxruntime_providers_cuda) - - install(TARGETS onnxruntime_providers_cuda - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) - + include(onnxruntime_providers_cuda.cmake) endif() if (onnxruntime_USE_DNNL) - file(GLOB_RECURSE onnxruntime_providers_dnnl_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/dnnl/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/dnnl/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" - ) - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_dnnl_cc_srcs}) - onnxruntime_add_shared_library_module(onnxruntime_providers_dnnl ${onnxruntime_providers_dnnl_cc_srcs}) - target_link_directories(onnxruntime_providers_dnnl PRIVATE ${DNNL_LIB_DIR}) - if (MSVC AND onnxruntime_ENABLE_STATIC_ANALYSIS) - # dnnl_convgrad.cc(47,0): Warning C6262: Function uses '38816' bytes of stack: exceeds /analyze:stacksize '16384'. Consider moving some data to heap. - target_compile_options(onnxruntime_providers_dnnl PRIVATE "/analyze:stacksize 131072") - endif() - - add_dependencies(onnxruntime_providers_dnnl onnxruntime_providers_shared project_dnnl ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_include_directories(onnxruntime_providers_dnnl PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${DNNL_INCLUDE_DIR} ${DNNL_OCL_INCLUDE_DIR}) - # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found - target_link_libraries(onnxruntime_providers_dnnl PRIVATE dnnl ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 ${ABSEIL_LIBS} ${GSL_TARGET} safeint_interface) - install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/dnnl/dnnl_provider_options.h - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) - set_target_properties(onnxruntime_providers_dnnl PROPERTIES FOLDER "ONNXRuntime") - set_target_properties(onnxruntime_providers_dnnl PROPERTIES LINKER_LANGUAGE CXX) - - # Needed for the provider interface, as it includes training headers when training is enabled - if (onnxruntime_ENABLE_TRAINING_OPS) - target_include_directories(onnxruntime_providers_dnnl PRIVATE ${ORTTRAINING_ROOT}) - endif() - - # Needed for threadpool handling - if(onnxruntime_BUILD_JAVA) - add_compile_definitions(DNNL_JAVA) - endif() - - if(APPLE) - set_property(TARGET onnxruntime_providers_dnnl APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/dnnl/exported_symbols.lst") - set_target_properties(onnxruntime_providers_dnnl PROPERTIES - INSTALL_RPATH "@loader_path" - BUILD_WITH_INSTALL_RPATH TRUE - INSTALL_RPATH_USE_LINK_PATH FALSE) - target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync::nsync_cpp) - elseif(UNIX) - set_property(TARGET onnxruntime_providers_dnnl APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/dnnl/version_script.lds -Xlinker --gc-sections -Xlinker -rpath=\$ORIGIN") - target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync::nsync_cpp) - elseif(WIN32) - set_property(TARGET onnxruntime_providers_dnnl APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/dnnl/symbols.def") - else() - message(FATAL_ERROR "onnxruntime_providers_dnnl unknown platform, need to specify shared library exports for it") - endif() - - install(TARGETS onnxruntime_providers_dnnl - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + include(onnxruntime_providers_dnnl.cmake) endif() if (onnxruntime_USE_TENSORRT) - add_definitions(-DUSE_TENSORRT=1) - if (onnxruntime_TENSORRT_PLACEHOLDER_BUILDER) - add_definitions(-DORT_TENSORRT_PLACEHOLDER_BUILDER) - endif() - set(BUILD_LIBRARY_ONLY 1) - add_definitions("-DONNX_ML=1") - add_definitions("-DONNX_NAMESPACE=onnx") - set(CUDA_INCLUDE_DIRS ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) - set(TENSORRT_ROOT ${onnxruntime_TENSORRT_HOME}) - set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) - set(PROTOBUF_LIBRARY ${PROTOBUF_LIB}) - if (WIN32) - add_definitions(-D_SILENCE_EXPERIMENTAL_FILESYSTEM_DEPRECATION_WARNING=1) - set(OLD_CMAKE_CUDA_FLAGS ${CMAKE_CUDA_FLAGS}) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4099 /wd4551 /wd4505 /wd4515 /wd4706 /wd4456 /wd4324 /wd4701 /wd4804 /wd4702 /wd4458 /wd4703") - if (CMAKE_BUILD_TYPE STREQUAL "Debug") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4805") - endif() - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -include algorithm") - set(DISABLED_WARNINGS_FOR_TRT /wd4456) - endif() - if ( CMAKE_COMPILER_IS_GNUCC ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers") - endif() - set(CXX_VERSION_DEFINED TRUE) - - # There is an issue when running "Debug build" TRT EP with "Release build" TRT builtin parser on Windows. - # We enforce following workaround for now until the real fix. - if (WIN32 AND CMAKE_BUILD_TYPE STREQUAL "Debug") - set(onnxruntime_USE_TENSORRT_BUILTIN_PARSER OFF) - MESSAGE(STATUS "[Note] There is an issue when running \"Debug build\" TRT EP with \"Release build\" TRT built-in parser on Windows. This build will use tensorrt oss parser instead.") - endif() - - if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) - # Add TensorRT library - find_path(TENSORRT_INCLUDE_DIR NvInfer.h - HINTS ${TENSORRT_ROOT} - PATH_SUFFIXES include) - MESSAGE(STATUS "Found TensorRT headers at ${TENSORRT_INCLUDE_DIR}") - find_library(TENSORRT_LIBRARY_INFER nvinfer - HINTS ${TENSORRT_ROOT} - PATH_SUFFIXES lib lib64 lib/x64) - find_library(TENSORRT_LIBRARY_INFER_PLUGIN nvinfer_plugin - HINTS ${TENSORRT_ROOT} - PATH_SUFFIXES lib lib64 lib/x64) - find_library(TENSORRT_LIBRARY_NVONNXPARSER nvonnxparser - HINTS ${TENSORRT_ROOT} - PATH_SUFFIXES lib lib64 lib/x64) - set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN} ${TENSORRT_LIBRARY_NVONNXPARSER}) - MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") - else() - FetchContent_Declare( - onnx_tensorrt - URL ${DEP_URL_onnx_tensorrt} - URL_HASH SHA1=${DEP_SHA1_onnx_tensorrt} - ) - # The onnx_tensorrt repo contains a test program, getSupportedAPITest, which doesn't support Windows. It uses - # unistd.h. So we must exclude it from our build. onnxruntime_fetchcontent_makeavailable is for the purpose. - onnxruntime_fetchcontent_makeavailable(onnx_tensorrt) - include_directories(${onnx_tensorrt_SOURCE_DIR}) - set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) - set(CUDA_INCLUDE_DIR ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # onnx-tensorrt repo needs this variable to build - if ( CMAKE_COMPILER_IS_GNUCC ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") - endif() - if (WIN32) - set(CMAKE_CUDA_FLAGS ${OLD_CMAKE_CUDA_FLAGS}) - unset(PROTOBUF_LIBRARY) - unset(OLD_CMAKE_CXX_FLAGS) - unset(OLD_CMAKE_CUDA_FLAGS) - set_target_properties(nvonnxparser PROPERTIES LINK_FLAGS "/ignore:4199") - target_compile_options(nvonnxparser_static PRIVATE /FIio.h /wd4100) - target_compile_options(nvonnxparser PRIVATE /FIio.h /wd4100) - endif() - set(onnxparser_link_libs nvonnxparser_static) - endif() - - include_directories(${TENSORRT_INCLUDE_DIR}) - # ${TENSORRT_LIBRARY} is empty if we link nvonnxparser_static. - # nvonnxparser_static is linked against tensorrt libraries in onnx-tensorrt - # See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121 - set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) - - file(GLOB_RECURSE onnxruntime_providers_tensorrt_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_stream_handle.h" - "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_stream_handle.cc" - "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_graph.h" - "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_graph.cc" - ) - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_tensorrt_cc_srcs}) - onnxruntime_add_shared_library_module(onnxruntime_providers_tensorrt ${onnxruntime_providers_tensorrt_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - add_dependencies(onnxruntime_providers_tensorrt onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS}) - else() - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS}) - endif() - target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) - if(onnxruntime_CUDNN_HOME) - target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include) - endif() - - # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found - set_target_properties(onnxruntime_providers_tensorrt PROPERTIES PUBLIC_HEADER ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h) - set_target_properties(onnxruntime_providers_tensorrt PROPERTIES LINKER_LANGUAGE CUDA) - set_target_properties(onnxruntime_providers_tensorrt PROPERTIES FOLDER "ONNXRuntime") - target_compile_definitions(onnxruntime_providers_tensorrt PRIVATE ONNXIFI_BUILD_LIBRARY=1) - target_compile_options(onnxruntime_providers_tensorrt PRIVATE ${DISABLED_WARNINGS_FOR_TRT}) - if (WIN32) - target_compile_options(onnxruntime_providers_tensorrt INTERFACE /wd4456) - endif() - - # Needed for the provider interface, as it includes training headers when training is enabled - if (onnxruntime_ENABLE_TRAINING_OPS) - target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ORTTRAINING_ROOT}) - if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) - onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt Python::Module) - endif() - endif() - - if(APPLE) - set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/tensorrt/exported_symbols.lst") - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync::nsync_cpp) - elseif(UNIX) - set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") - set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/tensorrt/version_script.lds -Xlinker --gc-sections") - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync::nsync_cpp stdc++fs) - elseif(WIN32) - set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/tensorrt/symbols.def") - else() - message(FATAL_ERROR "onnxruntime_providers_tensorrt unknown platform, need to specify shared library exports for it") - endif() - - install(TARGETS onnxruntime_providers_tensorrt - PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}) + include(onnxruntime_providers_tensorrt.cmake) endif() if (onnxruntime_USE_VITISAI) - if ("${GIT_COMMIT_ID}" STREQUAL "") - execute_process( - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - COMMAND git rev-parse HEAD - OUTPUT_VARIABLE GIT_COMMIT_ID - OUTPUT_STRIP_TRAILING_WHITESPACE) - endif() - configure_file(${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/version_info.hpp.in ${CMAKE_CURRENT_BINARY_DIR}/VitisAI/version_info.h) - file(GLOB onnxruntime_providers_vitisai_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h" - ) - list(REMOVE_ITEM onnxruntime_providers_vitisai_cc_srcs "${ONNXRUNTIME_ROOT}/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc") - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vitisai_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_vitisai onnxruntime_common onnxruntime_framework onnx onnx_proto) - onnxruntime_add_shared_library(onnxruntime_vitisai_ep ${ONNXRUNTIME_ROOT}/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc) - onnxruntime_add_include_to_target(onnxruntime_vitisai_ep onnxruntime_common) - target_include_directories(onnxruntime_vitisai_ep PRIVATE "${ONNXRUNTIME_ROOT}" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include") - target_link_libraries(onnxruntime_providers_vitisai PUBLIC onnxruntime_vitisai_ep PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json ) - target_compile_definitions(onnxruntime_vitisai_ep - PRIVATE "-DONNXRUNTIME_VITISAI_EP_STUB=1" "-DONNXRUNTIME_VITISAI_EP_EXPORT_DLL=1") - if(NOT MSVC) - target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) - endif(NOT MSVC) - - target_include_directories(onnxruntime_providers_vitisai PRIVATE "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include" ${XRT_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR}/VitisAI) - if(MSVC) - target_compile_options(onnxruntime_providers_vitisai PRIVATE "/Zc:__cplusplus") - # for dll interface warning. - target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4251") - # for unused formal parameter - target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4100") - else(MSVC) - target_compile_options(onnxruntime_providers_vitisai PRIVATE -Wno-unused-parameter) - endif(MSVC) - - set_target_properties(onnxruntime_providers_vitisai PROPERTIES FOLDER "ONNXRuntime") - set_target_properties(onnxruntime_providers_vitisai PROPERTIES LINKER_LANGUAGE CXX) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_vitisai - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() + include(onnxruntime_providers_vitisai.cmake) endif() if (onnxruntime_USE_OPENVINO) - -# include_directories("${CMAKE_CURRENT_BINARY_DIR}/onnx") - file(GLOB_RECURSE onnxruntime_providers_openvino_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/openvino/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/openvino/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/openvino/*.hpp" - "${ONNXRUNTIME_ROOT}/core/providers/openvino/*.cpp" - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" - ) - - if (WIN32) - set(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO Release) - endif() - - # Header paths - find_package(InferenceEngine REQUIRED) - find_package(ngraph REQUIRED) - - if (OPENVINO_2022_1 OR OPENVINO_2022_2) - find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX) - list (OV_20_LIBS openvino::frontend::onnx openvino::runtime) - endif() - - if (WIN32) - unset(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO) - endif() - - if ((DEFINED ENV{OPENCL_LIBS}) AND (DEFINED ENV{OPENCL_INCS})) - add_definitions(-DIO_BUFFER_ENABLED=1) - list(APPEND OPENVINO_LIB_LIST $ENV{OPENCL_LIBS} ${OV_20_LIBS} ${InferenceEngine_LIBRARIES} ${NGRAPH_LIBRARIES} ngraph::onnx_importer ${PYTHON_LIBRARIES}) - else() - list(APPEND OPENVINO_LIB_LIST ${OV_20_LIBS} ${InferenceEngine_LIBRARIES} ${NGRAPH_LIBRARIES} ngraph::onnx_importer ${PYTHON_LIBRARIES}) - endif() - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_openvino_cc_srcs}) - onnxruntime_add_shared_library_module(onnxruntime_providers_openvino ${onnxruntime_providers_openvino_cc_srcs} "${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc") - onnxruntime_add_include_to_target(onnxruntime_providers_openvino onnxruntime_common onnx) - install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/openvino/openvino_provider_factory.h - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) - set_target_properties(onnxruntime_providers_openvino PROPERTIES LINKER_LANGUAGE CXX) - set_target_properties(onnxruntime_providers_openvino PROPERTIES FOLDER "ONNXRuntime") - if(NOT MSVC) - target_compile_options(onnxruntime_providers_openvino PRIVATE "-Wno-parentheses") - endif() - add_dependencies(onnxruntime_providers_openvino onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_include_directories(onnxruntime_providers_openvino SYSTEM PUBLIC ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${OpenVINO_INCLUDE_DIR} ${OPENVINO_INCLUDE_DIR_LIST} ${PYTHON_INCLUDE_DIRS} $ENV{OPENCL_INCS} $ENV{OPENCL_INCS}/../../cl_headers/) - target_link_libraries(onnxruntime_providers_openvino ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 ${OPENVINO_LIB_LIST} ${ABSEIL_LIBS}) - - target_compile_definitions(onnxruntime_providers_openvino PRIVATE VER_MAJOR=${VERSION_MAJOR_PART}) - target_compile_definitions(onnxruntime_providers_openvino PRIVATE VER_MINOR=${VERSION_MINOR_PART}) - target_compile_definitions(onnxruntime_providers_openvino PRIVATE VER_BUILD=${VERSION_BUILD_PART}) - target_compile_definitions(onnxruntime_providers_openvino PRIVATE VER_PRIVATE=${VERSION_PRIVATE_PART}) - target_compile_definitions(onnxruntime_providers_openvino PRIVATE VER_STRING=\"${VERSION_STRING}\") - target_compile_definitions(onnxruntime_providers_openvino PRIVATE FILE_NAME=\"onnxruntime_providers_openvino.dll\") - - if(MSVC) - target_compile_options(onnxruntime_providers_openvino PUBLIC /wd4099 /wd4275 /wd4100 /wd4005 /wd4244 /wd4267) - endif() - - # Needed for the provider interface, as it includes training headers when training is enabled - if (onnxruntime_ENABLE_TRAINING_OPS) - target_include_directories(onnxruntime_providers_openvino PRIVATE ${ORTTRAINING_ROOT}) - endif() - - if(APPLE) - set_property(TARGET onnxruntime_providers_openvino APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/openvino/exported_symbols.lst") - elseif(UNIX) - set_property(TARGET onnxruntime_providers_openvino APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/openvino/version_script.lds -Xlinker --gc-sections") - elseif(WIN32) - set_property(TARGET onnxruntime_providers_openvino APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/openvino/symbols.def") - else() - message(FATAL_ERROR "onnxruntime_providers_openvino unknown platform, need to specify shared library exports for it") - endif() - - install(TARGETS onnxruntime_providers_openvino - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + include(onnxruntime_providers_openvino.cmake) endif() if (onnxruntime_USE_COREML) - if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) - message(FATAL_ERROR "CoreML EP can not be used in a basic minimal build. Please build with '--minimal_build extended'") - endif() - - add_compile_definitions(USE_COREML=1) - - # Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - set(COREML_PROTO_ROOT ${PROJECT_SOURCE_DIR}/../onnxruntime/core/providers/coreml/mlmodel_format) - file(GLOB coreml_proto_srcs - "${COREML_PROTO_ROOT}/*.proto" - ) - onnxruntime_add_static_library(onnxruntime_coreml_proto ${coreml_proto_srcs}) - target_include_directories(onnxruntime_coreml_proto PUBLIC $ "${CMAKE_CURRENT_BINARY_DIR}") - target_compile_definitions(onnxruntime_coreml_proto PUBLIC $) - set_target_properties(onnxruntime_coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") - set_target_properties(onnxruntime_coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility-inlines-hidden") - set(_src_sub_dir "coreml/") - onnxruntime_protobuf_generate( - APPEND_PATH - GEN_SRC_SUB_DIR ${_src_sub_dir} - IMPORT_DIRS ${COREML_PROTO_ROOT} - TARGET onnxruntime_coreml_proto - ) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_coreml_proto - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR} - ) - endif() - endif() - - # These are shared utils, - # TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML - file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" - ) - - file(GLOB - onnxruntime_providers_coreml_cc_srcs_top CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.cc" - ) - - # Add builder source code - file(GLOB_RECURSE - onnxruntime_providers_coreml_cc_srcs_nested CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.cc" - ) - if (NOT CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND NOT CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(REMOVE_ITEM onnxruntime_providers_coreml_cc_srcs_nested - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.cc" - ) - endif() - - # Add CoreML objective c++ source code - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - file(GLOB - onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.mm" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.mm" - ) - endif() - - set(onnxruntime_providers_coreml_cc_srcs - ${onnxruntime_providers_coreml_cc_srcs_top} - ${onnxruntime_providers_coreml_cc_srcs_nested} - ${onnxruntime_providers_shared_utils_cc_srcs} - ) - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_coreml_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_coreml - ${onnxruntime_providers_coreml_cc_srcs} ${onnxruntime_providers_coreml_objcc_srcs} - ) - onnxruntime_add_include_to_target(onnxruntime_providers_coreml - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface - ) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - onnxruntime_add_include_to_target(onnxruntime_providers_coreml onnxruntime_coreml_proto) - target_link_libraries(onnxruntime_providers_coreml PRIVATE onnxruntime_coreml_proto "-framework Foundation" "-framework CoreML") - add_dependencies(onnxruntime_providers_coreml onnxruntime_coreml_proto) - endif() - add_dependencies(onnxruntime_providers_coreml ${onnxruntime_EXTERNAL_DEPENDENCIES}) - - set_target_properties(onnxruntime_providers_coreml PROPERTIES CXX_STANDARD_REQUIRED ON) - set_target_properties(onnxruntime_providers_coreml PROPERTIES FOLDER "ONNXRuntime") - target_include_directories(onnxruntime_providers_coreml PRIVATE ${ONNXRUNTIME_ROOT} ${coreml_INCLUDE_DIRS}) - set_target_properties(onnxruntime_providers_coreml PROPERTIES LINKER_LANGUAGE CXX) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_coreml - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() + include(onnxruntime_providers_coreml.cmake) endif() if (onnxruntime_USE_WEBNN) - if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) - message(FATAL_ERROR "WebNN EP can not be used in a basic minimal build. Please build with '--minimal_build extended'") - endif() - - add_compile_definitions(USE_WEBNN=1) - if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) - add_definitions(-DENABLE_WEBASSEMBLY_THREADS=1) - endif() - file(GLOB_RECURSE onnxruntime_providers_webnn_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/webnn/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/webnn/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" - ) - - source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_webnn_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_webnn ${onnxruntime_providers_webnn_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_webnn onnxruntime_common onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface) - - add_dependencies(onnxruntime_providers_webnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) - set_target_properties(onnxruntime_providers_webnn PROPERTIES FOLDER "ONNXRuntime") - set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX) + include(onnxruntime_providers_webnn.cmake) endif() if (onnxruntime_USE_NNAPI_BUILTIN) - if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) - message(FATAL_ERROR "NNAPI can not be used in a basic minimal build. Please build with '--minimal_build extended'") - endif() - - add_compile_definitions(USE_NNAPI=1) - - # This is the minimum Android API Level required by ORT NNAPI EP to run - # ORT running on any host system with Android API level less than this will fall back to CPU EP - if(onnxruntime_NNAPI_MIN_API) - add_compile_definitions(ORT_NNAPI_MIN_API_LEVEL=${onnxruntime_NNAPI_MIN_API}) - endif() - - # This is the maximum Android API level supported in the ort model conversion for NNAPI EP - # Note: This is only for running NNAPI for ort format model conversion on non-Android system since we cannot - # get the actually Android system version. - if(onnxruntime_NNAPI_HOST_API) - if(CMAKE_SYSTEM_NAME STREQUAL "Android") - message(FATAL_ERROR "onnxruntime_NNAPI_HOST_API should only be set for non-Android target") - endif() - add_compile_definitions(ORT_NNAPI_MAX_SUPPORTED_API_LEVEL=${onnxruntime_NNAPI_HOST_API}) - endif() - - set(onnxruntime_provider_nnapi_cc_src_patterns - "${ONNXRUNTIME_ROOT}/core/providers/nnapi/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/nnapi/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/builders/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/builders/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/builders/impl/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/builders/impl/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksTypes.h" - "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksWrapper.h" - "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksWrapper.cc" - "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h" - ) - - # On Android, use the actual NNAPI implementation. - # Otherwise, use a stub implementation to support some unit testing. - if(CMAKE_SYSTEM_NAME STREQUAL "Android") - list(APPEND onnxruntime_provider_nnapi_cc_src_patterns - "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.cc") - else() - list(APPEND onnxruntime_provider_nnapi_cc_src_patterns - "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation_stub.cc") - endif() - - # These are shared utils, - # TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML - list(APPEND onnxruntime_provider_nnapi_cc_src_patterns - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" - ) - - file(GLOB onnxruntime_providers_nnapi_cc_srcs CONFIGURE_DEPENDS ${onnxruntime_provider_nnapi_cc_src_patterns}) - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_nnapi_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_nnapi ${onnxruntime_providers_nnapi_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_nnapi - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface - ) - target_link_libraries(onnxruntime_providers_nnapi) - add_dependencies(onnxruntime_providers_nnapi onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) - set_target_properties(onnxruntime_providers_nnapi PROPERTIES CXX_STANDARD_REQUIRED ON) - set_target_properties(onnxruntime_providers_nnapi PROPERTIES FOLDER "ONNXRuntime") - target_include_directories(onnxruntime_providers_nnapi PRIVATE ${ONNXRUNTIME_ROOT} ${nnapi_INCLUDE_DIRS}) - set_target_properties(onnxruntime_providers_nnapi PROPERTIES LINKER_LANGUAGE CXX) - # ignore the warning unknown-pragmas on "pragma region" - if(NOT MSVC) - target_compile_options(onnxruntime_providers_nnapi PRIVATE "-Wno-unknown-pragmas") - endif() - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_nnapi - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() + include(onnxruntime_providers_nnapi.cmake) endif() if (onnxruntime_USE_JSEP) - add_compile_definitions(USE_JSEP=1) - - file(GLOB_RECURSE onnxruntime_providers_js_cc_srcs - "${ONNXRUNTIME_ROOT}/core/providers/js/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/js/*.cc" - ) - if(NOT onnxruntime_DISABLE_CONTRIB_OPS) - source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_js_contrib_ops_cc_srcs}) - list(APPEND onnxruntime_providers_js_cc_srcs ${onnxruntime_js_contrib_ops_cc_srcs}) - endif() - - source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_providers_js_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_js ${onnxruntime_providers_js_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_js - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers Boost::mp11 - ) - target_include_directories(onnxruntime_providers_js PRIVATE ${eigen_INCLUDE_DIRS}) - add_dependencies(onnxruntime_providers_js ${onnxruntime_EXTERNAL_DEPENDENCIES}) - + include(onnxruntime_providers_js.cmake) endif() if (onnxruntime_USE_QNN) - add_compile_definitions(USE_QNN=1) - - # These are shared utils, - # TODO, move this to a separated lib when used by EPs other than QNN, NNAPI and CoreML - file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" - ) - - file(GLOB_RECURSE - onnxruntime_providers_qnn_ep_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/qnn/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/qnn/*.cc" - ) - - file(GLOB_RECURSE - onnxruntime_providers_qnn_builder_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/qnn/builder/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/qnn/builder/*.cc" - ) - - set(onnxruntime_providers_qnn_cc_srcs - ${onnxruntime_providers_shared_utils_cc_srcs} - ${onnxruntime_providers_qnn_ep_cc_srcs} - ${onnxruntime_providers_qnn_builder_cc_srcs} - ) - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_qnn_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_qnn ${onnxruntime_providers_qnn_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_qnn onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf-lite flatbuffers::flatbuffers Boost::mp11) - target_link_libraries(onnxruntime_providers_qnn) - add_dependencies(onnxruntime_providers_qnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) - set_target_properties(onnxruntime_providers_qnn PROPERTIES CXX_STANDARD_REQUIRED ON) - set_target_properties(onnxruntime_providers_qnn PROPERTIES FOLDER "ONNXRuntime") - target_include_directories(onnxruntime_providers_qnn PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_QNN_HOME}/include/QNN ${onnxruntime_QNN_HOME}/include) - set_target_properties(onnxruntime_providers_qnn PROPERTIES LINKER_LANGUAGE CXX) - # ignore the warning unknown-pragmas on "pragma region" - if(NOT MSVC) - target_compile_options(onnxruntime_providers_qnn PRIVATE "-Wno-unknown-pragmas") - endif() + include(onnxruntime_providers_qnn.cmake) endif() if (onnxruntime_USE_RKNPU) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-variable -Wno-unused-parameter") - add_definitions(-DUSE_RKNPU=1) - option(DNN_READ_ONNX "" ON) - set(DNN_CUSTOM_PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) - option(DNN_CMAKE_INSTALL "" OFF) - option(DNN_BUILD_BIN "" OFF) - if (NOT RKNPU_DDK_PATH) - message(FATAL_ERROR "RKNPU_DDK_PATH required for onnxruntime_USE_RKNPU") - endif() - set(RKNPU_DDK_INCLUDE_DIR ${RKNPU_DDK_PATH}/include) - if (CMAKE_SIZEOF_VOID_P EQUAL 8) - set(RKNPU_DDK_LIB_DIR ${RKNPU_DDK_PATH}/lib64) - else() - set(RKNPU_DDK_LIB_DIR ${RKNPU_DDK_PATH}/lib) - endif() - file(GLOB_RECURSE - onnxruntime_providers_rknpu_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/rknpu/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/rknpu/*.cc" - ) - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_rknpu_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_rknpu ${onnxruntime_providers_rknpu_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_rknpu - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface - ) - target_link_libraries(onnxruntime_providers_rknpu PRIVATE -lrknpu_ddk) - add_dependencies(onnxruntime_providers_rknpu onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) - set_target_properties(onnxruntime_providers_rknpu PROPERTIES FOLDER "ONNXRuntime") - target_include_directories(onnxruntime_providers_rknpu PRIVATE - ${ONNXRUNTIME_ROOT} ${rknpu_INCLUDE_DIRS} ${RKNPU_DDK_INCLUDE_DIR} - ) - link_directories(onnxruntime_providers_rknpu ${RKNPU_DDK_LIB_DIR}) - set_target_properties(onnxruntime_providers_rknpu PROPERTIES LINKER_LANGUAGE CXX) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_rknpu - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() + include(onnxruntime_providers_rknpu.cmake) endif() if (onnxruntime_USE_DML) - file(GLOB_RECURSE onnxruntime_providers_dml_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/dml/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/dml/*.cpp" - "${ONNXRUNTIME_ROOT}/core/providers/dml/*.cc" - ) - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_dml_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_dml ${onnxruntime_providers_dml_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_dml - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${WIL_TARGET} - ) - add_dependencies(onnxruntime_providers_dml ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_include_directories(onnxruntime_providers_dml PRIVATE - ${ONNXRUNTIME_ROOT} - ) - - target_compile_definitions(onnxruntime_providers_dml PRIVATE DML_TARGET_VERSION_USE_LATEST=1) - if(WIN32) - target_compile_options(onnxruntime_providers_dml PRIVATE "/wd4100" "/wd4238" "/wd4189" "/wd4702") - endif() - - if (NOT onnxruntime_USE_CUSTOM_DIRECTML) - foreach(file "DirectML.dll" "DirectML.pdb" "DirectML.Debug.dll" "DirectML.Debug.pdb") - add_custom_command(TARGET onnxruntime_providers_dml - POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different - "${DML_PACKAGE_DIR}/bin/${onnxruntime_target_platform}-win/${file}" $) - endforeach() - endif() - - function(target_add_dml target) - if (onnxruntime_USE_CUSTOM_DIRECTML) - if (dml_EXTERNAL_PROJECT) - # Internal build of DirectML: link against the "DirectML" target. - target_link_libraries(${target} PRIVATE DirectML) - else() - if (dml_LIB_DIR) - target_link_libraries(${target} PRIVATE ${dml_LIB_DIR}/DirectML.lib) - else() - target_link_libraries(${target} PRIVATE DirectML) - endif() - endif() - else() - add_dependencies(${target} RESTORE_PACKAGES) - target_link_libraries(${target} PRIVATE "${DML_PACKAGE_DIR}/bin/${onnxruntime_target_platform}-win/DirectML.lib") - target_compile_definitions(${target} PRIVATE DML_TARGET_VERSION_USE_LATEST) - endif() - endfunction() - - target_add_dml(onnxruntime_providers_dml) - target_link_libraries(onnxruntime_providers_dml PRIVATE onnxruntime_common) - target_link_libraries(onnxruntime_providers_dml PRIVATE onnxruntime_framework) - onnxruntime_add_include_to_target(onnxruntime_providers_dml onnxruntime_common) - if (GDK_PLATFORM STREQUAL Scarlett) - target_link_libraries(onnxruntime_providers_dml PRIVATE ${gdk_dx_libs}) - else() - target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib) - endif() - - target_link_libraries(onnxruntime_providers_dml PRIVATE delayimp.lib) - - if (NOT GDK_PLATFORM) - set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /ignore:4199") - endif() - - target_compile_definitions(onnxruntime_providers_dml - PRIVATE - ONNX_NAMESPACE=onnx ONNX_ML LOTUS_LOG_THRESHOLD=2 LOTUS_ENABLE_STDERR_LOGGING PLATFORM_WINDOWS - ) - target_compile_definitions(onnxruntime_providers_dml PRIVATE UNICODE _UNICODE NOMINMAX) - if (MSVC) - target_compile_definitions(onnxruntime_providers_dml PRIVATE _SILENCE_CXX17_ITERATOR_BASE_CLASS_DEPRECATION_WARNING) - target_compile_options(onnxruntime_providers_dml PRIVATE "/W3") - endif() - - install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/dml/dml_provider_factory.h - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/ - ) - - set_target_properties(onnxruntime_providers_dml PROPERTIES LINKER_LANGUAGE CXX) - set_target_properties(onnxruntime_providers_dml PROPERTIES FOLDER "ONNXRuntime") - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_dml - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() + include(onnxruntime_providers_dml.cmake) endif() if (onnxruntime_USE_MIGRAPHX) - add_definitions(-DUSE_MIGRAPHX=1) - set(BUILD_LIBRARY_ONLY 1) - add_definitions("-DONNX_ML=1") - add_definitions("-DONNX_NAMESPACE=onnx") - include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR}) - set(MIGRAPHX_ROOT ${onnxruntime_MIGRAPHX_HOME}) - include_directories(${onnx_SOURCE_DIR}) - set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) - if ( CMAKE_COMPILER_IS_GNUCC ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers") - endif() - set(CXX_VERSION_DEFINED TRUE) - set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) - if ( CMAKE_COMPILER_IS_GNUCC ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") - endif() - - # Add search paths for default rocm installation - list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm) - - find_package(hip) - find_package(migraphx PATHS ${AMD_MIGRAPHX_HOME}) - - find_package(miopen) - find_package(rocblas) - - set(migraphx_libs migraphx::c hip::host MIOpen roc::rocblas) - - file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/rocm/rocm_stream_handle.h" - "${ONNXRUNTIME_ROOT}/core/providers/rocm/rocm_stream_handle.cc" - ) - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs}) - onnxruntime_add_shared_library_module(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - add_dependencies(onnxruntime_providers_migraphx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}) - set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX) - set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime") - target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1) - target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare) - set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") - set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections") - target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync::nsync_cpp stdc++fs) - - include(CheckLibraryExists) - check_library_exists(migraphx::c "migraphx_program_run_async" "/opt/rocm/migraphx/lib" HAS_STREAM_SYNC) - if(HAS_STREAM_SYNC) - target_compile_definitions(onnxruntime_providers_migraphx PRIVATE -DMIGRAPHX_STREAM_SYNC) - message(STATUS "MIGRAPHX GPU STREAM SYNC is ENABLED") - else() - message(STATUS "MIGRAPHX GPU STREAM SYNC is DISABLED") - endif() - - install(TARGETS onnxruntime_providers_migraphx - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - ) + include(onnxruntime_providers_migraphx.cmake) endif() if (onnxruntime_USE_ACL) - add_definitions(-DUSE_ACL=1) - file(GLOB_RECURSE onnxruntime_providers_acl_cc_srcs - "${ONNXRUNTIME_ROOT}/core/providers/acl/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/acl/*.cc" - ) - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_acl_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_acl ${onnxruntime_providers_acl_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_acl - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface - ) - - target_link_libraries(onnxruntime_providers_acl -L$ENV{LD_LIBRARY_PATH}) - add_dependencies(onnxruntime_providers_acl ${onnxruntime_EXTERNAL_DEPENDENCIES}) - set_target_properties(onnxruntime_providers_acl PROPERTIES FOLDER "ONNXRuntime") - target_include_directories(onnxruntime_providers_acl - PRIVATE - ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${onnxruntime_ACL_HOME} ${onnxruntime_ACL_HOME}/include - ) - install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/acl/acl_provider_factory.h - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/ - ) - set_target_properties(onnxruntime_providers_acl PROPERTIES LINKER_LANGUAGE CXX) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_acl - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() + include(onnxruntime_providers_acl.cmake) endif() if (onnxruntime_USE_ARMNN) - add_definitions(-DUSE_ARMNN=1) - file(GLOB_RECURSE onnxruntime_providers_armnn_cc_srcs - "${ONNXRUNTIME_ROOT}/core/providers/armnn/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/armnn/*.cc" - ) - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_armnn_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_armnn ${onnxruntime_providers_armnn_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_armnn - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface - ) - - add_dependencies(onnxruntime_providers_armnn ${onnxruntime_EXTERNAL_DEPENDENCIES}) - set_target_properties(onnxruntime_providers_armnn PROPERTIES FOLDER "ONNXRuntime") - target_include_directories(onnxruntime_providers_armnn PRIVATE - ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${onnxruntime_ARMNN_HOME} ${onnxruntime_ARMNN_HOME}/include - ${onnxruntime_ACL_HOME} ${onnxruntime_ACL_HOME}/include - ) - install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/armnn/armnn_provider_factory.h - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) - - set_target_properties(onnxruntime_providers_armnn PROPERTIES LINKER_LANGUAGE CXX) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_armnn - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() + include(onnxruntime_providers_armnn.cmake) endif() if (onnxruntime_USE_ROCM) - add_definitions(-DUSE_ROCM=1) - include(onnxruntime_rocm_hipify.cmake) - - list(APPEND CMAKE_PREFIX_PATH ${onnxruntime_ROCM_HOME}) - - find_package(HIP) - find_package(hiprand REQUIRED) - find_package(rocblas REQUIRED) - find_package(MIOpen REQUIRED) - - # MIOpen version - if(NOT DEFINED ENV{MIOPEN_PATH}) - set(MIOPEN_PATH ${onnxruntime_ROCM_HOME}) - else() - set(MIOPEN_PATH $ENV{MIOPEN_PATH}) - endif() - find_path(MIOPEN_VERSION_H_PATH - NAMES version.h - HINTS - ${MIOPEN_PATH}/include/miopen - ${MIOPEN_PATH}/miopen/include) - if (MIOPEN_VERSION_H_PATH-NOTFOUND) - MESSAGE(FATAL_ERROR "miopen version.h not found") - endif() - MESSAGE(STATUS "Found miopen version.h at ${MIOPEN_VERSION_H_PATH}") - - file(READ ${MIOPEN_VERSION_H_PATH}/version.h MIOPEN_HEADER_CONTENTS) - string(REGEX MATCH "define MIOPEN_VERSION_MAJOR * +([0-9]+)" - MIOPEN_VERSION_MAJOR "${MIOPEN_HEADER_CONTENTS}") - string(REGEX REPLACE "define MIOPEN_VERSION_MAJOR * +([0-9]+)" "\\1" - MIOPEN_VERSION_MAJOR "${MIOPEN_VERSION_MAJOR}") - string(REGEX MATCH "define MIOPEN_VERSION_MINOR * +([0-9]+)" - MIOPEN_VERSION_MINOR "${MIOPEN_HEADER_CONTENTS}") - string(REGEX REPLACE "define MIOPEN_VERSION_MINOR * +([0-9]+)" "\\1" - MIOPEN_VERSION_MINOR "${MIOPEN_VERSION_MINOR}") - string(REGEX MATCH "define MIOPEN_VERSION_PATCH * +([0-9]+)" - MIOPEN_VERSION_PATCH "${MIOPEN_HEADER_CONTENTS}") - string(REGEX REPLACE "define MIOPEN_VERSION_PATCH * +([0-9]+)" "\\1" - MIOPEN_VERSION_PATCH "${MIOPEN_VERSION_PATCH}") - set(MIOPEN_VERSION_DEV "${MIOPEN_VERSION_MAJOR}.${MIOPEN_VERSION_MINOR}.${MIOPEN_VERSION_PATCH}") - math(EXPR MIOPEN_VERSION_DEV_INT "(${MIOPEN_VERSION_MAJOR}*10000) + (${MIOPEN_VERSION_MINOR}*100) + ${MIOPEN_VERSION_PATCH}") - message("MIOPEN_VERSION_DEV: ${MIOPEN_VERSION_DEV}") - message("MIOPEN_VERSION_DEV_INT: ${MIOPEN_VERSION_DEV_INT}") - add_definitions(-DMIOPEN_VERSION=${MIOPEN_VERSION_DEV_INT}) - - find_library(RCCL_LIB rccl REQUIRED) - find_library(ROCTRACER_LIB roctracer64 REQUIRED) - set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen ${RCCL_LIB} ${ROCTRACER_LIB}) - - file(GLOB_RECURSE onnxruntime_providers_rocm_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.cc" - ) - - # The shared_library files are in a separate list since they use precompiled headers, and the above files have them disabled. - file(GLOB_RECURSE onnxruntime_providers_rocm_shared_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" - ) - - file(GLOB_RECURSE onnxruntime_providers_rocm_cu_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.cu" - "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.cuh" - ) - - hipify("onnxruntime/core/providers" provider_excluded_files onnxruntime_providers_rocm_generated_cc_srcs onnxruntime_providers_rocm_generated_cu_srcs) - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_rocm_cc_srcs} ${onnxruntime_providers_rocm_shared_srcs} ${onnxruntime_providers_rocm_cu_srcs}) - set(onnxruntime_providers_rocm_src ${onnxruntime_providers_rocm_cc_srcs} ${onnxruntime_providers_rocm_shared_srcs} ${onnxruntime_providers_rocm_cu_srcs}) - list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_providers_rocm_generated_cc_srcs} ${onnxruntime_providers_rocm_generated_cu_srcs}) - - # disable contrib ops conditionally - if(NOT onnxruntime_DISABLE_CONTRIB_OPS) - hipify("onnxruntime/contrib_ops" contrib_ops_excluded_files onnxruntime_rocm_generated_contrib_ops_cc_srcs onnxruntime_rocm_generated_contrib_ops_cu_srcs) - - # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio - source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_rocm_contrib_ops_cc_srcs} ${onnxruntime_rocm_contrib_ops_cu_srcs}) - list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_rocm_contrib_ops_cc_srcs} ${onnxruntime_rocm_contrib_ops_cu_srcs}) - list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_rocm_generated_contrib_ops_cc_srcs} ${onnxruntime_rocm_generated_contrib_ops_cu_srcs}) - endif() - - if (onnxruntime_ENABLE_TRAINING_OPS) - file(GLOB_RECURSE onnxruntime_rocm_training_ops_cc_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.cc" - ) - - file(GLOB_RECURSE onnxruntime_rocm_training_ops_cu_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.cu" - "${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.cuh" - ) - - hipify("orttraining/orttraining/training_ops" training_ops_excluded_files onnxruntime_rocm_generated_training_ops_cc_srcs onnxruntime_rocm_generated_training_ops_cu_srcs) - - # NCCL is not support in Windows build - if (WIN32 OR NOT onnxruntime_USE_NCCL) - list(REMOVE_ITEM onnxruntime_rocm_generated_training_ops_cc_srcs - "${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining/orttraining/training_ops/rocm/collective/nccl_common.cc" - "${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining/orttraining/training_ops/rocm/collective/nccl_kernels.cc" - "${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining/orttraining/training_ops/rocm/collective/megatron.cc" - ) - endif() - - source_group(TREE ${ORTTRAINING_ROOT} FILES ${onnxruntime_rocm_training_ops_cc_srcs} ${onnxruntime_rocm_training_ops_cu_srcs}) - list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_rocm_training_ops_cc_srcs} ${onnxruntime_rocm_training_ops_cu_srcs}) - list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_rocm_generated_training_ops_cc_srcs} ${onnxruntime_rocm_generated_training_ops_cu_srcs}) - endif() - - auto_set_source_files_hip_language(${onnxruntime_providers_rocm_src}) - onnxruntime_add_shared_library_module(onnxruntime_providers_rocm ${onnxruntime_providers_rocm_src}) - target_compile_options(onnxruntime_providers_rocm PRIVATE -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1) - - if(NOT MSVC) - target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-sign-compare) - target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-unused-parameter) - target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-undefined-var-template) - endif() - - onnxruntime_add_include_to_target(onnxruntime_providers_rocm onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) - if (onnxruntime_ENABLE_TRAINING_OPS) - onnxruntime_add_include_to_target(onnxruntime_providers_rocm onnxruntime_training) - target_link_libraries(onnxruntime_providers_rocm PRIVATE onnxruntime_training) - if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) - onnxruntime_add_include_to_target(onnxruntime_providers_rocm Python::Module) - endif() - endif() - - add_custom_target(generate_hipified_files DEPENDS - ${onnxruntime_providers_rocm_generated_cc_srcs} - ${onnxruntime_providers_rocm_generated_cu_srcs} - ${onnxruntime_rocm_generated_contrib_ops_cc_srcs} - ${onnxruntime_rocm_generated_contrib_ops_cu_srcs} - ${onnxruntime_rocm_generated_training_ops_cc_srcs} - ${onnxruntime_rocm_generated_training_ops_cu_srcs}) - - add_dependencies(onnxruntime_providers_rocm generate_hipified_files onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(onnxruntime_providers_rocm PRIVATE ${ONNXRUNTIME_ROCM_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} ${ABSEIL_LIBS}) - target_include_directories(onnxruntime_providers_rocm SYSTEM - PRIVATE - ${ONNXRUNTIME_ROOT} - ${CMAKE_CURRENT_BINARY_DIR} - ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime - ${eigen_INCLUDE_DIRS} - PUBLIC - ${onnxruntime_ROCM_HOME}/include - ${onnxruntime_ROCM_HOME}/include/roctracer) - - set_target_properties(onnxruntime_providers_rocm PROPERTIES LINKER_LANGUAGE CXX) - set_target_properties(onnxruntime_providers_rocm PROPERTIES FOLDER "ONNXRuntime") - - if (onnxruntime_ENABLE_TRAINING) - target_include_directories(onnxruntime_providers_rocm PRIVATE ${ORTTRAINING_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining ${MPI_CXX_INCLUDE_DIRS}) - if(onnxruntime_USE_MPI) - target_link_libraries(onnxruntime_providers_rocm PRIVATE ${MPI_LIBRARIES} ${MPI_CXX_LINK_FLAGS}) - endif() - - # RCCL is enabled by default for ROCM builds - #if (onnxruntime_USE_NCCL) - # target_include_directories(onnxruntime_providers_rocm PRIVATE ${NCCL_INCLUDE_DIRS}) - # target_link_libraries(onnxruntime_providers_rocm PRIVATE ${NCCL_LIBRARIES}) - #endif() - endif() - - if (onnxruntime_USE_ROCBLAS_EXTENSION_API) - target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_ROCBLAS_EXTENSION_API) - target_compile_definitions(onnxruntime_providers_rocm PRIVATE ROCBLAS_NO_DEPRECATED_WARNINGS) - target_compile_definitions(onnxruntime_providers_rocm PRIVATE ROCBLAS_BETA_FEATURES_API) - endif() - - if (onnxruntime_USE_HIPBLASLT) - find_package(hipblaslt REQUIRED) - target_link_libraries(onnxruntime_providers_rocm PRIVATE roc::hipblaslt) - target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_HIPBLASLT) - endif() - - if (onnxruntime_USE_TRITON_KERNEL) - # compile triton kernel, generate .a and .h files - include(onnxruntime_compile_triton_kernel.cmake) - compile_triton_kernel(triton_kernel_obj_file triton_kernel_header_dir) - add_dependencies(onnxruntime_providers_rocm onnxruntime_triton_kernel) - target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_TRITON_KERNEL) - target_include_directories(onnxruntime_providers_rocm PRIVATE ${triton_kernel_header_dir}) - target_link_libraries(onnxruntime_providers_rocm PUBLIC -Wl,--whole-archive ${triton_kernel_obj_file} -Wl,--no-whole-archive) - endif() - - if (onnxruntime_USE_COMPOSABLE_KERNEL) - include(composable_kernel) - target_link_libraries(onnxruntime_providers_rocm PRIVATE - onnxruntime_composable_kernel_includes - # Currently we shall not use composablekernels::device_operations, the target includes all conv dependencies, which - # are extremely slow to compile. Instead, we only link all gemm related objects. See the following directory on - # updating. - # https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/library/src/tensor_operation_instance/gpu - device_gemm_instance - device_gemm_add_fastgelu_instance - device_gemm_fastgelu_instance - device_gemm_splitk_instance - device_gemm_streamk_instance - device_batched_gemm_instance - device_softmax_instance - ) - target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_COMPOSABLE_KERNEL) - endif() - - if(UNIX) - set_property(TARGET onnxruntime_providers_rocm APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/rocm/version_script.lds -Xlinker --gc-sections") - target_link_libraries(onnxruntime_providers_rocm PRIVATE nsync::nsync_cpp) - else() - message(FATAL_ERROR "onnxruntime_providers_rocm unknown platform, need to specify shared library exports for it") - endif() - - if (onnxruntime_ENABLE_ATEN) - target_compile_definitions(onnxruntime_providers_rocm PRIVATE ENABLE_ATEN) - endif() - - install(TARGETS onnxruntime_providers_rocm - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) - + include(onnxruntime_providers_rocm.cmake) endif() if (onnxruntime_USE_TVM) - add_definitions(-DUSE_TVM=1) - if (onnxruntime_TVM_USE_HASH) - add_definitions(-DUSE_TVM_HASH=1) - endif() - - if (onnxruntime_TVM_USE_HASH) - file (GLOB_RECURSE onnxruntime_providers_tvm_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.cc" - ) - else() - file (GLOB onnxruntime_providers_tvm_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.cc" - ) - endif() - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_tvm_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_tvm ${onnxruntime_providers_tvm_cc_srcs}) - - if ( CMAKE_COMPILER_IS_GNUCC ) - target_compile_options(onnxruntime_providers_tvm PRIVATE -Wno-unused-parameter -Wno-missing-field-initializers) - endif() - - target_include_directories(onnxruntime_providers_tvm PRIVATE - ${TVM_INCLUDES} - ${PYTHON_INLCUDE_DIRS}) - onnxruntime_add_include_to_target(onnxruntime_providers_tvm onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) - - add_dependencies(onnxruntime_providers_tvm ${onnxruntime_EXTERNAL_DEPENDENCIES}) - - if (onnxruntime_TVM_USE_HASH) - add_dependencies(onnxruntime_providers_tvm ippcp_s) - target_include_directories(onnxruntime_providers_tvm PRIVATE ${IPP_CRYPTO_INCLUDE_DIR}) - target_link_libraries(onnxruntime_providers_tvm PRIVATE ippcp_s) - endif() - - set_target_properties(onnxruntime_providers_tvm PROPERTIES FOLDER "ONNXRuntime") - set_target_properties(onnxruntime_providers_tvm PROPERTIES LINKER_LANGUAGE CXX) - - if (WIN32 AND MSVC) - # wd4100: identifier' : unreferenced formal parameter - # wd4127: conditional expression is constant - # wd4244: conversion from 'int' to 'char', possible loss of data - # TODO: 4244 should not be disabled - target_compile_options(onnxruntime_providers_tvm PRIVATE "/wd4100" "/wd4127" "/wd4244") - else() - target_compile_options(onnxruntime_providers_tvm PRIVATE "-Wno-error=type-limits") - endif() - target_compile_definitions(onnxruntime_providers_tvm PUBLIC DMLC_USE_LOGGING_LIBRARY=) - - install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/tvm/tvm_provider_factory.h - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_tvm - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() + include(onnxruntime_providers_tvm.cmake) endif() if (onnxruntime_USE_XNNPACK) - add_compile_definitions(USE_XNNPACK=1) - - file(GLOB_RECURSE onnxruntime_providers_xnnpack_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_INCLUDE_DIR}/core/providers/xnnpack/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.cc" - # utils for handling QDQ models - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" - ) - - source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_xnnpack_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_xnnpack ${onnxruntime_providers_xnnpack_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_xnnpack - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} XNNPACK pthreadpool flatbuffers::flatbuffers Boost::mp11 safeint_interface - ) - - add_dependencies(onnxruntime_providers_xnnpack onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) - set_target_properties(onnxruntime_providers_xnnpack PROPERTIES FOLDER "ONNXRuntime") - - set_target_properties(onnxruntime_providers_xnnpack PROPERTIES LINKER_LANGUAGE CXX) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_xnnpack - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() - - # TODO fix shorten-64-to-32 warnings - # there are some in builds where sizeof(size_t) != sizeof(int64_t), e.g., in 'ONNX Runtime Web CI Pipeline' - if (HAS_SHORTEN_64_TO_32 AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) - target_compile_options(onnxruntime_providers_xnnpack PRIVATE -Wno-error=shorten-64-to-32) - endif() + include(onnxruntime_providers_xnnpack.cmake) endif() if (onnxruntime_USE_CANN) - add_definitions(-DUSE_CANN=1) - - file(GLOB_RECURSE onnxruntime_providers_cann_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/cann/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/cann/*.cc" - ) - - # The shared_library files are in a separate list since they use precompiled headers, and the above files have them disabled. - file(GLOB_RECURSE onnxruntime_providers_cann_shared_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" - ) - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cann_cc_srcs} ${onnxruntime_providers_cann_shared_srcs}) - set(onnxruntime_providers_cann_src ${onnxruntime_providers_cann_cc_srcs} ${onnxruntime_providers_cann_shared_srcs}) - - onnxruntime_add_shared_library_module(onnxruntime_providers_cann ${onnxruntime_providers_cann_src}) - onnxruntime_add_include_to_target(onnxruntime_providers_cann onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) - - add_dependencies(onnxruntime_providers_cann onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(onnxruntime_providers_cann PRIVATE ascendcl acl_op_compiler fmk_onnx_parser nsync::nsync_cpp ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED}) - target_link_directories(onnxruntime_providers_cann PRIVATE ${onnxruntime_CANN_HOME}/lib64) - target_include_directories(onnxruntime_providers_cann PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${onnxruntime_CANN_HOME} ${onnxruntime_CANN_HOME}/include) - - set_target_properties(onnxruntime_providers_cann PROPERTIES LINKER_LANGUAGE CXX) - set_target_properties(onnxruntime_providers_cann PROPERTIES FOLDER "ONNXRuntime") - - install(TARGETS onnxruntime_providers_cann - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + include(onnxruntime_providers_cann.cmake) endif() if (onnxruntime_USE_AZURE) - - file(GLOB_RECURSE onnxruntime_providers_azure_src CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/azure/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/azure/*.cc" - ) - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_azure_src}) - onnxruntime_add_static_library(onnxruntime_providers_azure ${onnxruntime_providers_azure_src}) - add_dependencies(onnxruntime_providers_azure ${onnxruntime_EXTERNAL_DEPENDENCIES}) - onnxruntime_add_include_to_target(onnxruntime_providers_azure onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11) - target_link_libraries(onnxruntime_providers_azure PRIVATE onnx onnxruntime_common onnxruntime_framework) - set_target_properties(onnxruntime_providers_azure PROPERTIES FOLDER "ONNXRuntime") - set_target_properties(onnxruntime_providers_azure PROPERTIES LINKER_LANGUAGE CXX) - - install(TARGETS onnxruntime_providers_azure - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) -endif() - -if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) + include(onnxruntime_providers_azure.cmake) endif() diff --git a/cmake/onnxruntime_providers_acl.cmake b/cmake/onnxruntime_providers_acl.cmake new file mode 100644 index 0000000000000..e23d2892713fc --- /dev/null +++ b/cmake/onnxruntime_providers_acl.cmake @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + add_definitions(-DUSE_ACL=1) + file(GLOB_RECURSE onnxruntime_providers_acl_cc_srcs + "${ONNXRUNTIME_ROOT}/core/providers/acl/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/acl/*.cc" + ) + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_acl_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_acl ${onnxruntime_providers_acl_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_acl + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface + ) + + target_link_libraries(onnxruntime_providers_acl -L$ENV{LD_LIBRARY_PATH}) + add_dependencies(onnxruntime_providers_acl ${onnxruntime_EXTERNAL_DEPENDENCIES}) + set_target_properties(onnxruntime_providers_acl PROPERTIES FOLDER "ONNXRuntime") + target_include_directories(onnxruntime_providers_acl + PRIVATE + ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${onnxruntime_ACL_HOME} ${onnxruntime_ACL_HOME}/include + ) + install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/acl/acl_provider_factory.h + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/ + ) + set_target_properties(onnxruntime_providers_acl PROPERTIES LINKER_LANGUAGE CXX) + + if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_providers_acl + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() \ No newline at end of file diff --git a/cmake/onnxruntime_providers_armnn.cmake b/cmake/onnxruntime_providers_armnn.cmake new file mode 100644 index 0000000000000..33fadb7c64c2e --- /dev/null +++ b/cmake/onnxruntime_providers_armnn.cmake @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + add_definitions(-DUSE_ARMNN=1) + file(GLOB_RECURSE onnxruntime_providers_armnn_cc_srcs + "${ONNXRUNTIME_ROOT}/core/providers/armnn/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/armnn/*.cc" + ) + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_armnn_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_armnn ${onnxruntime_providers_armnn_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_armnn + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface + ) + + add_dependencies(onnxruntime_providers_armnn ${onnxruntime_EXTERNAL_DEPENDENCIES}) + set_target_properties(onnxruntime_providers_armnn PROPERTIES FOLDER "ONNXRuntime") + target_include_directories(onnxruntime_providers_armnn PRIVATE + ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${onnxruntime_ARMNN_HOME} ${onnxruntime_ARMNN_HOME}/include + ${onnxruntime_ACL_HOME} ${onnxruntime_ACL_HOME}/include + ) + install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/armnn/armnn_provider_factory.h + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) + + set_target_properties(onnxruntime_providers_armnn PROPERTIES LINKER_LANGUAGE CXX) + + if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_providers_armnn + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() \ No newline at end of file diff --git a/cmake/onnxruntime_providers_azure.cmake b/cmake/onnxruntime_providers_azure.cmake new file mode 100644 index 0000000000000..1b03563848556 --- /dev/null +++ b/cmake/onnxruntime_providers_azure.cmake @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + file(GLOB_RECURSE onnxruntime_providers_azure_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/azure/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/azure/*.cc" + ) + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_azure_src}) + onnxruntime_add_static_library(onnxruntime_providers_azure ${onnxruntime_providers_azure_src}) + add_dependencies(onnxruntime_providers_azure ${onnxruntime_EXTERNAL_DEPENDENCIES}) + onnxruntime_add_include_to_target(onnxruntime_providers_azure onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11) + target_link_libraries(onnxruntime_providers_azure PRIVATE onnx onnxruntime_common onnxruntime_framework) + set_target_properties(onnxruntime_providers_azure PROPERTIES FOLDER "ONNXRuntime") + set_target_properties(onnxruntime_providers_azure PROPERTIES LINKER_LANGUAGE CXX) + + install(TARGETS onnxruntime_providers_azure + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) \ No newline at end of file diff --git a/cmake/onnxruntime_providers_cann.cmake b/cmake/onnxruntime_providers_cann.cmake new file mode 100644 index 0000000000000..0e26f7ee3a57b --- /dev/null +++ b/cmake/onnxruntime_providers_cann.cmake @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + add_definitions(-DUSE_CANN=1) + + file(GLOB_RECURSE onnxruntime_providers_cann_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cann/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/cann/*.cc" + ) + + # The shared_library files are in a separate list since they use precompiled headers, and the above files have them disabled. + file(GLOB_RECURSE onnxruntime_providers_cann_shared_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" + ) + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cann_cc_srcs} ${onnxruntime_providers_cann_shared_srcs}) + set(onnxruntime_providers_cann_src ${onnxruntime_providers_cann_cc_srcs} ${onnxruntime_providers_cann_shared_srcs}) + + onnxruntime_add_shared_library_module(onnxruntime_providers_cann ${onnxruntime_providers_cann_src}) + onnxruntime_add_include_to_target(onnxruntime_providers_cann onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) + + add_dependencies(onnxruntime_providers_cann onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_link_libraries(onnxruntime_providers_cann PRIVATE ascendcl acl_op_compiler fmk_onnx_parser nsync::nsync_cpp ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED}) + target_link_directories(onnxruntime_providers_cann PRIVATE ${onnxruntime_CANN_HOME}/lib64) + target_include_directories(onnxruntime_providers_cann PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${onnxruntime_CANN_HOME} ${onnxruntime_CANN_HOME}/include) + + set_target_properties(onnxruntime_providers_cann PROPERTIES LINKER_LANGUAGE CXX) + set_target_properties(onnxruntime_providers_cann PROPERTIES FOLDER "ONNXRuntime") + + install(TARGETS onnxruntime_providers_cann + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) \ No newline at end of file diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake new file mode 100644 index 0000000000000..aa8c35526b274 --- /dev/null +++ b/cmake/onnxruntime_providers_coreml.cmake @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) + message(FATAL_ERROR "CoreML EP can not be used in a basic minimal build. Please build with '--minimal_build extended'") + endif() + + add_compile_definitions(USE_COREML=1) + + # Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml + if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") + set(COREML_PROTO_ROOT ${PROJECT_SOURCE_DIR}/../onnxruntime/core/providers/coreml/mlmodel_format) + file(GLOB coreml_proto_srcs + "${COREML_PROTO_ROOT}/*.proto" + ) + onnxruntime_add_static_library(onnxruntime_coreml_proto ${coreml_proto_srcs}) + target_include_directories(onnxruntime_coreml_proto PUBLIC $ "${CMAKE_CURRENT_BINARY_DIR}") + target_compile_definitions(onnxruntime_coreml_proto PUBLIC $) + set_target_properties(onnxruntime_coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") + set_target_properties(onnxruntime_coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility-inlines-hidden") + set(_src_sub_dir "coreml/") + onnxruntime_protobuf_generate( + APPEND_PATH + GEN_SRC_SUB_DIR ${_src_sub_dir} + IMPORT_DIRS ${COREML_PROTO_ROOT} + TARGET onnxruntime_coreml_proto + ) + + if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_coreml_proto + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR} + ) + endif() + endif() + + # These are shared utils, + # TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML + file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" + ) + + file(GLOB + onnxruntime_providers_coreml_cc_srcs_top CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.cc" + ) + + # Add builder source code + file(GLOB_RECURSE + onnxruntime_providers_coreml_cc_srcs_nested CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.cc" + ) + if (NOT CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND NOT CMAKE_SYSTEM_NAME STREQUAL "iOS") + list(REMOVE_ITEM onnxruntime_providers_coreml_cc_srcs_nested + "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.cc" + ) + endif() + + # Add CoreML objective c++ source code + if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") + file(GLOB + onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.mm" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.mm" + ) + endif() + + set(onnxruntime_providers_coreml_cc_srcs + ${onnxruntime_providers_coreml_cc_srcs_top} + ${onnxruntime_providers_coreml_cc_srcs_nested} + ${onnxruntime_providers_shared_utils_cc_srcs} + ) + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_coreml_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_coreml + ${onnxruntime_providers_coreml_cc_srcs} ${onnxruntime_providers_coreml_objcc_srcs} + ) + onnxruntime_add_include_to_target(onnxruntime_providers_coreml + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface + ) + if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") + onnxruntime_add_include_to_target(onnxruntime_providers_coreml onnxruntime_coreml_proto) + target_link_libraries(onnxruntime_providers_coreml PRIVATE onnxruntime_coreml_proto "-framework Foundation" "-framework CoreML") + add_dependencies(onnxruntime_providers_coreml onnxruntime_coreml_proto) + endif() + add_dependencies(onnxruntime_providers_coreml ${onnxruntime_EXTERNAL_DEPENDENCIES}) + + set_target_properties(onnxruntime_providers_coreml PROPERTIES CXX_STANDARD_REQUIRED ON) + set_target_properties(onnxruntime_providers_coreml PROPERTIES FOLDER "ONNXRuntime") + target_include_directories(onnxruntime_providers_coreml PRIVATE ${ONNXRUNTIME_ROOT} ${coreml_INCLUDE_DIRS}) + set_target_properties(onnxruntime_providers_coreml PROPERTIES LINKER_LANGUAGE CXX) + + if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_providers_coreml + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() \ No newline at end of file diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake new file mode 100644 index 0000000000000..f60faa4d39116 --- /dev/null +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -0,0 +1,261 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +file(GLOB_RECURSE onnxruntime_providers_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cpu/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/cpu/*.cc" +) + +if(onnxruntime_DISABLE_ML_OPS) + list(FILTER onnxruntime_providers_srcs EXCLUDE REGEX ".*/ml/.*") +endif() + +file(GLOB_RECURSE onnxruntime_cpu_contrib_ops_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/*.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/*.cc" +) + +file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cc" +) + +file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cu" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cuh" +) + +file(GLOB_RECURSE onnxruntime_rocm_contrib_ops_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/*.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/*.cc" +) + +file(GLOB_RECURSE onnxruntime_rocm_contrib_ops_cu_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/*.cu" + "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/*.cuh" +) + +file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/contrib_ops/js/*.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/js/*.cc" +) + +file(GLOB onnxruntime_providers_common_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/op_kernel_type_control_overrides.inc" +) + + +source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_common_srcs} ${onnxruntime_providers_srcs}) + +set(onnxruntime_providers_src ${onnxruntime_providers_common_srcs} ${onnxruntime_providers_srcs}) + +# disable contrib ops conditionally +if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + if (NOT onnxruntime_ENABLE_ATEN) + list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc" + ) + endif() + # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio + source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs}) + list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs}) +endif() + +if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING) + file(GLOB_RECURSE onnxruntime_cpu_training_ops_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.cc" + ) + + source_group(TREE ${ORTTRAINING_ROOT}/ FILES ${onnxruntime_cpu_training_ops_srcs}) + list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_training_ops_srcs}) + + file(GLOB_RECURSE onnxruntime_cpu_full_training_only_srcs + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/collective/*.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/collective/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/communication/*.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/communication/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/controlflow/record.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/controlflow/record.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/controlflow/wait.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/controlflow/wait.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/controlflow/yield.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/controlflow/yield.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/gist/*.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/gist/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/tensorboard/*.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/tensorboard/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/torch/*.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/torch/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/triton/triton_op.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/triton/triton_op.h" + ) + + list(REMOVE_ITEM onnxruntime_providers_src ${onnxruntime_cpu_full_training_only_srcs}) +endif() + +if (onnxruntime_ENABLE_ATEN) + file(GLOB_RECURSE onnxruntime_providers_dlpack_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/dlpack/dlpack_converter.cc" + "${ONNXRUNTIME_ROOT}/core/dlpack/dlpack_converter.h" + ) + set(onnxruntime_providers_dlpack_srcs ${onnxruntime_providers_dlpack_srcs}) + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_dlpack_srcs}) + list(APPEND onnxruntime_providers_src ${onnxruntime_providers_dlpack_srcs}) +endif() + +if (onnxruntime_ENABLE_TRAINING) + file(GLOB_RECURSE onnxruntime_cpu_training_ops_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.cc" + "${ORTTRAINING_SOURCE_DIR}/core/framework/*.h" + "${ORTTRAINING_SOURCE_DIR}/core/framework/*.cc" + "${ORTTRAINING_SOURCE_DIR}/core/framework/adasum/*" + "${ORTTRAINING_SOURCE_DIR}/core/framework/communication/*" + ) + + # This is already built in framework.cmake + file(GLOB_RECURSE onnxruntime_training_framework_excude_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/core/framework/torch/*.h" + "${ORTTRAINING_SOURCE_DIR}/core/framework/torch/*.cc" + "${ORTTRAINING_SOURCE_DIR}/core/framework/triton/*.h" + "${ORTTRAINING_SOURCE_DIR}/core/framework/triton/*.cc" + ) + + list(REMOVE_ITEM onnxruntime_cpu_training_ops_srcs ${onnxruntime_training_framework_excude_srcs}) + + source_group(TREE ${ORTTRAINING_ROOT}/ FILES ${onnxruntime_cpu_training_ops_srcs}) + list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_training_ops_srcs}) +endif() + +if (onnxruntime_REDUCED_OPS_BUILD) + substitute_op_reduction_srcs(onnxruntime_providers_src) +endif() +onnxruntime_add_static_library(onnxruntime_providers ${onnxruntime_providers_src}) +if (onnxruntime_REDUCED_OPS_BUILD) + add_op_reduction_include_dirs(onnxruntime_providers) +endif() + +if (HAS_BITWISE_INSTEAD_OF_LOGICAL) + target_compile_options(onnxruntime_providers PRIVATE "-Wno-bitwise-instead-of-logical") +endif() + +if (MSVC) + target_compile_options(onnxruntime_providers PRIVATE "/bigobj") +# if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) +# target_compile_options(onnxruntime_providers PRIVATE "/wd4244") +# endif() +endif() +onnxruntime_add_include_to_target(onnxruntime_providers onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) + +if (onnxruntime_BUILD_MS_EXPERIMENTAL_OPS) + target_compile_definitions(onnxruntime_providers PRIVATE BUILD_MS_EXPERIMENTAL_OPS=1) +endif() + +if(HAS_DEPRECATED_COPY) + #temporarily ignore this warning + #see: https://en.wikipedia.org/wiki/Rule_of_three_(C%2B%2B_programming) + set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/math/matmul_integer.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy) + set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/math/quantize_linear_matmul.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy) + set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/nn/qlinearconv.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy) + set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/nn/conv_integer.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy) + set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/generator/random.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy) + set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/tensor/onehot.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy) + set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/tensor/where_op.cc" PROPERTIES COMPILE_FLAGS -Wno-deprecated-copy) +endif() + +# This is enabled only for Adasum files in training mode. +# The flags won't be applied globally since some high-precision training and inferencing ops will incur precision loss. +if (onnxruntime_ENABLE_CPU_FP16_OPS) + set_source_files_properties(${ORTTRAINING_SOURCE_DIR}/core/framework/adasum/adasum_mpi.cc PROPERTIES COMPILE_FLAGS " -fassociative-math -ffast-math -ftree-vectorize -funsafe-math-optimizations -mf16c -mavx -mfma ") + set_source_files_properties(${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/collective/adasum_kernels.cc PROPERTIES COMPILE_FLAGS " -fassociative-math -ffast-math -ftree-vectorize -funsafe-math-optimizations -mf16c -mavx -mfma ") + set_source_files_properties(${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/adasum_kernels.cc PROPERTIES COMPILE_FLAGS " -fassociative-math -ffast-math -ftree-vectorize -funsafe-math-optimizations -mf16c -mavx -mfma ") +endif() + +target_include_directories(onnxruntime_providers PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS}) +onnxruntime_add_include_to_target(onnxruntime_providers re2::re2) +add_dependencies(onnxruntime_providers onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + +if (onnxruntime_ENABLE_TRAINING_OPS) + target_include_directories(onnxruntime_providers PRIVATE ${ORTTRAINING_ROOT}) +endif() + +if (onnxruntime_ENABLE_ATEN) + target_compile_definitions(onnxruntime_providers PRIVATE ENABLE_ATEN) + # DLPack is a header-only dependency + set(DLPACK_INCLUDE_DIR ${dlpack_SOURCE_DIR}/include) + target_include_directories(onnxruntime_providers PRIVATE ${DLPACK_INCLUDE_DIR}) +endif() + +if (onnxruntime_ENABLE_TRAINING) + add_dependencies(onnxruntime_providers tensorboard) + onnxruntime_add_include_to_target(onnxruntime_providers tensorboard) + if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP OR onnxruntime_ENABLE_TRITON) + onnxruntime_add_include_to_target(onnxruntime_providers Python::Module) + endif() + + if (onnxruntime_USE_NCCL OR onnxruntime_USE_MPI) + target_include_directories(onnxruntime_providers PUBLIC ${MPI_CXX_INCLUDE_DIRS}) + endif() +endif() + +install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/cpu/cpu_provider_factory.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) +set_target_properties(onnxruntime_providers PROPERTIES LINKER_LANGUAGE CXX) +set_target_properties(onnxruntime_providers PROPERTIES FOLDER "ONNXRuntime") + +if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD + AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin|iOS" + AND NOT CMAKE_SYSTEM_NAME STREQUAL "Android" + AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + file(GLOB onnxruntime_providers_shared_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/shared/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/*.cc" + ) + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_shared_cc_srcs}) + onnxruntime_add_shared_library(onnxruntime_providers_shared ${onnxruntime_providers_shared_cc_srcs} "${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc") + set_target_properties(onnxruntime_providers_shared PROPERTIES FOLDER "ONNXRuntime") + set_target_properties(onnxruntime_providers_shared PROPERTIES LINKER_LANGUAGE CXX) + + target_compile_definitions(onnxruntime_providers_shared PRIVATE VER_MAJOR=${VERSION_MAJOR_PART}) + target_compile_definitions(onnxruntime_providers_shared PRIVATE VER_MINOR=${VERSION_MINOR_PART}) + target_compile_definitions(onnxruntime_providers_shared PRIVATE VER_BUILD=${VERSION_BUILD_PART}) + target_compile_definitions(onnxruntime_providers_shared PRIVATE VER_PRIVATE=${VERSION_PRIVATE_PART}) + target_compile_definitions(onnxruntime_providers_shared PRIVATE VER_STRING=\"${VERSION_STRING}\") + target_compile_definitions(onnxruntime_providers_shared PRIVATE FILE_NAME=\"onnxruntime_providers_shared.dll\") + + + # On Apple/Unix we don't directly link with this library as we load it with RTLD_GLOBAL, so this is only set to the actual library on WIN32 + # But, in exchange we need to manually add Boost::mp11 to include dirs for every EP. + # It is because "provider_api.h" includes core/framework/op_kernel.h which includes op_kernel.h which includes "boost/mp11.hpp" + set(ONNXRUNTIME_PROVIDERS_SHARED) + + if(APPLE) + set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/shared/exported_symbols.lst") + elseif(UNIX) + set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/shared/version_script.lds -Xlinker --gc-sections") + elseif(WIN32) + set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/shared/symbols.def") + set(ONNXRUNTIME_PROVIDERS_SHARED onnxruntime_providers_shared) + else() + message(FATAL_ERROR "onnxruntime_providers_shared unknown platform, need to specify shared library exports for it") + endif() + + install(TARGETS onnxruntime_providers_shared + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + ) +endif() + +if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_providers + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() \ No newline at end of file diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake new file mode 100644 index 0000000000000..cf298aee9fa85 --- /dev/null +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -0,0 +1,252 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" + ) + # Remove pch files + list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs + "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.cc" + ) + + # The shared_library files are in a separate list since they use precompiled headers, and the above files have them disabled. + file(GLOB_RECURSE onnxruntime_providers_cuda_shared_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" + ) + file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh" + ) + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) + set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) + + # disable contrib ops conditionally + if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + if (NOT onnxruntime_ENABLE_ATEN) + list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/aten_ops/aten_op.cc" + ) + endif() + if (NOT onnxruntime_USE_NCCL) + list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/nccl_kernels.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding_spec.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_expand.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reduce.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_unsqueeze.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_squeeze.cc" + ) + endif() + # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio + source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs}) + list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs}) + endif() + + if (onnxruntime_ENABLE_TRAINING_OPS) + file(GLOB_RECURSE onnxruntime_cuda_training_ops_cc_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cc" + ) + + file(GLOB_RECURSE onnxruntime_cuda_training_ops_cu_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cu" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cuh" + ) + + source_group(TREE ${ORTTRAINING_ROOT} FILES ${onnxruntime_cuda_training_ops_cc_srcs} ${onnxruntime_cuda_training_ops_cu_srcs}) + list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_training_ops_cc_srcs} ${onnxruntime_cuda_training_ops_cu_srcs}) + + if(NOT onnxruntime_ENABLE_TRAINING) + file(GLOB_RECURSE onnxruntime_cuda_full_training_only_srcs + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/*.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/record.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/record.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/wait.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/wait.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/yield.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/gist/*.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/gist/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/gist/*.cu" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/torch/*.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/torch/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/triton/triton_op.cc" + ) + + list(REMOVE_ITEM onnxruntime_providers_cuda_src ${onnxruntime_cuda_full_training_only_srcs}) + elseif(WIN32 OR NOT onnxruntime_USE_NCCL) + # NCCL is not support in Windows build + file(GLOB_RECURSE onnxruntime_cuda_nccl_op_srcs + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_common.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_kernels.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/megatron.cc" + ) + list(REMOVE_ITEM onnxruntime_providers_cuda_src ${onnxruntime_cuda_nccl_op_srcs}) + endif() + endif() + + if (onnxruntime_REDUCED_OPS_BUILD) + substitute_op_reduction_srcs(onnxruntime_providers_cuda_src) + endif() + if(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) + # cuda_provider_interface.cc is removed from the object target: onnxruntime_providers_cuda_obj and + # add to the lib onnxruntime_providers_cuda separatedly. + # onnxruntime_providers_cuda_ut can share all the object files with onnxruntime_providers_cuda except cuda_provider_interface.cc. + set(cuda_provider_interface_src ${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_provider_interface.cc) + list(REMOVE_ITEM onnxruntime_providers_cuda_src ${cuda_provider_interface_src}) + onnxruntime_add_object_library(onnxruntime_providers_cuda_obj ${onnxruntime_providers_cuda_src}) + onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${cuda_provider_interface_src} $) + else() + onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${onnxruntime_providers_cuda_src}) + endif() + # config_cuda_provider_shared_module can be used to config onnxruntime_providers_cuda_obj, onnxruntime_providers_cuda & onnxruntime_providers_cuda_ut. + # This function guarantees that all 3 targets have the same configurations. + function(config_cuda_provider_shared_module target) + if (onnxruntime_REDUCED_OPS_BUILD) + add_op_reduction_include_dirs(${target}) + endif() + + if (HAS_GUARD_CF) + target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /guard:cf>") + endif() + if (HAS_QSPECTRE) + target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /Qspectre>") + endif() + foreach(ORT_FLAG ${ORT_WARNING_FLAGS}) + target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler \"${ORT_FLAG}\">") + endforeach() + # CUDA 11.3+ supports parallel compilation + # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver-threads + if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.3) + option(onnxruntime_NVCC_THREADS "Number of threads that NVCC can use for compilation." 1) + target_compile_options(${target} PRIVATE "$<$:SHELL:--threads \"${onnxruntime_NVCC_THREADS}\">") + endif() + if (UNIX) + target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler -Wno-reorder>" + "$<$>:-Wno-reorder>") + target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler -Wno-error=sign-compare>" + "$<$>:-Wno-error=sign-compare>") + else() + #mutex.cuh(91): warning C4834: discarding return value of function with 'nodiscard' attribute + target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4834>") + target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4127>") + endif() + + onnxruntime_add_include_to_target(${target} onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers) + if (onnxruntime_ENABLE_TRAINING_OPS) + onnxruntime_add_include_to_target(${target} onnxruntime_training) + if (onnxruntime_ENABLE_TRAINING) + target_link_libraries(${target} PRIVATE onnxruntime_training) + endif() + if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP OR onnxruntime_ENABLE_TRITON) + onnxruntime_add_include_to_target(${target} Python::Module) + endif() + endif() + + add_dependencies(${target} onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) + if(onnxruntime_CUDNN_HOME) + target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) + target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) + endif() + + if (onnxruntime_USE_TRITON_KERNEL) + # compile triton kernel, generate .a and .h files + include(onnxruntime_compile_triton_kernel.cmake) + compile_triton_kernel(triton_kernel_obj_file triton_kernel_header_dir) + add_dependencies(${target} onnxruntime_triton_kernel) + target_compile_definitions(${target} PRIVATE USE_TRITON_KERNEL) + target_include_directories(${target} PRIVATE ${triton_kernel_header_dir}) + target_link_libraries(${target} PUBLIC -Wl,--whole-archive ${triton_kernel_obj_file} -Wl,--no-whole-archive) + # lib cuda needed by cuLaunchKernel + target_link_libraries(${target} PRIVATE cuda) + endif() + + include(cutlass) + target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) + + target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found + set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA) + set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime") + + if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling + target_include_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/include) + target_link_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/lib64) + target_link_libraries(${target} PRIVATE cupti) + endif() + + if (onnxruntime_ENABLE_NVTX_PROFILE AND NOT WIN32) + target_link_libraries(${target} PRIVATE nvToolsExt) + endif() + + if (onnxruntime_ENABLE_TRAINING_OPS) + target_include_directories(${target} PRIVATE ${ORTTRAINING_ROOT} ${MPI_CXX_INCLUDE_DIRS}) + endif() + + if(onnxruntime_USE_MPI) + target_link_libraries(${target} PRIVATE ${MPI_LIBRARIES} ${MPI_CXX_LINK_FLAGS}) + endif() + + if (onnxruntime_USE_NCCL) + target_include_directories(${target} PRIVATE ${NCCL_INCLUDE_DIRS}) + target_link_libraries(${target} PRIVATE ${NCCL_LIBRARIES}) + endif() + + if (WIN32) + # *.cu cannot use PCH + if (NOT onnxruntime_BUILD_CACHE) + target_precompile_headers(${target} PUBLIC + "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.cc" + ) + endif() + + # minimize the Windows includes. + # this avoids an issue with CUDA 11.6 where 'small' is defined in the windows and cuda headers. + target_compile_definitions(${target} PRIVATE "WIN32_LEAN_AND_MEAN") + + # disable a warning from the CUDA headers about unreferenced local functions + #target_compile_options(${target} PRIVATE /wd4505) + set(onnxruntime_providers_cuda_static_library_flags + -IGNORE:4221 # LNK4221: This object file does not define any previously undefined public symbols, so it will not be used by any link operation that consumes this library + ) + set_target_properties(${target} PROPERTIES + STATIC_LIBRARY_FLAGS "${onnxruntime_providers_cuda_static_library_flags}") + endif() + + if(APPLE) + set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/cuda/exported_symbols.lst") + target_link_libraries(${target} PRIVATE nsync::nsync_cpp) + elseif(UNIX) + set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/cuda/version_script.lds -Xlinker --gc-sections") + target_link_libraries(${target} PRIVATE nsync::nsync_cpp) + elseif(WIN32) + set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/cuda/symbols.def") + else() + message(FATAL_ERROR "${target} unknown platform, need to specify shared library exports for it") + endif() + + if (onnxruntime_ENABLE_ATEN) + target_compile_definitions(${target} PRIVATE ENABLE_ATEN) + endif() + endfunction() + if(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) + config_cuda_provider_shared_module(onnxruntime_providers_cuda_obj) + endif() + config_cuda_provider_shared_module(onnxruntime_providers_cuda) + + install(TARGETS onnxruntime_providers_cuda + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_providers_dml.cmake b/cmake/onnxruntime_providers_dml.cmake new file mode 100644 index 0000000000000..01b0bda9fea6b --- /dev/null +++ b/cmake/onnxruntime_providers_dml.cmake @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + file(GLOB_RECURSE onnxruntime_providers_dml_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/dml/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/dml/*.cpp" + "${ONNXRUNTIME_ROOT}/core/providers/dml/*.cc" + ) + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_dml_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_dml ${onnxruntime_providers_dml_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_dml + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${WIL_TARGET} + ) + add_dependencies(onnxruntime_providers_dml ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_include_directories(onnxruntime_providers_dml PRIVATE + ${ONNXRUNTIME_ROOT} + ) + + target_compile_definitions(onnxruntime_providers_dml PRIVATE DML_TARGET_VERSION_USE_LATEST=1) + if(WIN32) + target_compile_options(onnxruntime_providers_dml PRIVATE "/wd4100" "/wd4238" "/wd4189" "/wd4702") + endif() + + if (NOT onnxruntime_USE_CUSTOM_DIRECTML) + foreach(file "DirectML.dll" "DirectML.pdb" "DirectML.Debug.dll" "DirectML.Debug.pdb") + add_custom_command(TARGET onnxruntime_providers_dml + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${DML_PACKAGE_DIR}/bin/${onnxruntime_target_platform}-win/${file}" $) + endforeach() + endif() + + function(target_add_dml target) + if (onnxruntime_USE_CUSTOM_DIRECTML) + if (dml_EXTERNAL_PROJECT) + # Internal build of DirectML: link against the "DirectML" target. + target_link_libraries(${target} PRIVATE DirectML) + else() + if (dml_LIB_DIR) + target_link_libraries(${target} PRIVATE ${dml_LIB_DIR}/DirectML.lib) + else() + target_link_libraries(${target} PRIVATE DirectML) + endif() + endif() + else() + add_dependencies(${target} RESTORE_PACKAGES) + target_link_libraries(${target} PRIVATE "${DML_PACKAGE_DIR}/bin/${onnxruntime_target_platform}-win/DirectML.lib") + target_compile_definitions(${target} PRIVATE DML_TARGET_VERSION_USE_LATEST) + endif() + endfunction() + + target_add_dml(onnxruntime_providers_dml) + target_link_libraries(onnxruntime_providers_dml PRIVATE onnxruntime_common) + target_link_libraries(onnxruntime_providers_dml PRIVATE onnxruntime_framework) + onnxruntime_add_include_to_target(onnxruntime_providers_dml onnxruntime_common) + if (GDK_PLATFORM STREQUAL Scarlett) + target_link_libraries(onnxruntime_providers_dml PRIVATE ${gdk_dx_libs}) + else() + target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib dxcore.lib) + endif() + + target_link_libraries(onnxruntime_providers_dml PRIVATE delayimp.lib) + + if (NOT GDK_PLATFORM) + set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /DELAYLOAD:ext-ms-win-dxcore-l1-*.dll /ignore:4199") + endif() + + target_compile_definitions(onnxruntime_providers_dml + PRIVATE + ONNX_NAMESPACE=onnx ONNX_ML LOTUS_LOG_THRESHOLD=2 LOTUS_ENABLE_STDERR_LOGGING PLATFORM_WINDOWS + ) + target_compile_definitions(onnxruntime_providers_dml PRIVATE UNICODE _UNICODE NOMINMAX) + if (MSVC) + target_compile_definitions(onnxruntime_providers_dml PRIVATE _SILENCE_CXX17_ITERATOR_BASE_CLASS_DEPRECATION_WARNING) + target_compile_options(onnxruntime_providers_dml PRIVATE "/W3") + endif() + + install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/dml/dml_provider_factory.h + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/ + ) + + set_target_properties(onnxruntime_providers_dml PROPERTIES LINKER_LANGUAGE CXX) + set_target_properties(onnxruntime_providers_dml PROPERTIES FOLDER "ONNXRuntime") + + if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_providers_dml + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() diff --git a/cmake/onnxruntime_providers_dnnl.cmake b/cmake/onnxruntime_providers_dnnl.cmake new file mode 100644 index 0000000000000..f2965728524b7 --- /dev/null +++ b/cmake/onnxruntime_providers_dnnl.cmake @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + file(GLOB_RECURSE onnxruntime_providers_dnnl_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/dnnl/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/dnnl/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" + ) + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_dnnl_cc_srcs}) + onnxruntime_add_shared_library_module(onnxruntime_providers_dnnl ${onnxruntime_providers_dnnl_cc_srcs}) + target_link_directories(onnxruntime_providers_dnnl PRIVATE ${DNNL_LIB_DIR}) + if (MSVC AND onnxruntime_ENABLE_STATIC_ANALYSIS) + # dnnl_convgrad.cc(47,0): Warning C6262: Function uses '38816' bytes of stack: exceeds /analyze:stacksize '16384'. Consider moving some data to heap. + target_compile_options(onnxruntime_providers_dnnl PRIVATE "/analyze:stacksize 131072") + endif() + + add_dependencies(onnxruntime_providers_dnnl onnxruntime_providers_shared project_dnnl ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_include_directories(onnxruntime_providers_dnnl PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${DNNL_INCLUDE_DIR} ${DNNL_OCL_INCLUDE_DIR}) + # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found + target_link_libraries(onnxruntime_providers_dnnl PRIVATE dnnl ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 ${ABSEIL_LIBS} ${GSL_TARGET} safeint_interface) + install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/dnnl/dnnl_provider_options.h + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) + set_target_properties(onnxruntime_providers_dnnl PROPERTIES FOLDER "ONNXRuntime") + set_target_properties(onnxruntime_providers_dnnl PROPERTIES LINKER_LANGUAGE CXX) + + # Needed for the provider interface, as it includes training headers when training is enabled + if (onnxruntime_ENABLE_TRAINING_OPS) + target_include_directories(onnxruntime_providers_dnnl PRIVATE ${ORTTRAINING_ROOT}) + endif() + + # Needed for threadpool handling + if(onnxruntime_BUILD_JAVA) + add_compile_definitions(DNNL_JAVA) + endif() + + if(APPLE) + set_property(TARGET onnxruntime_providers_dnnl APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/dnnl/exported_symbols.lst") + set_target_properties(onnxruntime_providers_dnnl PROPERTIES + INSTALL_RPATH "@loader_path" + BUILD_WITH_INSTALL_RPATH TRUE + INSTALL_RPATH_USE_LINK_PATH FALSE) + target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync::nsync_cpp) + elseif(UNIX) + set_property(TARGET onnxruntime_providers_dnnl APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/dnnl/version_script.lds -Xlinker --gc-sections -Xlinker -rpath=\$ORIGIN") + target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync::nsync_cpp) + elseif(WIN32) + set_property(TARGET onnxruntime_providers_dnnl APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/dnnl/symbols.def") + else() + message(FATAL_ERROR "onnxruntime_providers_dnnl unknown platform, need to specify shared library exports for it") + endif() + + install(TARGETS onnxruntime_providers_dnnl + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) \ No newline at end of file diff --git a/cmake/onnxruntime_providers_js.cmake b/cmake/onnxruntime_providers_js.cmake new file mode 100644 index 0000000000000..306f5c74cb4c6 --- /dev/null +++ b/cmake/onnxruntime_providers_js.cmake @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + add_compile_definitions(USE_JSEP=1) + + file(GLOB_RECURSE onnxruntime_providers_js_cc_srcs + "${ONNXRUNTIME_ROOT}/core/providers/js/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/js/*.cc" + ) + if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_js_contrib_ops_cc_srcs}) + list(APPEND onnxruntime_providers_js_cc_srcs ${onnxruntime_js_contrib_ops_cc_srcs}) + endif() + + source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_providers_js_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_js ${onnxruntime_providers_js_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_js + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers Boost::mp11 + ) + target_include_directories(onnxruntime_providers_js PRIVATE ${eigen_INCLUDE_DIRS}) + add_dependencies(onnxruntime_providers_js ${onnxruntime_EXTERNAL_DEPENDENCIES}) \ No newline at end of file diff --git a/cmake/onnxruntime_providers_migraphx.cmake b/cmake/onnxruntime_providers_migraphx.cmake new file mode 100644 index 0000000000000..91ac66a40721d --- /dev/null +++ b/cmake/onnxruntime_providers_migraphx.cmake @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + add_definitions(-DUSE_MIGRAPHX=1) + set(BUILD_LIBRARY_ONLY 1) + add_definitions("-DONNX_ML=1") + add_definitions("-DONNX_NAMESPACE=onnx") + include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR}) + set(MIGRAPHX_ROOT ${onnxruntime_MIGRAPHX_HOME}) + include_directories(${onnx_SOURCE_DIR}) + set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) + if ( CMAKE_COMPILER_IS_GNUCC ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers") + endif() + set(CXX_VERSION_DEFINED TRUE) + set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) + if ( CMAKE_COMPILER_IS_GNUCC ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") + endif() + + # Add search paths for default rocm installation + list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm) + + find_package(hip) + find_package(migraphx PATHS ${AMD_MIGRAPHX_HOME}) + + find_package(miopen) + find_package(rocblas) + + set(migraphx_libs migraphx::c hip::host MIOpen roc::rocblas) + + file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/rocm/rocm_stream_handle.h" + "${ONNXRUNTIME_ROOT}/core/providers/rocm/rocm_stream_handle.cc" + ) + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs}) + onnxruntime_add_shared_library_module(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) + add_dependencies(onnxruntime_providers_migraphx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) + target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime) + set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX) + set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime") + target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1) + target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare) + set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") + set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections") + target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync::nsync_cpp stdc++fs) + + include(CheckLibraryExists) + check_library_exists(migraphx::c "migraphx_program_run_async" "/opt/rocm/migraphx/lib" HAS_STREAM_SYNC) + if(HAS_STREAM_SYNC) + target_compile_definitions(onnxruntime_providers_migraphx PRIVATE -DMIGRAPHX_STREAM_SYNC) + message(STATUS "MIGRAPHX GPU STREAM SYNC is ENABLED") + else() + message(STATUS "MIGRAPHX GPU STREAM SYNC is DISABLED") + endif() + + if (onnxruntime_ENABLE_TRAINING_OPS) + onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_training) + target_link_libraries(onnxruntime_providers_migraphx PRIVATE onnxruntime_training) + if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) + onnxruntime_add_include_to_target(onnxruntime_providers_migraphx Python::Module) + endif() + endif() + + install(TARGETS onnxruntime_providers_migraphx + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + ) diff --git a/cmake/onnxruntime_providers_nnapi.cmake b/cmake/onnxruntime_providers_nnapi.cmake new file mode 100644 index 0000000000000..5ac25a3b76efb --- /dev/null +++ b/cmake/onnxruntime_providers_nnapi.cmake @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) + message(FATAL_ERROR "NNAPI can not be used in a basic minimal build. Please build with '--minimal_build extended'") + endif() + + add_compile_definitions(USE_NNAPI=1) + + # This is the minimum Android API Level required by ORT NNAPI EP to run + # ORT running on any host system with Android API level less than this will fall back to CPU EP + if(onnxruntime_NNAPI_MIN_API) + add_compile_definitions(ORT_NNAPI_MIN_API_LEVEL=${onnxruntime_NNAPI_MIN_API}) + endif() + + # This is the maximum Android API level supported in the ort model conversion for NNAPI EP + # Note: This is only for running NNAPI for ort format model conversion on non-Android system since we cannot + # get the actually Android system version. + if(onnxruntime_NNAPI_HOST_API) + if(CMAKE_SYSTEM_NAME STREQUAL "Android") + message(FATAL_ERROR "onnxruntime_NNAPI_HOST_API should only be set for non-Android target") + endif() + add_compile_definitions(ORT_NNAPI_MAX_SUPPORTED_API_LEVEL=${onnxruntime_NNAPI_HOST_API}) + endif() + + set(onnxruntime_provider_nnapi_cc_src_patterns + "${ONNXRUNTIME_ROOT}/core/providers/nnapi/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/nnapi/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/builders/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/builders/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/builders/impl/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/builders/impl/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksTypes.h" + "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksWrapper.h" + "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksWrapper.cc" + "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h" + ) + + # On Android, use the actual NNAPI implementation. + # Otherwise, use a stub implementation to support some unit testing. + if(CMAKE_SYSTEM_NAME STREQUAL "Android") + list(APPEND onnxruntime_provider_nnapi_cc_src_patterns + "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.cc") + else() + list(APPEND onnxruntime_provider_nnapi_cc_src_patterns + "${ONNXRUNTIME_ROOT}/core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation_stub.cc") + endif() + + # These are shared utils, + # TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML + list(APPEND onnxruntime_provider_nnapi_cc_src_patterns + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" + "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" + ) + + file(GLOB onnxruntime_providers_nnapi_cc_srcs CONFIGURE_DEPENDS ${onnxruntime_provider_nnapi_cc_src_patterns}) + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_nnapi_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_nnapi ${onnxruntime_providers_nnapi_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_nnapi + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface + ) + target_link_libraries(onnxruntime_providers_nnapi) + add_dependencies(onnxruntime_providers_nnapi onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + set_target_properties(onnxruntime_providers_nnapi PROPERTIES CXX_STANDARD_REQUIRED ON) + set_target_properties(onnxruntime_providers_nnapi PROPERTIES FOLDER "ONNXRuntime") + target_include_directories(onnxruntime_providers_nnapi PRIVATE ${ONNXRUNTIME_ROOT} ${nnapi_INCLUDE_DIRS}) + set_target_properties(onnxruntime_providers_nnapi PROPERTIES LINKER_LANGUAGE CXX) + # ignore the warning unknown-pragmas on "pragma region" + if(NOT MSVC) + target_compile_options(onnxruntime_providers_nnapi PRIVATE "-Wno-unknown-pragmas") + endif() + + if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_providers_nnapi + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() \ No newline at end of file diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake new file mode 100644 index 0000000000000..e26f0bfc0b751 --- /dev/null +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# include_directories("${CMAKE_CURRENT_BINARY_DIR}/onnx") + file(GLOB_RECURSE onnxruntime_providers_openvino_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/openvino/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/openvino/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/openvino/*.hpp" + "${ONNXRUNTIME_ROOT}/core/providers/openvino/*.cpp" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" + ) + + if (WIN32) + set(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO Release) + endif() + + # Header paths + find_package(InferenceEngine REQUIRED) + find_package(ngraph REQUIRED) + + if (OPENVINO_2022_1 OR OPENVINO_2022_2) + find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX) + list (OV_20_LIBS openvino::frontend::onnx openvino::runtime) + endif() + + if (WIN32) + unset(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO) + endif() + + if ((DEFINED ENV{OPENCL_LIBS}) AND (DEFINED ENV{OPENCL_INCS})) + add_definitions(-DIO_BUFFER_ENABLED=1) + list(APPEND OPENVINO_LIB_LIST $ENV{OPENCL_LIBS} ${OV_20_LIBS} ${InferenceEngine_LIBRARIES} ${NGRAPH_LIBRARIES} ngraph::onnx_importer ${PYTHON_LIBRARIES}) + else() + list(APPEND OPENVINO_LIB_LIST ${OV_20_LIBS} ${InferenceEngine_LIBRARIES} ${NGRAPH_LIBRARIES} ngraph::onnx_importer ${PYTHON_LIBRARIES}) + endif() + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_openvino_cc_srcs}) + onnxruntime_add_shared_library_module(onnxruntime_providers_openvino ${onnxruntime_providers_openvino_cc_srcs} "${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc") + onnxruntime_add_include_to_target(onnxruntime_providers_openvino onnxruntime_common onnx) + install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/openvino/openvino_provider_factory.h + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) + set_target_properties(onnxruntime_providers_openvino PROPERTIES LINKER_LANGUAGE CXX) + set_target_properties(onnxruntime_providers_openvino PROPERTIES FOLDER "ONNXRuntime") + if(NOT MSVC) + target_compile_options(onnxruntime_providers_openvino PRIVATE "-Wno-parentheses") + endif() + add_dependencies(onnxruntime_providers_openvino onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_include_directories(onnxruntime_providers_openvino SYSTEM PUBLIC ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${OpenVINO_INCLUDE_DIR} ${OPENVINO_INCLUDE_DIR_LIST} ${PYTHON_INCLUDE_DIRS} $ENV{OPENCL_INCS} $ENV{OPENCL_INCS}/../../cl_headers/) + target_link_libraries(onnxruntime_providers_openvino ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 ${OPENVINO_LIB_LIST} ${ABSEIL_LIBS}) + + target_compile_definitions(onnxruntime_providers_openvino PRIVATE VER_MAJOR=${VERSION_MAJOR_PART}) + target_compile_definitions(onnxruntime_providers_openvino PRIVATE VER_MINOR=${VERSION_MINOR_PART}) + target_compile_definitions(onnxruntime_providers_openvino PRIVATE VER_BUILD=${VERSION_BUILD_PART}) + target_compile_definitions(onnxruntime_providers_openvino PRIVATE VER_PRIVATE=${VERSION_PRIVATE_PART}) + target_compile_definitions(onnxruntime_providers_openvino PRIVATE VER_STRING=\"${VERSION_STRING}\") + target_compile_definitions(onnxruntime_providers_openvino PRIVATE FILE_NAME=\"onnxruntime_providers_openvino.dll\") + + if(MSVC) + target_compile_options(onnxruntime_providers_openvino PUBLIC /wd4099 /wd4275 /wd4100 /wd4005 /wd4244 /wd4267) + endif() + + # Needed for the provider interface, as it includes training headers when training is enabled + if (onnxruntime_ENABLE_TRAINING_OPS) + target_include_directories(onnxruntime_providers_openvino PRIVATE ${ORTTRAINING_ROOT}) + endif() + + if(APPLE) + set_property(TARGET onnxruntime_providers_openvino APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/openvino/exported_symbols.lst") + elseif(UNIX) + set_property(TARGET onnxruntime_providers_openvino APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/openvino/version_script.lds -Xlinker --gc-sections") + elseif(WIN32) + set_property(TARGET onnxruntime_providers_openvino APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/openvino/symbols.def") + else() + message(FATAL_ERROR "onnxruntime_providers_openvino unknown platform, need to specify shared library exports for it") + endif() + + install(TARGETS onnxruntime_providers_openvino + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) \ No newline at end of file diff --git a/cmake/onnxruntime_providers_qnn.cmake b/cmake/onnxruntime_providers_qnn.cmake new file mode 100644 index 0000000000000..a93a06e960c81 --- /dev/null +++ b/cmake/onnxruntime_providers_qnn.cmake @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + add_compile_definitions(USE_QNN=1) + + # These are shared utils, + # TODO, move this to a separated lib when used by EPs other than QNN, NNAPI and CoreML + file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" + "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" + ) + + file(GLOB_RECURSE + onnxruntime_providers_qnn_ep_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/qnn/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/qnn/*.cc" + ) + + file(GLOB_RECURSE + onnxruntime_providers_qnn_builder_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/qnn/builder/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/qnn/builder/*.cc" + ) + + set(onnxruntime_providers_qnn_cc_srcs + ${onnxruntime_providers_shared_utils_cc_srcs} + ${onnxruntime_providers_qnn_ep_cc_srcs} + ${onnxruntime_providers_qnn_builder_cc_srcs} + ) + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_qnn_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_qnn ${onnxruntime_providers_qnn_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_qnn onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf-lite flatbuffers::flatbuffers Boost::mp11) + target_link_libraries(onnxruntime_providers_qnn) + add_dependencies(onnxruntime_providers_qnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + set_target_properties(onnxruntime_providers_qnn PROPERTIES CXX_STANDARD_REQUIRED ON) + set_target_properties(onnxruntime_providers_qnn PROPERTIES FOLDER "ONNXRuntime") + target_include_directories(onnxruntime_providers_qnn PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_QNN_HOME}/include/QNN ${onnxruntime_QNN_HOME}/include) + set_target_properties(onnxruntime_providers_qnn PROPERTIES LINKER_LANGUAGE CXX) + # ignore the warning unknown-pragmas on "pragma region" + if(NOT MSVC) + target_compile_options(onnxruntime_providers_qnn PRIVATE "-Wno-unknown-pragmas") + endif() \ No newline at end of file diff --git a/cmake/onnxruntime_providers_rknpu.cmake b/cmake/onnxruntime_providers_rknpu.cmake new file mode 100644 index 0000000000000..408bcfde06c36 --- /dev/null +++ b/cmake/onnxruntime_providers_rknpu.cmake @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-variable -Wno-unused-parameter") + add_definitions(-DUSE_RKNPU=1) + option(DNN_READ_ONNX "" ON) + set(DNN_CUSTOM_PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + option(DNN_CMAKE_INSTALL "" OFF) + option(DNN_BUILD_BIN "" OFF) + if (NOT RKNPU_DDK_PATH) + message(FATAL_ERROR "RKNPU_DDK_PATH required for onnxruntime_USE_RKNPU") + endif() + set(RKNPU_DDK_INCLUDE_DIR ${RKNPU_DDK_PATH}/include) + if (CMAKE_SIZEOF_VOID_P EQUAL 8) + set(RKNPU_DDK_LIB_DIR ${RKNPU_DDK_PATH}/lib64) + else() + set(RKNPU_DDK_LIB_DIR ${RKNPU_DDK_PATH}/lib) + endif() + file(GLOB_RECURSE + onnxruntime_providers_rknpu_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/rknpu/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/rknpu/*.cc" + ) + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_rknpu_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_rknpu ${onnxruntime_providers_rknpu_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_rknpu + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface + ) + target_link_libraries(onnxruntime_providers_rknpu PRIVATE -lrknpu_ddk) + add_dependencies(onnxruntime_providers_rknpu onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + set_target_properties(onnxruntime_providers_rknpu PROPERTIES FOLDER "ONNXRuntime") + target_include_directories(onnxruntime_providers_rknpu PRIVATE + ${ONNXRUNTIME_ROOT} ${rknpu_INCLUDE_DIRS} ${RKNPU_DDK_INCLUDE_DIR} + ) + link_directories(onnxruntime_providers_rknpu ${RKNPU_DDK_LIB_DIR}) + set_target_properties(onnxruntime_providers_rknpu PROPERTIES LINKER_LANGUAGE CXX) + + if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_providers_rknpu + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() \ No newline at end of file diff --git a/cmake/onnxruntime_providers_rocm.cmake b/cmake/onnxruntime_providers_rocm.cmake new file mode 100644 index 0000000000000..b66268291579c --- /dev/null +++ b/cmake/onnxruntime_providers_rocm.cmake @@ -0,0 +1,223 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + add_definitions(-DUSE_ROCM=1) + include(onnxruntime_rocm_hipify.cmake) + + list(APPEND CMAKE_PREFIX_PATH ${onnxruntime_ROCM_HOME}) + + find_package(HIP) + find_package(hiprand REQUIRED) + find_package(rocblas REQUIRED) + find_package(MIOpen REQUIRED) + find_package(hipfft REQUIRED) + + # MIOpen version + if(NOT DEFINED ENV{MIOPEN_PATH}) + set(MIOPEN_PATH ${onnxruntime_ROCM_HOME}) + else() + set(MIOPEN_PATH $ENV{MIOPEN_PATH}) + endif() + find_path(MIOPEN_VERSION_H_PATH + NAMES version.h + HINTS + ${MIOPEN_PATH}/include/miopen + ${MIOPEN_PATH}/miopen/include) + if (MIOPEN_VERSION_H_PATH-NOTFOUND) + MESSAGE(FATAL_ERROR "miopen version.h not found") + endif() + MESSAGE(STATUS "Found miopen version.h at ${MIOPEN_VERSION_H_PATH}") + + file(READ ${MIOPEN_VERSION_H_PATH}/version.h MIOPEN_HEADER_CONTENTS) + string(REGEX MATCH "define MIOPEN_VERSION_MAJOR * +([0-9]+)" + MIOPEN_VERSION_MAJOR "${MIOPEN_HEADER_CONTENTS}") + string(REGEX REPLACE "define MIOPEN_VERSION_MAJOR * +([0-9]+)" "\\1" + MIOPEN_VERSION_MAJOR "${MIOPEN_VERSION_MAJOR}") + string(REGEX MATCH "define MIOPEN_VERSION_MINOR * +([0-9]+)" + MIOPEN_VERSION_MINOR "${MIOPEN_HEADER_CONTENTS}") + string(REGEX REPLACE "define MIOPEN_VERSION_MINOR * +([0-9]+)" "\\1" + MIOPEN_VERSION_MINOR "${MIOPEN_VERSION_MINOR}") + string(REGEX MATCH "define MIOPEN_VERSION_PATCH * +([0-9]+)" + MIOPEN_VERSION_PATCH "${MIOPEN_HEADER_CONTENTS}") + string(REGEX REPLACE "define MIOPEN_VERSION_PATCH * +([0-9]+)" "\\1" + MIOPEN_VERSION_PATCH "${MIOPEN_VERSION_PATCH}") + set(MIOPEN_VERSION_DEV "${MIOPEN_VERSION_MAJOR}.${MIOPEN_VERSION_MINOR}.${MIOPEN_VERSION_PATCH}") + math(EXPR MIOPEN_VERSION_DEV_INT "(${MIOPEN_VERSION_MAJOR}*10000) + (${MIOPEN_VERSION_MINOR}*100) + ${MIOPEN_VERSION_PATCH}") + message("MIOPEN_VERSION_DEV: ${MIOPEN_VERSION_DEV}") + message("MIOPEN_VERSION_DEV_INT: ${MIOPEN_VERSION_DEV_INT}") + add_definitions(-DMIOPEN_VERSION=${MIOPEN_VERSION_DEV_INT}) + + find_library(RCCL_LIB rccl REQUIRED) + find_library(ROCTRACER_LIB roctracer64 REQUIRED) + set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${RCCL_LIB} ${ROCTRACER_LIB}) + + file(GLOB_RECURSE onnxruntime_providers_rocm_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.cc" + ) + + # The shared_library files are in a separate list since they use precompiled headers, and the above files have them disabled. + file(GLOB_RECURSE onnxruntime_providers_rocm_shared_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" + ) + + file(GLOB_RECURSE onnxruntime_providers_rocm_cu_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.cu" + "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.cuh" + ) + + hipify("onnxruntime/core/providers" provider_excluded_files onnxruntime_providers_rocm_generated_cc_srcs onnxruntime_providers_rocm_generated_cu_srcs) + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_rocm_cc_srcs} ${onnxruntime_providers_rocm_shared_srcs} ${onnxruntime_providers_rocm_cu_srcs}) + set(onnxruntime_providers_rocm_src ${onnxruntime_providers_rocm_cc_srcs} ${onnxruntime_providers_rocm_shared_srcs} ${onnxruntime_providers_rocm_cu_srcs}) + list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_providers_rocm_generated_cc_srcs} ${onnxruntime_providers_rocm_generated_cu_srcs}) + + # disable contrib ops conditionally + if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + hipify("onnxruntime/contrib_ops" contrib_ops_excluded_files onnxruntime_rocm_generated_contrib_ops_cc_srcs onnxruntime_rocm_generated_contrib_ops_cu_srcs) + + # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio + source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_rocm_contrib_ops_cc_srcs} ${onnxruntime_rocm_contrib_ops_cu_srcs}) + list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_rocm_contrib_ops_cc_srcs} ${onnxruntime_rocm_contrib_ops_cu_srcs}) + list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_rocm_generated_contrib_ops_cc_srcs} ${onnxruntime_rocm_generated_contrib_ops_cu_srcs}) + endif() + + if (onnxruntime_ENABLE_TRAINING_OPS) + file(GLOB_RECURSE onnxruntime_rocm_training_ops_cc_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.cc" + ) + + file(GLOB_RECURSE onnxruntime_rocm_training_ops_cu_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.cu" + "${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.cuh" + ) + + hipify("orttraining/orttraining/training_ops" training_ops_excluded_files onnxruntime_rocm_generated_training_ops_cc_srcs onnxruntime_rocm_generated_training_ops_cu_srcs) + + # NCCL is not support in Windows build + if (WIN32 OR NOT onnxruntime_USE_NCCL) + list(REMOVE_ITEM onnxruntime_rocm_generated_training_ops_cc_srcs + "${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining/orttraining/training_ops/rocm/collective/nccl_common.cc" + "${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining/orttraining/training_ops/rocm/collective/nccl_kernels.cc" + "${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining/orttraining/training_ops/rocm/collective/megatron.cc" + ) + endif() + + source_group(TREE ${ORTTRAINING_ROOT} FILES ${onnxruntime_rocm_training_ops_cc_srcs} ${onnxruntime_rocm_training_ops_cu_srcs}) + list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_rocm_training_ops_cc_srcs} ${onnxruntime_rocm_training_ops_cu_srcs}) + list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_rocm_generated_training_ops_cc_srcs} ${onnxruntime_rocm_generated_training_ops_cu_srcs}) + endif() + + auto_set_source_files_hip_language(${onnxruntime_providers_rocm_src}) + onnxruntime_add_shared_library_module(onnxruntime_providers_rocm ${onnxruntime_providers_rocm_src}) + target_compile_options(onnxruntime_providers_rocm PRIVATE -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1) + + if(NOT MSVC) + target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-sign-compare) + target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-unused-parameter) + target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-undefined-var-template) + endif() + + onnxruntime_add_include_to_target(onnxruntime_providers_rocm onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) + if (onnxruntime_ENABLE_TRAINING_OPS) + onnxruntime_add_include_to_target(onnxruntime_providers_rocm onnxruntime_training) + target_link_libraries(onnxruntime_providers_rocm PRIVATE onnxruntime_training) + if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) + onnxruntime_add_include_to_target(onnxruntime_providers_rocm Python::Module) + endif() + endif() + + add_custom_target(generate_hipified_files DEPENDS + ${onnxruntime_providers_rocm_generated_cc_srcs} + ${onnxruntime_providers_rocm_generated_cu_srcs} + ${onnxruntime_rocm_generated_contrib_ops_cc_srcs} + ${onnxruntime_rocm_generated_contrib_ops_cu_srcs} + ${onnxruntime_rocm_generated_training_ops_cc_srcs} + ${onnxruntime_rocm_generated_training_ops_cu_srcs}) + + add_dependencies(onnxruntime_providers_rocm generate_hipified_files onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_link_libraries(onnxruntime_providers_rocm PRIVATE ${ONNXRUNTIME_ROCM_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} ${ABSEIL_LIBS}) + target_include_directories(onnxruntime_providers_rocm SYSTEM + PRIVATE + ${ONNXRUNTIME_ROOT} + ${CMAKE_CURRENT_BINARY_DIR} + ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime + ${eigen_INCLUDE_DIRS} + PUBLIC + ${onnxruntime_ROCM_HOME}/include + ${onnxruntime_ROCM_HOME}/include/roctracer) + + set_target_properties(onnxruntime_providers_rocm PROPERTIES LINKER_LANGUAGE CXX) + set_target_properties(onnxruntime_providers_rocm PROPERTIES FOLDER "ONNXRuntime") + + if (onnxruntime_ENABLE_TRAINING) + target_include_directories(onnxruntime_providers_rocm PRIVATE ${ORTTRAINING_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining ${MPI_CXX_INCLUDE_DIRS}) + if(onnxruntime_USE_MPI) + target_link_libraries(onnxruntime_providers_rocm PRIVATE ${MPI_LIBRARIES} ${MPI_CXX_LINK_FLAGS}) + endif() + + # RCCL is enabled by default for ROCM builds + #if (onnxruntime_USE_NCCL) + # target_include_directories(onnxruntime_providers_rocm PRIVATE ${NCCL_INCLUDE_DIRS}) + # target_link_libraries(onnxruntime_providers_rocm PRIVATE ${NCCL_LIBRARIES}) + #endif() + endif() + + if (onnxruntime_USE_ROCBLAS_EXTENSION_API) + target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_ROCBLAS_EXTENSION_API) + target_compile_definitions(onnxruntime_providers_rocm PRIVATE ROCBLAS_NO_DEPRECATED_WARNINGS) + target_compile_definitions(onnxruntime_providers_rocm PRIVATE ROCBLAS_BETA_FEATURES_API) + endif() + + if (onnxruntime_USE_HIPBLASLT) + find_package(hipblaslt REQUIRED) + target_link_libraries(onnxruntime_providers_rocm PRIVATE roc::hipblaslt) + target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_HIPBLASLT) + endif() + + if (onnxruntime_USE_TRITON_KERNEL) + # compile triton kernel, generate .a and .h files + include(onnxruntime_compile_triton_kernel.cmake) + compile_triton_kernel(triton_kernel_obj_file triton_kernel_header_dir) + add_dependencies(onnxruntime_providers_rocm onnxruntime_triton_kernel) + target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_TRITON_KERNEL) + target_include_directories(onnxruntime_providers_rocm PRIVATE ${triton_kernel_header_dir}) + target_link_libraries(onnxruntime_providers_rocm PUBLIC -Wl,--whole-archive ${triton_kernel_obj_file} -Wl,--no-whole-archive) + endif() + + if (onnxruntime_USE_COMPOSABLE_KERNEL) + include(composable_kernel) + target_link_libraries(onnxruntime_providers_rocm PRIVATE + onnxruntime_composable_kernel_includes + # Currently we shall not use composablekernels::device_operations, the target includes all conv dependencies, which + # are extremely slow to compile. Instead, we only link all gemm related objects. See the following directory on + # updating. + # https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/library/src/tensor_operation_instance/gpu + device_gemm_instance + device_gemm_add_fastgelu_instance + device_gemm_fastgelu_instance + device_gemm_splitk_instance + device_gemm_streamk_instance + device_batched_gemm_instance + device_softmax_instance + ) + target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_COMPOSABLE_KERNEL) + endif() + + if(UNIX) + set_property(TARGET onnxruntime_providers_rocm APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/rocm/version_script.lds -Xlinker --gc-sections") + target_link_libraries(onnxruntime_providers_rocm PRIVATE nsync::nsync_cpp) + else() + message(FATAL_ERROR "onnxruntime_providers_rocm unknown platform, need to specify shared library exports for it") + endif() + + if (onnxruntime_ENABLE_ATEN) + target_compile_definitions(onnxruntime_providers_rocm PRIVATE ENABLE_ATEN) + endif() + + install(TARGETS onnxruntime_providers_rocm + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake new file mode 100644 index 0000000000000..686a993de3a4a --- /dev/null +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -0,0 +1,147 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + add_definitions(-DUSE_TENSORRT=1) + if (onnxruntime_TENSORRT_PLACEHOLDER_BUILDER) + add_definitions(-DORT_TENSORRT_PLACEHOLDER_BUILDER) + endif() + set(BUILD_LIBRARY_ONLY 1) + add_definitions("-DONNX_ML=1") + add_definitions("-DONNX_NAMESPACE=onnx") + set(CUDA_INCLUDE_DIRS ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + set(TENSORRT_ROOT ${onnxruntime_TENSORRT_HOME}) + set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) + set(PROTOBUF_LIBRARY ${PROTOBUF_LIB}) + if (WIN32) + add_definitions(-D_SILENCE_EXPERIMENTAL_FILESYSTEM_DEPRECATION_WARNING=1) + set(OLD_CMAKE_CUDA_FLAGS ${CMAKE_CUDA_FLAGS}) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4099 /wd4551 /wd4505 /wd4515 /wd4706 /wd4456 /wd4324 /wd4701 /wd4804 /wd4702 /wd4458 /wd4703") + if (CMAKE_BUILD_TYPE STREQUAL "Debug") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4805") + endif() + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -include algorithm") + set(DISABLED_WARNINGS_FOR_TRT /wd4456) + endif() + if ( CMAKE_COMPILER_IS_GNUCC ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers") + endif() + set(CXX_VERSION_DEFINED TRUE) + + # There is an issue when running "Debug build" TRT EP with "Release build" TRT builtin parser on Windows. + # We enforce following workaround for now until the real fix. + if (WIN32 AND CMAKE_BUILD_TYPE STREQUAL "Debug") + set(onnxruntime_USE_TENSORRT_BUILTIN_PARSER OFF) + MESSAGE(STATUS "[Note] There is an issue when running \"Debug build\" TRT EP with \"Release build\" TRT built-in parser on Windows. This build will use tensorrt oss parser instead.") + endif() + + if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) + # Add TensorRT library + find_path(TENSORRT_INCLUDE_DIR NvInfer.h + HINTS ${TENSORRT_ROOT} + PATH_SUFFIXES include) + MESSAGE(STATUS "Found TensorRT headers at ${TENSORRT_INCLUDE_DIR}") + find_library(TENSORRT_LIBRARY_INFER nvinfer + HINTS ${TENSORRT_ROOT} + PATH_SUFFIXES lib lib64 lib/x64) + find_library(TENSORRT_LIBRARY_INFER_PLUGIN nvinfer_plugin + HINTS ${TENSORRT_ROOT} + PATH_SUFFIXES lib lib64 lib/x64) + find_library(TENSORRT_LIBRARY_NVONNXPARSER nvonnxparser + HINTS ${TENSORRT_ROOT} + PATH_SUFFIXES lib lib64 lib/x64) + set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN} ${TENSORRT_LIBRARY_NVONNXPARSER}) + MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") + else() + FetchContent_Declare( + onnx_tensorrt + URL ${DEP_URL_onnx_tensorrt} + URL_HASH SHA1=${DEP_SHA1_onnx_tensorrt} + ) + if (NOT CUDA_INCLUDE_DIR) + set(CUDA_INCLUDE_DIR ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # onnx-tensorrt repo needs this variable to build + endif() + # The onnx_tensorrt repo contains a test program, getSupportedAPITest, which doesn't support Windows. It uses + # unistd.h. So we must exclude it from our build. onnxruntime_fetchcontent_makeavailable is for the purpose. + onnxruntime_fetchcontent_makeavailable(onnx_tensorrt) + include_directories(${onnx_tensorrt_SOURCE_DIR}) + set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) + if ( CMAKE_COMPILER_IS_GNUCC ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") + endif() + if (WIN32) + set(CMAKE_CUDA_FLAGS ${OLD_CMAKE_CUDA_FLAGS}) + unset(PROTOBUF_LIBRARY) + unset(OLD_CMAKE_CXX_FLAGS) + unset(OLD_CMAKE_CUDA_FLAGS) + set_target_properties(nvonnxparser PROPERTIES LINK_FLAGS "/ignore:4199") + target_compile_options(nvonnxparser_static PRIVATE /FIio.h /wd4100) + target_compile_options(nvonnxparser PRIVATE /FIio.h /wd4100) + endif() + set(onnxparser_link_libs nvonnxparser_static) + endif() + + include_directories(${TENSORRT_INCLUDE_DIR}) + # ${TENSORRT_LIBRARY} is empty if we link nvonnxparser_static. + # nvonnxparser_static is linked against tensorrt libraries in onnx-tensorrt + # See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121 + set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) + + file(GLOB_RECURSE onnxruntime_providers_tensorrt_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_stream_handle.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_stream_handle.cc" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_graph.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_graph.cc" + ) + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_tensorrt_cc_srcs}) + onnxruntime_add_shared_library_module(onnxruntime_providers_tensorrt ${onnxruntime_providers_tensorrt_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) + add_dependencies(onnxruntime_providers_tensorrt onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) + if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS}) + else() + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS}) + endif() + target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + if(onnxruntime_CUDNN_HOME) + target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include) + endif() + + # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found + set_target_properties(onnxruntime_providers_tensorrt PROPERTIES LINKER_LANGUAGE CUDA) + set_target_properties(onnxruntime_providers_tensorrt PROPERTIES FOLDER "ONNXRuntime") + target_compile_definitions(onnxruntime_providers_tensorrt PRIVATE ONNXIFI_BUILD_LIBRARY=1) + target_compile_options(onnxruntime_providers_tensorrt PRIVATE ${DISABLED_WARNINGS_FOR_TRT}) + if (WIN32) + target_compile_options(onnxruntime_providers_tensorrt INTERFACE /wd4456) + endif() + + # Needed for the provider interface, as it includes training headers when training is enabled + if (onnxruntime_ENABLE_TRAINING_OPS) + target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ORTTRAINING_ROOT}) + if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) + onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt Python::Module) + endif() + endif() + + if(APPLE) + set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/tensorrt/exported_symbols.lst") + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync::nsync_cpp) + elseif(UNIX) + set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") + set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/tensorrt/version_script.lds -Xlinker --gc-sections") + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync::nsync_cpp stdc++fs) + elseif(WIN32) + set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/tensorrt/symbols.def") + else() + message(FATAL_ERROR "onnxruntime_providers_tensorrt unknown platform, need to specify shared library exports for it") + endif() + + install(TARGETS onnxruntime_providers_tensorrt + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_providers_tvm.cmake b/cmake/onnxruntime_providers_tvm.cmake new file mode 100644 index 0000000000000..8fd50c70dd5d7 --- /dev/null +++ b/cmake/onnxruntime_providers_tvm.cmake @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + add_definitions(-DUSE_TVM=1) + if (onnxruntime_TVM_USE_HASH) + add_definitions(-DUSE_TVM_HASH=1) + endif() + + if (onnxruntime_TVM_USE_HASH) + file (GLOB_RECURSE onnxruntime_providers_tvm_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.cc" + ) + else() + file (GLOB onnxruntime_providers_tvm_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.cc" + ) + endif() + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_tvm_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_tvm ${onnxruntime_providers_tvm_cc_srcs}) + + if ( CMAKE_COMPILER_IS_GNUCC ) + target_compile_options(onnxruntime_providers_tvm PRIVATE -Wno-unused-parameter -Wno-missing-field-initializers) + endif() + + target_include_directories(onnxruntime_providers_tvm PRIVATE + ${TVM_INCLUDES} + ${PYTHON_INCLUDE_DIRS}) + onnxruntime_add_include_to_target(onnxruntime_providers_tvm onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) + + add_dependencies(onnxruntime_providers_tvm ${onnxruntime_EXTERNAL_DEPENDENCIES}) + + if (onnxruntime_TVM_USE_HASH) + add_dependencies(onnxruntime_providers_tvm ippcp_s) + target_include_directories(onnxruntime_providers_tvm PRIVATE ${IPP_CRYPTO_INCLUDE_DIR}) + target_link_libraries(onnxruntime_providers_tvm PRIVATE ippcp_s) + endif() + + set_target_properties(onnxruntime_providers_tvm PROPERTIES FOLDER "ONNXRuntime") + set_target_properties(onnxruntime_providers_tvm PROPERTIES LINKER_LANGUAGE CXX) + + if (WIN32 AND MSVC) + # wd4100: identifier' : unreferenced formal parameter + # wd4127: conditional expression is constant + # wd4244: conversion from 'int' to 'char', possible loss of data + # TODO: 4244 should not be disabled + target_compile_options(onnxruntime_providers_tvm PRIVATE "/wd4100" "/wd4127" "/wd4244") + else() + target_compile_options(onnxruntime_providers_tvm PRIVATE "-Wno-error=type-limits") + endif() + target_compile_definitions(onnxruntime_providers_tvm PUBLIC DMLC_USE_LOGGING_LIBRARY=) + + install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/tvm/tvm_provider_factory.h + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) + + if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_providers_tvm + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() \ No newline at end of file diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake new file mode 100644 index 0000000000000..7ac4a82c89a76 --- /dev/null +++ b/cmake/onnxruntime_providers_vitisai.cmake @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + if ("${GIT_COMMIT_ID}" STREQUAL "") + execute_process( + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMAND git rev-parse HEAD + OUTPUT_VARIABLE GIT_COMMIT_ID + OUTPUT_STRIP_TRAILING_WHITESPACE) + endif() + configure_file(${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/version_info.hpp.in ${CMAKE_CURRENT_BINARY_DIR}/VitisAI/version_info.h) + file(GLOB onnxruntime_providers_vitisai_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h" + ) + list(REMOVE_ITEM onnxruntime_providers_vitisai_cc_srcs "${ONNXRUNTIME_ROOT}/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc") + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vitisai_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_vitisai onnxruntime_common onnxruntime_framework onnx onnx_proto) + onnxruntime_add_shared_library(onnxruntime_vitisai_ep ${ONNXRUNTIME_ROOT}/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc) + onnxruntime_add_include_to_target(onnxruntime_vitisai_ep onnxruntime_common) + target_include_directories(onnxruntime_vitisai_ep PRIVATE "${ONNXRUNTIME_ROOT}" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include") + target_link_libraries(onnxruntime_providers_vitisai PUBLIC onnxruntime_vitisai_ep PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json ) + target_compile_definitions(onnxruntime_vitisai_ep + PRIVATE "-DONNXRUNTIME_VITISAI_EP_STUB=1" "-DONNXRUNTIME_VITISAI_EP_EXPORT_DLL=1") + if(NOT MSVC) + target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) + endif(NOT MSVC) + + target_include_directories(onnxruntime_providers_vitisai PRIVATE "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include" ${XRT_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR}/VitisAI) + if(MSVC) + target_compile_options(onnxruntime_providers_vitisai PRIVATE "/Zc:__cplusplus") + # for dll interface warning. + target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4251") + # for unused formal parameter + target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4100") + else(MSVC) + target_compile_options(onnxruntime_providers_vitisai PRIVATE -Wno-unused-parameter) + endif(MSVC) + + set_target_properties(onnxruntime_providers_vitisai PROPERTIES FOLDER "ONNXRuntime") + set_target_properties(onnxruntime_providers_vitisai PROPERTIES LINKER_LANGUAGE CXX) + + if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_providers_vitisai + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() \ No newline at end of file diff --git a/cmake/onnxruntime_providers_webnn.cmake b/cmake/onnxruntime_providers_webnn.cmake new file mode 100644 index 0000000000000..05c63c22244db --- /dev/null +++ b/cmake/onnxruntime_providers_webnn.cmake @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) + message(FATAL_ERROR "WebNN EP can not be used in a basic minimal build. Please build with '--minimal_build extended'") + endif() + + add_compile_definitions(USE_WEBNN=1) + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + add_definitions(-DENABLE_WEBASSEMBLY_THREADS=1) + endif() + file(GLOB_RECURSE onnxruntime_providers_webnn_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/webnn/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/webnn/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" + ) + + source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_webnn_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_webnn ${onnxruntime_providers_webnn_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_webnn onnxruntime_common onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface) + + add_dependencies(onnxruntime_providers_webnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + set_target_properties(onnxruntime_providers_webnn PROPERTIES FOLDER "ONNXRuntime") + set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX) \ No newline at end of file diff --git a/cmake/onnxruntime_providers_xnnpack.cmake b/cmake/onnxruntime_providers_xnnpack.cmake new file mode 100644 index 0000000000000..9c00703ca0846 --- /dev/null +++ b/cmake/onnxruntime_providers_xnnpack.cmake @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + add_compile_definitions(USE_XNNPACK=1) + + file(GLOB_RECURSE onnxruntime_providers_xnnpack_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_INCLUDE_DIR}/core/providers/xnnpack/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.cc" + # utils for handling QDQ models + "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" + ) + + source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_xnnpack_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_xnnpack ${onnxruntime_providers_xnnpack_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_xnnpack + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} XNNPACK pthreadpool + flatbuffers::flatbuffers Boost::mp11 safeint_interface + ) + + add_dependencies(onnxruntime_providers_xnnpack onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + set_target_properties(onnxruntime_providers_xnnpack PROPERTIES FOLDER "ONNXRuntime") + + set_target_properties(onnxruntime_providers_xnnpack PROPERTIES LINKER_LANGUAGE CXX) + + if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_providers_xnnpack + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() + + # TODO fix shorten-64-to-32 warnings + # there are some in builds where sizeof(size_t) != sizeof(int64_t), e.g., in 'ONNX Runtime Web CI Pipeline' + if (HAS_SHORTEN_64_TO_32 AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) + target_compile_options(onnxruntime_providers_xnnpack PRIVATE -Wno-error=shorten-64-to-32) + endif() diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index bf9adbaefabcc..345ef2b504aa4 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -339,9 +339,6 @@ configure_file(${ONNXRUNTIME_ROOT}/python/_pybind_state.py.in ${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py) if (onnxruntime_ENABLE_TRAINING) - file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/python/deprecated/*.py" - ) file(GLOB onnxruntime_python_root_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/*.py" ) @@ -387,6 +384,9 @@ if (onnxruntime_ENABLE_TRAINING) file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/*" ) + file(GLOB onnxruntime_python_ortmodule_graph_optimizers_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/graph_optimizers/*" + ) file(GLOB onnxruntime_python_ort_triton_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ort_triton/*.py" ) @@ -416,10 +416,6 @@ if (onnxruntime_ENABLE_TRAINING) "${ORTTRAINING_SOURCE_DIR}/python/training/onnxblock/optim/*" ) endif() -else() - file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/python/training/*.py" - ) endif() if (onnxruntime_BUILD_UNIT_TESTS) @@ -440,6 +436,9 @@ if (onnxruntime_BUILD_UNIT_TESTS) file(GLOB onnxruntime_python_transformers_testdata_whisper CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/whisper/*.onnx" ) + file(GLOB onnxruntime_python_transformers_testdata_conformer CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/conformer/*.onnx" + ) endif() file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS @@ -553,6 +552,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models/whisper COMMAND ${CMAKE_COMMAND} -E make_directory $/eager_test + COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models/conformer COMMAND ${CMAKE_COMMAND} -E copy ${ONNXRUNTIME_ROOT}/__init__.py $/onnxruntime/ @@ -574,9 +574,6 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py $/onnxruntime/capi/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_capi_training_srcs} - $/onnxruntime/capi/training/ COMMAND ${CMAKE_COMMAND} -E copy $ $/onnxruntime/capi/ @@ -708,6 +705,9 @@ if (onnxruntime_BUILD_UNIT_TESTS) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_testdata_whisper} $/transformers/test_data/models/whisper/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_testdata_conformer} + $/transformers/test_data/models/conformer/ ) endif() @@ -741,14 +741,12 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/graph_optimizers COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ort_triton COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ort_triton/kernel COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils/data/ COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils/hooks/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_capi_training_srcs} - $/onnxruntime/capi/training/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_root_srcs} $/onnxruntime/training/ @@ -794,6 +792,9 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs} $/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_ortmodule_graph_optimizers_srcs} + $/onnxruntime/training/ortmodule/graph_optimizers/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ort_triton_srcs} $/onnxruntime/training/ort_triton/ diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index cf71b6bcf7c7d..980bd59b22c3f 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -48,19 +48,24 @@ set(contrib_ops_excluded_files "diffusion/group_norm_impl.cu" "diffusion/group_norm_impl.h" "diffusion/nhwc_conv.cc" - "math/complex_mul.cc" - "math/complex_mul.h" - "math/complex_mul_impl.cu" - "math/complex_mul_impl.h" - "math/cufft_plan_cache.h" - "math/fft_ops.cc" - "math/fft_ops.h" - "math/fft_ops_impl.cu" - "math/fft_ops_impl.h" + "math/gemm_float8.cc" + "math/gemm_float8.cu" + "math/gemm_float8.h" + "moe/*" "quantization/attention_quantization.cc" "quantization/attention_quantization.h" "quantization/attention_quantization_impl.cu" "quantization/attention_quantization_impl.cuh" + "quantization/dequantize_blockwise.cuh" + "quantization/dequantize_blockwise.cu" + "quantization/dequantize_blockwise_bnb4.cuh" + "quantization/dequantize_blockwise_bnb4.cu" + "quantization/matmul_bnb4.cc" + "quantization/matmul_bnb4.cuh" + "quantization/matmul_bnb4.cu" + "quantization/matmul_nbits.cc" + "quantization/matmul_nbits.cuh" + "quantization/matmul_nbits.cu" "quantization/quantize_dequantize_linear.cc" "quantization/qordered_ops/qordered_attention_impl.cu" "quantization/qordered_ops/qordered_attention_impl.h" @@ -86,38 +91,37 @@ set(contrib_ops_excluded_files "quantization/qordered_ops/qordered_unary_ops.cc" "quantization/qordered_ops/qordered_unary_ops_impl.h" "quantization/qordered_ops/qordered_unary_ops_impl.cu" - "tensor/crop.cc" - "tensor/crop.h" - "tensor/crop_impl.cu" - "tensor/crop_impl.h" - "tensor/dynamicslice.cc" - "tensor/image_scaler.cc" - "tensor/image_scaler.h" - "tensor/image_scaler_impl.cu" - "tensor/image_scaler_impl.h" - "transformers/greedy_search.cc" - "transformers/greedy_search.h" - "conv_transpose_with_dynamic_pads.cc" - "conv_transpose_with_dynamic_pads.h" "cuda_contrib_kernels.cc" "cuda_contrib_kernels.h" "inverse.cc" "fused_conv.cc" + "bert/group_query_attention_helper.h" + "bert/group_query_attention.h" + "bert/group_query_attention.cc" + "bert/group_query_attention_impl.h" + "bert/group_query_attention_impl.cu" ) if (NOT onnxruntime_ENABLE_ATEN) list(APPEND contrib_ops_excluded_files "aten_ops/aten_op.cc") endif() if (NOT onnxruntime_USE_NCCL) + # Those are string patterns to exclude. Do NOT use stars such as + # collective/*.cc or *.h. list(APPEND contrib_ops_excluded_files "collective/nccl_kernels.cc") + list(APPEND contrib_ops_excluded_files "collective/sharding.cc") + list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_reduce.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_unsqueeze.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_squeeze.cc") endif() set(provider_excluded_files "atomic/common.cuh" - "controlflow/loop.cc" - "controlflow/loop.h" - "controlflow/scan.cc" - "controlflow/scan.h" "cu_inc/common.cuh" "math/einsum_utils/einsum_auxiliary_ops.cc" "math/einsum_utils/einsum_auxiliary_ops.h" @@ -165,7 +169,6 @@ set(provider_excluded_files "cuda_memory_check.h" "cuda_fence.cc" "cuda_fence.h" - "cuda_fwd.h" "cuda_kernel.h" "cuda_pch.cc" "cuda_pch.h" @@ -185,6 +188,8 @@ set(provider_excluded_files "gpu_data_transfer.h" "integer_gemm.cc" "tunable/*" + "cuda_nhwc_kernels.cc" + "cuda_nhwc_kernels.h" ) set(training_ops_excluded_files diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index ec83eb2095071..df62199dc2b42 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -35,13 +35,17 @@ function(AddTest) if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) #TODO: fix the warnings, they are dangerous - target_compile_options(${_UT_TARGET} PRIVATE "/wd4244") + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4244>" + "$<$>:/wd4244>") endif() if (MSVC) - target_compile_options(${_UT_TARGET} PRIVATE "/wd6330") - #Abseil has a lot of C4127/C4324 warnings. - target_compile_options(${_UT_TARGET} PRIVATE "/wd4127") - target_compile_options(${_UT_TARGET} PRIVATE "/wd4324") + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd6330>" + "$<$>:/wd6330>") + #Abseil has a lot of C4127/C4324 warnings. + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4127>" + "$<$>:/wd4127>") + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4324>" + "$<$>:/wd4324>") endif() set_target_properties(${_UT_TARGET} PROPERTIES FOLDER "ONNXRuntimeTest") @@ -60,6 +64,11 @@ function(AddTest) Threads::Threads) target_compile_definitions(${_UT_TARGET} PRIVATE -DUSE_ONNXRUNTIME_DLL) else() + if(onnxruntime_USE_CUDA) + #XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs, + # otherwise it will impact when CUDA DLLs can be unloaded. + target_link_libraries(${_UT_TARGET} PRIVATE cudart) + endif() target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES}) endif() @@ -85,29 +94,22 @@ function(AddTest) # include dbghelp in case tests throw an ORT exception, as that exception includes a stacktrace, which requires dbghelp. target_link_libraries(${_UT_TARGET} PRIVATE debug dbghelp) - if (onnxruntime_USE_CUDA) - # disable a warning from the CUDA headers about unreferenced local functions - if (MSVC) - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd4505>" - "$<$>:/wd4505>") - endif() - endif() if (MSVC) # warning C6326: Potential comparison of a constant with another constant. # Lot of such things came from gtest - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd6326>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd6326>" "$<$>:/wd6326>") # Raw new and delete. A lot of such things came from googletest. - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") # "Global initializer calls a non-constexpr function." - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd26426>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") endif() target_compile_options(${_UT_TARGET} PRIVATE ${disabled_warnings}) else() target_compile_options(${_UT_TARGET} PRIVATE ${DISABLED_WARNINGS_FOR_TVM}) - target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:-Xcompiler -Wno-error=sign-compare>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options -Wno-error=sign-compare>" "$<$>:-Wno-error=sign-compare>") target_compile_options(${_UT_TARGET} PRIVATE "-Wno-error=uninitialized") endif() @@ -199,8 +201,18 @@ function(AddTest) list(APPEND TEST_NODE_FLAGS "--experimental-wasm-simd") endif() + # prefer Node from emsdk so the version is more deterministic + if (DEFINED ENV{EMSDK_NODE}) + set(NODE_EXECUTABLE $ENV{EMSDK_NODE}) + else() + # warning as we don't know what node version is being used and whether things like the TEST_NODE_FLAGS + # will be valid. e.g. "--experimental-wasm-simd" is not valid with node v20 or later. + message(WARNING "EMSDK_NODE environment variable was not set. Falling back to system `node`.") + set(NODE_EXECUTABLE node) + endif() + add_test(NAME ${_UT_TARGET} - COMMAND node ${TEST_NODE_FLAGS} ${_UT_TARGET}.js ${TEST_ARGS} + COMMAND ${NODE_EXECUTABLE} ${TEST_NODE_FLAGS} ${_UT_TARGET}.js ${TEST_ARGS} WORKING_DIRECTORY $ ) endif() @@ -372,6 +384,13 @@ if (onnxruntime_USE_CUDA AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_R "${TEST_SRC_DIR}/providers/cuda/*" ) list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_cuda_src}) + + if (onnxruntime_USE_CUDA_NHWC_OPS) + file(GLOB onnxruntime_test_providers_cuda_nhwc_src CONFIGURE_DEPENDS + "${TEST_SRC_DIR}/providers/cuda/nhwc/*.cc" + ) + list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_cuda_nhwc_src}) + endif() endif() if (onnxruntime_USE_CANN) @@ -698,7 +717,7 @@ onnxruntime_add_static_library(onnxruntime_test_utils ${onnxruntime_test_utils_s if(MSVC) target_compile_options(onnxruntime_test_utils PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") - target_compile_options(onnxruntime_test_utils PRIVATE "$<$:-Xcompiler /wd6326>" + target_compile_options(onnxruntime_test_utils PRIVATE "$<$:SHELL:--compiler-options /wd6326>" "$<$>:/wd6326>") else() target_compile_definitions(onnxruntime_test_utils PUBLIC -DNSYNC_ATOMIC_CPP11) @@ -755,13 +774,8 @@ set_target_properties(onnx_test_runner_common PROPERTIES FOLDER "ONNXRuntimeTest set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src} ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src} ${onnxruntime_test_quantiztion_src}) -if(NOT TARGET onnxruntime AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - list(APPEND all_tests ${onnxruntime_shared_lib_test_SRC}) -endif() -if (onnxruntime_USE_CUDA) - onnxruntime_add_static_library(onnxruntime_test_cuda_ops_lib ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) - list(APPEND onnxruntime_test_common_libs onnxruntime_test_cuda_ops_lib) +if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) file(GLOB onnxruntime_test_providers_cuda_ut_src CONFIGURE_DEPENDS "${TEST_SRC_DIR}/providers/cuda/test_cases/*" ) @@ -769,7 +783,7 @@ if (onnxruntime_USE_CUDA) onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $) config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut) onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock) - target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock) + target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut) endif() @@ -822,7 +836,9 @@ if (onnxruntime_USE_TENSORRT) # made test name contain the "ep" and "model path" information, so we can easily filter the tests using cuda ep or other ep with *cpu_* or *xxx_*. list(APPEND test_all_args "--gtest_filter=-*cpu_*:*cuda_*" ) endif () - +if(NOT onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) + list(REMOVE_ITEM all_tests ${TEST_SRC_DIR}/providers/cuda/cuda_provider_test.cc) +endif() AddTest( TARGET onnxruntime_test_all SOURCES ${all_tests} ${onnxruntime_unittest_main_src} @@ -832,11 +848,15 @@ AddTest( DEPENDS ${all_dependencies} TEST_ARGS ${test_all_args} ) + if (MSVC) # The warning means the type of two integral values around a binary operator is narrow than their result. # If we promote the two input values first, it could be more tolerant to integer overflow. # However, this is test code. We are less concerned. - target_compile_options(onnxruntime_test_all PRIVATE "/wd26451" "/wd4244") + target_compile_options(onnxruntime_test_all PRIVATE "$<$:SHELL:--compiler-options /wd26451>" + "$<$>:/wd26451>") + target_compile_options(onnxruntime_test_all PRIVATE "$<$:SHELL:--compiler-options /wd4244>" + "$<$>:/wd4244>") else() target_compile_options(onnxruntime_test_all PRIVATE "-Wno-parentheses") endif() @@ -848,7 +868,7 @@ if (HAS_SHORTEN_64_TO_32 AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) endif() if (UNIX AND onnxruntime_USE_TENSORRT) - # The test_main.cc includes NvInfer.h where it has many deprecated declarations + # The test_main.cc includes NvInfer.h where it has many deprecated declarations # simply ignore them for TensorRT EP build set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") endif() @@ -886,7 +906,7 @@ if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js) - set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s ALLOW_MEMORY_GROWTH=1 --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") + set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " -s DEFAULT_PTHREAD_STACK_SIZE=131072 -s PROXY_TO_PTHREAD=1") endif() @@ -972,9 +992,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() if (MSVC OR ${CMAKE_SYSTEM_NAME} STREQUAL "Linux") - file(GLOB QNN_LIB_FILES LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/*.so" "${onnxruntime_QNN_HOME}/target/${QNN_ARCH_ABI}/lib/*.so" "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/*.dll" "${onnxruntime_QNN_HOME}/target/${QNN_ARCH_ABI}/lib/*.dll") + file(GLOB QNN_LIB_FILES LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/*.so" "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/*.dll") if (${QNN_ARCH_ABI} STREQUAL "aarch64-windows-msvc") - file(GLOB EXTRA_HTP_LIB LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/hexagon-v68/unsigned/libQnnHtpV68Skel.so" "${onnxruntime_QNN_HOME}/target/hexagon-v68/lib/unsigned/libQnnHtpV68Skel.so") + file(GLOB EXTRA_HTP_LIB LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/hexagon-v68/unsigned/libQnnHtpV68Skel.so" "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so") list(APPEND QNN_LIB_FILES ${EXTRA_HTP_LIB}) endif() message(STATUS "QNN lib files: " ${QNN_LIB_FILES}) @@ -1092,18 +1112,18 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${ONNXRUNTIME_ROOT}/core/mlas/inc) target_compile_definitions(onnxruntime_benchmark PRIVATE BENCHMARK_STATIC_DEFINE) if(WIN32) - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd4141>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd4141>" "$<$>:/wd4141>") # Avoid using new and delete. But this is a benchmark program, it's ok if it has a chance to leak. - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26400>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26400>" "$<$>:/wd26400>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26814>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26814>" "$<$>:/wd26814>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26814>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26814>" "$<$>:/wd26497>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26426>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") @@ -1255,7 +1275,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo) endif() if (onnxruntime_USE_CUDA) - list(APPEND onnxruntime_shared_lib_test_LIBS onnxruntime_test_cuda_ops_lib cudart) + list(APPEND onnxruntime_shared_lib_test_LIBS cudart) endif() if (onnxruntime_USE_TENSORRT) list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER}) @@ -1270,7 +1290,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) LIBS ${onnxruntime_shared_lib_test_LIBS} DEPENDS ${all_dependencies} ) - + if (onnxruntime_USE_CUDA) + target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) + endif() if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc" @@ -1288,7 +1311,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() if (UNIX AND onnxruntime_USE_TENSORRT) - # The test_main.cc includes NvInfer.h where it has many deprecated declarations + # The test_main.cc includes NvInfer.h where it has many deprecated declarations # simply ignore them for TensorRT EP build set_property(TARGET onnxruntime_shared_lib_test APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") endif() @@ -1356,13 +1379,13 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) ) onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src}) if(MSVC) - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:-Xcompiler /wd6326>" + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd6326>" "$<$>:/wd6326>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:-Xcompiler /wd26426>" + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") endif() if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") @@ -1476,7 +1499,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") "${TEST_SRC_DIR}/testdata/custom_op_library/cuda/cuda_ops.*") list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include) if (HAS_QSPECTRE) - list(APPEND custom_op_lib_option "$<$:SHELL:-Xcompiler /Qspectre>") + list(APPEND custom_op_lib_option "$<$:SHELL:--compiler-options /Qspectre>") endif() endif() @@ -1503,7 +1526,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") else() set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-DEF:${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.def") if (NOT onnxruntime_USE_CUDA) - target_compile_options(custom_op_library PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(custom_op_library PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") endif() endif() @@ -1577,7 +1600,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") endif() if (UNIX AND onnxruntime_USE_TENSORRT) - # The test_main.cc includes NvInfer.h where it has many deprecated declarations + # The test_main.cc includes NvInfer.h where it has many deprecated declarations # simply ignore them for TensorRT EP build set_property(TARGET onnxruntime_customopregistration_test APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") endif() diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index c6510c97a617e..9014089cb6112 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -192,8 +192,13 @@ else() onnxruntime_util re2::re2 ) + + set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8'") + if (onnxruntime_USE_XNNPACK) target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK) + string(APPEND EXPORTED_RUNTIME_METHODS ",'addFunction'") + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s ALLOW_TABLE_GROWTH=1") endif() if(onnxruntime_USE_WEBNN) @@ -204,7 +209,6 @@ else() target_link_libraries(onnxruntime_webassembly PRIVATE tensorboard) endif() - set(EXPORTED_RUNTIME_METHODS "['stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8']") if (onnxruntime_USE_JSEP) set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput,_JsepGetNodeName") else() @@ -212,7 +216,7 @@ else() endif() target_link_options(onnxruntime_webassembly PRIVATE - "SHELL:-s EXPORTED_RUNTIME_METHODS=${EXPORTED_RUNTIME_METHODS}" + "SHELL:-s EXPORTED_RUNTIME_METHODS=[${EXPORTED_RUNTIME_METHODS}]" "SHELL:-s EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}" "SHELL:-s MAXIMUM_MEMORY=4294967296" "SHELL:-s EXIT_RUNTIME=0" diff --git a/cmake/patches/abseil/absl_gh_issue_1435_workaround.patch b/cmake/patches/abseil/absl_gh_issue_1435_workaround.patch new file mode 100644 index 0000000000000..0a864cdc019b4 --- /dev/null +++ b/cmake/patches/abseil/absl_gh_issue_1435_workaround.patch @@ -0,0 +1,17 @@ +--- absl/container/internal/layout.h 2023-11-28 09:35:48 ++++ absl/container/internal/layout.updated.h 2023-11-28 10:13:14 +@@ -181,9 +181,11 @@ + #include + #endif + +-#if defined(__GXX_RTTI) +-#define ABSL_INTERNAL_HAS_CXA_DEMANGLE +-#endif ++// Comment out ABSL_INTERNAL_HAS_CXA_DEMANGLE definition to work around this issue: ++// https://github.com/abseil/abseil-cpp/issues/1435 ++// #if defined(__GXX_RTTI) ++// #define ABSL_INTERNAL_HAS_CXA_DEMANGLE ++// #endif + + #ifdef ABSL_INTERNAL_HAS_CXA_DEMANGLE + #include diff --git a/cmake/patches/composable_kernel/Fix_Clang_Build.patch b/cmake/patches/composable_kernel/Fix_Clang_Build.patch index d564ffba914fe..15844dd917744 100644 --- a/cmake/patches/composable_kernel/Fix_Clang_Build.patch +++ b/cmake/patches/composable_kernel/Fix_Clang_Build.patch @@ -1,17 +1,17 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index 514b98fde..59c8a568a 100644 +index 04674124c..12e8b8b00 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -1,7 +1,7 @@ - cmake_minimum_required(VERSION 3.14) +@@ -19,7 +19,7 @@ endif() + set(version 1.1.0) # Check support for CUDA/HIP in Cmake --project(composable_kernel) -+project(composable_kernel LANGUAGES CXX HIP) +-project(composable_kernel VERSION ${version}) ++project(composable_kernel VERSION ${version} LANGUAGES CXX HIP) list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") -@@ -94,27 +94,6 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) +@@ -173,27 +173,6 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") @@ -39,7 +39,7 @@ index 514b98fde..59c8a568a 100644 ## HIP find_package(HIP REQUIRED) # Override HIP version in config.h, if necessary. -@@ -136,8 +115,6 @@ if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH ) +@@ -215,8 +194,6 @@ if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH ) message(STATUS "CK_HIP_VERSION_PATCH overriden with ${CK_OVERRIDE_HIP_VERSION_PATCH}") endif() message(STATUS "Build with HIP ${HIP_VERSION}") @@ -48,7 +48,18 @@ index 514b98fde..59c8a568a 100644 ## tidy include(EnableCompilerWarnings) -@@ -391,11 +368,3 @@ rocm_install(FILES +@@ -376,7 +353,9 @@ if(BUILD_DEV) + add_compile_options(-Werror -Weverything) + endif() + #add flags to reduce the size of binaries +-add_compile_options(-Oz -flto=thin) ++# -flto requires ORT to use a linker that support LTO and -flto flag shoud be passed to linker together. ++# add_compile_options(-Oz -flto=thin) ++add_compile_options(-Oz) + message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") + + add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) +@@ -482,11 +461,3 @@ rocm_install(FILES set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") set(CPACK_RPM_PACKAGE_LICENSE "MIT") @@ -61,20 +72,21 @@ index 514b98fde..59c8a568a 100644 - HEADER_ONLY -) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt -index 1d54a141b..4edd7dbfb 100644 +index 9cb5d0e9a..141a46f3d 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt -@@ -1,7 +1,13 @@ - function(add_instance_library INSTANCE_NAME) - message("adding instance ${INSTANCE_NAME}") -+ set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) - add_library(${INSTANCE_NAME} OBJECT ${ARGN}) -+ # Always disable debug symbol and C debug assert due to -+ # - Linker error: ... relocation truncated to fit ..., caused by object files to be linked are too huge. -+ # - https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/622 -+ target_compile_options(${INSTANCE_NAME} PRIVATE -g0 -DNDEBUG) - target_compile_features(${INSTANCE_NAME} PUBLIC) -+ target_compile_definitions(${INSTANCE_NAME} PRIVATE "__HIP_PLATFORM_AMD__=1" "__HIP_PLATFORM_HCC__=1") - set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) - clang_tidy_check(${INSTANCE_NAME}) - endfunction(add_instance_library INSTANCE_NAME) +@@ -44,8 +44,14 @@ function(add_instance_library INSTANCE_NAME) + endforeach() + #only continue if there are some source files left on the list + if(ARGN) ++ set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) + add_library(${INSTANCE_NAME} OBJECT ${ARGN}) ++ # Always disable debug symbol and C debug assert due to ++ # - Linker error: ... relocation truncated to fit ..., caused by object files to be linked are too huge. ++ # - https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/622 ++ target_compile_options(${INSTANCE_NAME} PRIVATE -g0 -DNDEBUG) + target_compile_features(${INSTANCE_NAME} PUBLIC) ++ target_compile_definitions(${INSTANCE_NAME} PRIVATE "__HIP_PLATFORM_AMD__=1" "__HIP_PLATFORM_HCC__=1") + set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) + clang_tidy_check(${INSTANCE_NAME}) + set(result 0) diff --git a/cmake/patches/cutlass/cutlass.patch b/cmake/patches/cutlass/cutlass.patch deleted file mode 100644 index bda1de8b46916..0000000000000 --- a/cmake/patches/cutlass/cutlass.patch +++ /dev/null @@ -1,92 +0,0 @@ -diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp -index 3790ebd3..cf727d09 100644 ---- a/include/cute/numeric/complex.hpp -+++ b/include/cute/numeric/complex.hpp -@@ -41,10 +41,14 @@ - // With CUDA 11.4, builds show spurious "-Wconversion" warnings - // on line 656 of thrust/detail/type_traits.h. - // These pragmas suppress the warnings. -+#ifdef __GNUC__ - #pragma GCC diagnostic push - #pragma GCC diagnostic ignored "-Wconversion" -+#endif - #include -+#ifdef __GNUC__ - #pragma GCC diagnostic pop -+#endif - - #include - -diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h -index 59aec46a..8f2a913a 100644 ---- a/include/cutlass/functional.h -+++ b/include/cutlass/functional.h -@@ -89,7 +89,7 @@ struct multiplies { - } - }; - --#if defined(__CUDA_ARCH__) -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - /// Partial specializations needed when __CUDA_NO_HALF2_OPERATORS__ is set - template<> - struct plus<__half2> { -@@ -143,12 +143,12 @@ struct multiplies<__half> { - - - // Maximum with nan propogation --// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN -+// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN - template - struct maximum_with_nan_propogation { - CUTLASS_HOST_DEVICE - T operator()(T const &lhs, T const &rhs) const { -- return lhs > rhs or std::isnan(lhs) ? lhs : rhs; -+ return lhs > rhs or isnan(lhs) ? lhs : rhs; - } - }; - -@@ -160,7 +160,7 @@ struct maximum_with_nan_propogation { - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); - #else -- res = lhs > rhs or std::isnan(lhs) ? lhs : rhs; -+ res = lhs > rhs or isnan(lhs) ? lhs : rhs; - #endif - return res; - } -@@ -233,7 +233,7 @@ struct negate { - } - }; - --/// Greater equal -+/// Greater equal - template - struct greater_equal { - CUTLASS_HOST_DEVICE -@@ -242,7 +242,7 @@ struct greater_equal { - } - }; - --/// Greater -+/// Greater - template - struct greater { - CUTLASS_HOST_DEVICE -@@ -251,7 +251,7 @@ struct greater { - } - }; - --/// Less equal -+/// Less equal - template - struct less_equal { - CUTLASS_HOST_DEVICE -@@ -260,7 +260,7 @@ struct less_equal { - } - }; - --/// Less -+/// Less - template - struct less { - CUTLASS_HOST_DEVICE diff --git a/cmake/patches/eigen/Fix_Eigen_Build_Break.patch b/cmake/patches/eigen/Fix_Eigen_Build_Break.patch deleted file mode 100644 index ca0e0fd23ddee..0000000000000 --- a/cmake/patches/eigen/Fix_Eigen_Build_Break.patch +++ /dev/null @@ -1,27 +0,0 @@ -diff -Naur git_org/cmake/external/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h git/cmake/external/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h ---- git_org/cmake/external/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h 2019-07-17 15:27:59.540667336 -0500 -+++ git/cmake/external/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h 2019-07-17 15:30:16.000000000 -0500 -@@ -1076,8 +1076,9 @@ - dest = *b; - } - -- EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketx4& dest) const -- {} -+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const -+ { -+ } - - EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const - { -@@ -1145,8 +1146,9 @@ - loadRhs(b,dest); - } - -- EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketx4& dest) const -- {} -+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const -+ { -+ } - - EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const - { diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index 155d153019f85..a2d7672a3d48d 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -64,16 +64,3 @@ index 0aab3e26..0f859267 100644 +#endif + #endif // ! ONNX_ONNX_PB_H -diff --git a/onnx/checker.cc b/onnx/checker.cc -index 8fdaf037..1beb1b88 100644 ---- a/onnx/checker.cc -+++ b/onnx/checker.cc -@@ -190,7 +190,7 @@ void check_tensor(const TensorProto& tensor, const CheckerContext& ctx) { - } - std::string data_path = path_join(ctx.get_model_dir(), relative_path); - // use stat64 to check whether the file exists --#ifdef __APPLE__ -+#if defined(__APPLE__) || defined(__wasm__) - struct stat buffer; // APPLE does not have stat64 - if (stat((data_path).c_str(), &buffer) != 0) { - #else diff --git a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch index 37bdbf9fb53f6..736fffb1e384c 100644 --- a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch +++ b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch @@ -1,66 +1,39 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index d53c48aa1..77c3cf983 100755 +index dba9b4687..a4345898d 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -105,22 +105,12 @@ ENDIF() - +@@ -122,7 +122,7 @@ ENDIF() + # ---[ Build flags IF(NOT CMAKE_SYSTEM_NAME) MESSAGE(FATAL_ERROR "CMAKE_SYSTEM_NAME not defined") --ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Darwin|Linux|Android|Windows|CYGWIN|MSYS)$") -+ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Darwin|Linux|Android|Windows|CYGWIN|MSYS|Emscripten|iOS)$") - MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME = ${CMAKE_SYSTEM_NAME}") +-ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Android|Darwin|iOS|Linux|Windows|CYGWIN|MSYS|QURT)$") ++ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Android|Darwin|iOS|Linux|Windows|CYGWIN|MSYS|QURT|Emscripten|iOS)$") + MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME value \"${CMAKE_SYSTEM_NAME}\"") ENDIF() - - # ---[ Download deps - IF(NOT XNNPACK_USE_SYSTEM_LIBS) -- IF(NOT DEFINED CLOG_SOURCE_DIR) -- MESSAGE(STATUS "Downloading clog to ${CMAKE_BINARY_DIR}/clog-source (define CLOG_SOURCE_DIR to avoid it)") -- CONFIGURE_FILE(cmake/DownloadCLog.cmake "${CMAKE_BINARY_DIR}/clog-download/CMakeLists.txt") -- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . -- WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/clog-download") -- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . -- WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/clog-download") -- SET(CLOG_SOURCE_DIR "${CMAKE_BINARY_DIR}/clog-source" CACHE STRING "clog source directory") -- ENDIF() -- - IF(NOT DEFINED CPUINFO_SOURCE_DIR) - MESSAGE(STATUS "Downloading cpuinfo to ${CMAKE_BINARY_DIR}/cpuinfo-source (define CPUINFO_SOURCE_DIR to avoid it)") - CONFIGURE_FILE(cmake/DownloadCpuinfo.cmake "${CMAKE_BINARY_DIR}/cpuinfo-download/CMakeLists.txt") -@@ -7108,6 +7098,10 @@ IF(MSVC) - SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: /O2 >") - SET_PROPERTY(SOURCE ${HOT_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: /O2 >") - SET_PROPERTY(SOURCE ${COLD_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: /O1 >") -+ELSEIF(CMAKE_GENERATOR STREQUAL Xcode) -+ TARGET_COMPILE_OPTIONS(all_microkernels PRIVATE $<$>: -O2 >) -+ TARGET_COMPILE_OPTIONS(XNNPACK PRIVATE $<$>: -O2 >) -+ TARGET_COMPILE_OPTIONS(XNNPACK PRIVATE $<$>: -Os >) - ELSE() - SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: -O2 >") - SET_PROPERTY(SOURCE ${HOT_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: -O2 >") -@@ -7142,26 +7136,6 @@ IF(LIBM) - TARGET_LINK_LIBRARIES(indirection PRIVATE ${LIBM}) + IF(CMAKE_SYSTEM_NAME MATCHES "Windows") +@@ -534,7 +534,12 @@ IF(XNNPACK_BUILD_LIBRARY) + TARGET_LINK_LIBRARIES(operator-utils PRIVATE logging) + TARGET_LINK_LIBRARIES(post-operation PRIVATE logging) + TARGET_LINK_LIBRARIES(subgraph PRIVATE allocator logging memory mutex operators operator-run) +- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph) ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") ++ # omit microkernels-prod as the list is manually created by ORT in cmake/external/xnnpack.cmake ++ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation subgraph) ++ ELSE() ++ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph) ++ ENDIF() + SET_TARGET_PROPERTIES(XNNPACK PROPERTIES C_EXTENSIONS YES) ENDIF() - --# ---[ Configure clog --IF(NOT TARGET clog) -- IF(NOT XNNPACK_USE_SYSTEM_LIBS) -- SET(CLOG_BUILD_TESTS OFF CACHE BOOL "") -- SET(CLOG_RUNTIME_TYPE "${CPUINFO_RUNTIME_TYPE}" CACHE STRING "") -- ADD_SUBDIRECTORY( -- "${CLOG_SOURCE_DIR}/deps/clog" -- "${CMAKE_BINARY_DIR}/clog") -- # We build static version of clog but a dynamic library may indirectly depend on it -- SET_PROPERTY(TARGET clog PROPERTY POSITION_INDEPENDENT_CODE ON) -- ELSE() -- ADD_LIBRARY(clog STATIC IMPORTED) -- FIND_LIBRARY(CLOG_LIBRARY clog) -- IF(NOT CLOG_LIBRARY) -- MESSAGE(FATAL_ERROR "Cannot find clog") -- ENDIF() -- SET_PROPERTY(TARGET clog PROPERTY IMPORTED_LOCATION "${CLOG_LIBRARY}") -- ENDIF() --ENDIF() -- - # ---[ Configure cpuinfo - IF(NOT TARGET cpuinfo) - IF(NOT XNNPACK_USE_SYSTEM_LIBS) + IF(NOT MSVC) +@@ -543,8 +548,9 @@ ENDIF() + IF(XNNPACK_TARGET_PROCESSOR STREQUAL "arm") + SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -marm ") + SET_PROPERTY(SOURCE ${PROD_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -marm ") +- SET_PROPERTY(SOURCE ${ALL_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv6 -mfpu=vfp -munaligned-access ") +- SET_PROPERTY(SOURCE ${PROD_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv6 -mfpu=vfp -munaligned-access ") ++ # set this to armv7-a to workaround build issue. we don't target armv6 so it shouldn't matter ++ SET_PROPERTY(SOURCE ${ALL_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=vfp -munaligned-access ") ++ SET_PROPERTY(SOURCE ${PROD_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=vfp -munaligned-access ") + SET_PROPERTY(SOURCE ${ALL_NEON_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon ") + SET_PROPERTY(SOURCE ${PROD_NEON_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon ") + SET_PROPERTY(SOURCE ${ALL_NEONFP16_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon-fp16 ") diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 395996f0fa4b9..268ee3960e75a 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -451,6 +451,8 @@ onnxruntime_add_static_library(winml_lib_api ${winml_lib_api_dir}/impl/TensorKindFrom.h ${winml_lib_api_dir}/impl/TensorMemoryBufferReference.h ${winml_lib_api_dir}/NumericData.cpp + ${winml_lib_api_dir}/HardwareCoreEnumerator.cpp + ${winml_lib_api_dir}/HardwareCoreEnumerator.h ${winml_lib_api_dir}/ImageFeatureDescriptor.cpp ${winml_lib_api_dir}/ImageFeatureDescriptor.h ${winml_lib_api_dir}/ImageFeatureValue.cpp diff --git a/csharp/OnnxRuntime.CSharp.proj b/csharp/OnnxRuntime.CSharp.proj index 0288d752d8749..5e43756ced7b1 100644 --- a/csharp/OnnxRuntime.CSharp.proj +++ b/csharp/OnnxRuntime.CSharp.proj @@ -17,9 +17,13 @@ CMake creates a target to this project x64 false - false + true + true None + + true + .. ..\tools\nuget\generate_nuspec_for_native_nuget.py @@ -30,13 +34,15 @@ CMake creates a target to this project ..\build\Linux $(OnnxRuntimeBuildDirectory)\packages $(OnnxRuntimeBuildDirectory)\$(Configuration) + python3 - + ..\build\Windows $(OnnxRuntimeBuildDirectory)\packages $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration) + python @@ -86,28 +92,55 @@ CMake creates a target to this project - + + + + + + + + Properties="NoBuild=true;Platform=AnyCPU;PackageVersion=$(PackageVersion);OrtPackageId=$(OrtPackageId);IncludeMobileTargets=$(IncludeMobileTargets)"/> - - + + + - - + + + - + + + + - diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj index 29ccf55f081d5..0c74a23204d4f 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj @@ -4,66 +4,53 @@ Microsoft.ML.OnnxRuntime - - PreNet6 - netstandard2.0;netcoreapp3.1;net6.0 + true + netstandard2.0 + - - - xamarinios10;monoandroid11.0 + + + false - - monoandroid11.0 - + + NOTE: We include in a build of the managed package when creating Microsoft.ML.OnnxRuntime.Gpu as both + the CPU and GPU packaging pipelines can publish Microsoft.ML.OnnxRuntime.Managed, and we need the targets + to be consistent in both. + --> - net6.0;net6.0-android;net6.0-ios;net6.0-macos + '$(OrtPackageId)' == 'Microsoft.ML.OnnxRuntime.Gpu') AND + '$(IncludeMobileTargets)' == 'true' AND + Exists('$(MSBuildExtensionsPath)\Xamarin\Android') AND + Exists('$(MSBuildExtensionsPath)\Xamarin\iOS')"> + xamarinios10;monoandroid11.0 - - net6.0;net6.0-android + + monoandroid11.0 - - $(BaseTargets);$(XamarinTargets);$(XamarinTargetsForTraining) + + + $(MobileTargets);net6.0-android;net6.0-ios - - $(Net6Targets);$(Net6TargetsForTrainingPackage) + + $(MobileTargets);net6.0-android - - - $(BaseTargets);$(XamarinTargets);$(XamarinTargetsForTraining);$(Net6Targets);$(Net6TargetsForTrainingPackage) + + $(BaseTargets);$(MobileTargets) - AnyCPU;x86 default @@ -204,8 +191,9 @@ $(DefineConstants);$(OrtConstants) - + @@ -214,7 +202,6 @@ - --> /// Creates native OrtTensorRTProviderOptions instance /// /// (output) native instance of OrtTensorRTProviderOptions [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateTensorRTProviderOptions( + public delegate IntPtr /* OrtStatus* */ DOrtCreateTensorRTProviderOptions( out IntPtr /*(OrtTensorRTProviderOptions**)*/ trtProviderOptionsInstance); public static DOrtCreateTensorRTProviderOptions OrtCreateTensorRTProviderOptions; @@ -610,7 +609,7 @@ internal class NativeLib /// configuration values of OrtTensorRTProviderOptions /// number of configuration keys [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtUpdateTensorRTProviderOptions( + public delegate IntPtr /* OrtStatus* */ DOrtUpdateTensorRTProviderOptions( IntPtr /*(OrtTensorRTProviderOptions*)*/ trtProviderOptionsInstance, IntPtr[] /*(const char* const *)*/ providerOptionsKeys, IntPtr[] /*(const char* const *)*/ providerOptionsValues, @@ -623,10 +622,10 @@ internal class NativeLib /// instance of OrtAllocator /// is a UTF-8 null terminated string allocated using 'allocator' [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtGetTensorRTProviderOptionsAsString( + public delegate IntPtr /* OrtStatus* */ DOrtGetTensorRTProviderOptionsAsString( IntPtr /*(OrtTensorRTProviderOptionsV2**)*/ trtProviderOptionsInstance, IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(char**)*/ptr); + out IntPtr /*(char**)*/ ptr); public static DOrtGetTensorRTProviderOptionsAsString OrtGetTensorRTProviderOptionsAsString; /// @@ -642,7 +641,7 @@ internal class NativeLib /// /// (output) native instance of OrtCUDAProviderOptions [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateCUDAProviderOptions( + public delegate IntPtr /* OrtStatus* */ DOrtCreateCUDAProviderOptions( out IntPtr /*(OrtCUDAProviderOptions**)*/ cudaProviderOptionsInstance); public static DOrtCreateCUDAProviderOptions OrtCreateCUDAProviderOptions; @@ -654,7 +653,7 @@ internal class NativeLib /// configuration values of OrtCUDAProviderOptions /// number of configuration keys [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtUpdateCUDAProviderOptions( + public delegate IntPtr /* OrtStatus* */ DOrtUpdateCUDAProviderOptions( IntPtr /*(OrtCUDAProviderOptions*)*/ cudaProviderOptionsInstance, IntPtr[] /*(const char* const *)*/ providerOptionsKeys, IntPtr[] /*(const char* const *)*/ providerOptionsValues, @@ -667,10 +666,10 @@ internal class NativeLib /// instance of OrtAllocator /// is a UTF-8 null terminated string allocated using 'allocator' [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtGetCUDAProviderOptionsAsString( + public delegate IntPtr /* OrtStatus* */ DOrtGetCUDAProviderOptionsAsString( IntPtr /*(OrtCUDAProviderOptionsV2**)*/ cudaProviderOptionsInstance, IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(char**)*/ptr); + out IntPtr /*(char**)*/ ptr); public static DOrtGetCUDAProviderOptionsAsString OrtGetCUDAProviderOptionsAsString; /// @@ -686,7 +685,7 @@ internal class NativeLib /// /// (output) native instance of OrtROCMProviderOptions [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateROCMProviderOptions( + public delegate IntPtr /* OrtStatus* */ DOrtCreateROCMProviderOptions( out IntPtr /*(OrtROCMProviderOptions**)*/ rocmProviderOptionsInstance); public static DOrtCreateROCMProviderOptions OrtCreateROCMProviderOptions; @@ -698,7 +697,7 @@ internal class NativeLib /// configuration values of OrtROCMProviderOptions /// number of configuration keys [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtUpdateROCMProviderOptions( + public delegate IntPtr /* OrtStatus* */ DOrtUpdateROCMProviderOptions( IntPtr /*(OrtROCMProviderOptions*)*/ rocmProviderOptionsInstance, IntPtr[] /*(const char* const *)*/ providerOptionsKeys, IntPtr[] /*(const char* const *)*/ providerOptionsValues, @@ -711,10 +710,10 @@ internal class NativeLib /// instance of OrtAllocator /// is a UTF-8 null terminated string allocated using 'allocator' [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtGetROCMProviderOptionsAsString( + public delegate IntPtr /* OrtStatus* */ DOrtGetROCMProviderOptionsAsString( IntPtr /*(OrtROCMProviderOptions**)*/ rocmProviderOptionsInstance, IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(char**)*/ptr); + out IntPtr /*(char**)*/ ptr); public static DOrtGetROCMProviderOptionsAsString OrtGetROCMProviderOptionsAsString; /// @@ -725,34 +724,34 @@ internal class NativeLib public delegate void DOrtReleaseROCMProviderOptions(IntPtr /*(OrtROCMProviderOptions*)*/ rocmProviderOptionsInstance); public static DOrtReleaseROCMProviderOptions OrtReleaseROCMProviderOptions; - #endregion +#endregion - #region Status API +#region Status API [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate ErrorCode DOrtGetErrorCode(IntPtr /*(OrtStatus*)*/status); + public delegate ErrorCode DOrtGetErrorCode(IntPtr /*(OrtStatus*)*/ status); public static DOrtGetErrorCode OrtGetErrorCode; // returns char*, need to convert to string by the caller. // does not free the underlying OrtStatus* [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* char* */DOrtGetErrorMessage(IntPtr /* (OrtStatus*) */status); + public delegate IntPtr /* char* */ DOrtGetErrorMessage(IntPtr /* (OrtStatus*) */ status); public static DOrtGetErrorMessage OrtGetErrorMessage; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtReleaseStatus(IntPtr /*(OrtStatus*)*/ statusPtr); public static DOrtReleaseStatus OrtReleaseStatus; - #endregion Status API +#endregion Status API - #region InferenceSession API +#region InferenceSession API [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateSession( - IntPtr /* (OrtEnv*) */ environment, - //[MarshalAs(UnmanagedType.LPStr)]string modelPath - byte[] modelPath, - IntPtr /* (OrtSessionOptions*) */sessopnOptions, - out IntPtr /**/ session); + public delegate IntPtr /* OrtStatus* */ DOrtCreateSession( + IntPtr /* (OrtEnv*) */ environment, + //[MarshalAs(UnmanagedType.LPStr)]string modelPath + byte[] modelPath, + IntPtr /* (OrtSessionOptions*) */ sessopnOptions, + out IntPtr /**/ session); public static DOrtCreateSession OrtCreateSession; @@ -765,22 +764,22 @@ internal class NativeLib /// Native OrtPrepackedWeightsContainer instance /// (Output) Created native OrtSession instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateSessionWithPrepackedWeightsContainer( - IntPtr /* (OrtEnv*) */ environment, - byte[] modelPath, - IntPtr /* (OrtSessionOptions*) */sessionOptions, - IntPtr /* (OrtPrepackedWeightsContainer*) */prepackedWeightsContainer, - out IntPtr /* (OrtSession**) */ session); + public delegate IntPtr /* OrtStatus* */ DOrtCreateSessionWithPrepackedWeightsContainer( + IntPtr /* (OrtEnv*) */ environment, + byte[] modelPath, + IntPtr /* (OrtSessionOptions*) */ sessionOptions, + IntPtr /* (OrtPrepackedWeightsContainer*) */ prepackedWeightsContainer, + out IntPtr /* (OrtSession**) */ session); public static DOrtCreateSessionWithPrepackedWeightsContainer OrtCreateSessionWithPrepackedWeightsContainer; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateSessionFromArray( - IntPtr /* (OrtEnv*) */ environment, - byte[] modelData, - UIntPtr modelSize, - IntPtr /* (OrtSessionOptions*) */ sessionOptions, - out IntPtr /**/ session); + public delegate IntPtr /* OrtStatus* */ DOrtCreateSessionFromArray( + IntPtr /* (OrtEnv*) */ environment, + byte[] modelData, + UIntPtr modelSize, + IntPtr /* (OrtSessionOptions*) */ sessionOptions, + out IntPtr /**/ session); public static DOrtCreateSessionFromArray OrtCreateSessionFromArray; /// @@ -793,169 +792,167 @@ internal class NativeLib /// Native OrtPrepackedWeightsContainer instance /// (Output) Created native OrtSession instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateSessionFromArrayWithPrepackedWeightsContainer( - IntPtr /* (OrtEnv*) */ environment, - byte[] /* (void*) */ modelData, - UIntPtr /* (size_t) */ modelSize, - IntPtr /* (OrtSessionOptions*) */ sessionOptions, - IntPtr /* (OrtPrepackedWeightsContainer*) */prepackedWeightsContainer, - out IntPtr /* (OrtSession**) */ session); + public delegate IntPtr /* OrtStatus* */ DOrtCreateSessionFromArrayWithPrepackedWeightsContainer( + IntPtr /* (OrtEnv*) */ environment, + byte[] /* (void*) */ modelData, + UIntPtr /* (size_t) */ modelSize, + IntPtr /* (OrtSessionOptions*) */ sessionOptions, + IntPtr /* (OrtPrepackedWeightsContainer*) */ prepackedWeightsContainer, + out IntPtr /* (OrtSession**) */ session); public static DOrtCreateSessionFromArrayWithPrepackedWeightsContainer OrtCreateSessionFromArrayWithPrepackedWeightsContainer; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(ONNStatus*)*/ DOrtRun( - IntPtr /*(OrtSession*)*/ session, - IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options - IntPtr[] inputNames, - IntPtr[] /* (OrtValue*[])*/ inputValues, - UIntPtr inputCount, - IntPtr[] outputNames, - UIntPtr outputCount, - IntPtr[] outputValues /* An array of output value pointers. Array must be allocated by the caller */ - ); + IntPtr /*(OrtSession*)*/ session, + IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options + IntPtr[] inputNames, + IntPtr[] /* (OrtValue*[])*/ inputValues, + UIntPtr inputCount, + IntPtr[] outputNames, + UIntPtr outputCount, + IntPtr[] outputValues /* An array of output value pointers. Array must be allocated by the caller */ + ); public static DOrtRun OrtRun; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(ONNStatus*)*/ DOrtRunWithBinding( - IntPtr /*(OrtSession*)*/ session, - IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can not be null - IntPtr /*(const OrtIoBinding*)*/ io_binding - ); + IntPtr /*(OrtSession*)*/ session, + IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can not be null + IntPtr /*(const OrtIoBinding*)*/ io_binding); public static DOrtRunWithBinding OrtRunWithBinding; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetInputCount( - IntPtr /*(OrtSession*)*/ session, - out UIntPtr count); + IntPtr /*(OrtSession*)*/ session, + out UIntPtr count); public static DOrtSessionGetInputCount OrtSessionGetInputCount; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetOutputCount( - IntPtr /*(OrtSession*)*/ session, - out UIntPtr count); + IntPtr /*(OrtSession*)*/ session, + out UIntPtr count); public static DOrtSessionGetOutputCount OrtSessionGetOutputCount; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetOverridableInitializerCount( - IntPtr /*(OrtSession*)*/ session, - out UIntPtr count); + IntPtr /*(OrtSession*)*/ session, + out UIntPtr count); public static DOrtSessionGetOverridableInitializerCount OrtSessionGetOverridableInitializerCount; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetInputName( - IntPtr /*(OrtSession*)*/ session, - UIntPtr index, - IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(char**)*/name); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetInputName( + IntPtr /*(OrtSession*)*/ session, + UIntPtr index, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(char**)*/ name); public static DOrtSessionGetInputName OrtSessionGetInputName; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOutputName( - IntPtr /*(OrtSession*)*/ session, - UIntPtr index, - IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(char**)*/name); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetOutputName( + IntPtr /*(OrtSession*)*/ session, + UIntPtr index, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(char**)*/ name); public static DOrtSessionGetOutputName OrtSessionGetOutputName; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtSessionEndProfiling( - IntPtr /*(const OrtSession*)*/ session, - IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(char**)*/profile_file); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionEndProfiling( + IntPtr /*(const OrtSession*)*/ session, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(char**)*/ profile_file); public static DOrtSessionEndProfiling OrtSessionEndProfiling; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOverridableInitializerName( - IntPtr /*(OrtSession*)*/ session, - UIntPtr index, - IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(char**)*/name); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetOverridableInitializerName( + IntPtr /*(OrtSession*)*/ session, + UIntPtr index, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(char**)*/ name); public static DOrtSessionGetOverridableInitializerName OrtSessionGetOverridableInitializerName; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetInputTypeInfo( - IntPtr /*(const OrtSession*)*/ session, - UIntPtr index, - out IntPtr /*(struct OrtTypeInfo**)*/ typeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetInputTypeInfo( + IntPtr /*(const OrtSession*)*/ session, + UIntPtr index, + out IntPtr /*(struct OrtTypeInfo**)*/ typeInfo); public static DOrtSessionGetInputTypeInfo OrtSessionGetInputTypeInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOutputTypeInfo( - IntPtr /*(const OrtSession*)*/ session, - UIntPtr index, - out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetOutputTypeInfo( + IntPtr /*(const OrtSession*)*/ session, + UIntPtr index, + out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo); public static DOrtSessionGetOutputTypeInfo OrtSessionGetOutputTypeInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOverridableInitializerTypeInfo( - IntPtr /*(const OrtSession*)*/ session, - UIntPtr index, - out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetOverridableInitializerTypeInfo( + IntPtr /*(const OrtSession*)*/ session, + UIntPtr index, + out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo); public static DOrtSessionGetOverridableInitializerTypeInfo OrtSessionGetOverridableInitializerTypeInfo; // release the typeinfo using OrtReleaseTypeInfo [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate void DOrtReleaseTypeInfo(IntPtr /*(OrtTypeInfo*)*/session); + public delegate void DOrtReleaseTypeInfo(IntPtr /*(OrtTypeInfo*)*/ session); public static DOrtReleaseTypeInfo OrtReleaseTypeInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate void DOrtReleaseSession(IntPtr /*(OrtSession*)*/session); + public delegate void DOrtReleaseSession(IntPtr /*(OrtSession*)*/ session); public static DOrtReleaseSession OrtReleaseSession; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetProfilingStartTimeNs( - IntPtr /*(const OrtSession*)*/ session, - out UIntPtr /*(ulong* out)*/ startTime); + IntPtr /*(const OrtSession*)*/ session, + out UIntPtr /*(ulong* out)*/ startTime); public static DOrtSessionGetProfilingStartTimeNs OrtSessionGetProfilingStartTimeNs; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(ONNStatus*)*/ DCreateAndRegisterAllocatorV2( - IntPtr /* (OrtEnv*) */ environment, - IntPtr /*(char*)*/ provider_type, - IntPtr /*(OrtMemoryInfo*)*/ mem_info, - IntPtr /*(OrtArenaCfg*)*/ arena_cfg, - IntPtr /*(char**)*/ provider_options_keys, - IntPtr /*(char**)*/ provider_options_values, - UIntPtr /*(size_t)*/num_keys); + IntPtr /* (OrtEnv*) */ environment, + IntPtr /*(char*)*/ provider_type, + IntPtr /*(OrtMemoryInfo*)*/ mem_info, + IntPtr /*(OrtArenaCfg*)*/ arena_cfg, + IntPtr /*(char**)*/ provider_options_keys, + IntPtr /*(char**)*/ provider_options_values, + UIntPtr /*(size_t)*/ num_keys); public static DCreateAndRegisterAllocatorV2 OrtCreateAndRegisterAllocatorV2; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(ONNStatus*)*/ DOrtRunAsync( - IntPtr /*(OrtSession*)*/ session, - IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options - IntPtr[] /*(char**)*/ inputNames, - IntPtr[] /*(OrtValue*[])*/ inputValues, - UIntPtr /*(size_t)*/ inputCount, - IntPtr[] /*(char**)*/ outputNames, - UIntPtr /*(size_t)*/ outputCount, - IntPtr[] /*(OrtValue*[])*/ outputValues, - IntPtr /*(void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status))*/ callback, // callback function - IntPtr /*(void*)*/ user_data - ); + IntPtr /*(OrtSession*)*/ session, + IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options + IntPtr[] /*(char**)*/ inputNames, + IntPtr[] /*(OrtValue*[])*/ inputValues, + UIntPtr /*(size_t)*/ inputCount, + IntPtr[] /*(char**)*/ outputNames, + UIntPtr /*(size_t)*/ outputCount, + IntPtr[] /*(OrtValue*[])*/ outputValues, + IntPtr /*(void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status))*/ callback, // callback function + IntPtr /*(void*)*/ user_data); public static DOrtRunAsync OrtRunAsync; - #endregion InferenceSession API +#endregion InferenceSession API - #region SessionOptions API +#region SessionOptions API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateSessionOptions(out IntPtr /*(OrtSessionOptions**)*/ sessionOptions); public static DOrtCreateSessionOptions OrtCreateSessionOptions; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate void DOrtReleaseSessionOptions(IntPtr /*(OrtSessionOptions*)*/session); + public delegate void DOrtReleaseSessionOptions(IntPtr /*(OrtSessionOptions*)*/ session); public static DOrtReleaseSessionOptions OrtReleaseSessionOptions; [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -964,7 +961,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionExecutionMode(IntPtr /*(OrtSessionOptions*)*/ options, - ExecutionMode execution_mode); + ExecutionMode execution_mode); public static DOrtSetSessionExecutionMode OrtSetSessionExecutionMode; [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -996,7 +993,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtDisableCpuMemArena OrtDisableCpuMemArena; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, byte[] /* const char* */logId); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, byte[] /* const char* */ logId); public static DOrtSetSessionLogId OrtSetSessionLogId; [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -1027,7 +1024,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Config value [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtAddSessionConfigEntry(IntPtr /* OrtSessionOptions* */ options, - byte[] /* const char* */configKey, + byte[] /* const char* */ configKey, byte[] /* const char* */ configValue); public static DOrtAddSessionConfigEntry OrtAddSessionConfigEntry; @@ -1090,9 +1087,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Native OrtSessionOptions instance /// Native OrtTensorRTProviderOptions instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider_TensorRT( - IntPtr /*(OrtSessionOptions*)*/ options, - IntPtr /*(const OrtTensorRTProviderOptions*)*/ trtProviderOptions); + public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider_TensorRT( + IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const OrtTensorRTProviderOptions*)*/ trtProviderOptions); public static DSessionOptionsAppendExecutionProvider_TensorRT SessionOptionsAppendExecutionProvider_TensorRT; @@ -1102,9 +1099,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Native OrtSessionOptions instance /// Native OrtTensorRTProviderOptionsV2 instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider_TensorRT_V2( - IntPtr /*(OrtSessionOptions*)*/ options, - IntPtr /*(const OrtTensorRTProviderOptionsV2*)*/ trtProviderOptions); + public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider_TensorRT_V2( + IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const OrtTensorRTProviderOptionsV2*)*/ trtProviderOptions); public static DSessionOptionsAppendExecutionProvider_TensorRT_V2 SessionOptionsAppendExecutionProvider_TensorRT_V2; @@ -1114,9 +1111,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Native OrtSessionOptions instance /// Native OrtCUDAProviderOptions instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider_CUDA( - IntPtr /*(OrtSessionOptions*)*/ options, - IntPtr /*(const OrtCUDAProviderOptions*)*/ cudaProviderOptions); + public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider_CUDA( + IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const OrtCUDAProviderOptions*)*/ cudaProviderOptions); public static DSessionOptionsAppendExecutionProvider_CUDA SessionOptionsAppendExecutionProvider_CUDA; @@ -1126,9 +1123,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Native OrtSessionOptions instance /// Native OrtCUDAProviderOptionsV2 instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider_CUDA_V2( - IntPtr /*(OrtSessionOptions*)*/ options, - IntPtr /*(const OrtCUDAProviderOptionsV2*)*/ cudaProviderOptions); + public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider_CUDA_V2( + IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const OrtCUDAProviderOptionsV2*)*/ cudaProviderOptions); public static DSessionOptionsAppendExecutionProvider_CUDA_V2 SessionOptionsAppendExecutionProvider_CUDA_V2; @@ -1138,9 +1135,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Native OrtSessionOptions instance /// Native OrtROCMProviderOptions instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider_ROCM( - IntPtr /*(OrtSessionOptions*)*/ options, - IntPtr /*(const OrtROCMProviderOptions*)*/ rocmProviderOptions); + public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider_ROCM( + IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const OrtROCMProviderOptions*)*/ rocmProviderOptions); public static DSessionOptionsAppendExecutionProvider_ROCM SessionOptionsAppendExecutionProvider_ROCM; @@ -1151,9 +1148,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Dimension denotation /// Dimension value [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtAddFreeDimensionOverride(IntPtr /*(OrtSessionOptions*)*/ options, - byte[] /*(const char*)*/ dimDenotation, - long dimValue); + public delegate IntPtr /*(OrtStatus*)*/ DOrtAddFreeDimensionOverride(IntPtr /*(OrtSessionOptions*)*/ options, + byte[] /*(const char*)*/ dimDenotation, + long dimValue); public static DOrtAddFreeDimensionOverride OrtAddFreeDimensionOverride; @@ -1164,9 +1161,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Dimension name /// Dimension value [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtAddFreeDimensionOverrideByName(IntPtr /*(OrtSessionOptions*)*/ options, - byte[] /*(const char*)*/ dimName, - long dimValue); + public delegate IntPtr /*(OrtStatus*)*/ DOrtAddFreeDimensionOverrideByName(IntPtr /*(OrtSessionOptions*)*/ options, + byte[] /*(const char*)*/ dimName, + long dimValue); public static DOrtAddFreeDimensionOverrideByName OrtAddFreeDimensionOverrideByName; @@ -1177,9 +1174,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Library path /// (out) Native library handle [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtRegisterCustomOpsLibrary(IntPtr /*(OrtSessionOptions*) */ options, - byte[] /*(const char*)*/ libraryPath, - out IntPtr /*(void**)*/ libraryHandle); + public delegate IntPtr /*(OrtStatus*)*/ DOrtRegisterCustomOpsLibrary(IntPtr /*(OrtSessionOptions*) */ options, + byte[] /*(const char*)*/ libraryPath, + out IntPtr /*(void**)*/ libraryHandle); public static DOrtRegisterCustomOpsLibrary OrtRegisterCustomOpsLibrary; @@ -1189,8 +1186,8 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Native SessionOptions instance /// Library path [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtRegisterCustomOpsLibrary_V2(IntPtr /*(OrtSessionOptions*) */ options, - byte[] /*(const ORTCHAR_T*)*/ libraryPath); + public delegate IntPtr /*(OrtStatus*)*/ DOrtRegisterCustomOpsLibrary_V2(IntPtr /*(OrtSessionOptions*) */ options, + byte[] /*(const ORTCHAR_T*)*/ libraryPath); public static DOrtRegisterCustomOpsLibrary_V2 OrtRegisterCustomOpsLibrary_V2; @@ -1201,9 +1198,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Name of the initializer /// Native OrtValue instnce [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtAddInitializer(IntPtr /*(OrtSessionOptions*)*/ options, - byte[] /*(const char*)*/ name, - IntPtr /*(OrtValue*)*/ ortValue); + public delegate IntPtr /*(OrtStatus*)*/ DOrtAddInitializer(IntPtr /*(OrtSessionOptions*)*/ options, + byte[] /*(const char*)*/ name, + IntPtr /*(OrtValue*)*/ ortValue); public static DOrtAddInitializer OrtAddInitializer; @@ -1220,25 +1217,25 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Configuration values to add /// Number of configuration keys [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider( - IntPtr /*(OrtSessionOptions*)*/ options, - byte[] /*(const char*)*/ providerName, - IntPtr[] /*(const char* const *)*/ providerOptionsKeys, - IntPtr[] /*(const char* const *)*/ providerOptionsValues, - UIntPtr /*(size_t)*/ numKeys); + public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider( + IntPtr /*(OrtSessionOptions*)*/ options, + byte[] /*(const char*)*/ providerName, + IntPtr[] /*(const char* const *)*/ providerOptionsKeys, + IntPtr[] /*(const char* const *)*/ providerOptionsValues, + UIntPtr /*(size_t)*/ numKeys); public static DSessionOptionsAppendExecutionProvider SessionOptionsAppendExecutionProvider; - #endregion +#endregion - #region RunOptions API +#region RunOptions API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateRunOptions(out IntPtr /* OrtRunOptions** */ runOptions); public static DOrtCreateRunOptions OrtCreateRunOptions; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate void DOrtReleaseRunOptions(IntPtr /*(OrtRunOptions*)*/options); + public delegate void DOrtReleaseRunOptions(IntPtr /*(OrtRunOptions*)*/ options); public static DOrtReleaseRunOptions OrtReleaseRunOptions; [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -1259,11 +1256,11 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtRunOptionsGetRunLogSeverityLevel(IntPtr /* OrtRunOptions* */ options, - out OrtLoggingLevel severityLevel); + out OrtLoggingLevel severityLevel); public static DOrtRunOptionsGetRunLogSeverityLevel OrtRunOptionsGetRunLogSeverityLevel; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtRunOptionsGetRunTag(IntPtr /* const OrtRunOptions* */options, out IntPtr /* const char** */ runtag); + public delegate IntPtr /*(OrtStatus*)*/ DOrtRunOptionsGetRunTag(IntPtr /* const OrtRunOptions* */ options, out IntPtr /* const char** */ runtag); public static DOrtRunOptionsGetRunTag OrtRunOptionsGetRunTag; // Set a flag so that any running OrtRun* calls that are using this instance of OrtRunOptions @@ -1276,7 +1273,6 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public delegate IntPtr /*(OrtStatus*)*/ DOrtRunOptionsUnsetTerminate(IntPtr /* OrtRunOptions* */ options); public static DOrtRunOptionsUnsetTerminate OrtRunOptionsUnsetTerminate; - /// /// Add run config entry /// @@ -1285,13 +1281,13 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Config value [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtAddRunConfigEntry(IntPtr /* OrtRunOptions* */ options, - byte[] /* const char* */configKey, + byte[] /* const char* */ configKey, byte[] /* const char* */ configValue); public static DOrtAddRunConfigEntry OrtAddRunConfigEntry; - #endregion +#endregion - #region ThreadingOptions API +#region ThreadingOptions API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateThreadingOptions(out IntPtr /* OrtCreateThreadingOptions** */ threadingOptions); @@ -1316,27 +1312,26 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtThreadingOptionsSetGlobalSpinControl(IntPtr /* OrtThreadingOptions* */ threadingOptions, int allowSpinning); public static DOrtThreadingOptionsSetGlobalSpinControl OrtThreadingOptionsSetGlobalSpinControl; - #endregion +#endregion - #region Allocator/MemoryInfo API +#region Allocator / MemoryInfo API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*)*/ DOrtCreateMemoryInfo( - byte[] /*(const char*) */name, - OrtAllocatorType allocatorType, - int identifier, - OrtMemType memType, - out IntPtr /*(OrtMemoryInfo*)*/ allocatorInfo // memory ownership transfered to caller - ); + byte[] /*(const char*) */ name, + OrtAllocatorType allocatorType, + int identifier, + OrtMemType memType, + out IntPtr /*(OrtMemoryInfo*)*/ allocatorInfo // memory ownership transfered to caller + ); public static DOrtCreateMemoryInfo OrtCreateMemoryInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*)*/ DOrtCreateCpuMemoryInfo( - OrtAllocatorType allocatorType, - OrtMemType memoryType, - out IntPtr /*(OrtMemoryInfo*)*/ allocatorInfo - ); + OrtAllocatorType allocatorType, + OrtMemType memoryType, + out IntPtr /*(OrtMemoryInfo*)*/ allocatorInfo); public static DOrtCreateCpuMemoryInfo OrtCreateCpuMemoryInfo; @@ -1347,15 +1342,15 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCompareMemoryInfo( - IntPtr /*(const OrtMemoryInfo*)*/ info1, - IntPtr /*(const OrtMemoryInfo*)*/ info2, - out int /*(int* out)*/ result); + IntPtr /*(const OrtMemoryInfo*)*/ info1, + IntPtr /*(const OrtMemoryInfo*)*/ info2, + out int /*(int* out)*/ result); public static DOrtCompareMemoryInfo OrtCompareMemoryInfo; /** - * Do not free the returned value - */ + * Do not free the returned value + */ [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtMemoryInfoGetName(IntPtr /*(const OrtMemoryInfo* ptr)*/ mem_info, out IntPtr /*(const char**)*/ name); @@ -1368,26 +1363,25 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtMemoryInfoGetMemType( - IntPtr /*(const OrtMemoryInfo* ptr)*/ mem_info, - out OrtMemType /*(OrtMemType*)*/ mem_type); + IntPtr /*(const OrtMemoryInfo* ptr)*/ mem_info, + out OrtMemType /*(OrtMemType*)*/ mem_type); public static DOrtMemoryInfoGetMemType OrtMemoryInfoGetMemType; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtMemoryInfoGetType( - IntPtr /*(const OrtMemoryInfo* ptr)*/ mem_info, - out OrtAllocatorType /*(OrtAllocatorType*)*/ alloc_type - ); + IntPtr /*(const OrtMemoryInfo* ptr)*/ mem_info, + out OrtAllocatorType /*(OrtAllocatorType*)*/ alloc_type); public static DOrtMemoryInfoGetType OrtMemoryInfoGetType; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtGetAllocatorWithDefaultOptions(out IntPtr /*(OrtAllocator**)*/ allocator); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetAllocatorWithDefaultOptions(out IntPtr /*(OrtAllocator**)*/ allocator); public static DOrtGetAllocatorWithDefaultOptions OrtGetAllocatorWithDefaultOptions; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtAllocatorGetInfo(IntPtr /*(const OrtAllocator*)*/ ptr, out IntPtr /*(const struct OrtMemoryInfo**)*/info); + public delegate IntPtr /*(OrtStatus*)*/ DOrtAllocatorGetInfo(IntPtr /*(const OrtAllocator*)*/ ptr, out IntPtr /*(const struct OrtMemoryInfo**)*/ info); public static DOrtAllocatorGetInfo OrtAllocatorGetInfo; @@ -1402,8 +1396,8 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Pointer to a native OrtStatus instance indicating success/failure of config creation [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateArenaCfg(UIntPtr /*(size_t)*/ maxMemory, int /*(int)*/ arenaExtendStrategy, - int /*(int)*/ initialChunkSizeBytes, int /*(int)*/ maxDeadBytesPerChunk, - out IntPtr /*(OrtArenaCfg**)*/ arenaCfg); + int /*(int)*/ initialChunkSizeBytes, int /*(int)*/ maxDeadBytesPerChunk, + out IntPtr /*(OrtArenaCfg**)*/ arenaCfg); public static DOrtCreateArenaCfg OrtCreateArenaCfg; @@ -1457,9 +1451,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtAllocatorFree OrtAllocatorFree; - #endregion Allocator/MemoryInfo API +#endregion Allocator / MemoryInfo API - #region IoBinding API +#region IoBinding API /// /// Create OrtIoBinding instance that is used to bind memory that is allocated @@ -1634,7 +1628,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateAndRegisterAllocator(IntPtr /*(OrtEnv*)*/ env, IntPtr /*(const OrtMemoryInfo*)*/ memInfo, - IntPtr/*(const OrtArenaCfg*)*/ arenaCfg); + IntPtr /*(const OrtArenaCfg*)*/ arenaCfg); public static DOrtCreateAndRegisterAllocator OrtCreateAndRegisterAllocator; @@ -1644,13 +1638,13 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// the source projected language [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSetLanguageProjection(IntPtr /* (OrtEnv*) */ environment, - int projection); + int projection); public static DOrtSetLanguageProjection OrtSetLanguageProjection; - #endregion IoBinding API +#endregion IoBinding API - #region ModelMetadata API +#region ModelMetadata API /// /// Gets the ModelMetadata associated with an InferenceSession @@ -1670,7 +1664,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// (output) producer name from the ModelMetadata instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetProducerName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, - IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); public static DOrtModelMetadataGetProducerName OrtModelMetadataGetProducerName; @@ -1682,7 +1676,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// (output) graph name from the ModelMetadata instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetGraphName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, - IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); public static DOrtModelMetadataGetGraphName OrtModelMetadataGetGraphName; @@ -1694,7 +1688,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// (output) domain from the ModelMetadata instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetDomain(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, - IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); public static DOrtModelMetadataGetDomain OrtModelMetadataGetDomain; @@ -1706,7 +1700,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// (output) description from the ModelMetadata instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetDescription(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, - IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); public static DOrtModelMetadataGetDescription OrtModelMetadataGetDescription; @@ -1718,7 +1712,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// (output) graph description from the ModelMetadata instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetGraphDescription(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, - IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); public static DOrtModelMetadataGetGraphDescription OrtModelMetadataGetGraphDescription; @@ -1742,7 +1736,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// (output) number of keys in the custom metadata map [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetCustomMetadataMapKeys(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, - IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char***) */ keys, out long /* (int64_t*) */ numKeys); + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char***) */ keys, out long /* (int64_t*) */ numKeys); public static DOrtModelMetadataGetCustomMetadataMapKeys OrtModelMetadataGetCustomMetadataMapKeys; @@ -1755,7 +1749,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// (output) value for the key in the custom metadata map [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataLookupCustomMetadataMap(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, - IntPtr /* (OrtAllocator*) */ allocator, IntPtr /* (const char*) */ key, out IntPtr /* (char**) */ value); + IntPtr /* (OrtAllocator*) */ allocator, IntPtr /* (const char*) */ key, out IntPtr /* (char**) */ value); public static DOrtModelMetadataLookupCustomMetadataMap OrtModelMetadataLookupCustomMetadataMap; @@ -1768,9 +1762,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtReleaseModelMetadata OrtReleaseModelMetadata; - #endregion ModelMetadata API +#endregion ModelMetadata API - #region OrtValue API +#region OrtValue API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtHasValue(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(int*)*/ hasValue); @@ -1779,9 +1773,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetValue(IntPtr /*(OrtValue*)*/ value, - int index, - IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(OrtValue**)*/ outputValue); + int index, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(OrtValue**)*/ outputValue); public static DOrtGetValue OrtGetValue; @@ -1801,8 +1795,8 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtGetValueCount OrtGetValueCount; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr/*(OrtStatus*)*/ DOrtCreateValue(IntPtr[] /* const OrtValue* const* in */ values, - UIntPtr /* size_t */ num_values, IntPtr /* (OnnxValueType */ onnxValueType, out IntPtr /* OrtValue** */ ortValue); + public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateValue(IntPtr[] /* const OrtValue* const* in */ values, + UIntPtr /* size_t */ num_values, IntPtr /* (OnnxValueType */ onnxValueType, out IntPtr /* OrtValue** */ ortValue); public static DOrtCreateValue OrtCreateValue; @@ -1813,23 +1807,23 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateTensorAsOrtValue( - IntPtr /*_Inout_ OrtAllocator* */ allocator, - long[] /*_In_ const int64_t* */ shape, - UIntPtr /*size_t*/ shape_len, - Tensors.TensorElementType type, - out IntPtr /* OrtValue** */ outputValue); + IntPtr /*_Inout_ OrtAllocator* */ allocator, + long[] /*_In_ const int64_t* */ shape, + UIntPtr /*size_t*/ shape_len, + Tensors.TensorElementType type, + out IntPtr /* OrtValue** */ outputValue); public static DOrtCreateTensorAsOrtValue OrtCreateTensorAsOrtValue; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus */ DOrtCreateTensorWithDataAsOrtValue( - IntPtr /* (const OrtMemoryInfo*) */ allocatorInfo, - IntPtr /* (void*) */dataBufferHandle, - UIntPtr dataLength, - long[] shape, - UIntPtr shapeLength, - Tensors.TensorElementType type, - out IntPtr /* OrtValue** */ outputValue); + IntPtr /* (const OrtMemoryInfo*) */ allocatorInfo, + IntPtr /* (void*) */ dataBufferHandle, + UIntPtr dataLength, + long[] shape, + UIntPtr shapeLength, + Tensors.TensorElementType type, + out IntPtr /* OrtValue** */ outputValue); public static DOrtCreateTensorWithDataAsOrtValue OrtCreateTensorWithDataAsOrtValue; @@ -1854,60 +1848,67 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// \param len total data length, not including the trailing '\0' chars. [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtFillStringTensor( - IntPtr /* OrtValue */ value, - IntPtr[] /* const char* const* */s, - UIntPtr /* size_t */ s_len); + IntPtr /* OrtValue */ value, + IntPtr[] /* const char* const* */ s, + UIntPtr /* size_t */ s_len); public static DOrtFillStringTensor OrtFillStringTensor; + /// \param value A tensor created from OrtCreateTensor... function. + /// \param index The index of the entry in the tensor to resize. + /// \param length_in_bytes Length to resize the string to. + /// \param buffer The resized buffer. + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetResizedStringTensorElementBuffer( - IntPtr /* OrtValue */ value, - UIntPtr /* size_t */ index, - UIntPtr /* size_t */ length_in_bytes, - out IntPtr /* char** */ buffer - ); + IntPtr /* OrtValue */ value, + UIntPtr /* size_t */ index, + UIntPtr /* size_t */ length_in_bytes, + out IntPtr /* char** */ buffer); public static DOrtGetResizedStringTensorElementBuffer OrtGetResizedStringTensorElementBuffer; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorContent( - IntPtr /*(OrtValue*)*/ value, - byte[] /*(void*)*/ dst_buffer, - UIntPtr dst_buffer_len, - UIntPtr[] offsets, - UIntPtr offsets_len); + IntPtr /*(OrtValue*)*/ value, + byte[] /*(void*)*/ dst_buffer, + UIntPtr dst_buffer_len, + UIntPtr[] offsets, + UIntPtr offsets_len); public static DOrtGetStringTensorContent OrtGetStringTensorContent; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorDataLength(IntPtr /*(OrtValue*)*/ value, - out UIntPtr /*(size_t*)*/ len); + out UIntPtr /*(size_t*)*/ len); public static DOrtGetStringTensorDataLength OrtGetStringTensorDataLength; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElementLength(IntPtr /*(OrtValue*)*/ value, - UIntPtr /*(size_t)*/ index, - out UIntPtr /*(size_t*)*/ len); + UIntPtr /*(size_t)*/ index, + out UIntPtr /*(size_t*)*/ len); public static DOrtGetStringTensorElementLength OrtGetStringTensorElementLength; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElement(IntPtr /*(OrtValue*)*/ value, - UIntPtr /*(size_t)*/ bufferLength, - UIntPtr /*(size_t)*/ elementIndex, - byte[] buffer); + UIntPtr /*(size_t)*/ bufferLength, + UIntPtr /*(size_t)*/ elementIndex, + byte[] buffer); public static DOrtGetStringTensorElement OrtGetStringTensorElement; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ - DOrtCastTypeInfoToTensorInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToTensorInfo( + IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, + out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo); public static DOrtCastTypeInfoToTensorInfo OrtCastTypeInfoToTensorInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape( + IntPtr /*(OrtValue*)*/ value, + out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo); public static DOrtGetTensorTypeAndShape OrtGetTensorTypeAndShape; @@ -1917,39 +1918,43 @@ out IntPtr /* char** */ buffer public static DOrtReleaseTensorTypeAndShapeInfo OrtReleaseTensorTypeAndShapeInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out IntPtr /*(TensorElementType*)*/ output); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType( + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + out IntPtr /*(TensorElementType*)*/ output); public static DOrtGetTensorElementType OrtGetTensorElementType; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out UIntPtr output); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount( + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + out UIntPtr output); public static DOrtGetDimensionsCount OrtGetDimensionsCount; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensions( - IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, - long[] dim_values, - UIntPtr dim_values_length); + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + long[] dim_values, + UIntPtr dim_values_length); public static DOrtGetDimensions OrtGetDimensions; /** - * Get the symbolic dimension names for dimensions with a value of -1. - * Order and number of entries is the same as values returned by GetDimensions. - * The name may be empty for an unnamed symbolic dimension. - * e.g. - * If OrtGetDimensions returns [-1, -1, 2], OrtGetSymbolicDimensions would return an array with 3 entries. - * If the values returned were ['batch', '', ''] it would indicate that - * - the first dimension was a named symbolic dimension (-1 dim value and name in symbolic dimensions), - * - the second dimension was an unnamed symbolic dimension (-1 dim value and empty string), - * - the entry for the third dimension should be ignored as it is not a symbolic dimension (dim value >= 0). - */ + * Get the symbolic dimension names for dimensions with a value of -1. + * Order and number of entries is the same as values returned by GetDimensions. + * The name may be empty for an unnamed symbolic dimension. + * e.g. + * If OrtGetDimensions returns [-1, -1, 2], OrtGetSymbolicDimensions would return an array with 3 entries. + * If the values returned were ['batch', '', ''] it would indicate that + * - the first dimension was a named symbolic dimension (-1 dim value and name in symbolic dimensions), + * - the second dimension was an unnamed symbolic dimension (-1 dim value and empty string), + * - the entry for the third dimension should be ignored as it is not a symbolic dimension (dim value >= 0). + */ [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetSymbolicDimensions( - IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, - IntPtr[] dim_params, /* const char* values, converted to string by caller */ - UIntPtr dim_params_length); + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + IntPtr[] dim_params, /* const char* values, converted to string by caller */ + UIntPtr dim_params_length); public static DOrtGetSymbolicDimensions OrtGetSymbolicDimensions; @@ -1964,15 +1969,15 @@ out IntPtr /* char** */ buffer */ [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorShapeElementCount(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, - out UIntPtr /* size_t */ output); + out UIntPtr /* size_t */ output); public static DOrtGetTensorShapeElementCount OrtGetTensorShapeElementCount; - [UnmanagedFunctionPointer(CallingConvention.Winapi)] // The out ortMemoryInfo must not be destroyed/deallocated. The pointer points to an object owned by // the contained Tensor/SparseTensor. + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorMemoryInfo(IntPtr /* const OrtValue* */ ortValue, - out IntPtr /* const OrtMemoryInfo** */ ortMemoryInfo); + out IntPtr /* const OrtMemoryInfo** */ ortMemoryInfo); public static DOrtGetTensorMemoryInfo OrtGetTensorMemoryInfo; @@ -1982,10 +1987,12 @@ out IntPtr /* char** */ buffer public static DCastTypeInfoToMapTypeInfo OrtCastTypeInfoToMapTypeInfo; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DGetMapKeyType(IntPtr /*const OrtMapTypeInfo* */ mapTypeInfo, out IntPtr /*(TensorElementType*)*/ tensorElementType); public static DGetMapKeyType OrtGetMapKeyType; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DGetMapValueType(IntPtr /* const OrtMapTypeInfo* */ map_type_info, out IntPtr /* OrtTypeInfo** */ type_info); public static DGetMapValueType OrtGetMapValueType; @@ -1996,13 +2003,14 @@ out IntPtr /* char** */ buffer public static DCastTypeInfoToSequenceTypeInfo OrtCastTypeInfoToSequenceTypeInfo; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DGetSequenceElementType(IntPtr /* const OrtSequenceTypeInfo* */ sequenceTypeInfo, out IntPtr /* OrtTypeInfo** */ elementTypeInfo); public static DGetSequenceElementType OrtGetSequenceElementType; // OptionalTypeInfo [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToOptionalTypeInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /* const struct OrtOptionalTypeInfo** */ optionalTypeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToOptionalTypeInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /* const struct OrtOptionalTypeInfo** */ optionalTypeInfo); public static DOrtCastTypeInfoToOptionalTypeInfo OrtCastTypeInfoToOptionalTypeInfo; @@ -2016,10 +2024,9 @@ out IntPtr /* char** */ buffer public static DOrtReleaseValue OrtReleaseValue; - #endregion - +#endregion - #region Misc API +#region Misc API /// /// Queries all the execution providers supported in the native onnxruntime shared library @@ -2059,8 +2066,8 @@ out IntPtr /* char** */ buffer public static DOrtReleasePrepackedWeightsContainer OrtReleasePrepackedWeightsContainer; - #endregion - } //class NativeMethods +#endregion + } // class NativeMethods // onnxruntime-extensions helpers to make usage simpler. // The onnxruntime-extensions nuget package containing the native library can be optionally added to the app. @@ -2081,7 +2088,5 @@ internal static class OrtExtensionsNativeMethods CallingConvention = CallingConvention.Winapi)] public static extern IntPtr /* OrtStatus* */ RegisterCustomOps(IntPtr /* OrtSessionOptions* */ sessionOptions, ref OrtApiBase /* OrtApiBase* */ ortApiBase); - - } -} //namespace +} // namespace diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs index 659c6303702ac..6889112acb385 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs @@ -40,20 +40,16 @@ internal enum PropertyType : long String = 2 } - private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) + private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); - T[] value = new T[1]; - value[0] = propertyValue; - Memory memory = value; - using (var memHandle = memory.Pin()) + T[] value = { propertyValue }; + unsafe { - IntPtr memPtr; - unsafe + fixed (T* memPtr = value) { - memPtr = (IntPtr)memHandle.Pointer; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, (IntPtr)memPtr)); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr)); } } @@ -103,13 +99,13 @@ public static void SaveCheckpoint(CheckpointState state, string checkpointPath, } /// - /// Adds the given int property to the checkpoint state. + /// Adds or updates the given int property to/in the checkpoint state. /// - /// Runtime properties that are ints such as epoch, training step, and others can be added to the checkpoint - /// state by the user if they desire by calling this function with the appropriate property name and - /// value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, long propertyValue) { @@ -117,13 +113,13 @@ public void AddProperty(string propertyName, long propertyValue) } /// - /// Adds the given float property to the checkpoint state. + /// Adds or updates the given float property to/in the checkpoint state. /// - /// Runtime properties that are floats such as loss, best score, and others can be added to the checkpoint - /// state by the user if they desire by calling this function with the appropriate property name and - /// value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, float propertyValue) { @@ -131,28 +127,25 @@ public void AddProperty(string propertyName, float propertyValue) } /// - /// Adds the given string property to the checkpoint state. + /// Adds or updates the given string property to/in the checkpoint state. /// - /// Runtime properties that are strings such as parameter names, custom strings, and others can be added - /// to the checkpoint state by the user if they desire by calling this function with the appropriate property - /// name and value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, string propertyValue) { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue); - IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length); - try - { - Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer)); - } - finally + unsafe { - Marshal.FreeHGlobal(unmanagedPointer); + fixed (byte* p = propertyValueUtf8) + { + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, (IntPtr)p)); + } } } @@ -162,34 +155,86 @@ public void AddProperty(string propertyName, string propertyValue) /// Gets the property value from an existing entry in the checkpoint state. The property must /// exist in the checkpoint state to be able to retrieve it successfully. /// - /// Unique name of the property being retrieved. + /// Name of the property being retrieved. /// Property value associated with the given property name. public object GetProperty(string propertyName) { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); var allocator = OrtAllocator.DefaultInstance; IntPtr propertyValue = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue)); - if (propertyType == PropertyType.Int) + try { - var longPropertyValue = Marshal.ReadInt64(propertyValue); - allocator.FreeMemory(propertyValue); - return longPropertyValue; + if (propertyType == PropertyType.Int) + { + Int64 value; + unsafe + { + value = *(Int64*)propertyValue; + } + return value; + } + else if (propertyType == PropertyType.Float) + { + float value; + unsafe + { + value = *(float*)propertyValue; + } + return value; + } + else if (propertyType == PropertyType.String) + { + return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue); + } + + throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); } - else if (propertyType == PropertyType.Float) + finally { - float[] value = new float[1]; - Marshal.Copy(propertyValue, value, 0, 1); allocator.FreeMemory(propertyValue); - return value[0]; } - else if (propertyType == PropertyType.String) + } + + /// + /// Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + /// + /// This function updates a model parameter in the checkpoint state with the given parameter data. + /// The training session must be already created with the checkpoint state that contains the parameter + /// being updated. The given parameter is copied over to the registered device for the training session. + /// The parameter must exist in the checkpoint state to be able to update it successfully. + /// + /// Name of the parameter being updated. + /// The parameter data that should replace the existing parameter data. + public void UpdateParameter(string parameterName, OrtValue parameter) + { + if (parameter.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) { - return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator); + throw new ArgumentException("Incorrect buffer received. Expected a tensor parameter."); } - throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); + var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtUpdateParameter(handle, parameterNameUtf8, parameter.Handle)); + } + + /// + /// Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + /// + /// This function retrieves the model parameter data from the checkpoint state for the given parameter name. + /// The parameter is copied over to the provided OrtValue. The training session must be already created + /// with the checkpoint state that contains the parameter being retrieved. + /// The parameter must exist in the checkpoint state to be able to retrieve it successfully. + /// + /// Name of the parameter being updated. + /// The parameter data that is retrieved from the checkpoint state. + public OrtValue GetParameter(string parameterName) + { + var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle)); + + return new OrtValue(parameterHandle); } #region SafeHandle diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 1868ff509bfc3..68a399f8b9671 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -42,6 +42,9 @@ public struct OrtTrainingApi public IntPtr AddProperty; public IntPtr GetProperty; public IntPtr LoadCheckpointFromBuffer; + public IntPtr GetParameterTypeAndShape; + public IntPtr UpdateParameter; + public IntPtr GetParameter; } internal static class NativeTrainingMethods @@ -97,6 +100,9 @@ static NativeTrainingMethods() OrtGetEvalModelInputName = (DOrtGetEvalModelInputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelInputName, typeof(DOrtGetEvalModelInputName)); OrtAddProperty = (DOrtAddProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.AddProperty, typeof(DOrtAddProperty)); OrtGetProperty = (DOrtGetProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetProperty, typeof(DOrtGetProperty)); + OrtGetParameterTypeAndShape = (DOrtGetParameterTypeAndShape)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameterTypeAndShape, typeof(DOrtGetParameterTypeAndShape)); + OrtUpdateParameter = (DOrtUpdateParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.UpdateParameter, typeof(DOrtUpdateParameter)); + OrtGetParameter = (DOrtGetParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameter, typeof(DOrtGetParameter)); } } @@ -359,6 +365,34 @@ out UIntPtr inputCount public static DOrtGetProperty OrtGetProperty; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameterTypeAndShape( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + out IntPtr /*(OrtTensorTypeAndShapeInfo**)*/ parameterTypeAndShape + ); + + public static DOrtGetParameterTypeAndShape OrtGetParameterTypeAndShape; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtUpdateParameter( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + IntPtr /*(OrtValue*)*/ parameter + ); + + public static DOrtUpdateParameter OrtUpdateParameter; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameter( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(OrtValue**)*/ parameter + ); + + public static DOrtGetParameter OrtGetParameter; + #endregion TrainingSession API public static bool TrainingEnabled() diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index 33993c2be135b..877677dcad57b 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -358,13 +358,14 @@ public void EvalStep( IReadOnlyCollection inputValues, IReadOnlyCollection outputValues) { - if (!_evalOutputCount.Equals(outputValues.Count)) + if (_evalOutputCount != (ulong)outputValues.Count()) { - throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount})."); + throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of eval model ({_evalOutputCount})."); } - IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true); + const bool isInput = true; + IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, isInput); - IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */ + IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, !isInput); /* pointers to Pre-allocated OrtValue instances */ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count, inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray)); } @@ -509,18 +510,17 @@ public void ExportModelForInferencing(string inferenceModelPath, IReadOnlyCollec /// Returns a contiguous buffer that holds a copy of all training state parameters /// /// Whether to only copy trainable parameters or to copy all parameters. - public FixedBufferOnnxValue ToBuffer(bool onlyTrainable) + public OrtValue ToBuffer(bool onlyTrainable) { UIntPtr bufferSize = UIntPtr.Zero; NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, onlyTrainable)); float[] bufferMemory = new float[bufferSize.ToUInt64()]; - var memInfo = OrtMemoryInfo.DefaultInstance; // CPU - var shape = new long[] { (long)bufferSize.ToUInt64() }; - var buffer = FixedBufferOnnxValue.CreateFromMemory(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float)); + var shape = new long[] { (long)bufferSize }; + var buffer = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, Tensors.TensorElementType.Float, shape); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, onlyTrainable)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Handle, onlyTrainable)); return buffer; } @@ -528,45 +528,30 @@ public FixedBufferOnnxValue ToBuffer(bool onlyTrainable) /// /// Loads the training session model parameters from a contiguous buffer /// - /// Contiguous buffer to load the parameters from. - public void FromBuffer(FixedBufferOnnxValue buffer) + /// Contiguous buffer to load the parameters from. + /// Whether to only load trainable parameters or to load all parameters. + public void FromBuffer(OrtValue ortValue, bool onlyTrainable) { - if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) + if (ortValue.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) { throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer."); } - IntPtr typeAndShapeInfo = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo)); - UIntPtr numDimensions = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions)); - if (numDimensions.ToUInt64() != 1) + var tensorInfo = ortValue.GetTensorTypeAndShape(); + if (tensorInfo.ElementDataType != Tensors.TensorElementType.Float) { - string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString(); - throw new ArgumentException(errorMessage); - } - - // Here buffer size represents the number of elements in the buffer - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out UIntPtr bufferSize)); - - // OrtGetParametersSize returns the total number of elements in the model's parameters. - UIntPtr numElementsTrainingOnly = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true)); - if ((ulong)bufferSize == (ulong)numElementsTrainingOnly) - { - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true)); - return; + throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer of type float."); } UIntPtr numElements = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false)); - if ((ulong)bufferSize != (ulong)numElements) + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, onlyTrainable)); + if ((ulong)tensorInfo.ElementCount != (ulong)numElements) { - string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString(); + string errorMessage = "Incorrect buffer size received. Expected size to be " + numElements.ToString() + ". Actual size: " + tensorInfo.ElementCount.ToString(); throw new ArgumentException(errorMessage); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, ortValue.Handle, onlyTrainable)); } /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs index ea2b6d7dbc118..68b1d5bcc6147 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs @@ -484,20 +484,23 @@ public void TestEvalModelOutputNames() public void TestToBuffer() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - - var buffer = trainingSession.ToBuffer(true); - cleanUp.Add(buffer); + using (var buffer = trainingSession.ToBuffer(true)) + { + Assert.NotNull(buffer); + var typeShape = buffer.GetTensorTypeAndShape(); + Assert.Equal(1, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(397510, fetchedShape[0]); + } } } @@ -505,22 +508,25 @@ public void TestToBuffer() public void TestFromBuffer() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - var buffer = trainingSession.ToBuffer(true); - cleanUp.Add(buffer); + using (var buffer = trainingSession.ToBuffer(true)) + { + Assert.NotNull(buffer); + var typeShape = buffer.GetTensorTypeAndShape(); + Assert.Equal(1, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(397510, fetchedShape[0]); - trainingSession.FromBuffer(buffer); + trainingSession.FromBuffer(buffer, true); + } } } @@ -530,6 +536,82 @@ public void TestSetSeed() TrainingUtils.SetSeed(8888); } + [Fact(DisplayName = "TestGetParameter")] + public void TestGetParameter() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) + using (var parameter = state.GetParameter("fc1.weight")) + { + Assert.NotNull(state); + Assert.NotNull(parameter); + + var typeShape = parameter.GetTensorTypeAndShape(); + Assert.Equal(2, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(500, fetchedShape[0]); + Assert.Equal(784, fetchedShape[1]); + } + } + + [Fact(DisplayName = "TestUpdateParameter")] + public void TestUpdateParameter() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) + { + Assert.NotNull(state); + + using (var parameter = state.GetParameter("fc1.weight")) + { + Assert.NotNull(parameter); + var typeShape = parameter.GetTensorTypeAndShape(); + + Assert.Equal(2, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(500, fetchedShape[0]); + Assert.Equal(784, fetchedShape[1]); + + float maxVal = 20; + Random randNum = new Random(); + float[] updated_parameter_buffer = Enumerable + .Repeat(0, 500 * 784) + .Select(i => maxVal * (float)randNum.NextDouble()) + .ToArray(); + + using (var updated_parameter = OrtValue.CreateTensorValueFromMemory(updated_parameter_buffer, fetchedShape)) + { + state.UpdateParameter("fc1.weight", updated_parameter); + using (var current_parameter = state.GetParameter("fc1.weight")) + { + var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(updated_parameter_buffer, current_parameter_tensor); + Assert.NotEqual(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + } + + state.UpdateParameter("fc1.weight", parameter); + + using (var current_parameter = state.GetParameter("fc1.weight")) + { + var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + Assert.NotEqual(updated_parameter_buffer, current_parameter_tensor); + } + } + } + } + } + internal class FloatComparer : IEqualityComparer { private float atol = 1e-3f; diff --git a/csharp/tools/ValidateNativeDelegateAttributes.py b/csharp/tools/ValidateNativeDelegateAttributes.py new file mode 100644 index 0000000000000..acd6c173bfeb0 --- /dev/null +++ b/csharp/tools/ValidateNativeDelegateAttributes.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import pathlib + + +def check_all_delegates_have_unmanaged_function_pointer_attribute(file: pathlib.Path): + """ + Check that all 'public delegate' declarations have a matching UnmanagedFunctionPointer attribute. + :param file: C# source file to check. + :return: Number of errors + """ + + print(f"Checking {file!s}") + + errors = 0 + line_num = 0 + with open(str(file.resolve(strict=True))) as f: + prev_line = "" + for line in f.readlines(): + line_num += 1 + + # strip so it's easier to deal with commented out lines. + line = line.strip() # noqa + if line.startswith("public delegate ") and not prev_line.startswith("[UnmanagedFunctionPointer"): + errors += 1 + print(f"Line {line_num} is missing UnmanagedFunctionPointer attribute:\n\t{prev_line}\n\t{line}") + + prev_line = line + + return errors + + +def main(): + arg_parser = argparse.ArgumentParser( + "Script to validate that the native delegates for the ONNX Runtime C# managed projects have the required " + "attributes for iOS AOT. Paths are inferred from the script location." + "Errors of this nature can only be detected at runtime, in a release build, of a Xamarin/MAUI app, " + "on an actual iOS device. Due to that we take extra steps to identify problems early." + ) + + # no real args. just using this to provide description as help message + _ = arg_parser.parse_args() + + # CI needs resolve() as __file__ is a relative path when the script is run there + script_dir = pathlib.Path(__file__).resolve().parent + csharp_root = script_dir.parent + + managed_dir = csharp_root / "src" / "Microsoft.ML.OnnxRuntime" + native_methods = managed_dir / "NativeMethods.shared.cs" + training_native_methods = managed_dir / "Training" / "NativeTrainingMethods.shared.cs" + errors = check_all_delegates_have_unmanaged_function_pointer_attribute(native_methods) + errors += check_all_delegates_have_unmanaged_function_pointer_attribute(training_native_methods) + + if errors: + raise ValueError(f"{errors} errors were found. Please check output for specifics.") + + +if __name__ == "__main__": + main() diff --git a/csharp/tools/linux_pack/LinuxPackNativeNuget.csproj b/csharp/tools/linux_pack/LinuxPackNativeNuget.csproj new file mode 100644 index 0000000000000..098078d2e3683 --- /dev/null +++ b/csharp/tools/linux_pack/LinuxPackNativeNuget.csproj @@ -0,0 +1,15 @@ + + + + netstandard2.0 + $(OnnxRuntimeBuildDirectory)/NativeNuget.nuspec + + diff --git a/dockerfiles/Dockerfile.source b/dockerfiles/Dockerfile.source index 110e484e77d21..5822a805c674e 100644 --- a/dockerfiles/Dockerfile.source +++ b/dockerfiles/Dockerfile.source @@ -8,13 +8,14 @@ FROM mcr.microsoft.com/cbl-mariner/base/python:3 MAINTAINER Changming Sun "chasun@microsoft.com" ADD . /code -RUN tdnf install -y tar ca-certificates build-essential python3-numpy cmake python3-setuptools python3-wheel python3-pip curl python3-devel +RUN tdnf install -y tar ca-certificates build-essential cmake curl python3-devel python3-setuptools python3-wheel python3-pip python3-numpy python3-flatbuffers python3-packaging python3-protobuf +# The latest cmake version in Mariner2 is 3.21, but we need 3.26+ RUN /code/dockerfiles/scripts/install_cmake.sh # Prepare onnxruntime repository & build onnxruntime -RUN cd /code && python3 -m pip install -r tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt && /bin/bash ./build.sh --allow_running_as_root --skip_submodule_sync --config Release --build_wheel --update --build --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) +RUN cd /code && /bin/bash ./build.sh --allow_running_as_root --skip_submodule_sync --config Release --build_wheel --update --build --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) FROM mcr.microsoft.com/cbl-mariner/base/python:3 COPY --from=0 /code/build/Linux/Release/dist /root COPY --from=0 /code/dockerfiles/LICENSE-IMAGE.txt /code/LICENSE-IMAGE.txt -RUN tdnf install -y ca-certificates python3-setuptools python3-wheel python3-pip && python3 -m pip install /root/*.whl && rm -rf /root/*.whl +RUN tdnf install -y ca-certificates python3-setuptools python3-wheel python3-pip python3-numpy python3-flatbuffers python3-packaging python3-protobuf python3-mpmath python3-sympy && python3 -m pip install coloredlogs humanfriendly && python3 -m pip install --no-index --find-links /root onnxruntime && rm -rf /root/*.whl diff --git a/docs/ABI_Dev_Notes.md b/docs/ABI_Dev_Notes.md index f9b55176cf95b..f85dd9d19a336 100644 --- a/docs/ABI_Dev_Notes.md +++ b/docs/ABI_Dev_Notes.md @@ -4,7 +4,7 @@ Global variables may get constructed or destructed inside "DllMain". There are s ## Thread Local variables Onnxruntime must support explicit linking, where the operating system loads the DLL on demand at runtime, instead of process startup time. This is required by our language bindings like C#/Java. -However, there are some special restrictions on this, If a thread local variable need non-trivial construction, for the threads already exist before onnxruntime.dll is loaded, the variable won't get initialized correctly. So it's better to only access such variables from onnxruntime internal threads, or make these variables function local (Like the magic statics). +However, there are some special restrictions on this. If a thread local variable need non-trivial construction, for the threads already exist before onnxruntime.dll is loaded, the variable won't get initialized correctly. So it's better to only access such variables from onnxruntime internal threads, or make these variables function local (Like the magic statics). ## No undefined symbols diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 95dc8c3cde46c..c73f978bdf404 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -27,6 +27,8 @@ Do not modify directly.* * com.microsoft.DequantizeWithOrder * com.microsoft.DynamicQuantizeLSTM * com.microsoft.DynamicQuantizeMatMul + * com.microsoft.DynamicTimeWarping + * com.microsoft.EPContext * com.microsoft.EmbedLayerNormalization * com.microsoft.ExpandDims * com.microsoft.FastGelu @@ -38,16 +40,21 @@ Do not modify directly.* * com.microsoft.GatherND * com.microsoft.Gelu * com.microsoft.GemmFastGelu + * com.microsoft.GemmFloat8 * com.microsoft.GreedySearch * com.microsoft.GridSample * com.microsoft.GroupNorm + * com.microsoft.GroupQueryAttention * com.microsoft.Inverse * com.microsoft.Irfft * com.microsoft.LongformerAttention + * com.microsoft.MatMulBnb4 * com.microsoft.MatMulFpQ4 * com.microsoft.MatMulInteger16 * com.microsoft.MatMulIntegerToFloat + * com.microsoft.MatMulNBits * com.microsoft.MaxpoolWithMask + * com.microsoft.MoE * com.microsoft.MulInteger * com.microsoft.MultiHeadAttention * com.microsoft.MurmurHash3 @@ -86,8 +93,10 @@ Do not modify directly.* * com.microsoft.RemovePadding * com.microsoft.RestorePadding * com.microsoft.Rfft + * com.microsoft.RotaryEmbedding * com.microsoft.SampleOp * com.microsoft.Sampling + * com.microsoft.SkipGroupNorm * com.microsoft.SkipLayerNormalization * com.microsoft.SkipSimplifiedLayerNormalization * com.microsoft.Snpe @@ -96,7 +105,9 @@ Do not modify directly.* * com.microsoft.TorchEmbedding * com.microsoft.TransposeMatMul * com.microsoft.Trilu + * com.microsoft.UnfoldTensor * com.microsoft.Unique + * com.microsoft.WhisperBeamSearch * com.microsoft.WordConvEmbedding * experimental com.microsoft.IsAllFinite * experimental com.microsoft.QEmbedLayerNormalization @@ -1130,6 +1141,8 @@ This version of the operator has been available since version 1 of the 'com.micr
The value to be filled in the attention mask. Default value is -10000.0f
num_heads : int (required)
Number of attention heads
+
output_qk : int
+
Need output the cross attention MatMul(Q, K)
past_present_share_buffer : int
Corresponding past and present are same tensor, its size is (batch_size, num_heads, max_sequence_length, head_size)
scale : float
@@ -1163,20 +1176,24 @@ This version of the operator has been available since version 1 of the 'com.micr
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
-#### Outputs (1 - 3) +#### Outputs (1 - 4)
output : T
3D output tensor with shape (batch_size, sequence_length, v_hidden_size)
present_key (optional) : T
-
past state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
+
present state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
present_value (optional) : T
-
past state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
+
present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
+
qk (optional) : V
+
normalized Q * K, of shape (batch_size, num_heads, 1, head_size).
#### Type Constraints
+
V : tensor(float)
+
Constrain qk output types to float32 tensors.
T : tensor(float), tensor(float16)
Constrain input and output types to float tensors.
M : tensor(int32)
@@ -1520,6 +1537,87 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.DynamicTimeWarping** + + Input is cost matrix where each value in input[r][c] is the cost for pass the point (r, c). From current point(r, c), points (r+1, c), (r+1, c+1) or (r, c+1) could be arrived in next move. Given such cost matrix, return dynamic time wrapping of shape [2, x], where the path made by all points (output[0][t], output[1][t])have the lowest cost among all paths from (0, 0) to (M-1, N-1). + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Inputs + +
+
input : F
+
Input cost tensor, it must be 2D tensor of shape M x N, or 1 x M x N
+
+ +#### Outputs + +
+
output : I
+
Output tensor. shape is [2, x], where max(M, N) <= x < M + N
+
+ +#### Type Constraints + +
+
F : tensor(float)
+
Constrain to float tensors.
+
I : tensor(int32)
+
Constrain to integer types.
+
+ + +### **com.microsoft.EPContext** + + Onnx node container for EP context. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
embed_mode : int
+
1: indicate ep_cache_context is the context content. 0: indicate ep_cache_context is the file path to the context content.The path is relative to this Onnx file. Default is 1.
+
ep_cache_context : string
+
payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0.
+
ep_sdk_version : string
+
(Optional) SDK version used to convert the model.
+
main_context : int
+
Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.
+
notes : string
+
(Optional) Some notes for the model
+
partition_name : string
+
(Optional) partitioned graph name.
+
source : string
+
(Optional) the source used to generate the engine/context cache file. Ort EP or native SDK tool chain
+
+ +#### Inputs (1 - ∞) + +
+
inputs (variadic) : T
+
List of tensors for inputs
+
+ +#### Outputs (1 - ∞) + +
+
outputs (variadic) : T
+
One or more outputs, list of tensors for outputs
+
+ +#### Type Constraints + +
+
T : tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(float16), tensor(float), tensor(double)
+
Constrain input and output types.
+
+ + ### **com.microsoft.EmbedLayerNormalization** EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing. @@ -2042,6 +2140,71 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.GemmFloat8** + + Generic Gemm for float and float 8. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation : string
+
Activation function, RELU or GELU or NONE (default).
+
alpha : float
+
Scalar multiplier for the product of input tensors A * B.
+
beta : float
+
Scalar multiplier for the product of input bias C.
+
dtype : int
+
Output Type. Same definition as attribute 'to' for operator Cast.
+
transA : int
+
Whether A should be transposed. Float 8 only supprted transA=0.
+
transB : int
+
Whether B should be transposed. Float 8 only supprted transB=1.
+
+ +#### Inputs (2 - 6) + +
+
A : TA
+
Input tensor A. The shape of A should be (M, K) if transA is 0, or (K, M) if transA is non-zero.
+
B : TB
+
Input tensor B. The shape of B should be (K, N) if transB is 0, or (N, K) if transB is non-zero.
+
C (optional) : TC
+
Input tensor C.
+
scaleA (optional) : TS
+
Scale of tensor A if A is float 8 tensor
+
scaleB (optional) : TS
+
Scale of tensor B if B is float 8 tensor
+
scaleY (optional) : TS
+
Scale of the output tensor if A or B is float 8.
+
+ +#### Outputs + +
+
Y : TR
+
Output tensor of shape (M, N).
+
+ +#### Type Constraints + +
+
TA : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to input A.
+
TB : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to input B.
+
TC : tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to input C.
+
TR : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to result type.
+
TS : tensor(float)
+
Constrain type for all input scales (scaleA, scaleB, scaleY).
+
+ + ### **com.microsoft.GreedySearch** Greedy Search for text generation. @@ -2181,7 +2344,7 @@ This version of the operator has been available since version 1 of the 'com.micr
activation : int (required)
-
Activation after group normalization: 0 for None, 1 for Swish
+
Activation after group normalization: 0 for None, 1 for SiLU
channels_last : int
1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.
epsilon : float
@@ -2218,6 +2381,69 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.GroupQueryAttention** + + Group Query Self/Cross Attention. + + Supports different number of heads for q and kv. Only supports causal or local attention. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
kv_num_heads : int (required)
+
Number of attention heads for k and v
+
local_window_size : int
+
left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
+
num_heads : int (required)
+
Number of attention heads for q
+
scale : float
+
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
+
+ +#### Inputs + +
+
query : T
+
Query with shape (batch_size, sequence_length, hidden_size)
+
key : T
+
Key with shape (batch_size, kv_sequence_length, kv_hidden_size)
+
value : T
+
Value with shape (batch_size, kv_sequence_length, kv_hidden_size)
+
past_key (optional) : T
+
past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
+
past_value (optional) : T
+
past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
+
seqlens_k : M
+
1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.
+
total_sequence_length : M
+
Scalar tensor of total sequence length (past + new).
+
+ +#### Outputs + +
+
output : T
+
3D output tensor with shape (batch_size, sequence_length, hidden_size)
+
present_key : T
+
present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
+
present_value : T
+
present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
+
+ +#### Type Constraints + +
+
T : tensor(float16)
+
Constrain input and output to float tensors.
+
M : tensor(int32)
+
Constrain mask to int tensor.
+
+ + ### **com.microsoft.Inverse** #### Version @@ -2347,6 +2573,89 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.MatMulBnb4** + + MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's quantization constants or scales are specified by input 'absmax'. + + Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. + Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + + + 1. (Default value) transB=True (Majorly used for forward pass) + Shape of A: [D0, D1, ..., Dn, K] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + transposed_dequant_B = dequant_B^T + output = A @ transposed_dequant_B + + Shape of output: [D0, D1, ..., Dn, N] + + 2. transB=False (Majorly used for backward pass) + Shape of A: [D0, D1, ..., Dn, N] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + output = A @ dequant_B + + Shape of output: [D0, D1, ..., Dn, K] + + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
K : int (required)
+
size of each input feature
+
N : int (required)
+
size of each output feature
+
block_size : int (required)
+
number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.
+
quant_type : int (required)
+
quantization data type. 0 for FP4, 1 for NF4.
+
training_mode : int
+
Indicate if the ops run in training_mode, by default, False.
+
transB : int
+
Whether B should be transposed on the last two dimensions before doing multiplication. Default to be 1.
+
+ +#### Inputs + +
+
A : T1
+
The input tensor, not quantized
+
B : T2
+
1-dimensional quantized data for weight
+
absmax : T1
+
quantization constants
+
+ +#### Outputs + +
+
Y : T1
+
tensor. The output tensor has the same rank as the input.
+
+ +#### Type Constraints + +
+
T1 : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float/half_float/brain_float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
+ + ### **com.microsoft.MatMulFpQ4** Matrix product with right hand matrix being pre-packed and quantized int4 data blob. @@ -2479,6 +2788,78 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.MatMulNBits** + + MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7).It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwisely along dimension 0 (e.g. column) with block size specified by attribute block_size. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's scale and zero point are specified by input scales and zero_points. + + Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = block_size / 8 * bits + + For a block blob. It is stored in format: + struct Blob { + uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization + uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization + uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization + } + + Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] + Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: + - [(N * n_blocks_per_col + 1) / 2] if bits <=4 + - [N * n_blocks_per_col] if bits > 4 + + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
K : int (required)
+
size of each input feature
+
N : int (required)
+
size of each output feature
+
bits : int (required)
+
number of bits used for weight quantization (default 4)
+
block_size : int (required)
+
number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.
+
+ +#### Inputs (3 - 4) + +
+
A : T1
+
The input tensor, not quantized
+
B : T2
+
1-dimensional data blob
+
scales : T1
+
quantization scale
+
zero_points (optional) : T2
+
quantization zero points
+
+ +#### Outputs + +
+
Y : T1
+
tensor. The output tensor has the same rank as the input.
+
+ +#### Type Constraints + +
+
T1 : tensor(float), tensor(float16)
+
Constrain input and output types to float/half_float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
+ + ### **com.microsoft.MaxpoolWithMask** For internal use. @@ -2526,6 +2907,58 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.MoE** + + Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1, + GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) + usually uses top 32 experts. + + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation_type : string
+
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+
k : int
+
Number of top experts to select from expert pool
+
+ +#### Inputs (4 - 6) + +
+
input : T
+
2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+
router_probs : T
+
2D input tensor with shape (num_rows, num_experts)
+
fc1_experts_weights : T
+
3D input tensor with shape (num_experts, hidden_size, inter_size)
+
fc2_experts_weights : T
+
3D input tensor with shape (num_experts, inter_size, hidden_size)
+
fc1_experts_bias (optional) : T
+
2D optional input tensor with shape (num_experts, inter_size)
+
fc2_experts_bias (optional) : T
+
2D optional input tensor with shape (num_experts, hidden_size)
+
+ +#### Outputs + +
+
output : T
+
2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain input and output types to float or float16 tensors.
+
+ + ### **com.microsoft.MulInteger** Performs element-wise binary quantized multiplication (with Numpy-style broadcasting support). @@ -2606,7 +3039,7 @@ This version of the operator has been available since version 1 of the 'com.micr
bias (optional) : T
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
-
Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)
+
Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), or (batch_size, sequence_length, total_sequence_length)
relative_position_bias (optional) : T
relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
past_key (optional) : T
@@ -4568,6 +5001,54 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.RotaryEmbedding** + + RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices + that are multiplied to query and key before the inner product of query and key is taken. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
interleaved : int
+
Rotate using interleaved pattern. Default value is 0 (False).
+
scale : float
+
Custom scale will be used if specified. Default value is 1.0
+
+ +#### Inputs + +
+
input : T
+
3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)
+
position_ids : M
+
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
+
cos_cache : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
+
sin_cache : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
+
+ +#### Outputs + +
+
output : T
+
tensor with same shape as input.
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain input and output types to float tensors.
+
M : tensor(int64)
+
Constrain input and output types to integer tensors
+
+ + ### **com.microsoft.SampleOp** Sample echo operator. @@ -4683,6 +5164,72 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.SkipGroupNorm** + + This operator element-wise adds x, skip and bias, then apply group normalization and optional activation. + + This operator transforms input according to + s = x + skip + bias + y = gamma * (s - mean) / sqrt(variance + epsilon) + beta + + The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. + The num_channels must be divisible by num_groups. + The mean and standard-deviation of s are calculated separately over the each group. + The weight and bias are per-channel affine transform parameter vectors of size num_channels. + + The activation attribute can be used to enable activation after group normalization. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation : int (required)
+
Activation after group normalization: 0 for None, 1 for SiLU
+
channels_last : int
+
1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.
+
epsilon : float
+
The epsilon value to use to avoid division by zero
+
groups : int (required)
+
The number of groups of channels. It should be a divisor of the number of channels C
+
+ +#### Inputs (4 - 5) + +
+
X : T
+
Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels, and H and W are the height and width of the data
+
gamma : M
+
1D gamma tensor for normalization with shape (C), where C is number of channels
+
beta : M
+
1D beta tensor for normalization with shape (C), where C is number of channels
+
skip : T
+
4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)
+
bias (optional) : T
+
1D bias tensor. Dimensions are (C), where C is number of channels
+
+ +#### Outputs (1 - 2) + +
+
Y : T
+
The output tensor of the same shape as X
+
S (optional) : T
+
The element-wise sum of input x, skip and bias tensors. It has the same shape as X
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float)
+
Constrain input X, skip, bias and output Y, S types to float tensors.
+
M : tensor(float16), tensor(float)
+
Constrain gamma and beta to float tensors.
+
+ + ### **com.microsoft.SkipLayerNormalization** Skip and Layer Normalization Fusion @@ -5078,6 +5625,47 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.UnfoldTensor** + + Returns a tensor which contains all slices of size size from input tensor in the dimension dim. Step between two slices is given by step. If sizedim is the size of dimension dim for input tensor, the size of dimension dim in the returned tensor will be (sizedim - size) / step + 1. An additional dimension of size size is appended in the returned tensor. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
dim : int
+
specify the dimension to unfold
+
size : int (required)
+
specify the size
+
step : int
+
specify the step.
+
+ +#### Inputs + +
+
input : T
+
input tensor
+
+ +#### Outputs + +
+
output : T
+
Output tensor.
+
+ +#### Type Constraints + +
+
T : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(bfloat16), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128)
+
Allow inputs and outputs to be any kind of tensor.
+
+ + ### **com.microsoft.Unique** Finds all the unique values (deduped list) present in the given input tensor. @@ -5124,6 +5712,107 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.WhisperBeamSearch** + + Beam Search for whisper model, especiall with cross_qk features etc. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
decoder : graph (required)
+
Decoder subgraph to execute in a loop.
+
decoder_output_cross_qk : int
+
If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.
+
decoder_start_token_id : int
+
The id of the token that indicates decoding starts.
+
early_stopping : int
+
early stop or not
+
encoder : graph
+
The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.
+
eos_token_id : int (required)
+
The id of the end-of-sequence token
+
init_decoder : graph
+
The subgraph for the first decoding run. It will be called once before `decoder` subgraph. This is relevant only for the GPT2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs
+
model_type : int
+
Must be 2 for whisper
+
no_repeat_ngram_size : int
+
no repeat ngrams size
+
no_speech_token : int
+
The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.
+
pad_token_id : int (required)
+
The id of the padding token
+
vocab_size : int
+
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
+
+ +#### Inputs (5 - 14) + +
+
input_ids : F
+
The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)
+
max_length : I
+
The maximum length of the sequence to be generated. Shape is (1)
+
min_length (optional) : I
+
The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)
+
num_beams : I
+
Number of beams for beam search. 1 means no beam search. Shape is (1)
+
num_return_sequences : I
+
The number of returned sequences in the batch. Shape is (1)
+
length_penalty (optional) : T
+
Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)
+
repetition_penalty (optional) : T
+
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
+
vocab_mask (optional) : M
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
prefix_vocab_mask (optional) : M
+
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
+
attention_mask (optional) : I
+
Custom attention mask. Shape is (batch_size, sequence_length)
+
decoder_input_ids (optional) : I
+
The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)
+
logits_processor (optional) : I
+
Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)
+
cross_qk_layer_head (optional) : I
+
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
+
extra_decoding_ids (optional) : I
+
Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
+
+ +#### Outputs (1 - 5) + +
+
sequences : I
+
Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)
+
sequences_scores (optional) : T
+
Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)
+
scores (optional) : T
+
Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
+
cross_qk (optional) : V
+
Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F].If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
+
non_speech_probs (optional) : T
+
For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain to float tensors.
+
F : tensor(float), tensor(int32), tensor(float16)
+
Constrain input type to float or int tensors.
+
I : tensor(int32)
+
Constrain to integer types
+
M : tensor(int32)
+
Constrain mask to integer types
+
V : tensor(float)
+
Constrain cross_qk to float32 tensors.
+
+ + ### **com.microsoft.WordConvEmbedding** The WordConvEmbedding takes in a batch of sequence words and embed each word to a vector. diff --git a/docs/FAQ.md b/docs/FAQ.md index e039f4e4e4160..70bedbd02e944 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -2,13 +2,13 @@ Here are some commonly raised questions from users of ONNX Runtime and brought up in [Issues](https://github.com/microsoft/onnxruntime/issues). ## Do the GPU builds support quantized models? -The default CUDA build supports 3 standard quantization operators: QuantizeLinear, DequantizeLinear, and MatMulInteger. The TensorRT EP has limited support for INT8 quantized ops. In general, support of quantized models through ORT is continuing to expand on a model-driven basis. For performance improvements, quantization is not always required, and we suggest trying alternative strategies to [performance tune](./ONNX_Runtime_Perf_Tuning.md) before determining that quantization is necessary. +The default CUDA build supports 3 standard quantization operators: QuantizeLinear, DequantizeLinear, and MatMulInteger. The TensorRT EP has limited support for INT8 quantized ops. In general, support of quantized models through ORT is continuing to expand on a model-driven basis. For performance improvements, quantization is not always required, and we suggest trying alternative strategies to [performance tune](https://onnxruntime.ai/docs/performance/tune-performance/) before determining that quantization is necessary. ## How do I change the severity level of the default logger to something other than the default (WARNING)? Setting the severity level to VERBOSE is most useful when debugging errors. Refer to the API documentation: -* Python - [RunOptions.log_severity_level](https://microsoft.github.io/onnxruntime/python/api_summary.html#onnxruntime.RunOptions.log_severity_level) +* Python - [RunOptions.log_severity_level](https://onnxruntime.ai/docs/api/python/api_summary.html#onnxruntime.RunOptions.log_severity_level) ``` import onnxruntime as ort ort.set_default_logger_severity(0) diff --git a/docs/Memory_Optimizer.md b/docs/Memory_Optimizer.md index 3ef3a575f20a1..0147a937db81d 100644 --- a/docs/Memory_Optimizer.md +++ b/docs/Memory_Optimizer.md @@ -22,68 +22,113 @@ Not all models and recipes need this optimizer technique. Imagine if your traini 1. Make sure ONNX Runtime training wheel is installed and correctly configured. 2. Integrate models using `ORTModule`, be noted log_level should be equal or lower than INFO. > ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.INFO)) -3. Run the training as usual and redirect all outputs into log file; then stop it after training few steps. -4. Check the logging file, search "Summary", you could possibly find something like this: +3. Run the training as usual; then stop it after training few steps. +4. Check the logs, you could find something like this: ``` - MemoryOptimizer Summary: - User config: - - ================================= - ########Recompute######## - Subgraph: CumSum+Sub+Mul+Unsqueeze+Cast+Mul+Cast+Reshape+Mul+FusedMatMul+Add+Reshape+Cast+Where+Softmax+ - OptimizationType: Disabled - Patterns: - PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23 - -------------------------------- - Subgraph: FastGelu+ - OptimizationType: Disabled - Patterns: - PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24 - ================================= - ########RecomputeWithCompromise######## - Subgraph: Cast+Where+Softmax+ - OptimizationType: Disabled - Patterns: - PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:24 - -------------------------------- - ================================= + Memory Optimizer : OFF : Enable with env ORTMODULE_MEMORY_OPT_CONFIG=, available configs: + Config Freq Max Saving(B) Saving Symbolic(Bytes) + - Plan 1 : OFF : Reshape+Where+BiasSoftmax+:1:-1 5 671,088,640 640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 + - Plan 2 : OFF : Cast+:1:-1 6 402,587,648 inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0) + - Plan 3 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 + - Plan 4 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1) + - Plan 5 : OFF : BiasGelu+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0) + - Plan 6 : OFF : FusedMatMul+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0) + - Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 5 26,214,400 25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1 + - Plan 8 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) + - Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1 + - Plan 10 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1 + + + Note 1: use comma as delimiter to enable multiple memory optimization plans at the same time: + export ORTMODULE_MEMORY_OPT_CONFIG=,,... + Note 2: memory saving is calculated based on the 1st batch symbolic dim values: + inputs_input_ids_dim0=1, inputs_input_ids_dim1=1024, inputs_attention_mask_dim0=1, inputs_attention_mask_dim1=1024, inputs_labels_dim0=1, inputs_labels_dim1=1024, ``` -5. As shown above, 'Subgraph' shows 1) a string representative for a re-computable subgraph; and 2) current status of memory optimization. All are disabled for recompute in this case. -6. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraph to do recompute. In below example, 12 FastGelu related subgraphs are allowed to recompute. -`FastGelu+` is the subgraph string representative; `1` in the middle indicates 'Recompute' is enabled (0, on the contrary indicates it's disabled); `12` means the initial 12 subgraph occurrences will be recomputed, all others are left as it is, filling `-1` will make all occurrences be recomputed. +5. As shown above, `Config` is a string representative for a re-computable subgraph. All are disabled for recompute in this case. +6. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraph to do recompute. In below example, `6` `BiasGelu+` related subgraphs are allowed to recompute. +`BiasGelu+` is the subgraph string representative; `1` in the middle indicates 'Recompute' is enabled (0, on the contrary indicates it's disabled); `6` means the initial 6 subgraph occurrences will be recomputed, all others are left as it is, filling `-1` will make all occurrences be recomputed. ``` - export ORTMODULE_MEMORY_OPT_CONFIG="FastGelu+:1:12" + export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:6" # Use comma as separator for enabling more than one subgraphs. ``` -7. Then run the training again, you will see logs like this: +7. Then run the training again, and you will see logs like this: ``` - MemoryOptimizer Summary: - User config: - **FastGelu+:1:12** - ================================= - ########Recompute######## - Subgraph: CumSum+Sub+Mul+Unsqueeze+Cast+Mul+Cast+Reshape+Mul+FusedMatMul+Add+Reshape+Cast+Where+Softmax+ - OptimizationType: Disabled - Patterns: - PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23 - -------------------------------- - Subgraph: FastGelu+ - OptimizationType: **Recompute (requested_count=12, actual applied_count=12)** - Patterns: - PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24 - ================================= - ########RecomputeWithCompromise######## - Subgraph: Cast+Where+Softmax+ - OptimizationType: Disabled - Patterns: - PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:24 - -------------------------------- - ================================= + Memory Optimizer : ON : User config: Reshape+Where+BiasSoftmax+:1:-1, probe level: 1, available configs: + Config Freq Max Saving(B) Saving Symbolic(Bytes) + - Plan 1 : OFF : Reshape+Where+BiasSoftmax+:1:-1 5 671,088,640 640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 + - Plan 2 : OFF : Cast+:1:-1 6 402,587,648 inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0) + - Plan 3 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 + - Plan 4 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1) + - Plan 5 : ON : BiasGelu+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0) + - Plan 6 : OFF : FusedMatMul+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0) + - Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 5 26,214,400 25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1 + - Plan 8 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) + - Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1 + - Plan 10 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1 ``` 8. You may need iterate few times on step 6 and 7 until you find a good config for this model to run a bigger batch size. Or you may fail to find if memory optimization does not apply to the model well. +## Optimization Configuration + +The basic optimization unit is represented with a unique `cluster id`, for example `BiasGelu+` is one `cluster id`. +Following `cluster id` is the `optimization strategy`: 0 - none, 1 - recompute, 2 - recompute with compromised memory saving. +Following `optimization strategy` is the `request count` to apply the given optimization. Using `-1` to apply all. This would give user a bit more flexibility to avoid unnecessary memory saving. + ## Compromised Recompute -If you check the above logs, there is a separate section called "RecomputeWithCompromise". Recompute the subgraphs under it usually will save part of the activation (for example half of them), not all of them. Follow the same way to enable it. +If you check the above logs, there is a config `Cast+:2:-1`, `2` indicates it's a recomputation than can save part of the stashed activation size, not all. Recompute the subgraphs under it usually will save part of the activation (for example half of them), not all of them. Follow the same way to enable it. + +## Memory Optimization Debug Infos + +Using following log level +> ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.DEVINFO)) + +Besides the logs shown in `LogLevel.INFO`, you can also see different node patterns that can apply different optimization options. + +The way we get the table: +- For a specific node, it might has different optimization options, we [generates](../orttraining/orttraining/core/optimizer/memory_optimizer/common.h#L124C26-L124C26) a hash (called `Node Cluster ID`) for the node according to all available optimization options. +- Map all nodes having same `Node Cluster ID` in buckets, each bucket is displayed as one row. + +``` +MemoryInsight Summary - User config: not provided +=========================================================================================================================================== +|Freq | Memory Optimization Opportunities (Clustered by node-level activation patterns) | +|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _| +|6 |For each row options are mutually exclusive, only one of them can be enabled. | +| | | +| |>>Option 1 : Recompute subgraph FusedMatMul+Add+Reshape+ | +| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+Reshape+:1:-1 | +| | Stashed Activations: | +| | - ReuseFreq : Output 0(6), | +| | - Output 0 : [((inputs_input_ids_dim0)*(inputs_input_ids_dim1)*(32)*(240))], byte/elem: 2, 100% saved | +|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _| +|5 |For each row options are mutually exclusive, only one of them can be enabled. | +| | | +| |>>Option 1 : Recompute subgraph FusedMatMul+ | +| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 | +| | Stashed Activations: | +| | - Output 0 : [((inputs_input_ids_dim0)*(inputs_input_ids_dim1)*(10240))], byte/elem: 2, 100% saved | +|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _| +|5 |For each row options are mutually exclusive, only one of them can be enabled. | +| | | +| |>>Option 1 : Recompute subgraph Cast+ | +| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 | +| | Stashed Activations: | +| | - Output 0 : [((inputs_input_ids_dim0)*(32)*(inputs_input_ids_dim1)*(inputs_input_ids_dim1))], byte/elem: 2, 100% saved | +|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _| +|1 |For each row options are mutually exclusive, only one of them can be enabled. | +| | | +| |>>Option 1 : Recompute subgraph Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+ | +| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 | +| | Stashed Activations: | +| | - Output 0 : [((inputs_input_ids_dim0)*(1)*(1)*(inputs_input_ids_dim1))], byte/elem: 4, 100% saved | +| | | +| |>>Option 2 : RecomputeWithCompromise subgraph Cast+ | +| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:2:-1 | +| | Stashed Activations: | +| | - Output 0 : [((inputs_input_ids_dim0)*(1)*(1)*(inputs_input_ids_dim1))], byte/elem: 4, 50% saved | +|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _| + +``` ## Notes diff --git a/docs/ORTModule_PythonOp_Notes.md b/docs/ORTModule_PythonOp_Notes.md new file mode 100644 index 0000000000000..1bb549f8fab9d --- /dev/null +++ b/docs/ORTModule_PythonOp_Notes.md @@ -0,0 +1,156 @@ +# ORTModule Custom Autograd Function Support + +## What is autograd Functions? + +`PyTorch` allows users to define customized operators (for its forward and backward implementations) [PyTorch: Defining New autograd Functions](https://github.com/pytorch/tutorials/blob/d98606855d3c8c5bd78d55b95717be5a02960363/beginner_source/examples_autograd/polynomial_custom_function.py#L25). + +There are many such use cases as more optimized deep learning projects keep growing, here we just name a few: +- [NVIDIA/apex](https://github.com/NVIDIA/apex/blob/58acf96915eecd7e13adff61d2c389fba3efede2/apex/transformer/functional/fused_softmax.py#L21) +- [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM/blob/f7727433293427bef04858f67b2889fe9b177d88/megatron/core/tensor_parallel/mappings.py#L220C31-L220C31) +- [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention/blob/3a9fe7b0faaa9d648394026c9c20231c07bf999d/flash_attn/flash_attn_interface.py#L429), +- [openai/triton](https://github.com/openai/triton/blob/424e67e7275f0cb2cd231e7a4d17ff8570530b77/python/tutorials/06-fused-attention.py#L457) +- ... + +Those operators are used in training/evaluation scenarios a lot, where is ORTModule capability overlaps. +To best release ORTModule's acceleration power, we need tolerant and handle those customized operators +from the to-onnx conversion, to backward graph building, and also its execution in runtime as a full lifecycle. + +## How ORTModule support autograd.Function? + +The way we have here is through introduced `PythonOp`/`PythonOpGrad` MS domain operators in `ONNX Runtime`, +- Map autograd Function (`prim::PythonOp` in `PyTorch`) to `PythonOp` in `ONNX Runtime` during model export by [registering customized export function](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py#L69C16-L69C16) + ``` + class ScalarAndTupleFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, alpha, beta, gamma): + ctx.save_for_backward(input) + ctx.alpha = alpha + ctx.beta = beta + ctx.gamma = gamma + return alpha * beta[0] * beta[1] * gamma * input.clamp(min=0) + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + alpha = ctx.alpha + beta = ctx.beta + gamma = ctx.gamma + grad_input = grad_output.clone() + grad_input[input < 0] = 0 + return alpha * beta[0] * beta[1] * gamma * grad_input, None, None, None + ``` + The example above shows a customized function taking 4 inputs (despite of ctx), the first input is a tensor [exporter treats it as input for `PythonOp`](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py#L174), + the others are scalars, export function will convert all such non-tensor inputs to constant and [stores + in `PythonOp`'s attributes](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py#L272). Things to be noted here: if the non-tensor + input is one of those types "bool scalar, int scalar, float scalar, bool tuple, int tuple, float tuple", they will be + stored in corresponding attributes; otherwise, they will be treated a `object` and the object address stored in `input_pointer_scalars` ([reference count will be increased](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py#L250C27-L250C27) also to make sure it exists during model run). +- [PythonOp kernel](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/training_ops/cuda/torch/torch_custom_function_kernel.cc#L38) is responsible to run the `forward` interface user defined through [forward runner](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py#L409). +Similarly, [PythonOpGrad kernel](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/training_ops/cuda/torch/torch_custom_function_kernel.cc#L49) is responsible to run the `backward` interface user defined through [backward runner](https://github.com/microsoft/onnxruntime/blob/c2bd5b70b29eb3c687c5497696e7b0a1930604d3/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py#L554). + +Currently, for training python wheel, `PythonOp` support is by default enabled, users don't need to be aware of it. As long as the +defined torch.autograd.Function is working in `PyTorch` run, it should be runnable with `ORTModule`. If you need to enable it or +disable it explicitly, refer to the [wiki](https://github.com/microsoft/onnxruntime/blob/main/docs/ORTModule_Training_Guidelines.md#ortmodule_enable_custom_autograd). + + + +## Known Issues and Workaround + +PyTorch Versions +- Minimum version 1.9 (introduced "Support registering custom export for `prim::PythonOp`` from torch.autograd.Function ([#55630](https://github.com/pytorch/pytorch/pull/55630)) ([#57600](https://github.com/pytorch/pytorch/pull/57600))") +- If the static forward function has only one output, any version of Pytorch 1.9 is fine. Otherwise, a PyTorch version containing [this commit](https://github.com/pytorch/pytorch/commit/a55cae3d37e0f7852e391886c3904307caa4d06d) is required. +- [Throw _Map_base::at Exception](https://github.com/pytorch/pytorch/issues/88286), export errors like this: + ``` + RuntimeError: There was an error while exporting the PyTorch model to ONNX: + + Traceback (most recent call last): + File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_utils.py", line 316, in get_exception_as_string + raise exception + File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 425, in _get_exported_model + torch.onnx.export( + File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/onnx/utils.py", line 506, in export + _export( + File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/onnx/utils.py", line 1548, in _export + graph, params_dict, torch_out = _model_to_graph( + File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph + graph, params, torch_out, module = _create_jit_graph(model, args) + File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/onnx/utils.py", line 989, in _create_jit_graph + graph, torch_out = _trace_and_get_graph_from_model(model, args) + File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/onnx/utils.py", line 893, in _trace_and_get_graph_from_model + trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( + File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/jit/_trace.py", line 1268, in _get_trace_graph + outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs) + ... + File "/opt/conda/envs/ptca/lib/python3.8/site-packages/deepspeed-0.9.5+95680ca-py3.8.egg/deepspeed/runtime/zero/parameter_offload.py", line 632, in _ort_post_forward_module_hook + a = ORTPostForwardwardFunction.apply(module, _post_forward_module_hook, _ort_run_before_backward_function, len(input), len(output), *input_and_output) + File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply + return super().apply(*args, **kwargs) # type: ignore[misc] + RuntimeError: _Map_base::at + ``` + Resolution: upgrade `PyTorch` to new versions containing [this commit](https://github.com/thiagocrepaldi/pytorch/commit/3d3da109e3afa617c513e78aa999f5a1f44ffbce), when export param `autograd_inlining` is [set to false](https://github.com/microsoft/onnxruntime/blob/0e2782438a65b97919f15af14d2a4ada361157b6/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py#L387C26-L387C26) to skip this error. +- "Tried to trace <__torch__.torch.classes.c10d.ProcessGroup object at 0x2969c520> but it is not part of the active trace" + This usually happens when torch.autograd.Function's forward function used `PyTorch` collective calls and pass the group explicitly. + ``` + RuntimeError: There was an error while exporting the PyTorch model to ONNX: + + Traceback (most recent call last): + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_utils.py", line 324, in get_exception_as_string + raise exception + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 342, in _get_exported_model + torch.onnx.export( + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 507, in export + _export( + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 1567, in _export + graph, params_dict, torch_out = _model_to_graph( + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 1124, in _model_to_graph + graph, params, torch_out, module = _create_jit_graph(model, args) + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 1000, in _create_jit_graph + graph, torch_out = _trace_and_get_graph_from_model(model, args) + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 904, in _trace_and_get_graph_from_model + trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/jit/_trace.py", line 1269, in _get_trace_graph + outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs) + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/jit/_trace.py", line 128, in forward + graph, out = torch._C._create_graph_by_tracing( + ... + File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/parameter_offload.py", line 640, in _ort_pre_forward_module_hook + rets = ORTPreForwardwardFunction.apply(self, module, _ort_run_after_backward_function, *inputs) + ... + File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/parameter_offload.py", line 823, in pre_sub_module_forward_function + param_coordinator.fetch_sub_module(sub_module, forward=True) + ... + File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2841, in all_gather_into_tensor + work = group._allgather_base(output_tensor, input_tensor) + RuntimeError: Tried to trace <__torch__.torch.classes.c10d.ProcessGroup object at 0x56250ad114a0> but it is not part of the active trace. Modules that are called during a trace must be registered as submodules of the thing being traced. + ``` + Resolution: modify the autograd.Function, to skip the run the collection operator during onnx export, here is an example. + ```python + # Pre + def allgather_fn(output_tensor, input_tensor, group=None, async_op=False, debug=get_caller_func()): + return torch.distributed.all_gather_into_tensor(output_tensor, input_tensor, group=group, async_op=async_op, debug=debug) + + # Workaround + from typing import Any, List + class DummyWork(torch.distributed.distributed_c10d.Work): + def is_completed(self) -> bool: + return True + def is_success(self) -> bool: + return True + def exception(self) -> Any: + return None + def wait(self, timeout: timedelta = timedelta) -> bool: + return True + def source_rank(self) -> int: + return 0 + def _source_rank(self) -> int: + return 0 + def result(self) -> List[torch.Tensor]: + return [] + def synchronize(self): + pass + + def allgather_fn(output_tensor, input_tensor, group=None, async_op=False, debug=get_caller_func()): + if torch.onnx.is_in_onnx_export(): + return DummyWork() + + return torch.distributed.all_gather_into_tensor(output_tensor, input_tensor, group=group, async_op=async_op, debug=debug) + ``` diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index 5350988b20964..7fa89cca381d9 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -49,6 +49,90 @@ More options for **developers**. ``` Check [DebugOptions implementation](../orttraining/orttraining/python/training/ortmodule/options.py) for more details. +#### Log Level Explanations + + + + + + + + + + + + + + + + + + + + + + + + +
Log LevelDescription
+ +`FATAL` | `ERROR` | `WARNING` (For Users) + +`WARNING` is the default and recommended level for +
users.
+
+ +- ONNX Runtime backend log level - `FATAL` | `ERROR` | `WARNING`. +- ORTModule log level - `FATAL` | `ERROR` | `WARNING`. +- Rank-0 log filtering is `ON` (e.g. logging on rank-0-only). +- PyTorch exporter export logs filtering is `ON`. +- PyTorch exporter verbose logs (including tracing graph) filtering is `ON`. + +
+ +`INFO` (For Users | ORT Developers) + +`INFO` is used for collecting experimental +
feature stats, or a little bit more error messages.
+
+ +- ONNX Runtime backend log level - `WARNING`. +- ORTModule log level - `INFO`. +- Rank-0 log filtering is `ON` (e.g. logging on rank-0-only). +- PyTorch exporter export logs filtering is `ON`. +- PyTorch exporter verbose logs (including tracing graph) filtering is `OFF`. + +
+ +`DEVINFO` (For ORT Developers) + +`DEVINFO` is the recommended level for +
debugging purposes.
+
+ +- ONNX Runtime backend log level - `INFO`. +- ORTModule log level - `INFO`. +- Rank-0 log filtering is `OFF` (e.g. logging on all ranks). +- PyTorch exporter export logs filtering is `OFF`. +- PyTorch exporter verbose logs (including tracing graph) filtering is `OFF`. + +
+ +`VERBOSE` (For ORT Developers) + +`VERBOSE` is the last resort for debugging +
hard problems.
+
+ +- ONNX Runtime backend log level - `VERBOSE`. +- ORTModule log level - `VERBOSE`. +- Rank-0 log filtering is `OFF` (e.g. logging on all ranks). +- PyTorch exporter export logs filtering is `OFF`. +- PyTorch exporter verbose logs (including tracing graph) filtering is `OFF`. + +
+ + ### 2.1 Environment Variables `ORTModule` provides environment variables targeting different use cases. @@ -185,6 +269,15 @@ data sparsity based performance optimizations. unset ORTMODULE_CACHE_DIR # Disable ``` +#### ORTMODULE_USE_EFFICIENT_ATTENTION + +- **Feature Area**: *ORTMODULE/Optimizations* +- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and falling back to PyTorch's efficient_attention ATen kernel for execution. NOTE that it requires torch's version is 2.1.1 or above. There are some build-in patterns for attention fusion, if none of the patterns works for your model, you can add a custom one in your user script manually. + + ```bash + export ORTMODULE_USE_EFFICIENT_ATTENTION=1 + ``` + ### 2.2 Memory Optimization Q: *Want to run a bigger batch size?* @@ -313,6 +406,15 @@ Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training export ORTMODULE_TUNING_RESULTS_PATH=/tmp/tuning_results ``` +#### ORTMODULE_USE_FLASH_ATTENTION + +- **Feature Area**: *ORTMODULE/TritonOp* +- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and using Flash Attention's Triton version as the kernel. NOTE that it requires ORTMODULE_USE_TRITON to be enabled, and CUDA device capability is 8.0 or above. There are some build-in patterns for attention fusion, if none of the patterns works for your model, you can add a custom one in your user script manually. + + ```bash + export ORTMODULE_USE_FLASH_ATTENTION=1 + ``` + #### ORTMODULE_TRITON_DEBUG - **Feature Area**: *ORTMODULE/TritonOp* diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 33c187a28b62e..edf249a816923 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -25,6 +25,7 @@ Do not modify directly.* |||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| +|AffineGrid|*in* theta:**T1**
*in* size:**T2**
*out* grid:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| |ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| @@ -67,7 +68,8 @@ Do not modify directly.* |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[4, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ConcatFromSequence|*in* input_sequence:**S**
*out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|20+|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[9, 19]|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(float)| |||[1, 10]|**T** = tensor(float)| |ConvInteger|*in* x:**T1**
*in* w:**T2**
*in* x_zero_point:**T1**
*in* w_zero_point:**T2**
*out* y:**T3**|10+|**T1** = tensor(uint8)
**T2** = tensor(uint8)
**T3** = tensor(int32)| @@ -78,7 +80,7 @@ Do not modify directly.* |Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| |||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| -|DFT|*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| +|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**

or

*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[1, 10]|**T** = tensor(double), tensor(float)| @@ -136,7 +138,8 @@ Do not modify directly.* |||[7, 8]|**T** = tensor(double), tensor(float)
**T1** = tensor(bool)| |GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| |||[12, 15]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)| +|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(double), tensor(float)| +|||[16, 19]|**T1** = tensor(float)
**T2** = tensor(float)| |HammingWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |HannWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float)| @@ -155,8 +158,10 @@ Do not modify directly.* |||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(float)| -|IsInf|*in* X:**T1**
*out* Y:**T2**|10+|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| -|IsNaN|*in* X:**T1**
*out* Y:**T2**|13+|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| +|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| +|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float)| |||[1, 12]|**T** = tensor(float)| @@ -368,7 +373,7 @@ Do not modify directly.* |||[13, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|SplitToSequence|*in* input:**T**
*in* split:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string)| +|SplitToSequence|*in* input:**T**
*in* split:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(string)| |Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| |Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -453,9 +458,11 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| @@ -475,14 +482,17 @@ Do not modify directly.* |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float)| |SampleOp|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| +|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| |SparseToDenseMatMul|*in* A:**T**
*in* B:**T1**
*out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Tokenizer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(string)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int64)| |Unique|*in* x:**T**
*out* y:**T**
*out* idx:**tensor(int64)**
*out* counts:**tensor(int64)**|1+|**T** = tensor(float)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| |WordConvEmbedding|*in* Sequence:**T**
*in* W:**T1**
*in* B:**T1**
*in* C:**T1**
*out* Y:**T1**|1+|**T** = tensor(int32)
**T1** = tensor(float)| | | | | @@ -513,10 +523,8 @@ Do not modify directly.* |||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| -|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|11|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|11|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| +|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| |AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |||10|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -657,7 +665,7 @@ Do not modify directly.* |Mul|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|Neg|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| +|Neg|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| |NonZero|*in* X:**T**
*out* Y:**tensor(int64)**|13+|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| |||[9, 12]|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| @@ -687,39 +695,26 @@ Do not modify directly.* |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* output:**T**|11+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| |Reciprocal|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceL1|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceL2|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceLogSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceLogSumExp|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||13|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceL1|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceL2|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceLogSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| +|ReduceLogSumExp|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| +|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| |ReduceSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|ReduceSumSquare|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|ReduceSumSquare|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -807,7 +802,7 @@ Do not modify directly.* |||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Upsample|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T**
*out* Y:**T**|9|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| -|Where|*in* condition:**B**
*in* X:**T**
*in* Y:**T**
*out* output:**T**|16+|**B** = tensor(bool)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| +|Where|*in* condition:**B**
*in* X:**T**
*in* Y:**T**
*out* output:**T**|16+|**B** = tensor(bool)
**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| |||[9, 15]|**B** = tensor(bool)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| |Xor|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| | | @@ -826,22 +821,28 @@ Do not modify directly.* |ComplexMulConj|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |DecoderAttention|*in* query:**T**
*in* key:**T**
*in* q_weight:**T**
*in* kv_weight:**T**
*in* bias:**T**
*in* key_padding_mask:**B**
*in* key_cache:**T**
*in* value_cache:**T**
*in* static_kv:**B**
*in* use_past:**B**
*in* has_layer_state:**B**
*in* has_key_padding_mask:**B**
*out* output:**T**
*out* new_key_cache:**T**
*out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)| -|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| +|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**V**|1+|**T** = tensor(float), tensor(float16)| |DecoderMaskedSelfAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(float16)| |DequantizeWithOrder|*in* input:**Q**
*in* scale_input:**S**
*out* output:**F**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| +|DynamicTimeWarping|*in* input:**F**
*out* output:**I**|1+|**F** = tensor(float)
**I** = tensor(int32)| |EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| |FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|GemmFloat8|*in* A:**TA**
*in* B:**TB**
*in* C:**TC**
*in* scaleA:**TS**
*in* scaleB:**TS**
*in* scaleY:**TS**
*out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TS** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc2_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -860,11 +861,15 @@ Do not modify directly.* |RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)| |RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| +|SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|UnfoldTensor|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| | | | | @@ -935,7 +940,7 @@ Do not modify directly.* |Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||11+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|DFT|*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int64)| +|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**

or

*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int64)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1244,6 +1249,7 @@ Do not modify directly.* |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| | | | | diff --git a/docs/python/ReadMeOV.rst b/docs/python/ReadMeOV.rst index f12c01d278dca..6ef16e1378139 100644 --- a/docs/python/ReadMeOV.rst +++ b/docs/python/ReadMeOV.rst @@ -7,7 +7,6 @@ OpenVINO™ Execution Provider for ONNX Runtime accelerates inference across man - Intel® CPUs - Intel® integrated GPUs - Intel® discrete GPUs - - Intel® integrated VPUs Installation ------------ @@ -22,7 +21,6 @@ This package supports: - Intel® CPUs - Intel® integrated GPUs - Intel® discrete GPUs - - Intel® integrated VPUs ``pip3 install onnxruntime-openvino`` diff --git a/include/onnxruntime/core/framework/float8.h b/include/onnxruntime/core/framework/float8.h index 0fd04f28d44b7..dd607cbbc6952 100644 --- a/include/onnxruntime/core/framework/float8.h +++ b/include/onnxruntime/core/framework/float8.h @@ -208,9 +208,10 @@ struct Float8E4M3FNUZ { val = static_cast((b & 0x80000000) >> 24); // sign if ((b & 0x7fffffff) == 0x7f800000) { // infinity if (saturate) { + // the highest available value val |= 0x7F; } else { - // infinity + // NaN val = 0x80; } } else if ((b & 0x7F800000) == 0x7F800000) { // NaN @@ -362,8 +363,10 @@ struct Float8E5M2 { val = (b & 0x80000000) >> 24; // sign if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf if (saturate) { + // the highest available value val |= 0x7B; } else { + // the infinity val |= 0x7C; } } else if ((b & 0x7F800000) == 0x7F800000) { // NaN diff --git a/include/onnxruntime/core/framework/op_node_proto_helper.h b/include/onnxruntime/core/framework/op_node_proto_helper.h index 700e1edc0cb7d..e7ac01947af41 100644 --- a/include/onnxruntime/core/framework/op_node_proto_helper.h +++ b/include/onnxruntime/core/framework/op_node_proto_helper.h @@ -10,20 +10,6 @@ #include "core/common/gsl.h" #endif -#ifdef __has_attribute -#define ORT_HAVE_ATTRIBUTE(x) __has_attribute(x) -#else -#define ORT_HAVE_ATTRIBUTE(x) 0 -#endif - -#if ORT_HAVE_ATTRIBUTE(nodiscard) -#define MUST_USE_RESULT [[nodiscard]] -#elif defined(__clang__) && ORT_HAVE_ATTRIBUTE(warn_unused_result) -#define MUST_USE_RESULT __attribute__((warn_unused_result)) -#else -#define MUST_USE_RESULT -#endif - class IMLOpKernel; namespace onnxruntime { @@ -43,14 +29,26 @@ class OpNodeProtoHelper { Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema */ template - MUST_USE_RESULT Status GetAttr(const std::string& name, T* value) const; + Status GetAttr(const std::string& name, T* value) const; + + /** + Get a single attribute + Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema + Throws if an attribute with the specified type doesn't exist + */ + template + [[nodiscard]] T GetAttr(const std::string& name) const { + T value; + ORT_THROW_IF_ERROR(GetAttr(name, &value)); + return value; + } /** Get a single attribute Call this function only when a default value for an optional attribute isn't specified in the op schema */ template - T GetAttrOrDefault(const std::string& name, const T& default_value) const { + [[nodiscard]] T GetAttrOrDefault(const std::string& name, const T& default_value) const { T tmp; return GetAttr(name, &tmp).IsOK() ? tmp : default_value; } @@ -70,7 +68,8 @@ class OpNodeProtoHelper { Call this function only when a default value for an optional attribute isn't specified in the op schema */ template - MUST_USE_RESULT std::vector GetAttrsOrDefault(const std::string& name, const std::vector& default_value = std::vector{}) const { + [[nodiscard]] std::vector GetAttrsOrDefault(const std::string& name, + const std::vector& default_value = {}) const { std::vector tmp; return GetAttrs(name, tmp).IsOK() ? tmp : default_value; } @@ -87,11 +86,12 @@ class OpNodeProtoHelper { /// Attribute data in a span, out parameter /// Status template - MUST_USE_RESULT Status GetAttrsAsSpan(const std::string& name, gsl::span& values) const; + Status GetAttrsAsSpan(const std::string& name, gsl::span& values) const; - MUST_USE_RESULT Status GetAttrs(const std::string& name, TensorShapeVector& out) const; + Status GetAttrs(const std::string& name, TensorShapeVector& out) const; - MUST_USE_RESULT TensorShapeVector GetAttrsOrDefault(const std::string& name, const TensorShapeVector& default_value = TensorShapeVector{}) const { + [[nodiscard]] TensorShapeVector GetAttrsOrDefault(const std::string& name, + const TensorShapeVector& default_value = {}) const { TensorShapeVector tmp; return GetAttrs(name, tmp).IsOK() ? tmp : default_value; } @@ -100,43 +100,43 @@ class OpNodeProtoHelper { Get repeated attributes */ template - MUST_USE_RESULT Status GetAttrs(const std::string& name, std::vector& values) const; + Status GetAttrs(const std::string& name, std::vector& values) const; template - MUST_USE_RESULT Status GetAttrs(const std::string& name, gsl::span values) const; + Status GetAttrs(const std::string& name, gsl::span values) const; - MUST_USE_RESULT Status GetAttrsStringRefs(const std::string& name, - std::vector>& refs) const; + Status GetAttrsStringRefs(const std::string& name, + std::vector>& refs) const; - uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type, - const std::string& name) const noexcept; + [[nodiscard]] uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type, + const std::string& name) const noexcept; - bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type, - const std::string& name) const noexcept; + [[nodiscard]] bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type, + const std::string& name) const noexcept; - uint32_t GetInputCount() const { + [[nodiscard]] uint32_t GetInputCount() const { return gsl::narrow_cast(impl_->getNumInputs()); } - uint32_t GetOutputCount() const { + [[nodiscard]] uint32_t GetOutputCount() const { return gsl::narrow_cast(impl_->getNumOutputs()); } - const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const { + [[nodiscard]] const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const { return impl_->getInputType(index); } - const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const { + [[nodiscard]] const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const { // Work around lack of a const method from the onnx InferenceContext interface return const_cast(impl_)->getOutputType(index); } // Try to query an attribute, returning nullptr if it doesn't exist - const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const { + [[nodiscard]] const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const { return impl_->getAttribute(name); } - const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const { + [[nodiscard]] const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const { const ONNX_NAMESPACE::AttributeProto* attr = TryGetAttribute(name); ORT_ENFORCE(attr != nullptr); return attr; diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index 7f3f26fa4aa02..a867ab6066485 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -3,8 +3,9 @@ #pragma once -#include +#include #include +#include #include #include @@ -37,7 +38,8 @@ namespace onnxruntime { class Tensor final { public: - // NB! Removing Create() methods returning unique_ptr. Still available in other EPs that are dynamically linked. + // NB! Removing Create() methods returning unique_ptr. + // Still available in other EPs that are dynamically linked. // Strive not to allocate Tensor with new/delete as it is a shallow class and using it by value is just fine. // Use InitOrtValue() methods to allocate for OrtValue. @@ -45,105 +47,104 @@ class Tensor final { /** * Create tensor with given type, shape, pre-allocated memory and allocator info. - * This function won't check if the preallocated buffer(p_data) has enough room for the shape. - * \param p_type Data type of the tensor + * This function does not check if the preallocated buffer(p_data) has enough room for the shape. + * \param elt_type Data type of the tensor elements. * \param shape Shape of the tensor * \param p_data A preallocated buffer. Can be NULL if the shape is empty. - * Tensor does not own the data and will not delete it - * \param alloc Where the buffer('p_data') was allocated from + * Tensor does not own the data and will not delete it + * \param location Memory info for location of p_data. * \param offset Offset in bytes to start of Tensor within p_data. * \param strides Strides span. Can be empty if the tensor is contiguous. */ - Tensor(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, + Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, ptrdiff_t offset = 0, gsl::span strides = {}); - /// - /// Creates an instance of Tensor on the heap using the appropriate __ctor and - /// initializes OrtValue with it. - /// - /// - /// - /// - /// - /// - /// - static void InitOrtValue(MLDataType p_type, const TensorShape& shape, - void* p_data, const OrtMemoryInfo& location, - OrtValue& ort_value, ptrdiff_t offset = 0, - gsl::span strides = {}); - - /// - /// Creates an instance of Tensor who own the pre-allocated buffer. - /// - /// - /// - /// - /// - /// - /// - static void InitOrtValue(MLDataType p_type, const TensorShape& shape, - void* p_data, std::shared_ptr allocator, - OrtValue& ort_value, ptrdiff_t offset = 0, - gsl::span strides = {}); - - static size_t CalculateTensorStorageSize(MLDataType p_type, - const TensorShape& shape, - gsl::span strides = {}); - /** - * Deprecated. The original design is this Tensor class won't do any allocation / release. - * However, this function will allocate the buffer for the shape, and do placement new if p_type is string tensor. - */ - Tensor(MLDataType p_type, const TensorShape& shape, std::shared_ptr allocator, - gsl::span strides = {}); - - /// - /// Creates an instance of Tensor on the heap using the appropriate __ctor and - /// initializes OrtValue with it. - /// - /// - /// - /// - /// - /// - static void InitOrtValue(MLDataType elt_type, - const TensorShape& shape, - std::shared_ptr allocator, - OrtValue& ort_value, - gsl::span strides = {}); - - /// - /// Creates an instance of Tensor on the heap using the appropriate __ctor and - /// initializes OrtValue with it. - /// - /// - /// - static void InitOrtValue(Tensor&& tensor, OrtValue& ort_value); - - /** - * Create tensor with given type, shape, pre-allocated memory and allocator which will be used to free the pre-allocated memory. - * This function won't check if the preallocated buffer(p_data) has enough room for the shape. - * However, this function will de-allocate the buffer upon the tensor getting destructed. - * \param p_type Data type of the tensor + * Create tensor with given type, shape, pre-allocated memory and allocator which will be used to free the + * pre-allocated memory. The Tensor will take over ownership of p_data. + * This function does not check if the preallocated buffer(p_data) has enough room for the shape. + * \param elt_type Data type of the tensor elements. * \param shape Shape of the tensor * \param p_data A preallocated buffer. Can be NULL if the shape is empty. - * Tensor will own the memory and will delete it when the tensor instance is destructed. + * Tensor will own the memory and will delete it when the tensor instance is destructed. * \param deleter Allocator used to free the pre-allocated memory * \param offset Offset in bytes to start of Tensor within p_data. * \param strides Strides span. Can be empty if the tensor is contiguous. */ - Tensor(MLDataType p_type, const TensorShape& shape, void* p_data, std::shared_ptr deleter, + Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, std::shared_ptr deleter, ptrdiff_t offset = 0, gsl::span strides = {}); + /// + /// Create a Tensor that allocates and owns the buffer required for the specified shape. + /// + /// Data type of the tensor elements. + /// Tensor shape. + /// Allocator to use to create and free buffer. + Tensor(MLDataType elt_type, const TensorShape& shape, std::shared_ptr allocator); + ~Tensor(); // Move is allowed ORT_DISALLOW_COPY_AND_ASSIGNMENT(Tensor); Tensor(Tensor&& other) noexcept; - Tensor& operator=(Tensor&& other) noexcept; + /// + /// Creates an instance of Tensor on the heap and initializes OrtValue with it. + /// + /// Data type of the tensor elements. + /// Tensor shape. + /// Tensor data. + /// Memory info for location of p_data. + /// OrtValue to populate with Tensor. + /// Optional offset if Tensor refers to a subset of p_data. + /// Optional strides if Tensor refers to a subset of p_data. + static void InitOrtValue(MLDataType elt_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, + OrtValue& ort_value, + ptrdiff_t offset = 0, gsl::span strides = {}); + + /// + /// Creates an instance of Tensor on the heap which will take over ownership of the pre-allocated buffer. + /// + /// Data type of the tensor elements. + /// + /// Tensor data. + /// Allocator that was used to create p_data and will be used to free it. + /// OrtValue to populate with Tensor. + /// Optional offset if Tensor refers to a subset of p_data. + /// Optional strides if Tensor refers to a subset of p_data. + static void InitOrtValue(MLDataType elt_type, const TensorShape& shape, void* p_data, + std::shared_ptr allocator, + OrtValue& ort_value, + ptrdiff_t offset = 0, gsl::span strides = {}); + + /// + /// Creates an instance of Tensor on the heap and initializes OrtValue with it. + /// The Tensor instance will allocate and own the data required for `shape`. + /// + /// Data type of the tensor elements. + /// Tensor shape. + /// Allocator that was used to create p_data and will be used to free it. + /// OrtValue to populate with Tensor. + static void InitOrtValue(MLDataType elt_type, const TensorShape& shape, std::shared_ptr allocator, + OrtValue& ort_value); + + /// + /// Initializes OrtValue with an existing Tensor. + /// + /// Tensor. + /// OrtValue to populate with Tensor. + static void InitOrtValue(Tensor&& tensor, OrtValue& ort_value); + + /// + /// Calculate the required storage for the tensor. + /// + /// Data type of the tensor elements. + /// Tensor shape. + /// Bytes required. + static size_t CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape); + /** Returns the data type. */ @@ -294,7 +295,7 @@ class Tensor final { // More API methods. private: - void Init(MLDataType p_type, + void Init(MLDataType elt_type, const TensorShape& shape, void* p_raw_data, AllocatorPtr deleter, diff --git a/include/onnxruntime/core/framework/tensor_shape.h b/include/onnxruntime/core/framework/tensor_shape.h index b3783696b8d78..82a1c1de83523 100644 --- a/include/onnxruntime/core/framework/tensor_shape.h +++ b/include/onnxruntime/core/framework/tensor_shape.h @@ -2,34 +2,17 @@ // Licensed under the MIT License. #pragma once -#include -#include + #include -#include #include -#include "core/common/gsl.h" -#include "onnxruntime_config.h" - -#ifndef DISABLE_ABSEIL -// Need to include abseil inlined_vector.h header directly here -// as hash tables cause CUDA 10.2 compilers to fail. inlined_vector.h is fine. -#ifdef _MSC_VER -#pragma warning(push) -// C4127: conditional expression is constant -#pragma warning(disable : 4127) -// C4324: structure was padded due to alignment specifier -// Usage of alignas causes some internal padding in places. -#pragma warning(disable : 4324) -#endif - -#include - -#ifdef _MSC_VER -#pragma warning(pop) -#endif -#endif // DISABLE_ABSEIL +#include +#include +#include +#include "core/common/gsl.h" +#include "core/common/inlined_containers_fwd.h" #include "core/common/span_utils.h" +#include "onnxruntime_config.h" namespace onnxruntime { #ifdef __GNUC__ @@ -41,18 +24,10 @@ namespace onnxruntime { constexpr size_t kTensorShapeSmallBufferElementsSize = 5; -#ifndef DISABLE_ABSEIL // Use this type to build a shape and then create TensorShape. -using TensorShapeVector = absl::InlinedVector; -#else -class TensorShapeVector : public std::vector { - using Base = std::vector; - - public: - using Base::Base; -}; - -#endif // DISABLE_ABSEIL +// We opt to re-use a common instantiation instead of a typedef with kTensorShapeSmallBufferElementsSize +// To reduce on binary size. +using TensorShapeVector = InlinedVector; inline TensorShapeVector ToShapeVector(const gsl::span& span) { TensorShapeVector out; @@ -194,9 +169,7 @@ class TensorShape { friend struct ProviderHostImpl; // So that the shared provider interface can access Allocate }; -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif + // operator<< to nicely output to a stream std::ostream& operator<<(std::ostream& out, const TensorShape& shape); diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index f153e88909b8d..22827d43b200f 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -83,10 +84,10 @@ class Node { gsl::span output_args, const NodeAttributes* attributes, std::string_view domain) { - Init(std::string{name}, std::string{op_type}, std::string{description}, - std::vector{input_args.begin(), input_args.end()}, - std::vector{output_args.begin(), output_args.end()}, - attributes, std::string{domain}); + Init(name, op_type, description, + input_args, + output_args, + attributes, domain); } #endif @@ -396,6 +397,10 @@ class Node { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** Remove the specified attribute from this Node */ bool ClearAttribute(const std::string& attr_name); + + /** Gets the Node's mutable attributes. */ + NodeAttributes& GetMutableAttributes() noexcept { return attributes_; } + #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** @@ -405,8 +410,6 @@ class Node { int PruneRemovableAttributes(gsl::span removable_attributes); #if !defined(ORT_MINIMAL_BUILD) - /** Gets the Node's mutable attributes. */ - NodeAttributes& GetMutableAttributes() noexcept { return attributes_; } /** Gets the Graph instance that is instantiated from a GraphProto attribute during Graph::Resolve. @param attr_name Attribute name for the GraphProto attribute. @@ -440,6 +443,13 @@ class Node { return attr_to_subgraph_map_; } + /** Gets a map of attribute name to the mutable Graph instances for all subgraphs of the Node. + * @returns a mutable map of mutable subgraphs. + */ + std::unordered_map>& GetMutableMapOfAttributeNameToSubgraph() { + return attr_to_subgraph_map_; + } + /** Gets a map of attribute name to the const Graph instances for all subgraphs of the Node. @returns Map of the attribute name that defines the subgraph to the subgraph's Graph instance. nullptr if the Node has no subgraphs. @@ -563,13 +573,13 @@ class Node { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Node); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) - void Init(const std::string& name, - const std::string& op_type, - const std::string& description, - const std::vector& input_args, - const std::vector& output_args, + void Init(std::string_view name, + std::string_view op_type, + std::string_view description, + gsl::span input_args, + gsl::span output_args, const NodeAttributes* attributes, - const std::string& domain); + std::string_view domain); #endif #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -585,7 +595,7 @@ class Node { // create a Graph instance for an attribute that contains a GraphProto void CreateSubgraph(const std::string& attr_name); - const std::vector>& MutableSubgraphs() noexcept { return subgraphs_; } + std::vector>& MutableSubgraphs() noexcept { return subgraphs_; } // validate and update the input arg count common::Status UpdateInputArgCount(); @@ -658,7 +668,7 @@ class Node { The Graph representation containing the graph inputs and outputs, the Node instances, and the edges connecting the nodes. */ -class Graph { +class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve existing data member order for readability public: /** Gets the Graph name. */ const std::string& Name() const noexcept; @@ -1133,6 +1143,26 @@ class Graph { */ Node& FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name); + /** + Directly insert one of the If node branches into this Graph. + `If` node condition must be a constant. The function would + rename the nodes of the corresponding subgraph to make sure there is no conflict. + + Explicit and implicit inputs references stay the same. + + All of the outputs of the subgraph being inlined should be renamed + to the outputs of the If node. + + The function will process any subgraphs in each of the nodes being inlined, + and will rename any references to the new names introduced. + + @param condition_value If condition value + @param if_node - the node that contains the graph_to_inline. This node is going + to be deleted and replaced by the corresponding graph (either then or else) + @param logger + */ + Status InlineIfSubgraph(bool condition_value, Node& if_node, const logging::Logger& logger); + /** Directly insert the nodes in the function Node provided into this Graph. The Graph needs to be Resolve()d after this call. @@ -1141,8 +1171,22 @@ class Graph { */ Status InlineFunction(Node& node); + /** + Directly insert the nodes in the function proto provided into the graph. + The function converts Constant nodes into the initializers in the graph. + It then creates a node in the graph for each of the function nodes. + All of the names are expected to be specialized, and, therefore unique. + See function_utils::Specialize(). + + The Graph needs to be Resolve()d after this call. + @param func_to_inline + @returns Status indicating success or providing an error message. + */ + + Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline); + /** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will - be used as a GraphProto attribute in another Node.. + be used as a GraphProto attribute in another Node. e.g. when creating a Graph instance that will be used as a subgraph in a control flow operator, it is necessary to define placeholder NodeArgs for outer scope values. This prevents these values from becoming explicit graph inputs when the Graph is resolved. @@ -1391,6 +1435,13 @@ class Graph { Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto, const ArgNameToTypeMap& name_to_type); + /** Helper that converts and adds constant node proto to an initializer in the graph. + @param constant_node_proto Constant node to convert + @param new_name use the new name for the initializer. + */ + Status AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& constant_node_proto, + std::optional new_name); + #endif Version IrVersion() const noexcept { diff --git a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h index a57385f6e23f1..f9b694efb936f 100644 --- a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h +++ b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h @@ -278,7 +278,7 @@ class ThreadPoolProfiler { int num_threads_; #ifdef _MSC_VER #pragma warning(push) -// C4324: structure was padded due to alignment specifier + // C4324: structure was padded due to alignment specifier #pragma warning(disable : 4324) #endif // _MSC_VER struct ORT_ALIGN_TO_AVOID_FALSE_SHARING ChildThreadStat { diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 13c176dad3cc5..d73d551920d47 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -1,5 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. + +// This header is to expose a context for cuda custom ops. +// By the context, a custom cuda operator could fetch existing resources, +// such as cuda stream and cudnn handle, for reusing. + +// For concrete usage, pls find page here: +// https://onnxruntime.ai/docs/reference/operators/add-custom-op.html#custom-ops-for-cuda-and-rocm + #pragma once #define ORT_CUDA_CTX @@ -19,8 +27,9 @@ struct CudaContext : public CustomOpContext { cudaStream_t cuda_stream = {}; cudnnHandle_t cudnn_handle = {}; cublasHandle_t cublas_handle = {}; + OrtAllocator* deferred_cpu_allocator = {}; - void Init(const OrtKernelContext& kernel_ctx) override { + void Init(const OrtKernelContext& kernel_ctx) { const auto& ort_api = Ort::GetApi(); void* resource = {}; OrtStatus* status = nullptr; @@ -44,6 +53,36 @@ struct CudaContext : public CustomOpContext { ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION); } cublas_handle = reinterpret_cast(resource); + + resource = {}; + status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource); + if (status) { + ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + deferred_cpu_allocator = reinterpret_cast(resource); + } + + void* AllocDeferredCpuMem(size_t size) const { + if (0 == size) { + return {}; + } + const auto& ort_api = Ort::GetApi(); + void* mem = {}; + auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem); + if (status) { + ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return mem; + } + + void FreeDeferredCpuMem(void* mem) const { + if (mem) { + const auto& ort_api = Ort::GetApi(); + auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem); + if (status) { + ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + } } }; diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index eaf0e5337b8b6..82bb8ba83be4a 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once +#include + #include "onnxruntime_c_api.h" #include "core/framework/arena_extend_strategy.h" @@ -32,4 +35,6 @@ struct OrtCUDAProviderOptionsV2 { int tunable_op_max_tuning_duration_ms = 0; // Max tuning duration time limit for TunableOp. int enable_skip_layer_norm_strict_mode = 0; // flag specifying if SkipLayerNorm is in strict mode. If true, use LayerNormalization kernel. // The strict mode has better accuracy but lower performance. + int prefer_nhwc = 0; // make the CUDA EP NHWC preferred + int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not }; diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h index e46fc5b4219dd..8c3ed46ade6a1 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_resource.h +++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h @@ -3,10 +3,11 @@ #include "core/providers/resource.h" -#define ORT_CUDA_RESOUCE_VERSION 1 +#define ORT_CUDA_RESOUCE_VERSION 2 enum CudaResource : int { cuda_stream_t = cuda_resource_offset, cudnn_handle_t, - cublas_handle_t + cublas_handle_t, + deferred_cpu_allocator_t, }; \ No newline at end of file diff --git a/include/onnxruntime/core/providers/custom_op_context.h b/include/onnxruntime/core/providers/custom_op_context.h index 547f9a90aff85..8f3d2476d4fdb 100644 --- a/include/onnxruntime/core/providers/custom_op_context.h +++ b/include/onnxruntime/core/providers/custom_op_context.h @@ -3,11 +3,8 @@ #pragma once -#include - // CustomOpContext defines an interface allowing a custom op to access ep-specific resources. struct CustomOpContext { CustomOpContext() = default; virtual ~CustomOpContext(){}; - virtual void Init(const OrtKernelContext&){}; }; \ No newline at end of file diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h index 0782d2d9ed760..7d7f05193f486 100644 --- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h +++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h @@ -30,6 +30,35 @@ typedef struct IDMLDevice IDMLDevice; extern "C" { #endif +enum OrtDmlPerformancePreference { + Default = 0, + HighPerformance = 1, + MinimumPower = 2 +}; + +enum OrtDmlDeviceFilter : uint32_t { +#ifdef ENABLE_NPU_ADAPTER_ENUMERATION + Any = 0xffffffff, + Gpu = 1 << 0, + Npu = 1 << 1, +#else + Gpu = 1 << 0, +#endif +}; + +inline OrtDmlDeviceFilter operator~(OrtDmlDeviceFilter a) { return (OrtDmlDeviceFilter) ~(int)a; } +inline OrtDmlDeviceFilter operator|(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a | (int)b); } +inline OrtDmlDeviceFilter operator&(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a & (int)b); } +inline OrtDmlDeviceFilter operator^(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a ^ (int)b); } +inline OrtDmlDeviceFilter& operator|=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a |= (int)b); } +inline OrtDmlDeviceFilter& operator&=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a &= (int)b); } +inline OrtDmlDeviceFilter& operator^=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a ^= (int)b); } + +struct OrtDmlDeviceOptions { + OrtDmlPerformancePreference Preference; + OrtDmlDeviceFilter Filter; +}; + /** * [[deprecated]] * This export is deprecated. @@ -99,6 +128,13 @@ struct OrtDmlApi { * This API gets the D3D12 resource when an OrtValue has been allocated by the DML EP. */ ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_resource, _Out_ ID3D12Resource** d3d_resource); + + /** + * SessionOptionsAppendExecutionProvider_DML2 + * Creates a DirectML Execution Provider given the supplied device options that contain a performance preference + * (high power, low power, or default) and a device filter (None, GPU, or NPU). + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts); }; #ifdef __cplusplus diff --git a/include/onnxruntime/core/providers/rocm/rocm_context.h b/include/onnxruntime/core/providers/rocm/rocm_context.h index ff62094f3f439..5f04289a8c6e0 100644 --- a/include/onnxruntime/core/providers/rocm/rocm_context.h +++ b/include/onnxruntime/core/providers/rocm/rocm_context.h @@ -18,7 +18,7 @@ struct RocmContext : public CustomOpContext { miopenHandle_t miopen_handle = {}; rocblas_handle rblas_handle = {}; - void Init(const OrtKernelContext& kernel_ctx) override { + void Init(const OrtKernelContext& kernel_ctx) { const auto& ort_api = Ort::GetApi(); void* resource = {}; OrtStatus* status = nullptr; diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h deleted file mode 100644 index 44debc901cb77..0000000000000 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "onnxruntime_c_api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id); - -#ifdef __cplusplus -} -#endif diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index e7d0f9f03ade9..680ce1cc5b9a2 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -11,38 +11,39 @@ /// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions. ///
struct OrtTensorRTProviderOptionsV2 { - int device_id; // cuda device id. - int has_user_compute_stream; // indicator of user specified CUDA compute stream. - void* user_compute_stream; // user specified CUDA compute stream. - int trt_max_partition_iterations; // maximum iterations for TensorRT parser to get capability - int trt_min_subgraph_size; // minimum size of TensorRT subgraphs - size_t trt_max_workspace_size; // maximum workspace size for TensorRT. - int trt_fp16_enable; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true - int trt_int8_enable; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true - const char* trt_int8_calibration_table_name; // TensorRT INT8 calibration table name. - int trt_int8_use_native_calibration_table; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true - int trt_dla_enable; // enable DLA. Default 0 = false, nonzero = true - int trt_dla_core; // DLA core number. Default 0 - int trt_dump_subgraphs; // dump TRT subgraph. Default 0 = false, nonzero = true - int trt_engine_cache_enable; // enable engine caching. Default 0 = false, nonzero = true - const char* trt_engine_cache_path; // specify engine cache path - int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true - const char* trt_engine_decryption_lib_path; // specify engine decryption library path - int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true - int trt_context_memory_sharing_enable; // enable context memory sharing between subgraphs. Default 0 = false, nonzero = true - int trt_layer_norm_fp32_fallback; // force Pow + Reduce ops in layer norm to FP32. Default 0 = false, nonzero = true - int trt_timing_cache_enable; // enable TensorRT timing cache. Default 0 = false, nonzero = true - int trt_force_timing_cache; // force the TensorRT cache to be used even if device profile does not match. Default 0 = false, nonzero = true - int trt_detailed_build_log; // Enable detailed build step logging on TensorRT EP with timing for each engine build. Default 0 = false, nonzero = true - int trt_build_heuristics_enable; // Build engine using heuristics to reduce build time. Default 0 = false, nonzero = true - int trt_sparsity_enable; // Control if sparsity can be used by TRT. Default 0 = false, 1 = true - int trt_builder_optimization_level; // Set the builder optimization level. WARNING: levels below 3 do not guarantee good engine performance, but greatly improve build time. Default 3, valid range [0-5] - int trt_auxiliary_streams; // Set maximum number of auxiliary streams per inference stream. Setting this value to 0 will lead to optimal memory usage. Default -1 = heuristics - const char* trt_tactic_sources; // pecify the tactics to be used by adding (+) or removing (-) tactics from the default - // tactic sources (default = all available tactics) e.g. "-CUDNN,+CUBLAS" available keys: "CUBLAS"|"CUBLAS_LT"|"CUDNN"|"EDGE_MASK_CONVOLUTIONS" - const char* trt_extra_plugin_lib_paths; // specify extra TensorRT plugin library paths - const char* trt_profile_min_shapes; // Specify the range of the input shapes to build the engine with - const char* trt_profile_max_shapes; // Specify the range of the input shapes to build the engine with - const char* trt_profile_opt_shapes; // Specify the range of the input shapes to build the engine with - int trt_cuda_graph_enable; // Enable CUDA graph in ORT TRT + int device_id{0}; // cuda device id. + int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream. + void* user_compute_stream{nullptr}; // user specified CUDA compute stream. + int trt_max_partition_iterations{1000}; // maximum iterations for TensorRT parser to get capability + int trt_min_subgraph_size{1}; // minimum size of TensorRT subgraphs + size_t trt_max_workspace_size{1 << 30}; // maximum workspace size for TensorRT. + int trt_fp16_enable{0}; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true + int trt_int8_enable{0}; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true + const char* trt_int8_calibration_table_name{nullptr}; // TensorRT INT8 calibration table name. + int trt_int8_use_native_calibration_table{0}; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true + int trt_dla_enable{0}; // enable DLA. Default 0 = false, nonzero = true + int trt_dla_core{0}; // DLA core number. Default 0 + int trt_dump_subgraphs{0}; // dump TRT subgraph. Default 0 = false, nonzero = true + int trt_engine_cache_enable{0}; // enable engine caching. Default 0 = false, nonzero = true + const char* trt_engine_cache_path{nullptr}; // specify engine cache path, defaults to the working directory + int trt_engine_decryption_enable{0}; // enable engine decryption. Default 0 = false, nonzero = true + const char* trt_engine_decryption_lib_path{nullptr}; // specify engine decryption library path + int trt_force_sequential_engine_build{0}; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true + int trt_context_memory_sharing_enable{0}; // enable context memory sharing between subgraphs. Default 0 = false, nonzero = true + int trt_layer_norm_fp32_fallback{0}; // force Pow + Reduce ops in layer norm to FP32. Default 0 = false, nonzero = true + int trt_timing_cache_enable{0}; // enable TensorRT timing cache. Default 0 = false, nonzero = true + const char* trt_timing_cache_path{nullptr}; // specify timing cache path, if none is provided the trt_engine_cache_path is used + int trt_force_timing_cache{0}; // force the TensorRT cache to be used even if device profile does not match. Default 0 = false, nonzero = true + int trt_detailed_build_log{0}; // Enable detailed build step logging on TensorRT EP with timing for each engine build. Default 0 = false, nonzero = true + int trt_build_heuristics_enable{0}; // Build engine using heuristics to reduce build time. Default 0 = false, nonzero = true + int trt_sparsity_enable{0}; // Control if sparsity can be used by TRT. Default 0 = false, 1 = true + int trt_builder_optimization_level{3}; // Set the builder optimization level. WARNING: levels below 3 do not guarantee good engine performance, but greatly improve build time. Default 3, valid range [0-5] + int trt_auxiliary_streams{-1}; // Set maximum number of auxiliary streams per inference stream. Setting this value to 0 will lead to optimal memory usage. Default -1 = heuristics + const char* trt_tactic_sources{nullptr}; // pecify the tactics to be used by adding (+) or removing (-) tactics from the default + // tactic sources (default = all available tactics) e.g. "-CUDNN,+CUBLAS" available keys: "CUBLAS"|"CUBLAS_LT"|"CUDNN"|"EDGE_MASK_CONVOLUTIONS" + const char* trt_extra_plugin_lib_paths{nullptr}; // specify extra TensorRT plugin library paths + const char* trt_profile_min_shapes{nullptr}; // Specify the range of the input shapes to build the engine with + const char* trt_profile_max_shapes{nullptr}; // Specify the range of the input shapes to build the engine with + const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with + int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT }; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index e483c67a0cfe6..cddad732104ed 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -299,6 +299,7 @@ ORT_RUNTIME_CLASS(DnnlProviderOptions); ORT_RUNTIME_CLASS(Op); ORT_RUNTIME_CLASS(OpAttr); ORT_RUNTIME_CLASS(Logger); +ORT_RUNTIME_CLASS(ShapeInferContext); #ifdef _WIN32 typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -598,9 +599,11 @@ typedef struct OrtTensorRTProviderOptions { * \see OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX */ typedef struct OrtMIGraphXProviderOptions { - int device_id; // hip device id. - int migraphx_fp16_enable; // enable MIGraphX FP16 precision. Default 0 = false, nonzero = true - int migraphx_int8_enable; // enable MIGraphX INT8 precision. Default 0 = false, nonzero = true + int device_id; // hip device id. + int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true + int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true + int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true + const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name } OrtMIGraphXProviderOptions; /** \brief OpenVINO Provider Options @@ -610,7 +613,7 @@ typedef struct OrtMIGraphXProviderOptions { typedef struct OrtOpenVINOProviderOptions { #ifdef __cplusplus OrtOpenVINOProviderOptions() : device_type{}, - enable_vpu_fast_compile{}, + enable_npu_fast_compile{}, device_id{}, num_of_threads{}, cache_dir{}, @@ -623,7 +626,7 @@ typedef struct OrtOpenVINOProviderOptions { * Valid settings are one of: "CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16" */ const char* device_type; - unsigned char enable_vpu_fast_compile; ///< 0 = disabled, nonzero = enabled + unsigned char enable_npu_fast_compile; ///< 0 = disabled, nonzero = enabled const char* device_id; size_t num_of_threads; ///< 0 = Use default number of threads const char* cache_dir; // path is set to empty by default @@ -745,6 +748,8 @@ struct OrtApi { /** \brief Create an OrtEnv * + * \note Invoking this function will return the same instance of the environment as that returned by a previous call + * to another env creation function; all arguments to this function will be ignored. * \param[in] log_severity_level The log severity level. * \param[in] logid The log identifier. * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv @@ -755,17 +760,20 @@ struct OrtApi { /** \brief Create an OrtEnv * + * \note Invoking this function will return the same instance of the environment as that returned by a previous call + * to another env creation function; all arguments to this function will be ignored. If you want to provide your + * own logging function, consider setting it using the SetUserLoggingFunction API instead. * \param[in] logging_function A pointer to a logging function. * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to - * `logging_function`. + * `logging_function`. This parameter is optional. * \param[in] log_severity_level The log severity level. * \param[in] logid The log identifier. * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv * * \snippet{doc} snippets.dox OrtStatus Return Value */ - ORT_API2_STATUS(CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, - OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); + ORT_API2_STATUS(CreateEnvWithCustomLogger, _In_ OrtLoggingFunction logging_function, _In_opt_ void* logger_param, + _In_ OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); /** \brief Enable Telemetry * @@ -3592,6 +3600,18 @@ struct OrtApi { * "rpc_control_latency": QNN RPC control latency. * "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance", * "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default". + * "qnn_context_embed_mode", 1 means dump the QNN context binary into node attribute EPContext->ep_cache_context in the ONNX skeleton model. + * 0 means dump the QNN context binary into separate bin file and set the path to EPContext->ep_cache_context. + * The path is relative path to the ONNX skeleton model file. + * "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will + * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and + * may alter model/EP partitioning. Use only for debugging. + * "qnn_context_priority": QNN context priority, options: "low", "normal", "normal_high", "high". Default to "normal". + * "htp_graph_finalization_optimization_mode": Set the optimization mode for graph finalization on the HTP backend. Available options: + * - "0": Default. + * - "1": Faster preparation time, less optimal graph. + * - "2": Longer preparation time, more optimal graph. + * - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details. * * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", @@ -4413,6 +4433,93 @@ struct OrtApi { * \since Version 1.16. */ ORT_API2_STATUS(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resouce_version, _In_ int resource_id, _Outptr_ void** resource); + + /** \brief Set user logging function + * + * By default the logger created by the CreateEnv* functions is used to create the session logger as well. + * This function allows a user to override this default session logger with a logger of their own choosing. This way + * the user doesn't have to create a separate environment with a custom logger. This addresses the problem when + * the user already created an env but now wants to use a different logger for a specific session (for debugging or + * other reasons). + * + * \param[in] options + * \param[in] user_logging_function A pointer to a logging function. + * \param[in] user_logging_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to + * `user_logging_function`. This parameter is optional. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.17. + */ + ORT_API2_STATUS(SetUserLoggingFunction, _Inout_ OrtSessionOptions* options, + _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param); + + /** + * Get number of input from OrtShapeInferContext + * + * \param[in] context + * \param[out] out The number of inputs + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ShapeInferContext_GetInputCount, _In_ const OrtShapeInferContext* context, _Out_ size_t* out); + + /** + * Get type and shape info of an input + * + * \param[in] context + * \param[in] index The index of the input + * \param[out] info Type shape info of the input + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ShapeInferContext_GetInputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _Outptr_ OrtTensorTypeAndShapeInfo** info); + + /** + * Get attribute from OrtShapeInferContext. Note that OrtShapeInferContext is a per-node context, one could only read attribute from current node. + * + * \param[in] context + * \param[in] attr_name Name of the attribute + * \param[out] attr Handle of the attribute fetched + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ShapeInferContext_GetAttribute, _In_ const OrtShapeInferContext* context, _In_ const char* attr_name, _Outptr_ const OrtOpAttr** attr); + + /** + * Set type and shape info of an ouput + * + * \param[in] context + * \param[in] index The index of the ouput + * \param[out] info Type shape info of the output + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info); + + /** + * Set symbolic shape to type shape info + * + * \param[in] info Type shape info + * \param[in] dim_params Symbolic strings + * \param[in] dim_params_length Number of strings + * + * \since Version 1.17. + */ + ORT_API2_STATUS(SetSymbolicDimensions, _In_ OrtTensorTypeAndShapeInfo* info, _In_ const char* dim_params[], _In_ size_t dim_params_length); + + /** + * Read contents of an attribute to data + * + * \param[in] op_attr + * \param[in] type Attribute type + * \param[out] data Memory address to save raw content of the attribute + * \param[in] len Number of bytes allowed to store in data + * \param[out] out Number of bytes required to save the data when the call failed, or the real number of bytes saved to data on success + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out); }; /* @@ -4504,6 +4611,12 @@ struct OrtCustomOp { // Perform the computation step. OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context); + + OrtStatusPtr(ORT_API_CALL* InferOutputShapeFn)(_In_ const struct OrtCustomOp* op, _In_ OrtShapeInferContext*); + + // Get start range + int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op); + int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op); }; /* @@ -4544,6 +4657,14 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessio */ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena); +/* + * This is the old way to add the TensorRT provider to the session, please use SessionOptionsAppendExecutionProvider_TensorRT_V2 above to access the latest functionality + * This function always exists, but will only succeed if Onnxruntime was built with TensorRT support and the TensorRT provider shared library exists + * + * \param device_id CUDA device id, starts from zero. + */ +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id); + #ifdef __cplusplus } #endif diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 47356c3fe3608..92c25d8688b66 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2055,6 +2055,7 @@ struct KernelContext { void* GetGPUComputeStream() const; Logger GetLogger() const; OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const; + OrtKernelContext* GetOrtKernelContext() const { return ctx_; } private: OrtKernelContext* ctx_; @@ -2155,6 +2156,80 @@ struct Op : detail::Base { size_t output_count); }; +/// +/// Provide access to per-node attributes and input shapes, so one could compute and set output shapes. +/// +struct ShapeInferContext { + struct SymbolicInteger { + SymbolicInteger(int64_t i) : i_(i), is_int_(true){}; + SymbolicInteger(const char* s) : s_(s), is_int_(false){}; + SymbolicInteger(const SymbolicInteger&) = default; + SymbolicInteger(SymbolicInteger&&) = default; + + SymbolicInteger& operator=(const SymbolicInteger&) = default; + SymbolicInteger& operator=(SymbolicInteger&&) = default; + + bool operator==(const SymbolicInteger& dim) const { + if (is_int_ == dim.is_int_) { + if (is_int_) { + return i_ == dim.i_; + } else { + return std::string{s_} == std::string{dim.s_}; + } + } + return false; + } + + bool IsInt() const { return is_int_; } + int64_t AsInt() const { return i_; } + const char* AsSym() const { return s_; } + + static constexpr int INVALID_INT_DIM = -2; + + private: + union { + int64_t i_; + const char* s_; + }; + bool is_int_; + }; + + using Shape = std::vector; + + ShapeInferContext(const OrtApi* ort_api, OrtShapeInferContext* ctx); + + const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); } + + size_t GetInputCount() const { return input_shapes_.size(); } + + Status SetOutputShape(size_t indice, const Shape& shape); + + int64_t GetAttrInt(const char* attr_name); + + using Ints = std::vector; + Ints GetAttrInts(const char* attr_name); + + float GetAttrFloat(const char* attr_name); + + using Floats = std::vector; + Floats GetAttrFloats(const char* attr_name); + + std::string GetAttrString(const char* attr_name); + + using Strings = std::vector; + Strings GetAttrStrings(const char* attr_name); + + private: + const OrtOpAttr* GetAttrHdl(const char* attr_name) const; + const OrtApi* ort_api_; + OrtShapeInferContext* ctx_; + std::vector input_shapes_; +}; + +using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&); + +#define MAX_CUSTOM_OP_END_VER (1UL << 31) - 1 + template struct CustomOpBase : OrtCustomOp { CustomOpBase() { @@ -2205,6 +2280,16 @@ struct CustomOpBase : OrtCustomOp { static_cast(op_kernel)->Compute(context); }; } + + SetShapeInferFn(0); + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->end_ver_; + }; } // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider @@ -2256,9 +2341,26 @@ struct CustomOpBase : OrtCustomOp { return std::vector{}; } + template + decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + ShapeInferContext ctx(&GetApi(), ort_ctx); + return C::InferOutputShape(ctx); + }; + return {}; + } + + template + void SetShapeInferFn(...) { + OrtCustomOp::InferOutputShapeFn = {}; + } + protected: // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys. void GetSessionConfigs(std::unordered_map& out, ConstSessionOptions options) const; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; }; } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 22172832cde8e..860a27fc73f79 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -8,6 +8,15 @@ // the main C++ file with implementation details. #include +#include + +#define RETURN_ON_API_FAIL(expression) \ + { \ + auto err = (expression); \ + if (err) { \ + return Status(err); \ + } \ + } namespace Ort { @@ -1883,4 +1892,154 @@ void CustomOpBase::GetSessionConfigs(std::unordered_ma } } +inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api, + OrtShapeInferContext* ctx) : ort_api_(ort_api), ctx_(ctx) { + size_t input_count = 0; + Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputCount(ctx_, &input_count)); + for (size_t ith_input = 0; ith_input < input_count; ++ith_input) { + OrtTensorTypeAndShapeInfo* info{}; + Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputTypeShape(ctx, ith_input, &info)); + TensorTypeAndShapeInfo type_shape_info(info); + auto integer_shape = type_shape_info.GetShape(); + std::vector symbolic_shape(integer_shape.size(), {}); + type_shape_info.GetSymbolicDimensions(&symbolic_shape[0], integer_shape.size()); + Shape shape; + for (size_t ith = 0; ith < integer_shape.size(); ++ith) { + if (symbolic_shape[ith] && std::string{symbolic_shape[ith]}.size() > 0) { + shape.emplace_back(symbolic_shape[ith]); + } else { + shape.emplace_back(integer_shape[ith]); + } + } + input_shapes_.push_back(std::move(shape)); + type_shape_info.release(); + } +} + +inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape) { + OrtTensorTypeAndShapeInfo* info = {}; + RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info)); + + using InfoPtr = std::unique_ptr>; + + InfoPtr info_ptr(info, [this](OrtTensorTypeAndShapeInfo* obj) { + ort_api_->ReleaseTensorTypeAndShapeInfo(obj); + }); + + std::vector integer_dims; + std::vector symbolic_dims; + + for (const auto dim : shape) { + if (dim.IsInt()) { + integer_dims.push_back(dim.IsInt()); + symbolic_dims.push_back(""); + } else { + if (!dim.AsSym() || std::string{dim.AsSym()}.empty()) { + ORT_CXX_API_THROW("Symbolic dim must not be an empty string", ORT_INVALID_ARGUMENT); + } + integer_dims.push_back(SymbolicInteger::INVALID_INT_DIM); + symbolic_dims.push_back(dim.AsSym()); + } + } + + RETURN_ON_API_FAIL(ort_api_->SetDimensions(info, integer_dims.data(), integer_dims.size())); + RETURN_ON_API_FAIL(ort_api_->SetSymbolicDimensions(info, symbolic_dims.data(), symbolic_dims.size())); + RETURN_ON_API_FAIL(ort_api_->ShapeInferContext_SetOutputTypeShape(ctx_, indice, info)); + return Status{nullptr}; +} + +inline int64_t ShapeInferContext::GetAttrInt(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + int64_t i = {}; + size_t out = {}; + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i, sizeof(i), &out)); + return i; +} + +inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + int64_t i = {}; + size_t out = {}; + // first call to get the bytes needed + auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out); + if (status) { + size_t num_i = out / sizeof(int64_t); + ShapeInferContext::Ints ints(num_i, 0); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out)); + return ints; + } else { + return {i}; + } +} + +inline float ShapeInferContext::GetAttrFloat(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + float f = {}; + size_t out = {}; + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f, sizeof(f), &out)); + return f; +} + +inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + float f = {}; + size_t out = {}; + // first call to get the bytes needed + auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out); + if (status) { + size_t num_f = out / sizeof(float); + ShapeInferContext::Floats floats(num_f, 0); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out)); + return floats; + } else { + return {f}; + } +} + +inline std::string ShapeInferContext::GetAttrString(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + char c = {}; + size_t out = {}; + // first call to get the bytes needed + auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c, sizeof(char), &out); + if (status) { + std::vector chars(out, '\0'); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out)); + return {chars.data()}; + } else { + return {c}; + } +} + +inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + char c = {}; + size_t out = {}; + // first call to get the bytes needed + auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out); + if (status) { + std::vector chars(out, '\0'); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out)); + ShapeInferContext::Strings strings; + char* char_st = chars.data(); + char* char_ed = char_st + out; + while (char_st < char_ed) { + strings.emplace_back(char_st); + while (*char_st != '\0') { + char_st++; + } + char_st++; + } + return strings; + } else { + return {std::string{c}}; + } +} + +inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) const { + const OrtOpAttr* attr_hdl = {}; + Ort::ThrowOnError(ort_api_->ShapeInferContext_GetAttribute(ctx_, attr_name, &attr_hdl)); + return attr_hdl; +} + } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index fd42824e81a56..0c0af16d4e20c 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -18,22 +18,81 @@ #include "onnxruntime_cxx_api.h" #include #include +#include #include namespace Ort { namespace Custom { -class TensorBase { +class ArgBase { public: - TensorBase(OrtKernelContext* ctx) : ctx_(ctx) {} - virtual ~TensorBase() {} + ArgBase(OrtKernelContext* ctx, + size_t indice, + bool is_input) : ctx_(ctx), indice_(indice), is_input_(is_input) {} + virtual ~ArgBase(){}; + + protected: + struct KernelContext ctx_; + size_t indice_; + bool is_input_; +}; + +using ArgPtr = std::unique_ptr; +using ArgPtrs = std::vector; + +class TensorBase : public ArgBase { + public: + TensorBase(OrtKernelContext* ctx, + size_t indice, + bool is_input) : ArgBase(ctx, indice, is_input) {} + operator bool() const { return shape_.has_value(); } + const std::vector& Shape() const { + if (!shape_.has_value()) { + ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return shape_.value(); + } + + ONNXTensorElementDataType Type() const { + return type_; + } + + int64_t NumberOfElement() const { + if (shape_.has_value()) { + return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies()); + } else { + return 0; + } + } + + std::string Shape2Str() const { + if (shape_.has_value()) { + std::string shape_str; + for (const auto& dim : *shape_) { + shape_str.append(std::to_string(dim)); + shape_str.append(", "); + } + return shape_str; + } else { + return "empty"; + } + } + + bool IsCpuTensor() const { + return strcmp("Cpu", mem_type_) == 0; + } + + virtual const void* DataRaw() const = 0; + virtual size_t SizeInBytes() const = 0; + protected: - struct KernelContext ctx_; std::optional> shape_; + ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + const char* mem_type_ = "Cpu"; }; template @@ -48,13 +107,14 @@ struct Span { T operator[](size_t indice) const { return data_[indice]; } + const T* data() const { return data_; } }; template class Tensor : public TensorBase { public: using TT = typename std::remove_reference::type; - Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) { + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) { if (is_input_) { if (indice >= ctx_.GetInputCount()) { ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); @@ -64,19 +124,6 @@ class Tensor : public TensorBase { shape_ = type_shape_info.GetShape(); } } - const std::vector& Shape() const { - if (!shape_.has_value()) { - ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION); - } - return shape_.value(); - } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies()); - } else { - return 0; - } - } const TT* Data() const { return reinterpret_cast(const_value_.GetTensorRawData()); } @@ -104,10 +151,15 @@ class Tensor : public TensorBase { } return *Data(); } + const void* DataRaw() const override { + return reinterpret_cast(Data()); + } + + size_t SizeInBytes() const override { + return sizeof(TT) * static_cast(NumberOfElement()); + } private: - size_t indice_; - bool is_input_; ConstValue const_value_; // for input TT* data_{}; // for output Span span_; @@ -118,7 +170,7 @@ class Tensor : public TensorBase { public: using strings = std::vector; - Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) { + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) { if (is_input_) { if (indice >= ctx_.GetInputCount()) { ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); @@ -147,16 +199,21 @@ class Tensor : public TensorBase { } } } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies()); - } else { - return 0; - } - } const strings& Data() const { return input_strings_; } + const void* DataRaw() const override { + if (input_strings_.size() != 1) { + ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return reinterpret_cast(input_strings_[0].c_str()); + } + size_t SizeInBytes() const override { + if (input_strings_.size() != 1) { + ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return input_strings_[0].size(); + } void SetStringOutput(const strings& ss, const std::vector& dims) { shape_ = dims; std::vector raw; @@ -179,8 +236,6 @@ class Tensor : public TensorBase { } private: - size_t indice_; - bool is_input_; std::vector input_strings_; // for input }; @@ -190,7 +245,7 @@ class Tensor : public TensorBase { using strings = std::vector; using string_views = std::vector; - Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) { + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) { if (is_input_) { if (indice >= ctx_.GetInputCount()) { ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); @@ -211,16 +266,21 @@ class Tensor : public TensorBase { } } } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies()); - } else { - return 0; - } - } const string_views& Data() const { return input_string_views_; } + const void* DataRaw() const override { + if (input_string_views_.size() != 1) { + ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return reinterpret_cast(input_string_views_[0].data()); + } + size_t SizeInBytes() const override { + if (input_string_views_.size() != 1) { + ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return input_string_views_[0].size(); + } void SetStringOutput(const strings& ss, const std::vector& dims) { shape_ = dims; std::vector raw; @@ -243,16 +303,111 @@ class Tensor : public TensorBase { } private: - size_t indice_; - bool is_input_; std::vector chars_; // for input std::vector input_string_views_; // for input }; using TensorPtr = std::unique_ptr; +using TensorPtrs = std::vector; -//////////////////////////// OrtLiteCustomOp //////////////////////////////// +struct TensorArray : public ArgBase { + TensorArray(OrtKernelContext* ctx, + size_t start_indice, + bool is_input) : ArgBase(ctx, + start_indice, + is_input) { + if (is_input) { + auto input_count = ctx_.GetInputCount(); + for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) { + auto const_value = ctx_.GetInput(start_indice); + auto type_shape_info = const_value.GetTensorTypeAndShapeInfo(); + auto type = type_shape_info.GetElementType(); + TensorPtr tensor; + switch (type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + tensor = std::make_unique>(ctx, ith_input, true); + break; + default: + ORT_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION); + break; + } + tensors_.emplace_back(tensor.release()); + } // for + } + } + template + T* AllocateOutput(size_t ith_output, const std::vector& shape) { + // ith_output is the indice of output relative to the tensor array + // indice_ + ith_output is the indice relative to context + auto tensor = std::make_unique>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false); + auto raw_output = tensor.get()->Allocate(shape); + tensors_.emplace_back(tensor.release()); + return raw_output; + } + Tensor& AllocateStringTensor(size_t ith_output) { + // ith_output is the indice of output relative to the tensor array + // indice_ + ith_output is the indice relative to context + auto tensor = std::make_unique>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false); + Tensor& output = *tensor; + tensors_.emplace_back(tensor.release()); + return output; + } + size_t Size() const { + return tensors_.size(); + } + const TensorPtr& operator[](size_t ith_input) const { + // ith_input is the indice of output relative to the tensor array + return tensors_.at(ith_input); + } + private: + TensorPtrs tensors_; +}; + +using Variadic = TensorArray; + +/* +Note: +OrtLiteCustomOp inherits from OrtCustomOp to bridge tween a custom func/struct and ort core. +The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so: +1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release since there is no virtual destructor in the hierachy. +2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in form of OrtLiteCustomOp since all members are kept in the OrtLiteCustomOp, + hence memory could still be recycled properly. +Further, OrtCustomOp is a c struct bearing no v-table, so offspring structs are by design to be of zero virtual functions to maintain cast safety. +*/ struct OrtLiteCustomOp : public OrtCustomOp { using ConstOptionalFloatTensor = std::optional&>; using OptionalFloatTensor = std::optional>; @@ -260,34 +415,34 @@ struct OrtLiteCustomOp : public OrtCustomOp { // CreateTuple template static typename std::enable_if>::type - CreateTuple(OrtKernelContext*, std::vector&, size_t, size_t, const std::string&) { + CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) { return std::make_tuple(); } template static typename std::enable_if::value, std::tuple>::type - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { std::tuple current = std::tuple{context}; - auto next = CreateTuple(context, tensors, num_input, num_output, ep); + auto next = CreateTuple(context, args, num_input, num_output, ep); return std::tuple_cat(current, next); } template static typename std::enable_if::value, std::tuple>::type - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { std::tuple current = std::tuple{*context}; - auto next = CreateTuple(context, tensors, num_input, num_output, ep); + auto next = CreateTuple(context, args, num_input, num_output, ep); return std::tuple_cat(current, next); } #ifdef ORT_CUDA_CTX template static typename std::enable_if::value, std::tuple>::type - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { thread_local CudaContext cuda_context; cuda_context.Init(*context); std::tuple current = std::tuple{cuda_context}; - auto next = CreateTuple(context, tensors, num_input, num_output, ep); + auto next = CreateTuple(context, args, num_input, num_output, ep); return std::tuple_cat(current, next); } #endif @@ -295,143 +450,179 @@ struct OrtLiteCustomOp : public OrtCustomOp { #ifdef ORT_ROCM_CTX template static typename std::enable_if::value, std::tuple>::type - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { thread_local RocmContext rocm_context; rocm_context.Init(*context); std::tuple current = std::tuple{rocm_context}; - auto next = CreateTuple(context, tensors, num_input, num_output, ep); + auto next = CreateTuple(context, args, num_input, num_output, ep); return std::tuple_cat(current, next); } #endif -#define CREATE_TUPLE_INPUT(data_type) \ - template \ - static typename std::enable_if*>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if&>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if*>>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_input < num_input) { \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - } \ - template \ - static typename std::enable_if*>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{&reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if&>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if*>>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_input < num_input) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{&reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - } \ - template \ - static typename std::enable_if::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsScalar()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_input < num_input) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsScalar()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - } -#define CREATE_TUPLE_OUTPUT(data_type) \ - template \ - static typename std::enable_if*>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_output, false)); \ - std::tuple current = std::tuple{reinterpret_cast(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if&>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_output, false)); \ - std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if*>>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_output < num_output) { \ - tensors.push_back(std::make_unique>(context, ith_output, false)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_input, true)); + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_input, true)); + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_output, false)); + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_output, false)); + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + +#define CREATE_TUPLE_INPUT(data_type) \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{&reinterpret_cast*>(args.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{&reinterpret_cast*>(args.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } \ + template \ + static typename std::enable_if::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())->AsScalar()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())->AsScalar()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } +#define CREATE_TUPLE_OUTPUT(data_type) \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_output < num_output) { \ + args.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ } #define CREATE_TUPLE(data_type) \ CREATE_TUPLE_INPUT(data_type) \ @@ -491,6 +682,34 @@ struct OrtLiteCustomOp : public OrtCustomOp { } #endif + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + #define PARSE_INPUT_BASE(pack_type, onnx_type) \ template \ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type \ @@ -499,6 +718,12 @@ struct OrtLiteCustomOp : public OrtCustomOp { ParseArgs(input_types, output_types); \ } \ template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + input_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } \ + template \ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same>::value>::type \ ParseArgs(std::vector& input_types, std::vector& output_types) { \ input_types.push_back(onnx_type); \ @@ -557,8 +782,14 @@ struct OrtLiteCustomOp : public OrtCustomOp { PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ) OrtLiteCustomOp(const char* op_name, - const char* execution_provider) : op_name_(op_name), - execution_provider_(execution_provider) { + const char* execution_provider, + ShapeInferFn shape_infer_fn, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), + execution_provider_(execution_provider), + shape_infer_fn_(shape_infer_fn), + start_ver_(start_ver), + end_ver_(end_ver) { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; @@ -585,18 +816,52 @@ struct OrtLiteCustomOp : public OrtCustomOp { return self->output_types_[indice]; }; - OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) { - return INPUT_OUTPUT_OPTIONAL; + OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL; + }; + + OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL; }; - OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp*, size_t) { - return INPUT_OUTPUT_OPTIONAL; + OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { + return 1; + }; + + OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { + return 0; + }; + + OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { + return 1; + }; + + OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { + return 0; }; OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; }; OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; }; OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; }; OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; }; + + OrtCustomOp::CreateKernelV2 = {}; + OrtCustomOp::KernelComputeV2 = {}; + OrtCustomOp::KernelCompute = {}; + + OrtCustomOp::InferOutputShapeFn = {}; + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->end_ver_; + }; } const std::string op_name_; @@ -604,6 +869,14 @@ struct OrtLiteCustomOp : public OrtCustomOp { std::vector input_types_; std::vector output_types_; + + ShapeInferFn shape_infer_fn_ = {}; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; + + void* compute_fn_ = {}; + void* compute_fn_return_status_ = {}; }; //////////////////////////// OrtLiteCustomFunc //////////////////////////////// @@ -619,31 +892,37 @@ struct OrtLiteCustomOp : public OrtCustomOp { template struct OrtLiteCustomFunc : public OrtLiteCustomOp { using ComputeFn = void (*)(Args...); + using ComputeFnReturnStatus = Status (*)(Args...); using MyType = OrtLiteCustomFunc; struct Kernel { size_t num_input_{}; size_t num_output_{}; ComputeFn compute_fn_{}; + ComputeFnReturnStatus compute_fn_return_status_{}; std::string ep_{}; }; OrtLiteCustomFunc(const char* op_name, const char* execution_provider, - ComputeFn compute_fn) : OrtLiteCustomOp(op_name, execution_provider), - compute_fn_(compute_fn) { + ComputeFn compute_fn, + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_ = reinterpret_cast(compute_fn); ParseArgs(input_types_, output_types_); OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { auto kernel = reinterpret_cast(op_kernel); - std::vector tensors; - auto t = CreateTuple<0, 0, Args...>(context, tensors, kernel->num_input_, kernel->num_output_, kernel->ep_); + std::vector args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t); }; OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); - kernel->compute_fn_ = static_cast(this_)->compute_fn_; + auto me = static_cast(this_); + kernel->compute_fn_ = reinterpret_cast(me->compute_fn_); Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); auto self = static_cast(this_); @@ -654,9 +933,55 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete reinterpret_cast(op_kernel); }; + + if (shape_infer_fn_) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + auto shape_info_fn = static_cast(op)->shape_infer_fn_; + ShapeInferContext ctx(&GetApi(), ort_ctx); + return shape_info_fn(ctx); + }; + } } - ComputeFn compute_fn_; + OrtLiteCustomFunc(const char* op_name, + const char* execution_provider, + ComputeFnReturnStatus compute_fn_return_status, + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_return_status_ = reinterpret_cast(compute_fn_return_status); + ParseArgs(input_types_, output_types_); + + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { + auto kernel = reinterpret_cast(op_kernel); + std::vector args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + return std::apply([kernel](Args const&... t_args) { Status status = kernel->compute_fn_return_status_(t_args...); return status.release(); }, t); + }; + + OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { + auto kernel = std::make_unique(); + auto me = static_cast(this_); + kernel->compute_fn_return_status_ = reinterpret_cast(me->compute_fn_return_status_); + Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); + Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); + auto self = static_cast(this_); + kernel->ep_ = self->execution_provider_; + return reinterpret_cast(kernel.release()); + }; + + OrtCustomOp::KernelDestroy = [](void* op_kernel) { + delete reinterpret_cast(op_kernel); + }; + + if (shape_infer_fn_) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + auto shape_info_fn = static_cast(op)->shape_infer_fn_; + ShapeInferContext ctx(&GetApi(), ort_ctx); + return shape_info_fn(ctx); + }; + } + } }; // struct OrtLiteCustomFunc /////////////////////////// OrtLiteCustomStruct /////////////////////////// @@ -679,6 +1004,10 @@ template struct OrtLiteCustomStruct : public OrtLiteCustomOp { template using CustomComputeFn = void (CustomOp::*)(Args...); + + template + using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...); + using MyType = OrtLiteCustomStruct; struct Kernel { @@ -689,21 +1018,10 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { }; OrtLiteCustomStruct(const char* op_name, - const char* execution_provider) : OrtLiteCustomOp(op_name, - execution_provider) { - init(&CustomOp::Compute); - } - - template - void init(CustomComputeFn) { - ParseArgs(input_types_, output_types_); - - OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { - auto kernel = reinterpret_cast(op_kernel); - std::vector tensors; - auto t = CreateTuple<0, 0, Args...>(context, tensors, kernel->num_input_, kernel->num_output_, kernel->ep_); - std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t); - }; + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) { + SetCompute(&CustomOp::Compute); OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); @@ -718,6 +1036,44 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete reinterpret_cast(op_kernel); }; + + SetShapeInfer(0); + } + + template + void SetCompute(CustomComputeFn) { + ParseArgs(input_types_, output_types_); + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { + auto kernel = reinterpret_cast(op_kernel); + ArgPtrs args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t); + }; + } + + template + void SetCompute(CustomComputeFnReturnStatus) { + ParseArgs(input_types_, output_types_); + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { + auto kernel = reinterpret_cast(op_kernel); + ArgPtrs args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + return std::apply([kernel](Args const&... t_args) { Status status = kernel->custom_op_->Compute(t_args...); return status.release(); }, t); + }; + } + + template + decltype(&C::InferOutputShape) SetShapeInfer(decltype(&C::InferOutputShape)) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + ShapeInferContext ctx(&GetApi(), ort_ctx); + return C::InferOutputShape(ctx); + }; + return {}; + } + + template + void SetShapeInfer(...) { + OrtCustomOp::InferOutputShapeFn = {}; } }; // struct OrtLiteCustomStruct @@ -726,17 +1082,33 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, const char* execution_provider, - void (*custom_compute_fn)(Args...)) { + void (*custom_compute_fn)(Args...), + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomFunc; - return std::make_unique(op_name, execution_provider, custom_compute_fn).release(); + return std::make_unique(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release(); +} + +template +OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, + const char* execution_provider, + Status (*custom_compute_fn_v2)(Args...), + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { + using LiteOp = OrtLiteCustomFunc; + return std::make_unique(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release(); } template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, - const char* execution_provider) { + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomStruct; - return std::make_unique(op_name, execution_provider).release(); + return std::make_unique(op_name, execution_provider, start_ver, end_ver).release(); } } // namespace Custom -} // namespace Ort +} // namespace Ort \ No newline at end of file diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 37545f41b43dd..4628afbb5a702 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -67,16 +67,26 @@ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enab // GeluApproximation has side effects which may change the inference results. It is disabled by default due to this. static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation"; +// This setting controls whether to enable AheadOfTime function inlining. +// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model +// as possible with the help of enabled execution providers. +// This can reduce the number of function calls and improve performance because it is done before +// Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available, +// one can disable the AOT inlining, produce an optimized model and postpone AOT until run time. +// "0": enable; "1": disable. +// Its default value is "0". +static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining"; + #ifdef ENABLE_TRAINING // Specifies a list of op types for memory footprint reduction. // The value should be a ","-delimited list of pair of -// . +// . // For example, "Gelu+Cast+:1:0,Dropout+:1:1". // A valid "subgraph string" should be one subgraph representation output by ORT graph transformations. // "optimization strategy" currently has valid values: 0 - disabled, 1 - recompute. // "number of subgraph to apply" is used to control how many subgraphs to apply optimization, to avoid "oversaving" // the memory. -static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.enable_memory_optimizer"; +static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.memory_optimizer_config"; // Specifies the level for detecting subgraphs for memory footprint reduction. // The value should be an integer. The default value is 0. diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index 09d2cefbb8224..0078adb6402f8 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -13,6 +13,7 @@ import java.nio.IntBuffer; import java.nio.LongBuffer; import java.nio.ShortBuffer; +import java.util.Optional; /** * A Java object wrapping an OnnxTensor. Tensors are the main input to the library, and can also be @@ -21,18 +22,60 @@ public class OnnxTensor extends OnnxTensorLike { /** - * This reference is held for OnnxTensors backed by a Java nio buffer to ensure the buffer does + * This reference is held for OnnxTensors backed by a java.nio.Buffer to ensure the buffer does * not go out of scope while the OnnxTensor exists. */ private final Buffer buffer; + /** + * Denotes if the OnnxTensor made a copy of the buffer on construction (i.e. it may have the only + * reference). + */ + private final boolean ownsBuffer; + OnnxTensor(long nativeHandle, long allocatorHandle, TensorInfo info) { - this(nativeHandle, allocatorHandle, info, null); + this(nativeHandle, allocatorHandle, info, null, false); } - OnnxTensor(long nativeHandle, long allocatorHandle, TensorInfo info, Buffer buffer) { + OnnxTensor( + long nativeHandle, long allocatorHandle, TensorInfo info, Buffer buffer, boolean ownsBuffer) { super(nativeHandle, allocatorHandle, info); this.buffer = buffer; + this.ownsBuffer = ownsBuffer; + } + + /** + * Returns true if the buffer in this OnnxTensor was created on construction of this tensor, i.e., + * it is a copy of a user supplied buffer or array and may hold the only reference to that buffer. + * + *

When this is true the backing buffer was copied from the user input, so users cannot mutate + * the state of this buffer without first getting the reference via {@link #getBufferRef()}. + * + * @return True if the buffer in this OnnxTensor was allocated by it on construction (i.e., it is + * a copy of a user buffer.) + */ + public boolean ownsBuffer() { + return this.ownsBuffer; + } + + /** + * Returns a reference to the buffer which backs this {@code OnnxTensor}. If the tensor is not + * backed by a buffer (i.e., it was created from a Java array, or is backed by memory allocated by + * ORT) this method returns an empty {@link Optional}. + * + *

Changes to the buffer elements will be reflected in the native {@code OrtValue}, this can be + * used to repeatedly update a single tensor for multiple different inferences without allocating + * new tensors, though the inputs must remain the same size and shape. + * + *

Note: the tensor could refer to a contiguous range of elements in this buffer, not the whole + * buffer. It is up to the user to manage this information by respecting the position and limit. + * As a consequence, accessing this reference should be considered problematic when multiple + * threads hold references to the buffer. + * + * @return A reference to the buffer. + */ + public Optional getBufferRef() { + return Optional.ofNullable(buffer); } @Override @@ -45,7 +88,8 @@ public OnnxValueType getType() { * primitives if it has multiple dimensions. * *

Java multidimensional arrays are quite slow for more than 2 dimensions, in that case it is - * recommended you use the java.nio.Buffer extractors below (e.g. {@link #getFloatBuffer}). + * recommended you use the {@link java.nio.Buffer} extractors below (e.g., {@link + * #getFloatBuffer}). * * @return A Java value. * @throws OrtException If the value could not be extracted as the Tensor is invalid, or if the @@ -283,6 +327,12 @@ private native void getArray(long apiHandle, long nativeHandle, Object carrier) * multidimensional array. The shape is inferred from the object using reflection. The default * allocator is used. * + *

Note: Java multidimensional arrays are not dense and this method requires traversing a large + * number of pointers for high dimensional arrays. For types other than Strings it is recommended + * to use one of the {@code createTensor} methods which accepts a {@link java.nio.Buffer}, e.g. + * {@link #createTensor(OrtEnvironment, FloatBuffer, long[])} as those methods are zero copy to + * transfer data into ORT when using direct buffers. + * * @param env The current OrtEnvironment. * @param data The data to store in a tensor. * @return An OnnxTensor storing the data. @@ -700,7 +750,8 @@ private static OnnxTensor createTensor( info.onnxType.value), allocator.handle, info, - tuple.data); + tuple.data, + tuple.isCopy); } private static native long createTensor( diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 435f86daa5fe2..fbea13d155507 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -239,7 +239,7 @@ public Result run(Map inputs, RunOptions runOp */ public Result run(Map inputs, Set requestedOutputs) throws OrtException { - return run(inputs, requestedOutputs, null); + return run(inputs, requestedOutputs, Collections.emptyMap(), null); } /** @@ -259,17 +259,90 @@ public Result run( Set requestedOutputs, RunOptions runOptions) throws OrtException { + return run(inputs, requestedOutputs, Collections.emptyMap(), runOptions); + } + + /** + * Scores an input feed dict, returning the map of pinned outputs. + * + *

The outputs are sorted based on the supplied map traversal order. + * + *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed + * when the result object is closed. + * + * @param inputs The inputs to score. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @return The inferred outputs. + * @throws OrtException If there was an error in native code, the input or output names are + * invalid, or if there are zero or too many inputs or outputs. + */ + public Result run( + Map inputs, Map pinnedOutputs) + throws OrtException { + return run(inputs, Collections.emptySet(), pinnedOutputs, null); + } + + /** + * Scores an input feed dict, returning the map of requested and pinned outputs. + * + *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed + * when the result object is closed. + * + * @param inputs The inputs to score. + * @param requestedOutputs The requested outputs which ORT will allocate. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @return The inferred outputs. + * @throws OrtException If there was an error in native code, the input or output names are + * invalid, or if there are zero or too many inputs or outputs. + */ + public Result run( + Map inputs, + Set requestedOutputs, + Map pinnedOutputs) + throws OrtException { + return run(inputs, requestedOutputs, pinnedOutputs, null); + } + + /** + * Scores an input feed dict, returning the map of requested and pinned outputs. + * + *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed + * when the result object is closed. + * + * @param inputs The inputs to score. + * @param requestedOutputs The requested outputs which ORT will allocate. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @param runOptions The RunOptions to control this run. + * @return The inferred outputs. + * @throws OrtException If there was an error in native code, the input or output names are + * invalid, or if there are zero or too many inputs or outputs. + */ + public Result run( + Map inputs, + Set requestedOutputs, + Map pinnedOutputs, + RunOptions runOptions) + throws OrtException { if (!closed) { if ((inputs.isEmpty() && (numInputs != 0)) || (inputs.size() > numInputs)) { throw new OrtException( "Unexpected number of inputs, expected [1," + numInputs + ") found " + inputs.size()); } - if (requestedOutputs.isEmpty() || (requestedOutputs.size() > numOutputs)) { + int totalOutputs = requestedOutputs.size() + pinnedOutputs.size(); + if ((totalOutputs == 0) || (totalOutputs > numOutputs)) { throw new OrtException( - "Unexpected number of requestedOutputs, expected [1," + "Unexpected number of requestedOutputs & pinnedOutputs, expected [1," + numOutputs + ") found " - + requestedOutputs.size()); + + totalOutputs); } String[] inputNamesArray = new String[inputs.size()]; long[] inputHandles = new long[inputs.size()]; @@ -284,20 +357,41 @@ public Result run( "Unknown input name " + t.getKey() + ", expected one of " + inputNames.toString()); } } - String[] outputNamesArray = new String[requestedOutputs.size()]; + String[] outputNamesArray = new String[requestedOutputs.size() + pinnedOutputs.size()]; + OnnxValue[] outputValues = new OnnxValue[outputNamesArray.length]; + long[] outputHandles = new long[outputNamesArray.length]; i = 0; + for (Map.Entry e : pinnedOutputs.entrySet()) { + if (outputNames.contains(e.getKey())) { + outputNamesArray[i] = e.getKey(); + outputValues[i] = e.getValue(); + outputHandles[i] = getHandle(e.getValue()); + i++; + } else { + throw new OrtException( + "Unknown output name " + e.getKey() + ", expected one of " + outputNames.toString()); + } + } for (String s : requestedOutputs) { if (outputNames.contains(s)) { - outputNamesArray[i] = s; - i++; + if (!pinnedOutputs.containsKey(s)) { + outputNamesArray[i] = s; + // outputValues and outputHandles can be null/0 for these outputs as ORT will allocate + // them. + i++; + } else { + throw new OrtException( + "Output '" + + s + + "' was found in both the requested outputs and the pinned outputs"); + } } else { throw new OrtException( "Unknown output name " + s + ", expected one of " + outputNames.toString()); } } long runOptionsHandle = runOptions == null ? 0 : runOptions.getNativeHandle(); - - OnnxValue[] outputValues = + boolean[] ownedByResult = run( OnnxRuntime.ortApiHandle, nativeHandle, @@ -307,13 +401,40 @@ public Result run( inputNamesArray.length, outputNamesArray, outputNamesArray.length, + outputValues, + outputHandles, runOptionsHandle); - return new Result(outputNamesArray, outputValues); + return new Result(outputNamesArray, outputValues, ownedByResult); } else { throw new IllegalStateException("Trying to score a closed OrtSession."); } } + /** + * Pulls out the native handle by casting it to the appropriate type. + * + * @param v The OnnxValue. + * @return The native handle. + */ + static long getHandle(OnnxValue v) { + /* + * Note this method exists as interface methods are all public, but we do not want users to be + * able to access the native pointer via a public API so can't add a method to OnnxValue which + * exposes it. + */ + if (v instanceof OnnxTensorLike) { + return ((OnnxTensorLike) v).nativeHandle; + } else if (v instanceof OnnxSequence) { + return ((OnnxSequence) v).nativeHandle; + } else if (v instanceof OnnxMap) { + return ((OnnxMap) v).nativeHandle; + } else { + throw new IllegalArgumentException( + "Unexpected OnnxValue subclass, should be {OnnxTensorLike, OnnxSequence, OnnxMap}, found " + + v.getClass()); + } + } + /** * Gets the metadata for the currently loaded model. * @@ -409,8 +530,9 @@ private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long throws OrtException; /** - * The native run call. runOptionsHandle can be zero (i.e. the null pointer), but all other - * handles must be valid pointers. + * The native run call. runOptionsHandle can be zero (i.e. the null pointer), outputValues can + * contain null entries, and outputHandles can contain zero values (i.e. the null pointer), but + * all other handles must be valid pointers. * * @param apiHandle The pointer to the api. * @param nativeHandle The pointer to the session. @@ -419,12 +541,14 @@ private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long * @param inputs The input tensors. * @param numInputs The number of inputs. * @param outputNamesArray The requested output names. + * @param outputValues The OnnxValue output array. + * @param outputHandles The OrtValue output pointer array. * @param numOutputs The number of requested outputs. * @param runOptionsHandle The (possibly null) pointer to the run options. - * @return The OnnxValues produced by this run. + * @return A boolean array representing if the OnnxValues were allocated by this run call. * @throws OrtException If the native call failed in some way. */ - private native OnnxValue[] run( + private native boolean[] run( long apiHandle, long nativeHandle, long allocatorHandle, @@ -433,6 +557,8 @@ private native OnnxValue[] run( long numInputs, String[] outputNamesArray, long numOutputs, + OnnxValue[] outputValues, + long[] outputHandles, long runOptionsHandle) throws OrtException; @@ -1417,9 +1543,13 @@ private native void addRunConfigEntry( /** * An {@link AutoCloseable} wrapper around a {@link Map} containing {@link OnnxValue}s. * - *

When this is closed it closes all the {@link OnnxValue}s inside it. If you maintain a - * reference to a value after this object has been closed it will throw an {@link + *

When this is closed it closes all the {@link OnnxValue}s owned by the result object. If you + * maintain a reference to a value after this object has been closed it will throw an {@link * IllegalStateException} upon access. + * + *

{@link OnnxValue}s which are supplied as pinned outputs to a {@code run} call are not closed + * by the {@link Result#close()} method. Ownership of each output can be checked with {@link + * Result#isResultOwner(int)}. */ public static class Result implements AutoCloseable, Iterable> { @@ -1429,6 +1559,8 @@ public static class Result implements AutoCloseable, Iterable list; + private final boolean[] ownedByResult; + private boolean closed; /** @@ -1437,21 +1569,23 @@ public static class Result implements AutoCloseable, IterableThrows {@link IllegalStateException} if the container has been closed, and {@link + * ArrayIndexOutOfBoundsException} if the index is invalid. + * + * @param index The index to lookup. + * @return Is that value owned by this result object? + */ + public boolean isResultOwner(int index) { + if (!closed) { + return ownedByResult[index]; + } else { + throw new IllegalStateException("Result is closed"); + } + } + /** * Returns the number of outputs in this Result. * diff --git a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java index 8c03c5b80433c..49ddf29c22335 100644 --- a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java @@ -418,7 +418,7 @@ private static native void setSeed(long apiHandle, long trainingHandle, long see */ public OrtSession.Result trainStep(Map inputs) throws OrtException { - return trainStep(inputs, trainOutputNames, null); + return trainStep(inputs, trainOutputNames, Collections.emptyMap(), null); } /** @@ -432,7 +432,7 @@ public OrtSession.Result trainStep(Map inputs) public OrtSession.Result trainStep( Map inputs, OrtSession.RunOptions runOptions) throws OrtException { - return trainStep(inputs, trainOutputNames, runOptions); + return trainStep(inputs, trainOutputNames, Collections.emptyMap(), runOptions); } /** @@ -446,14 +446,41 @@ public OrtSession.Result trainStep( public OrtSession.Result trainStep( Map inputs, Set requestedOutputs) throws OrtException { - return trainStep(inputs, requestedOutputs, null); + return trainStep(inputs, requestedOutputs, Collections.emptyMap(), null); } /** * Performs a single step of training, accumulating the gradients. * + *

The outputs are sorted based on the supplied map traversal order. + * + *

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are + * not closed when the result object is closed. + * * @param inputs The inputs (must include both the features and the target). - * @param requestedOutputs The requested outputs. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @return Requested outputs produced by the training step. + * @throws OrtException If the native call failed. + */ + public OrtSession.Result trainStep( + Map inputs, Map pinnedOutputs) + throws OrtException { + return trainStep(inputs, Collections.emptySet(), pinnedOutputs, null); + } + + /** + * Performs a single step of training, accumulating the gradients. + * + *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are + * not closed when the result object is closed. + * + * @param inputs The inputs (must include both the features and the target). + * @param requestedOutputs The requested outputs which ORT will allocate. + * @param pinnedOutputs The requested outputs which the user has allocated. * @param runOptions Run options for controlling this specific call. * @return Requested outputs produced by the training step. * @throws OrtException If the native call failed. @@ -461,6 +488,7 @@ public OrtSession.Result trainStep( public OrtSession.Result trainStep( Map inputs, Set requestedOutputs, + Map pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException { checkClosed(); @@ -472,12 +500,14 @@ public OrtSession.Result trainStep( + ") found " + inputs.size()); } - if (requestedOutputs.isEmpty() || (requestedOutputs.size() > trainOutputNames.size())) { + int numTrainOutputs = trainOutputNames.size(); + int totalOutputs = requestedOutputs.size() + pinnedOutputs.size(); + if ((totalOutputs == 0) || (totalOutputs > numTrainOutputs)) { throw new OrtException( - "Unexpected number of requestedOutputs, expected [1," - + trainOutputNames.size() + "Unexpected number of requestedOutputs & pinnedOutputs, expected [1," + + numTrainOutputs + ") found " - + requestedOutputs.size()); + + totalOutputs); } String[] inputNamesArray = new String[inputs.size()]; long[] inputHandles = new long[inputs.size()]; @@ -492,12 +522,35 @@ public OrtSession.Result trainStep( "Unknown input name " + t.getKey() + ", expected one of " + trainInputNames); } } - String[] outputNamesArray = new String[requestedOutputs.size()]; + String[] outputNamesArray = new String[requestedOutputs.size() + pinnedOutputs.size()]; + OnnxValue[] outputValues = new OnnxValue[outputNamesArray.length]; + long[] outputHandles = new long[outputNamesArray.length]; i = 0; + for (Map.Entry e : pinnedOutputs.entrySet()) { + if (trainOutputNames.contains(e.getKey())) { + outputNamesArray[i] = e.getKey(); + outputValues[i] = e.getValue(); + outputHandles[i] = OrtSession.getHandle(e.getValue()); + i++; + } else { + throw new OrtException( + "Unknown output name " + + e.getKey() + + ", expected one of " + + trainOutputNames.toString()); + } + } for (String s : requestedOutputs) { if (trainOutputNames.contains(s)) { - outputNamesArray[i] = s; - i++; + if (!pinnedOutputs.containsKey(s)) { + outputNamesArray[i] = s; + // outputValues and outputHandles can be null/0 for these outputs as ORT will allocate + // them. + i++; + } else { + throw new OrtException( + "Output '" + s + "' was found in both the requested outputs and the pinned outputs"); + } } else { throw new OrtException( "Unknown output name " + s + ", expected one of " + trainOutputNames.toString()); @@ -505,7 +558,7 @@ public OrtSession.Result trainStep( } long runOptionsHandle = runOptions == null ? 0 : runOptions.getNativeHandle(); - OnnxValue[] outputValues = + boolean[] ownedByResult = trainStep( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -516,8 +569,10 @@ public OrtSession.Result trainStep( inputNamesArray.length, outputNamesArray, outputNamesArray.length, + outputValues, + outputHandles, runOptionsHandle); - return new OrtSession.Result(outputNamesArray, outputValues); + return new OrtSession.Result(outputNamesArray, outputValues, ownedByResult); } /* @@ -540,7 +595,7 @@ public OrtSession.Result trainStep( * run_options, size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, size_t * outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); */ - private native OnnxValue[] trainStep( + private native boolean[] trainStep( long apiHandle, long trainingApiHandle, long nativeHandle, @@ -550,6 +605,8 @@ private native OnnxValue[] trainStep( long numInputs, String[] outputNamesArray, long numOutputs, + OnnxValue[] outputValues, + long[] outputHandles, long runOptionsHandle); /** @@ -561,7 +618,7 @@ private native OnnxValue[] trainStep( */ public OrtSession.Result evalStep(Map inputs) throws OrtException { - return evalStep(inputs, evalOutputNames, null); + return evalStep(inputs, evalOutputNames, Collections.emptyMap(), null); } /** @@ -575,7 +632,7 @@ public OrtSession.Result evalStep(Map inputs) public OrtSession.Result evalStep( Map inputs, OrtSession.RunOptions runOptions) throws OrtException { - return evalStep(inputs, evalOutputNames, runOptions); + return evalStep(inputs, evalOutputNames, Collections.emptyMap(), runOptions); } /** @@ -589,14 +646,41 @@ public OrtSession.Result evalStep( public OrtSession.Result evalStep( Map inputs, Set requestedOutputs) throws OrtException { - return evalStep(inputs, requestedOutputs, null); + return evalStep(inputs, requestedOutputs, Collections.emptyMap(), null); } /** * Performs a single evaluation step using the supplied inputs. * - * @param inputs The model inputs. - * @param requestedOutputs The requested output names. + *

The outputs are sorted based on the supplied map traversal order. + * + *

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are + * not closed when the result object is closed. + * + * @param inputs The inputs to score. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @return The requested outputs. + * @throws OrtException If the native call failed. + */ + public OrtSession.Result evalStep( + Map inputs, Map pinnedOutputs) + throws OrtException { + return evalStep(inputs, Collections.emptySet(), pinnedOutputs, null); + } + + /** + * Performs a single evaluation step using the supplied inputs. + * + *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are + * not closed when the result object is closed. + * + * @param inputs The inputs to score. + * @param requestedOutputs The requested outputs which ORT will allocate. + * @param pinnedOutputs The requested outputs which the user has allocated. * @param runOptions Run options for controlling this specific call. * @return The requested outputs. * @throws OrtException If the native call failed. @@ -604,6 +688,7 @@ public OrtSession.Result evalStep( public OrtSession.Result evalStep( Map inputs, Set requestedOutputs, + Map pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException { checkClosed(); @@ -615,12 +700,14 @@ public OrtSession.Result evalStep( + ") found " + inputs.size()); } - if (requestedOutputs.isEmpty() || (requestedOutputs.size() > evalOutputNames.size())) { + int numEvalOutputs = evalOutputNames.size(); + int totalOutputs = requestedOutputs.size() + pinnedOutputs.size(); + if ((totalOutputs == 0) || (totalOutputs > numEvalOutputs)) { throw new OrtException( - "Unexpected number of requestedOutputs, expected [1," - + evalOutputNames.size() + "Unexpected number of requestedOutputs & pinnedOutputs, expected [1," + + numEvalOutputs + ") found " - + requestedOutputs.size()); + + totalOutputs); } String[] inputNamesArray = new String[inputs.size()]; long[] inputHandles = new long[inputs.size()]; @@ -635,12 +722,35 @@ public OrtSession.Result evalStep( "Unknown input name " + t.getKey() + ", expected one of " + evalInputNames.toString()); } } - String[] outputNamesArray = new String[requestedOutputs.size()]; + String[] outputNamesArray = new String[requestedOutputs.size() + pinnedOutputs.size()]; + OnnxValue[] outputValues = new OnnxValue[outputNamesArray.length]; + long[] outputHandles = new long[outputNamesArray.length]; i = 0; + for (Map.Entry e : pinnedOutputs.entrySet()) { + if (evalOutputNames.contains(e.getKey())) { + outputNamesArray[i] = e.getKey(); + outputValues[i] = e.getValue(); + outputHandles[i] = OrtSession.getHandle(e.getValue()); + i++; + } else { + throw new OrtException( + "Unknown output name " + + e.getKey() + + ", expected one of " + + evalOutputNames.toString()); + } + } for (String s : requestedOutputs) { if (evalOutputNames.contains(s)) { - outputNamesArray[i] = s; - i++; + if (!pinnedOutputs.containsKey(s)) { + outputNamesArray[i] = s; + // outputValues and outputHandles can be null/0 for these outputs as ORT will allocate + // them. + i++; + } else { + throw new OrtException( + "Output '" + s + "' was found in both the requested outputs and the pinned outputs"); + } } else { throw new OrtException( "Unknown output name " + s + ", expected one of " + evalOutputNames.toString()); @@ -648,7 +758,7 @@ public OrtSession.Result evalStep( } long runOptionsHandle = runOptions == null ? 0 : runOptions.getNativeHandle(); - OnnxValue[] outputValues = + boolean[] ownedByResult = evalStep( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -659,8 +769,10 @@ public OrtSession.Result evalStep( inputNamesArray.length, outputNamesArray, outputNamesArray.length, + outputValues, + outputHandles, runOptionsHandle); - return new OrtSession.Result(outputNamesArray, outputValues); + return new OrtSession.Result(outputNamesArray, outputValues, ownedByResult); } /* @@ -682,7 +794,7 @@ public OrtSession.Result evalStep( * run_options, size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, size_t * outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); */ - private native OnnxValue[] evalStep( + private native boolean[] evalStep( long apiHandle, long trainingApiHandle, long nativeHandle, @@ -692,6 +804,8 @@ private native OnnxValue[] evalStep( long numInputs, String[] outputNamesArray, long numOutputs, + OnnxValue[] outputValues, + long[] outputHandles, long runOptionsHandle) throws OrtException; diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index 34d635b5c40f8..69ccb954e8afe 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -58,14 +58,37 @@ public enum OnnxTensorType { */ ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16( 16), // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN( - 17), // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ( - 18), // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2( - 19), // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ( - 20); // Non-IEEE floating-point format based on IEEE754 single-precision + /** + * A non-IEEE 8-bit floating point format with 4 exponent bits and 3 mantissa bits, with NaN and + * no infinite values (FN). + * + *

See the float 8 ONNX standard for + * details. + */ + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN(17), + /** + * A non-IEEE 8-bit floating point format with 4 exponent bits and 3 mantissa bits, with NaN, no + * infinite values (FN) and no negative zero (UZ). + * + *

See the float 8 ONNX standard for + * details. + */ + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ(18), + /** + * A non-IEEE 8-bit floating point format with 5 exponent bits and 2 mantissa bits. + * + *

See the float 8 ONNX standard for + * details. + */ + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2(19), + /** + * A non-IEEE 8-bit floating point format with 5 exponent bits and 2 mantissa bits, with NaN, no + * infinite values (FN) and no negative zero (UZ). + * + *

See the float 8 ONNX standard for + * details. + */ + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ(20); /** The int id on the native side. */ public final int value; diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c index 6f4e34648cf81..f4d5ab080cd31 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession.c @@ -316,14 +316,19 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputInfo(JNIE /* * Class: ai_onnxruntime_OrtSession * Method: run - * Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue; - * private native OnnxValue[] run(long apiHandle, long nativeHandle, long allocatorHandle, String[] inputNamesArray, long[] inputs, long numInputs, String[] outputNamesArray, long numOutputs) + * Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;J[Lai/onnxruntime/OnnxValue;[JJ)[Z + * private native boolean[] run(long apiHandle, long nativeHandle, long allocatorHandle, + * String[] inputNamesArray, long[] inputs, long numInputs, + * String[] outputNamesArray, long numOutputs, + * OnnxValue[] outputValues, long[] outputHandles, + * long runOptionsHandle) throws OrtException; */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, +JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray tensorArr, jlong numInputs, jobjectArray outputNamesArr, - jlong numOutputs, jlong runOptionsHandle) { + jlong numOutputs, jobjectArray outputValuesArr, + jlongArray outputHandlesArr, jlong runOptionsHandle) { (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; @@ -331,7 +336,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv OrtSession* session = (OrtSession*)sessionHandle; OrtRunOptions* runOptions = (OrtRunOptions*)runOptionsHandle; - jobjectArray outputArray = NULL; + jbooleanArray outputArray = NULL; // Create the buffers for the Java input & output strings, and the input pointers const char** inputNames = allocarray(numInputs, sizeof(char*)); @@ -376,13 +381,19 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv // Release the java array copy of pointers to the tensors. (*jniEnv)->ReleaseLongArrayElements(jniEnv, tensorArr, inputValueLongs, JNI_ABORT); + // Extract a C array of longs which are pointers to the output tensors. + jlong* outputHandleLongs = (*jniEnv)->GetLongArrayElements(jniEnv, outputHandlesArr, NULL); + // Extract the names of the output values. for (int i = 0; i < numOutputs; i++) { javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, outputNamesArr, i); outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaOutputStrings[i], NULL); - outputValues[i] = NULL; + outputValues[i] = (OrtValue*)outputHandleLongs[i]; } + // Release the java array copy of pointers to the outputs. + (*jniEnv)->ReleaseLongArrayElements(jniEnv, outputHandlesArr, outputHandleLongs, JNI_ABORT); + // Actually score the inputs. // ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess, _In_ OrtRunOptions* run_options, // _In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len, @@ -394,21 +405,26 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv goto cleanup_output_values; } - // Construct the output array of ONNXValues - jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, ORTJNI_OnnxValueClassName); - outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL); + // Create the output boolean array denoting if ORT owns the memory for each output. + // Java boolean arrays are initialized to false. + outputArray = (*jniEnv)->NewBooleanArray(jniEnv, safecast_int64_to_jsize(numOutputs)); + jboolean* boolArr = (*jniEnv)->GetBooleanArrayElements(jniEnv, outputArray, NULL); // Convert the output tensors into ONNXValues for (int i = 0; i < numOutputs; i++) { - if (outputValues[i] != NULL) { + if (outputValues[i] != NULL && (*jniEnv)->GetObjectArrayElement(jniEnv, outputValuesArr, i) == NULL) { jobject onnxValue = convertOrtValueToONNXValue(jniEnv, api, allocator, outputValues[i]); if (onnxValue == NULL) { break; // go to cleanup, exception thrown } - (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, i, onnxValue); + boolArr[i] = 1; + (*jniEnv)->SetObjectArrayElement(jniEnv, outputValuesArr, i, onnxValue); } } + // Write the output array back to Java. + (*jniEnv)->ReleaseBooleanArrayElements(jniEnv, outputArray, boolArr, 0); + // Note these gotos are in a specific order so they mirror the allocation pattern above. // They must be changed if the allocation code is rearranged. cleanup_output_values: diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index d3239c7442c80..3a1c0d1bb8fa1 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -19,7 +19,6 @@ #include "onnxruntime/core/providers/nnapi/nnapi_provider_factory.h" #include "onnxruntime/core/providers/tvm/tvm_provider_factory.h" #include "onnxruntime/core/providers/openvino/openvino_provider_factory.h" -#include "onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h" #include "onnxruntime/core/providers/acl/acl_provider_factory.h" #include "onnxruntime/core/providers/armnn/armnn_provider_factory.h" #include "onnxruntime/core/providers/coreml/coreml_provider_factory.h" diff --git a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c index b3b530a8b15aa..9f7b8d3a3dcfc 100644 --- a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -330,12 +330,12 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_lazyResetGrad /* * Class: ai_onnxruntime_OrtTrainingSession * Method: trainStep - * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue; + * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;J[Lai/onnxruntime/OnnxValue;[JJ)[Z */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep +JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray inputHandles, jlong numInputs, - jobjectArray outputNamesArr, jlong numOutputs, jlong runOptionsHandle) { + jobjectArray outputNamesArr, jlong numOutputs, jobjectArray outputValuesArr, jlongArray outputHandlesArr, jlong runOptionsHandle) { (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; const OrtTrainingApi* trainApi = (const OrtTrainingApi*)trainApiHandle; @@ -343,31 +343,31 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep OrtTrainingSession* trainSession = (OrtTrainingSession*)nativeHandle; OrtRunOptions* runOptions = (OrtRunOptions*)runOptionsHandle; - jobjectArray outputArray = NULL; + jbooleanArray outputArray = NULL; // Create the buffers for the Java input & output strings, and the input pointers - const char** inputNames = malloc(sizeof(char*) * numInputs); + const char** inputNames = allocarray(numInputs, sizeof(char*)); if (inputNames == NULL) { // Nothing to cleanup, return and throw exception return outputArray; } - const char** outputNames = malloc(sizeof(char*) * numOutputs); + const char** outputNames = allocarray(numOutputs, sizeof(char*)); if (outputNames == NULL) { goto cleanup_input_names; } - jobject* javaInputStrings = malloc(sizeof(jobject) * numInputs); + jobject* javaInputStrings = allocarray(numInputs, sizeof(jobject)); if (javaInputStrings == NULL) { goto cleanup_output_names; } - jobject* javaOutputStrings = malloc(sizeof(jobject) * numOutputs); + jobject* javaOutputStrings = allocarray(numOutputs, sizeof(jobject)); if (javaOutputStrings == NULL) { goto cleanup_java_input_strings; } - const OrtValue** inputValuePtrs = malloc(sizeof(OrtValue*) * numInputs); + const OrtValue** inputValuePtrs = allocarray(numInputs, sizeof(OrtValue*)); if (inputValuePtrs == NULL) { goto cleanup_java_output_strings; } - OrtValue** outputValues = malloc(sizeof(OrtValue*) * numOutputs); + OrtValue** outputValues = allocarray(numOutputs, sizeof(OrtValue*)); if (outputValues == NULL) { goto cleanup_input_values; } @@ -388,13 +388,19 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep // Release the java array copy of pointers to the tensors. (*jniEnv)->ReleaseLongArrayElements(jniEnv, inputHandles, inputValueLongs, JNI_ABORT); + // Extract a C array of longs which are pointers to the output tensors. + jlong* outputHandleLongs = (*jniEnv)->GetLongArrayElements(jniEnv, outputHandlesArr, NULL); + // Extract the names of the output values. for (int i = 0; i < numOutputs; i++) { javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, outputNamesArr, i); outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaOutputStrings[i], NULL); - outputValues[i] = NULL; + outputValues[i] = (OrtValue*)outputHandleLongs[i]; } + // Release the java array copy of pointers to the outputs. + (*jniEnv)->ReleaseLongArrayElements(jniEnv, outputHandlesArr, outputHandleLongs, JNI_ABORT); + // Actually score the inputs. //ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options, // size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, @@ -406,24 +412,29 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep goto cleanup_output_values; } - // Construct the output array of ONNXValues - jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxValue"); - outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL); + // Create the output boolean array denoting if ORT owns the memory for each output. + // Java boolean arrays are initialized to false. + outputArray = (*jniEnv)->NewBooleanArray(jniEnv, safecast_int64_to_jsize(numOutputs)); + jboolean* boolArr = (*jniEnv)->GetBooleanArrayElements(jniEnv, outputArray, NULL); // Convert the output tensors into ONNXValues for (int i = 0; i < numOutputs; i++) { - if (outputValues[i] != NULL) { + if (outputValues[i] != NULL && (*jniEnv)->GetObjectArrayElement(jniEnv, outputValuesArr, i) == NULL) { jobject onnxValue = convertOrtValueToONNXValue(jniEnv, api, allocator, outputValues[i]); if (onnxValue == NULL) { break; // go to cleanup, exception thrown } - (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, i, onnxValue); + boolArr[i] = 1; + (*jniEnv)->SetObjectArrayElement(jniEnv, outputValuesArr, i, onnxValue); } } + // Write the output array back to Java. + (*jniEnv)->ReleaseBooleanArrayElements(jniEnv, outputArray, boolArr, 0); + // Note these gotos are in a specific order so they mirror the allocation pattern above. // They must be changed if the allocation code is rearranged. - cleanup_output_values: +cleanup_output_values: free(outputValues); // Release the Java output strings @@ -437,15 +448,15 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep } // Release the buffers - cleanup_input_values: +cleanup_input_values: free((void*)inputValuePtrs); - cleanup_java_output_strings: +cleanup_java_output_strings: free(javaOutputStrings); - cleanup_java_input_strings: +cleanup_java_input_strings: free(javaInputStrings); - cleanup_output_names: +cleanup_output_names: free((void*)outputNames); - cleanup_input_names: +cleanup_input_names: free((void*)inputNames); return outputArray; @@ -454,12 +465,12 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep /* * Class: ai_onnxruntime_OrtTrainingSession * Method: evalStep - * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue; + * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;J[Lai/onnxruntime/OnnxValue;[JJ)[Z */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep +JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray inputHandles, jlong numInputs, - jobjectArray outputNamesArr, jlong numOutputs, jlong runOptionsHandle) { + jobjectArray outputNamesArr, jlong numOutputs, jobjectArray outputValuesArr, jlongArray outputHandlesArr, jlong runOptionsHandle) { (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; const OrtTrainingApi* trainApi = (const OrtTrainingApi*)trainApiHandle; @@ -467,31 +478,31 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep OrtTrainingSession* trainSession = (OrtTrainingSession*)nativeHandle; OrtRunOptions* runOptions = (OrtRunOptions*)runOptionsHandle; - jobjectArray outputArray = NULL; + jbooleanArray outputArray = NULL; // Create the buffers for the Java input & output strings, and the input pointers - const char** inputNames = malloc(sizeof(char*) * numInputs); + const char** inputNames = allocarray(numInputs, sizeof(char*)); if (inputNames == NULL) { // Nothing to cleanup, return and throw exception return outputArray; } - const char** outputNames = malloc(sizeof(char*) * numOutputs); + const char** outputNames = allocarray(numOutputs, sizeof(char*)); if (outputNames == NULL) { goto cleanup_input_names; } - jobject* javaInputStrings = malloc(sizeof(jobject) * numInputs); + jobject* javaInputStrings = allocarray(numInputs, sizeof(jobject)); if (javaInputStrings == NULL) { goto cleanup_output_names; } - jobject* javaOutputStrings = malloc(sizeof(jobject) * numOutputs); + jobject* javaOutputStrings = allocarray(numOutputs, sizeof(jobject)); if (javaOutputStrings == NULL) { goto cleanup_java_input_strings; } - const OrtValue** inputValuePtrs = malloc(sizeof(OrtValue*) * numInputs); + const OrtValue** inputValuePtrs = allocarray(numInputs, sizeof(OrtValue*)); if (inputValuePtrs == NULL) { goto cleanup_java_output_strings; } - OrtValue** outputValues = malloc(sizeof(OrtValue*) * numOutputs); + OrtValue** outputValues = allocarray(numOutputs, sizeof(OrtValue*)); if (outputValues == NULL) { goto cleanup_input_values; } @@ -512,11 +523,14 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep // Release the java array copy of pointers to the tensors. (*jniEnv)->ReleaseLongArrayElements(jniEnv, inputHandles, inputValueLongs, JNI_ABORT); + // Extract a C array of longs which are pointers to the output tensors. + jlong* outputHandleLongs = (*jniEnv)->GetLongArrayElements(jniEnv, outputHandlesArr, NULL); + // Extract the names of the output values. for (int i = 0; i < numOutputs; i++) { javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, outputNamesArr, i); outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaOutputStrings[i], NULL); - outputValues[i] = NULL; + outputValues[i] = (OrtValue*)outputHandleLongs[i]; } // Actually score the inputs. @@ -530,24 +544,29 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep goto cleanup_output_values; } - // Construct the output array of ONNXValues - jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxValue"); - outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL); + // Create the output boolean array denoting if ORT owns the memory for each output. + // Java boolean arrays are initialized to false. + outputArray = (*jniEnv)->NewBooleanArray(jniEnv, safecast_int64_to_jsize(numOutputs)); + jboolean* boolArr = (*jniEnv)->GetBooleanArrayElements(jniEnv, outputArray, NULL); // Convert the output tensors into ONNXValues for (int i = 0; i < numOutputs; i++) { - if (outputValues[i] != NULL) { + if (outputValues[i] != NULL && (*jniEnv)->GetObjectArrayElement(jniEnv, outputValuesArr, i) == NULL) { jobject onnxValue = convertOrtValueToONNXValue(jniEnv, api, allocator, outputValues[i]); if (onnxValue == NULL) { break; // go to cleanup, exception thrown } - (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, i, onnxValue); + boolArr[i] = 1; + (*jniEnv)->SetObjectArrayElement(jniEnv, outputValuesArr, i, onnxValue); } } + // Write the output array back to Java. + (*jniEnv)->ReleaseBooleanArrayElements(jniEnv, outputArray, boolArr, 0); + // Note these gotos are in a specific order so they mirror the allocation pattern above. // They must be changed if the allocation code is rearranged. - cleanup_output_values: +cleanup_output_values: free(outputValues); // Release the Java output strings @@ -561,15 +580,15 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep } // Release the buffers - cleanup_input_values: +cleanup_input_values: free((void*)inputValuePtrs); - cleanup_java_output_strings: +cleanup_java_output_strings: free(javaOutputStrings); - cleanup_java_input_strings: +cleanup_java_input_strings: free(javaInputStrings); - cleanup_output_names: +cleanup_output_names: free((void*)outputNames); - cleanup_input_names: +cleanup_input_names: free((void*)inputNames); return outputArray; diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 08d2a5698d579..e975117fb75bd 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -6,11 +6,14 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import ai.onnxruntime.OrtException.OrtErrorCode; import ai.onnxruntime.OrtSession.Result; import ai.onnxruntime.OrtSession.SessionOptions; import ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode; @@ -31,6 +34,8 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -71,7 +76,7 @@ public void environmentTest() { @Test public void testVersion() { String version = env.getVersion(); - Assertions.assertFalse(version.isEmpty()); + assertFalse(version.isEmpty()); } @Test @@ -749,6 +754,151 @@ public void testOverridingInitializer() throws OrtException { } } + @Test + public void testPinnedOutputs() throws OrtException { + String modelPath = TestHelpers.getResourcePath("/java-three-output-matmul.onnx").toString(); + FloatBuffer outputABuf = + ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + FloatBuffer outputBBuf = + ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + FloatBuffer outputCBuf = + ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + FloatBuffer tooSmallBuf = + ByteBuffer.allocateDirect(4 * 2).order(ByteOrder.nativeOrder()).asFloatBuffer(); + FloatBuffer tooBigBuf = + ByteBuffer.allocateDirect(4 * 6).order(ByteOrder.nativeOrder()).asFloatBuffer(); + FloatBuffer wrongShapeBuf = + ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + LongBuffer wrongTypeBuf = + ByteBuffer.allocateDirect(8 * 4).order(ByteOrder.nativeOrder()).asLongBuffer(); + + try (SessionOptions options = new SessionOptions()) { + try (OrtSession session = env.createSession(modelPath, options); + OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}}); + OnnxTensor outputA = OnnxTensor.createTensor(env, outputABuf, new long[] {1, 4}); + OnnxTensor outputB = OnnxTensor.createTensor(env, outputBBuf, new long[] {1, 4}); + OnnxTensor outputC = OnnxTensor.createTensor(env, outputCBuf, new long[] {1, 4}); + OnnxTensor tooSmall = OnnxTensor.createTensor(env, tooSmallBuf, new long[] {1, 2}); + OnnxTensor tooBig = OnnxTensor.createTensor(env, tooBigBuf, new long[] {1, 6}); + OnnxTensor wrongShape = OnnxTensor.createTensor(env, wrongShapeBuf, new long[] {2, 2}); + OnnxTensor wrongType = OnnxTensor.createTensor(env, wrongTypeBuf, new long[] {1, 4})) { + Map inputMap = Collections.singletonMap("input", t); + Set requestedOutputs = new LinkedHashSet<>(); + Map pinnedOutputs = new LinkedHashMap<>(); + + // Test that all outputs can be pinned + pinnedOutputs.put("output-0", outputA); + pinnedOutputs.put("output-1", outputB); + pinnedOutputs.put("output-2", outputC); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + assertEquals(3, r.size()); + assertSame(outputA, r.get(0)); + assertSame(outputB, r.get(1)); + assertSame(outputC, r.get(2)); + assertFalse(r.isResultOwner(0)); + assertFalse(r.isResultOwner(1)); + assertFalse(r.isResultOwner(2)); + // More tests + } + TestHelpers.zeroBuffer(outputABuf); + TestHelpers.zeroBuffer(outputBBuf); + TestHelpers.zeroBuffer(outputCBuf); + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test a single pinned output + pinnedOutputs.put("output-1", outputB); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + assertEquals(1, r.size()); + assertSame(outputB, r.get(0)); + assertSame(outputB, r.get("output-1").get()); + assertFalse(r.isResultOwner(0)); + // More tests + } + TestHelpers.zeroBuffer(outputABuf); + TestHelpers.zeroBuffer(outputBBuf); + TestHelpers.zeroBuffer(outputCBuf); + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test a mixture of pinned and generated outputs + requestedOutputs.add("output-0"); + requestedOutputs.add("output-2"); + pinnedOutputs.put("output-1", outputB); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + assertEquals(3, r.size()); + // pinned outputs are first + assertSame(outputB, r.get(0)); + assertSame(outputB, r.get("output-1").get()); + // requested outputs are different + assertNotSame(outputA, r.get("output-0").get()); + assertNotSame(outputC, r.get("output-2").get()); + // check ownership. + assertFalse(r.isResultOwner(0)); + assertTrue(r.isResultOwner(1)); + assertTrue(r.isResultOwner(2)); + // More tests + } + TestHelpers.zeroBuffer(outputABuf); + TestHelpers.zeroBuffer(outputBBuf); + TestHelpers.zeroBuffer(outputCBuf); + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test that overlapping names causes an error + requestedOutputs.add("output-1"); + pinnedOutputs.put("output-1", outputB); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + fail("Should have thrown OrtException"); + } catch (OrtException e) { + assertEquals(OrtErrorCode.ORT_JAVA_UNKNOWN, e.getCode()); + } + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test that a tensor of the wrong type causes an error + pinnedOutputs.put("output-0", wrongType); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + fail("Should have thrown OrtException"); + } catch (OrtException e) { + assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode()); + } + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test that a tensor of the wrong shape (but right capacity) causes an error. + pinnedOutputs.put("output-1", wrongShape); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + fail("Should have thrown OrtException"); + } catch (OrtException e) { + assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode()); + } + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test that a tensor which is too small causes an error + pinnedOutputs.put("output-1", tooSmall); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + fail("Should have thrown OrtException"); + } catch (OrtException e) { + assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode()); + } + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test that a tensor which is too large causes an error + pinnedOutputs.put("output-1", tooBig); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + fail("Should have thrown OrtException"); + } catch (OrtException e) { + assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode()); + } + requestedOutputs.clear(); + pinnedOutputs.clear(); + } + } + } + private static File getTestModelsDir() throws IOException { // get build directory, append downloaded models location String cwd = System.getProperty("user.dir"); diff --git a/java/src/test/java/ai/onnxruntime/ModelGenerators.java b/java/src/test/java/ai/onnxruntime/ModelGenerators.java index 90fda4c5cf610..7bf7cef43208a 100644 --- a/java/src/test/java/ai/onnxruntime/ModelGenerators.java +++ b/java/src/test/java/ai/onnxruntime/ModelGenerators.java @@ -182,6 +182,102 @@ public void generateMatMul() throws IOException { } } + public void generateThreeOutputMatmul() throws IOException { + OnnxMl.GraphProto.Builder graph = OnnxMl.GraphProto.newBuilder(); + graph.setName("ort-test-three-matmul"); + + // Add placeholders + OnnxMl.ValueInfoProto.Builder input = OnnxMl.ValueInfoProto.newBuilder(); + input.setName("input"); + OnnxMl.TypeProto inputType = + buildTensorTypeNode( + new long[] {-1, 4}, + new String[] {"batch_size", null}, + OnnxMl.TensorProto.DataType.FLOAT); + input.setType(inputType); + graph.addInput(input); + OnnxMl.ValueInfoProto.Builder outputA = OnnxMl.ValueInfoProto.newBuilder(); + outputA.setName("output-0"); + OnnxMl.TypeProto outputType = + buildTensorTypeNode( + new long[] {-1, 4}, + new String[] {"batch_size", null}, + OnnxMl.TensorProto.DataType.FLOAT); + outputA.setType(outputType); + graph.addOutput(outputA); + OnnxMl.ValueInfoProto.Builder outputB = OnnxMl.ValueInfoProto.newBuilder(); + outputB.setName("output-1"); + outputB.setType(outputType); + graph.addOutput(outputB); + OnnxMl.ValueInfoProto.Builder outputC = OnnxMl.ValueInfoProto.newBuilder(); + outputC.setName("output-2"); + outputC.setType(outputType); + graph.addOutput(outputC); + + // Add initializers + OnnxMl.TensorProto.Builder tensor = OnnxMl.TensorProto.newBuilder(); + tensor.addDims(4); + tensor.addDims(4); + Float[] floats = + new Float[] {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f, 16f}; + tensor.addAllFloatData(Arrays.asList(floats)); + tensor.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()); + tensor.setName("tensor"); + graph.addInitializer(tensor); + OnnxMl.TensorProto.Builder addInit = OnnxMl.TensorProto.newBuilder(); + addInit.addDims(4); + Float[] addFloats = new Float[] {1f, 2f, 3f, 4f}; + addInit.addAllFloatData(Arrays.asList(addFloats)); + addInit.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()); + addInit.setName("add-init"); + graph.addInitializer(addInit); + + // Add operations + OnnxMl.NodeProto.Builder matmul = OnnxMl.NodeProto.newBuilder(); + matmul.setName("matmul-0"); + matmul.setOpType("MatMul"); + matmul.addInput("input"); + matmul.addInput("tensor"); + matmul.addOutput("matmul-output"); + graph.addNode(matmul); + + OnnxMl.NodeProto.Builder id = OnnxMl.NodeProto.newBuilder(); + id.setName("id-1"); + id.setOpType("Identity"); + id.addInput("matmul-output"); + id.addOutput("output-0"); + graph.addNode(id); + + OnnxMl.NodeProto.Builder add = OnnxMl.NodeProto.newBuilder(); + add.setName("add-2"); + add.setOpType("Add"); + add.addInput("matmul-output"); + add.addInput("add-init"); + add.addOutput("output-1"); + graph.addNode(add); + + OnnxMl.NodeProto.Builder log = OnnxMl.NodeProto.newBuilder(); + log.setName("log-3"); + log.setOpType("Log"); + log.addInput("matmul-output"); + log.addOutput("output-2"); + graph.addNode(log); + + // Build model + OnnxMl.ModelProto.Builder model = OnnxMl.ModelProto.newBuilder(); + model.setGraph(graph); + model.setDocString("ORT three output matmul test"); + model.setModelVersion(0); + model.setIrVersion(8); + model.setDomain("ai.onnxruntime.test"); + model.addOpsetImport(OnnxMl.OperatorSetIdProto.newBuilder().setVersion(18).build()); + try (OutputStream os = + Files.newOutputStream( + Paths.get("src", "test", "resources", "java-three-output-matmul.onnx"))) { + model.build().writeTo(os); + } + } + private static void genCast( String name, OnnxMl.TensorProto.DataType inputDataType, diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java index fcb4590717fea..a5f285ba86a14 100644 --- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java +++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java @@ -88,6 +88,91 @@ public void testScalarCreation() throws OrtException { } } + @Test + public void testBufferCreation() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + + // Test creating a value from an array + // Arrays result in tensors allocated by ORT, so they do not have a backing java.nio.Buffer + float[] arrValues = new float[] {0, 1, 2, 3, 4}; + try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) { + // array creation isn't backed by buffers + Assertions.assertFalse(t.ownsBuffer()); + Assertions.assertFalse(t.getBufferRef().isPresent()); + FloatBuffer buf = t.getFloatBuffer(); + float[] output = new float[arrValues.length]; + buf.get(output); + Assertions.assertArrayEquals(arrValues, output); + + // Can't modify the tensor through this buffer. + buf.put(0, 25); + Assertions.assertArrayEquals(arrValues, output); + } + + // Test creating a value from a non-direct byte buffer + // Non-direct byte buffers are allocated on the Java heap and must be copied into off-heap + // direct byte buffers + // which can be directly passed to ORT + FloatBuffer nonDirectBuffer = FloatBuffer.allocate(5); + nonDirectBuffer.put(arrValues); + nonDirectBuffer.rewind(); + try (OnnxTensor t = OnnxTensor.createTensor(env, nonDirectBuffer, new long[] {1, 5})) { + // non-direct buffers trigger a copy + Assertions.assertTrue(t.ownsBuffer()); + // tensors backed by buffers can get the buffer ref back out + Assertions.assertTrue(t.getBufferRef().isPresent()); + FloatBuffer buf = t.getFloatBuffer(); + float[] output = new float[arrValues.length]; + buf.get(output); + Assertions.assertArrayEquals(arrValues, output); + + // Can't modify the tensor through getFloatBuffer. + buf.put(0, 25); + Assertions.assertArrayEquals(arrValues, output); + + // Can modify the tensor through getBufferRef. + FloatBuffer ref = (FloatBuffer) t.getBufferRef().get(); + ref.put(0, 25); + buf = t.getFloatBuffer(); + buf.get(output); + Assertions.assertEquals(25, output[0]); + } + + // Test creating a value from a direct byte buffer + // Direct byte buffers can be passed into ORT without additional copies or processing + FloatBuffer directBuffer = + ByteBuffer.allocateDirect(5 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + directBuffer.put(arrValues); + directBuffer.rewind(); + try (OnnxTensor t = OnnxTensor.createTensor(env, directBuffer, new long[] {1, 5})) { + // direct buffers don't trigger a copy + Assertions.assertFalse(t.ownsBuffer()); + // tensors backed by buffers can get the buffer ref back out + Assertions.assertTrue(t.getBufferRef().isPresent()); + FloatBuffer buf = t.getFloatBuffer(); + float[] output = new float[arrValues.length]; + buf.get(output); + Assertions.assertArrayEquals(arrValues, output); + + // Can't modify the tensor through getFloatBuffer. + buf.put(0, 25); + Assertions.assertArrayEquals(arrValues, output); + + // Can modify the tensor through getBufferRef. + FloatBuffer ref = (FloatBuffer) t.getBufferRef().get(); + ref.put(0, 25); + buf = t.getFloatBuffer(); + buf.get(output); + Assertions.assertEquals(25, output[0]); + + // Can modify the tensor through our original ref to the direct byte buffer + directBuffer.put(1, 15); + buf = t.getFloatBuffer(); + buf.get(output); + Assertions.assertEquals(15, output[1]); + } + } + @Test public void testStringCreation() throws OrtException { OrtEnvironment env = OrtEnvironment.getEnvironment(); diff --git a/java/src/test/java/ai/onnxruntime/TestHelpers.java b/java/src/test/java/ai/onnxruntime/TestHelpers.java index 7d41918b1c6c7..55d8169434d48 100644 --- a/java/src/test/java/ai/onnxruntime/TestHelpers.java +++ b/java/src/test/java/ai/onnxruntime/TestHelpers.java @@ -262,6 +262,12 @@ public static Path getResourcePath(String path) { return new File(TestHelpers.class.getResource(path).getFile()).toPath(); } + public static void zeroBuffer(FloatBuffer buf) { + for (int i = 0; i < buf.capacity(); i++) { + buf.put(i, 0.0f); + } + } + public static float[] loadTensorFromFile(Path filename) { return loadTensorFromFile(filename, true); } diff --git a/java/src/test/java/ai/onnxruntime/TrainingTest.java b/java/src/test/java/ai/onnxruntime/TrainingTest.java index a02f5a88b2ac5..eaa7da1fc6a16 100644 --- a/java/src/test/java/ai/onnxruntime/TrainingTest.java +++ b/java/src/test/java/ai/onnxruntime/TrainingTest.java @@ -16,7 +16,6 @@ import java.util.Map; import java.util.Set; import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfSystemProperty; @@ -69,8 +68,6 @@ public void testCreateTrainingSessionWithEval() throws OrtException { } } - // this test is not enabled as ORT Java doesn't support supplying an output buffer - @Disabled @Test public void testTrainingSessionTrainStep() throws OrtException { String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString(); @@ -99,14 +96,11 @@ public void testTrainingSessionTrainStep() throws OrtException { ByteBuffer.allocateDirect(4 * expectedOutput.length) .order(ByteOrder.nativeOrder()) .asFloatBuffer(); - OnnxTensor outputTensor = - OnnxTensor.createTensor(env, output, new long[expectedOutput.length]); + OnnxTensor outputTensor = OnnxTensor.createTensor(env, output, new long[0]); outputMap.put("onnx::loss::21273", outputTensor); - /* Disabled as we haven't implemented this yet - try (trainingSession.trainStep(pinnedInputs, outputMap)) { - Assertions.assertArrayEquals(expectedOutput, (float[]) outputTensor.getValue(), 1e-3f); + try (OrtSession.Result r = trainingSession.trainStep(pinnedInputs, outputMap)) { + Assertions.assertEquals(expectedOutput[0], (float) outputTensor.getValue(), 1e-3f); } - */ } finally { OnnxValue.close(outputMap); OnnxValue.close(pinnedInputs); diff --git a/java/src/test/resources/java-three-output-matmul.onnx b/java/src/test/resources/java-three-output-matmul.onnx new file mode 100644 index 0000000000000..fed0bbca460cf Binary files /dev/null and b/java/src/test/resources/java-three-output-matmul.onnx differ diff --git a/js/.eslintrc.js b/js/.eslintrc.js index e13cabae9ed45..0bf47c5264f61 100644 --- a/js/.eslintrc.js +++ b/js/.eslintrc.js @@ -5,10 +5,18 @@ module.exports = { root: true, - ignorePatterns: ['**/*.js', 'ort-schema/', 'common/test/type-tests/', 'node_modules/', 'types/', 'dist/'], + ignorePatterns: [ + '**/*.js', + 'node_modules/', + 'ort-schema/', + 'common/test/type-tests/', + 'web/types.d.ts', + 'test/data/', + 'dist/', + ], env: { 'es6': true }, parser: '@typescript-eslint/parser', - parserOptions: { 'project': 'tsconfig.json', 'sourceType': 'module' }, + parserOptions: { 'project': true, 'sourceType': 'module' }, plugins: ['@typescript-eslint', 'prefer-arrow', 'header', 'import', 'unicorn', 'jsdoc'], rules: { 'unicorn/filename-case': 'error', @@ -119,7 +127,6 @@ module.exports = { rules: { 'jsdoc/check-alignment': 'error', 'jsdoc/check-indentation': 'error', - 'jsdoc/newline-after-description': 'error', } }, { files: ['common/test/**/*.ts'], @@ -146,7 +153,54 @@ module.exports = { } }, { files: ['web/lib/**/*.ts'], rules: { - 'no-underscore-dangle': 'off', + 'no-underscore-dangle': ['error', { + 'allow': [ + '_free', + '_malloc', + '_JsepGetNodeName', + '_JsepOutput', + '_OrtAddFreeDimensionOverride', + '_OrtAddRunConfigEntry', + '_OrtAddSessionConfigEntry', + '_OrtAppendExecutionProvider', + '_OrtBindInput', + '_OrtBindOutput', + '_OrtClearBoundOutputs', + '_OrtCreateBinding', + '_OrtCreateRunOptions', + '_OrtCreateSession', + '_OrtCreateSessionOptions', + '_OrtCreateTensor', + '_OrtEndProfiling', + '_OrtFree', + '_OrtGetInputName', + '_OrtGetInputOutputCount', + '_OrtGetLastError', + '_OrtGetOutputName', + '_OrtGetTensorData', + '_OrtInit', + '_OrtReleaseBinding', + '_OrtReleaseRunOptions', + '_OrtReleaseSession', + '_OrtReleaseSessionOptions', + '_OrtReleaseTensor', + '_OrtRun', + '_OrtRunWithBinding', + '_OrtTrainingCopyParametersFromBuffer', + '_OrtTrainingCopyParametersToBuffer', + '_OrtTrainingCreateSession', + '_OrtTrainingEvalStep', + '_OrtTrainingGetModelInputOutputCount', + '_OrtTrainingGetModelInputOutputName', + '_OrtTrainingGetParametersSize', + '_OrtTrainingLazyResetGrad', + '_OrtTrainingLoadCheckpoint', + '_OrtTrainingOptimizerStep', + '_OrtTrainingReleaseCheckpoint', + '_OrtTrainingReleaseSession', + '_OrtTrainingRunTrainStep' + ] + }] } }, { files: ['web/lib/onnxjs/**/*.ts'], rules: { @@ -159,6 +213,7 @@ module.exports = { 'import/no-internal-modules': 'off', 'prefer-arrow/prefer-arrow-functions': 'off', 'no-param-reassign': 'off', + 'no-underscore-dangle': 'off', 'guard-for-in': 'off' } }, { diff --git a/js/.gitignore b/js/.gitignore index 076984956d0b8..028fde4c17af9 100644 --- a/js/.gitignore +++ b/js/.gitignore @@ -9,3 +9,5 @@ tsconfig.tsbuildinfo *.d.ts *.tgz + +*.esbuild.metafile.json diff --git a/js/.vscode/launch.json b/js/.vscode/launch.json index 3f4ec74b7de58..26a08d37488ba 100644 --- a/js/.vscode/launch.json +++ b/js/.vscode/launch.json @@ -5,7 +5,7 @@ "version": "0.2.0", "configurations": [ { - "name": "[common] Launch Unit Tests", + "name": "[common] Launch Unit Tests in Node.js", "args": ["-u", "bdd", "--timeout", "999999", "--colors", "${workspaceFolder}/common/test/**/*.js"], "internalConsoleOptions": "openOnSessionStart", "program": "${workspaceFolder}/node_modules/mocha/bin/_mocha", @@ -17,7 +17,7 @@ "preLaunchTask": "tsc: build - common/test/tsconfig.json" }, { - "name": "[web] Launch Test Runner", + "name": "[web] Launch Test Runner CLI in Node.js", "program": "${workspaceFolder}/web/script/test-runner-cli.js", "request": "launch", "skipFiles": ["/**"], @@ -26,17 +26,35 @@ "args": ["suite1"] }, { - "name": "[web] Attach to Chrome", + "name": "[web] Launch NPM tests in Node.js", + "args": [ + "--timeout", + "999999", + "--colors", + "-r", + "${workspaceFolder}/web/dist/ort.node.min.js", + "${workspaceFolder}/web/test/test-main" + ], + "internalConsoleOptions": "openOnSessionStart", + "program": "${workspaceFolder}/node_modules/mocha/bin/_mocha", + "request": "launch", + "skipFiles": ["/**"], + "type": "node", + "cwd": "${workspaceFolder}/web" + }, + { + "name": "[web] Attach to Chrome for NPM tests", "type": "chrome", "request": "attach", "port": 9333, "webRoot": "${workspaceFolder}", "sourceMapPathOverrides": { - "webpack://ort/*": "${webRoot}/common/*", - "webpack:///*": "${webRoot}/web/*" + "../../common/*": "${webRoot}/common/*", + "../lib/*": "${webRoot}/web/lib/*" }, "sourceMaps": true, - "smartStep": true + "smartStep": true, + "skipFiles": ["**/node_modules/**"] }, { "name": "[web] Remote Browser via Webkit Adaptor", diff --git a/js/.vscode/settings.json b/js/.vscode/settings.json index 4948899ec671b..9c2fe646d728d 100644 --- a/js/.vscode/settings.json +++ b/js/.vscode/settings.json @@ -36,7 +36,8 @@ "node/lib/**/*.js.map": true, "node/lib/**/*.js": true, "web/lib/**/*.js.map": true, - "web/lib/**/*.js": true + "web/lib/**/*.js": true, + "web/lib/**/*.d.ts": true }, "files.insertFinalNewline": true, "files.trimTrailingWhitespace": true, diff --git a/js/README.md b/js/README.md index 7e6681e6bd897..1662de6d4ac78 100644 --- a/js/README.md +++ b/js/README.md @@ -344,13 +344,13 @@ From ORT v1.13 onwards the 'full' ONNX Runtime package is used. It supports both Full build: ```sh - python tools/ci_build/github/apple/build_ios_framework.py tools/ci_build/github/apple/default_full_ios_framework_build_settings.json --config Release + python tools/ci_build/github/apple/build_apple_framework.py tools/ci_build/github/apple/default_full_apple_framework_build_settings.json --config Release ``` Reduced size build: ```sh - python tools/ci_build/github/apple/build_ios_framework.py tools/ci_build/github/apple/default_mobile_ios_framework_build_settings.json --config MinSizeRel --include_ops_by_config --enable_reduced_operator_type_support + python tools/ci_build/github/apple/build_apple_framework.py tools/ci_build/github/apple/default_mobile_ios_framework_build_settings.json --config MinSizeRel --include_ops_by_config --enable_reduced_operator_type_support ``` The build creates `Headers`, `LICENSE`, and `onnxruntime.xcframework` in `build/iOS_framework/framework_out` directory. From `framework_out` directory, create an archive file named `onnxruntime-c.zip` for a full build or `onnxruntime-mobile-c.zip` for a reduced size build and copy to `/js/react_native/local_pods` directory. diff --git a/js/build_jsep.bat b/js/build_jsep.bat index 02f1170ecb067..acd40ff920774 100644 --- a/js/build_jsep.bat +++ b/js/build_jsep.bat @@ -22,7 +22,7 @@ if ["%~1"]==["d"] ( ) if ["%~1"]==["r"] ( set CONFIG=Release - set CONFIG_EXTRA_FLAG= + set CONFIG_EXTRA_FLAG=--enable_wasm_api_exception_catching --disable_rtti goto :arg2 ) echo Invalid configuration "%~1", must be "d"(Debug) or "r"(Release) diff --git a/js/common/build.js b/js/common/build.js index cf459f5efa812..b0956c608b350 100644 --- a/js/common/build.js +++ b/js/common/build.js @@ -21,6 +21,3 @@ execSync('npm run build:esm', {shell: true, stdio: 'inherit', cwd: __dirname}); // see also: https://evertpot.com/universal-commonjs-esm-typescript-packages/ writeFileSync(resolve(__dirname, './dist/cjs', 'package.json'), '{"type": "commonjs"}'); writeFileSync(resolve(__dirname, './dist/esm', 'package.json'), '{"type": "module"}'); - -// launch webpack to generate bundles -execSync('npm run build:bundles', {shell: true, stdio: 'inherit', cwd: __dirname}); diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts index 75feba1d0ae08..e129c6971a85c 100644 --- a/js/common/lib/backend-impl.ts +++ b/js/common/lib/backend-impl.ts @@ -26,7 +26,7 @@ const backendsSortedByPriority: string[] = []; * @ignore */ export const registerBackend = (name: string, backend: Backend, priority: number): void => { - if (backend && typeof backend.init === 'function' && typeof backend.createSessionHandler === 'function') { + if (backend && typeof backend.init === 'function' && typeof backend.createInferenceSessionHandler === 'function') { const currentBackend = backends.get(name); if (currentBackend === undefined) { backends.set(name, {backend, priority}); diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index 804f33f00d103..67d283b694955 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -3,6 +3,7 @@ import {InferenceSession} from './inference-session.js'; import {OnnxValue} from './onnx-value.js'; +import {TrainingSession} from './training-session.js'; /** * @ignore @@ -14,16 +15,23 @@ export declare namespace SessionHandler { } /** - * Represent a handler instance of an inference session. + * Represents shared SessionHandler functionality * * @ignore */ -export interface SessionHandler { +interface SessionHandler { dispose(): Promise; readonly inputNames: readonly string[]; readonly outputNames: readonly string[]; +} +/** + * Represent a handler instance of an inference session. + * + * @ignore + */ +export interface InferenceSessionHandler extends SessionHandler { startProfiling(): void; endProfiling(): void; @@ -31,6 +39,21 @@ export interface SessionHandler { options: InferenceSession.RunOptions): Promise; } +/** + * Represent a handler instance of a training inference session. + * + * @ignore + */ +export interface TrainingSessionHandler extends SessionHandler { + runTrainStep( + feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions): Promise; + + getParametersSize(trainableOnly: boolean): Promise; + loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + getContiguousParameters(trainableOnly: boolean): Promise; +} + /** * Represent a backend that provides implementation of model inferencing. * @@ -42,8 +65,13 @@ export interface Backend { */ init(): Promise; - createSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise; + createInferenceSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise; + + createTrainingSessionHandler? + (checkpointStateUriOrBuffer: TrainingSession.URIorBuffer, trainModelUriOrBuffer: TrainingSession.URIorBuffer, + evalModelUriOrBuffer: TrainingSession.URIorBuffer, optimizerModelUriOrBuffer: TrainingSession.URIorBuffer, + options: InferenceSession.SessionOptions): Promise; } export {registerBackend} from './backend-impl.js'; diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 525272294c587..76575ef7b9368 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -9,6 +9,7 @@ export declare namespace Env { 'ort-wasm.wasm'?: string; 'ort-wasm-threaded.wasm'?: string; 'ort-wasm-simd.wasm'?: string; + 'ort-training-wasm-simd.wasm'?: string; 'ort-wasm-simd-threaded.wasm'?: string; /* eslint-enable @typescript-eslint/naming-convention */ }; @@ -105,6 +106,12 @@ export declare namespace Env { * see comments on {@link GpuBufferType} for more details about why not use types defined in "@webgpu/types". */ readonly device: unknown; + /** + * Set or get whether validate input content. + * + * @defaultValue `false` + */ + validateInputContent?: boolean; } } diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts index 85df1747f8576..9cbfcc4e8bcdc 100644 --- a/js/common/lib/index.ts +++ b/js/common/lib/index.ts @@ -22,3 +22,4 @@ export * from './env.js'; export * from './inference-session.js'; export * from './tensor.js'; export * from './onnx-value.js'; +export * from './training-session.js'; diff --git a/js/common/lib/inference-session-impl.ts b/js/common/lib/inference-session-impl.ts index 06949b4a26c0d..9bc2088f2088a 100644 --- a/js/common/lib/inference-session-impl.ts +++ b/js/common/lib/inference-session-impl.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. import {resolveBackend} from './backend-impl.js'; -import {SessionHandler} from './backend.js'; +import {InferenceSessionHandler} from './backend.js'; import {InferenceSession as InferenceSessionInterface} from './inference-session.js'; import {OnnxValue} from './onnx-value.js'; import {Tensor} from './tensor.js'; @@ -14,7 +14,7 @@ type FetchesType = InferenceSessionInterface.FetchesType; type ReturnType = InferenceSessionInterface.ReturnType; export class InferenceSession implements InferenceSessionInterface { - private constructor(handler: SessionHandler) { + private constructor(handler: InferenceSessionHandler) { this.handler = handler; } run(feeds: FeedsType, options?: RunOptions): Promise; @@ -195,7 +195,7 @@ export class InferenceSession implements InferenceSessionInterface { const eps = options.executionProviders || []; const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); const backend = await resolveBackend(backendHints); - const handler = await backend.createSessionHandler(filePathOrUint8Array, options); + const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, options); return new InferenceSession(handler); } @@ -213,5 +213,5 @@ export class InferenceSession implements InferenceSessionInterface { return this.handler.outputNames; } - private handler: SessionHandler; + private handler: InferenceSessionHandler; } diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 71a5912df2464..c7760692eed00 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -192,6 +192,7 @@ export declare namespace InferenceSession { wasm: WebAssemblyExecutionProviderOption; webgl: WebGLExecutionProviderOption; xnnpack: XnnpackExecutionProviderOption; + webgpu: WebGpuExecutionProviderOption; webnn: WebNNExecutionProviderOption; nnapi: NnapiExecutionProviderOption; } @@ -233,9 +234,14 @@ export declare namespace InferenceSession { export interface XnnpackExecutionProviderOption extends ExecutionProviderOption { readonly name: 'xnnpack'; } + export interface WebGpuExecutionProviderOption extends ExecutionProviderOption { + readonly name: 'webgpu'; + preferredLayout?: 'NCHW'|'NHWC'; + } export interface WebNNExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webnn'; deviceType?: 'cpu'|'gpu'; + numThreads?: number; powerPreference?: 'default'|'low-power'|'high-performance'; } export interface CoreMLExecutionProviderOption extends ExecutionProviderOption { diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts new file mode 100644 index 0000000000000..03694738387f2 --- /dev/null +++ b/js/common/lib/training-session-impl.ts @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {resolveBackend} from './backend-impl.js'; +import {SessionHandler, TrainingSessionHandler} from './backend.js'; +import {InferenceSession as InferenceSession} from './inference-session.js'; +import {OnnxValue} from './onnx-value.js'; +import {Tensor} from './tensor.js'; +import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; + +type SessionOptions = InferenceSession.SessionOptions; +type FeedsType = InferenceSession.FeedsType; +type FetchesType = InferenceSession.FetchesType; +type ReturnType = InferenceSession.ReturnType; +type RunOptions = InferenceSession.RunOptions; + +const noBackendErrMsg: string = 'Training backend could not be resolved. ' + + 'Make sure you\'re using the correct configuration & WebAssembly files.'; + +export class TrainingSession implements TrainingSessionInterface { + private constructor(handler: TrainingSessionHandler) { + this.handler = handler; + } + private handler: TrainingSessionHandler; + + get inputNames(): readonly string[] { + return this.handler.inputNames; + } + get outputNames(): readonly string[] { + return this.handler.outputNames; + } + + static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions): + Promise { + const evalModel: string|Uint8Array = trainingOptions.evalModel || ''; + const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || ''; + const options: SessionOptions = sessionOptions || {}; + + // get backend hints + const eps = options.executionProviders || []; + const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); + const backend = await resolveBackend(backendHints); + if (backend.createTrainingSessionHandler) { + const handler = await backend.createTrainingSessionHandler( + trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options); + return new TrainingSession(handler); + } else { + throw new Error(noBackendErrMsg); + } + } + + /** + * Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from + * the given parameters to SessionHandler.FetchesType and RunOptions. + * + * @param feeds the required input + * @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object + * @param arg2 optional RunOptions object. + * @returns + */ + typeNarrowingForRunStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): + [SessionHandler.FetchesType, RunOptions] { + const fetches: {[name: string]: OnnxValue|null} = {}; + let options: RunOptions = {}; + // check inputs + if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) { + throw new TypeError( + '\'feeds\' must be an object that use input names as keys and OnnxValue as corresponding values.'); + } + + let isFetchesEmpty = true; + // determine which override is being used + if (typeof arg1 === 'object') { + if (arg1 === null) { + throw new TypeError('Unexpected argument[1]: cannot be null.'); + } + if (arg1 instanceof Tensor) { + throw new TypeError('\'fetches\' cannot be a Tensor'); + } + + if (Array.isArray(arg1)) { + if (arg1.length === 0) { + throw new TypeError('\'fetches\' cannot be an empty array.'); + } + isFetchesEmpty = false; + // output names + for (const name of arg1) { + if (typeof name !== 'string') { + throw new TypeError('\'fetches\' must be a string array or an object.'); + } + if (this.outputNames.indexOf(name) === -1) { + throw new RangeError(`'fetches' contains invalid output name: ${name}.`); + } + fetches[name] = null; + } + + if (typeof arg2 === 'object' && arg2 !== null) { + options = arg2; + } else if (typeof arg2 !== 'undefined') { + throw new TypeError('\'options\' must be an object.'); + } + } else { + // decide whether arg1 is fetches or options + // if any output name is present and its value is valid OnnxValue, we consider it fetches + let isFetches = false; + const arg1Keys = Object.getOwnPropertyNames(arg1); + for (const name of this.outputNames) { + if (arg1Keys.indexOf(name) !== -1) { + const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name]; + if (v === null || v instanceof Tensor) { + isFetches = true; + isFetchesEmpty = false; + fetches[name] = v; + } + } + } + + if (isFetches) { + if (typeof arg2 === 'object' && arg2 !== null) { + options = arg2; + } else if (typeof arg2 !== 'undefined') { + throw new TypeError('\'options\' must be an object.'); + } + } else { + options = arg1 as RunOptions; + } + } + } else if (typeof arg1 !== 'undefined') { + throw new TypeError('Unexpected argument[1]: must be \'fetches\' or \'options\'.'); + } + + // check if all inputs are in feed + for (const name of this.inputNames) { + if (typeof feeds[name] === 'undefined') { + throw new Error(`input '${name}' is missing in 'feeds'.`); + } + } + + // if no fetches is specified, we use the full output names list + if (isFetchesEmpty) { + for (const name of this.outputNames) { + fetches[name] = null; + } + } + + return [fetches, options]; + } + + /** + * Helper method for runTrainStep and any other runStep methods. Takes the ReturnType result from the SessionHandler + * and changes it into a map of Tensors. + * + * @param results + * @returns + */ + convertHandlerReturnTypeToMapOfTensors(results: SessionHandler.ReturnType): ReturnType { + const returnValue: {[name: string]: OnnxValue} = {}; + for (const key in results) { + if (Object.hasOwnProperty.call(results, key)) { + const result = results[key]; + if (result instanceof Tensor) { + returnValue[key] = result; + } else { + returnValue[key] = new Tensor(result.type, result.data, result.dims); + } + } + } + return returnValue; + } + + runTrainStep(feeds: FeedsType, options?: RunOptions): Promise; + runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; + async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { + const [fetches, options] = this.typeNarrowingForRunStep(feeds, arg1, arg2); + const results = await this.handler.runTrainStep(feeds, fetches, options); + return this.convertHandlerReturnTypeToMapOfTensors(results); + } + + async getParametersSize(trainableOnly = true): Promise { + return this.handler.getParametersSize(trainableOnly); + } + + async loadParametersBuffer(array: Uint8Array, trainableOnly = true): Promise { + const paramsSize = await this.getParametersSize(trainableOnly); + // checking that the size of the Uint8Array is equivalent to the byte length of a Float32Array of the number + // of parameters + if (array.length !== 4 * paramsSize) { + throw new Error( + 'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' + + 'the model. Please use getParametersSize method to check.'); + } + return this.handler.loadParametersBuffer(array, trainableOnly); + } + + async getContiguousParameters(trainableOnly = true): Promise { + return this.handler.getContiguousParameters(trainableOnly); + } + + async release(): Promise { + return this.handler.dispose(); + } +} diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts new file mode 100644 index 0000000000000..810ec2a8583b3 --- /dev/null +++ b/js/common/lib/training-session.ts @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceSession} from './inference-session.js'; +import {OnnxValue} from './onnx-value.js'; +import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js'; + +/* eslint-disable @typescript-eslint/no-redeclare */ + +export declare namespace TrainingSession { + /** + * Either URI file path (string) or Uint8Array containing model or checkpoint information. + */ + type URIorBuffer = string|Uint8Array; +} + +/** + * Represent a runtime instance of an ONNX training session, + * which contains a model that can be trained, and, optionally, + * an eval and optimizer model. + */ +export interface TrainingSession { + // #region run() + + /** + * Run TrainStep asynchronously with the given feeds and options. + * + * @param feeds - Representation of the model input. See type description of `InferenceSession.InputType` for + detail. + * @param options - Optional. A set of options that controls the behavior of model training. + * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values. + */ + runTrainStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions): + Promise; + + /** + * Run a single train step with the given inputs and options. + * + * @param feeds - Representation of the model input. + * @param fetches - Representation of the model output. + * detail. + * @param options - Optional. A set of options that controls the behavior of model inference. + * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding + values. + */ + runTrainStep( + feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType, + options?: InferenceSession.RunOptions): Promise; + + // #endregion + + // #region copy parameters + + /** + * Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of + * the parameters) elements of all the parameters in the training state. + * + * @param trainableOnly - When set to true, the size is calculated for trainable params only. Default value is true. + */ + getParametersSize(trainableOnly: boolean): Promise; + + /** + * Copies parameter values from the given array to the training state. Currently, only supporting models with + * parameters of type Float32. + * + * @param buffer - Float32 buffer containing parameters converted to a Uint8Array. + * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true. + */ + loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + + /** + * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning. + * Currently, only supporting models with parameters of type Float32. + * + * @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters + * for which requires_grad is set to true. Default value is true. + * @returns A promise that resolves to a Float32 OnnxValue of the requested parameters. + */ + getContiguousParameters(trainableOnly: boolean): Promise; + // #endregion + + // #region release() + + /** + * Release the inference session and the underlying resources. + */ + release(): Promise; + // #endregion + + // #region metadata + + /** + * Get input names of the loaded model. + */ + readonly inputNames: readonly string[]; + + /** + * Get output names of the loaded model. + */ + readonly outputNames: readonly string[]; + // #endregion +} + +/** + * Represents the optional parameters that can be passed into the TrainingSessionFactory. + */ +export interface TrainingSessionCreateOptions { + /** + * URI or buffer for a .ckpt file that contains the checkpoint for the training model. + */ + checkpointState: TrainingSession.URIorBuffer; + /** + * URI or buffer for the .onnx training file. + */ + trainModel: TrainingSession.URIorBuffer; + /** + * Optional. URI or buffer for the .onnx optimizer model file. + */ + optimizerModel?: TrainingSession.URIorBuffer; + /** + * Optional. URI or buffer for the .onnx eval model file. + */ + evalModel?: TrainingSession.URIorBuffer; +} + +/** + * Defines method overload possibilities for creating a TrainingSession. + */ +export interface TrainingSessionFactory { + // #region create() + + /** + * Creates a new TrainingSession and asynchronously loads any models passed in through trainingOptions + * + * @param trainingOptions specify models and checkpoints to load into the Training Session + * @param sessionOptions specify configuration for training session behavior + * + * @returns Promise that resolves to a TrainingSession object + */ + create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: InferenceSession.SessionOptions): + Promise; + + // #endregion +} + +// eslint-disable-next-line @typescript-eslint/naming-convention +export const TrainingSession: TrainingSessionFactory = TrainingSessionImpl; diff --git a/js/common/tsconfig.json b/js/common/tsconfig.json index 4751cd7564f5d..d7bb3a593f66f 100644 --- a/js/common/tsconfig.json +++ b/js/common/tsconfig.json @@ -4,8 +4,7 @@ "outDir": "./dist/esm", "declaration": true, "declarationMap": true, - "esModuleInterop": false, - "noUnusedParameters": true + "esModuleInterop": false }, "include": ["lib"] } diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index d3680f9d44236..5f5ad49a2dea8 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Backend, InferenceSession, SessionHandler} from 'onnxruntime-common'; +import {Backend, InferenceSession, InferenceSessionHandler, SessionHandler} from 'onnxruntime-common'; import {Binding, binding} from './binding'; -class OnnxruntimeSessionHandler implements SessionHandler { +class OnnxruntimeSessionHandler implements InferenceSessionHandler { #inferenceSession: Binding.InferenceSession; constructor(pathOrBuffer: string|Uint8Array, options: InferenceSession.SessionOptions) { @@ -53,8 +53,8 @@ class OnnxruntimeBackend implements Backend { return Promise.resolve(); } - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { return new Promise((resolve, reject) => { process.nextTick(() => { try { diff --git a/js/node/package-lock.json b/js/node/package-lock.json index ce390aa88c0aa..c1cf8af4bb80e 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -22,7 +22,7 @@ "jsonc": "^2.0.0", "minimist": "^1.2.8", "node-addon-api": "^6.0.0", - "onnx-proto": "^8.0.1" + "protobufjs": "^7.2.4" } }, "../common": { @@ -97,12 +97,6 @@ "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", "dev": true }, - "node_modules/@types/long": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/@types/long/-/long-4.0.2.tgz", - "integrity": "sha512-MqTGEo5bj5t157U6fA/BiDynNkn0YknVdh48CMPkTSpFTVmvao5UQmm7uEF6xBEo7qIMAlY/JSleYaE6VOdpaA==", - "dev": true - }, "node_modules/@types/minimist": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/@types/minimist/-/minimist-1.2.2.tgz", @@ -165,9 +159,9 @@ "dev": true }, "node_modules/axios": { - "version": "1.3.4", - "resolved": "https://registry.npmjs.org/axios/-/axios-1.3.4.tgz", - "integrity": "sha512-toYm+Bsyl6VC5wSkfkbbNB6ROv7KY93PEBBL6xyDczaIHasAiv4wPqQ/c4RjoQzipxRD2W5g21cOqQulZ7rHwQ==", + "version": "1.6.1", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz", + "integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==", "dev": true, "dependencies": { "follow-redirects": "^1.15.0", @@ -528,9 +522,9 @@ "dev": true }, "node_modules/long": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==", "dev": true }, "node_modules/lru-cache": { @@ -663,15 +657,6 @@ "node": "^12.13.0 || ^14.15.0 || >=16.0.0" } }, - "node_modules/onnx-proto": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-8.0.1.tgz", - "integrity": "sha512-ZpPTqp5dneh2bvavk/QpDsf20JJRArjqTkiMfshGmxR8ocjmfTk80fkW00FwLO7qRtybo9NPugcWQrumHYctLQ==", - "dev": true, - "dependencies": { - "protobufjs": "^6.11.2" - } - }, "node_modules/onnxruntime-common": { "resolved": "../common", "link": true @@ -690,9 +675,9 @@ } }, "node_modules/protobufjs": { - "version": "6.11.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.4.tgz", - "integrity": "sha512-5kQWPaJHi1WoCpjTGszzQ32PG2F4+wRY6BmAT4Vfw56Q2FZ4YZzK20xUYQH4YkfehY1e6QSICrJquM6xXZNcrw==", + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz", + "integrity": "sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==", "dev": true, "hasInstallScript": true, "dependencies": { @@ -706,13 +691,11 @@ "@protobufjs/path": "^1.1.2", "@protobufjs/pool": "^1.1.0", "@protobufjs/utf8": "^1.1.0", - "@types/long": "^4.0.1", "@types/node": ">=13.7.0", - "long": "^4.0.0" + "long": "^5.0.0" }, - "bin": { - "pbjs": "bin/pbjs", - "pbts": "bin/pbts" + "engines": { + "node": ">=12.0.0" } }, "node_modules/proxy-from-env": { @@ -789,9 +772,9 @@ ] }, "node_modules/semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "dependencies": { "lru-cache": "^6.0.0" @@ -1070,12 +1053,6 @@ "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", "dev": true }, - "@types/long": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/@types/long/-/long-4.0.2.tgz", - "integrity": "sha512-MqTGEo5bj5t157U6fA/BiDynNkn0YknVdh48CMPkTSpFTVmvao5UQmm7uEF6xBEo7qIMAlY/JSleYaE6VOdpaA==", - "dev": true - }, "@types/minimist": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/@types/minimist/-/minimist-1.2.2.tgz", @@ -1126,9 +1103,9 @@ "dev": true }, "axios": { - "version": "1.3.4", - "resolved": "https://registry.npmjs.org/axios/-/axios-1.3.4.tgz", - "integrity": "sha512-toYm+Bsyl6VC5wSkfkbbNB6ROv7KY93PEBBL6xyDczaIHasAiv4wPqQ/c4RjoQzipxRD2W5g21cOqQulZ7rHwQ==", + "version": "1.6.1", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz", + "integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==", "dev": true, "requires": { "follow-redirects": "^1.15.0", @@ -1413,9 +1390,9 @@ "dev": true }, "long": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==", "dev": true }, "lru-cache": { @@ -1523,15 +1500,6 @@ "set-blocking": "^2.0.0" } }, - "onnx-proto": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-8.0.1.tgz", - "integrity": "sha512-ZpPTqp5dneh2bvavk/QpDsf20JJRArjqTkiMfshGmxR8ocjmfTk80fkW00FwLO7qRtybo9NPugcWQrumHYctLQ==", - "dev": true, - "requires": { - "protobufjs": "^6.11.2" - } - }, "onnxruntime-common": { "version": "file:../common", "requires": { @@ -1549,9 +1517,9 @@ } }, "protobufjs": { - "version": "6.11.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.4.tgz", - "integrity": "sha512-5kQWPaJHi1WoCpjTGszzQ32PG2F4+wRY6BmAT4Vfw56Q2FZ4YZzK20xUYQH4YkfehY1e6QSICrJquM6xXZNcrw==", + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz", + "integrity": "sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==", "dev": true, "requires": { "@protobufjs/aspromise": "^1.1.2", @@ -1564,9 +1532,8 @@ "@protobufjs/path": "^1.1.2", "@protobufjs/pool": "^1.1.0", "@protobufjs/utf8": "^1.1.0", - "@types/long": "^4.0.1", "@types/node": ">=13.7.0", - "long": "^4.0.0" + "long": "^5.0.0" } }, "proxy-from-env": { @@ -1619,9 +1586,9 @@ "dev": true }, "semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "requires": { "lru-cache": "^6.0.0" diff --git a/js/node/package.json b/js/node/package.json index 0f8f0e9d2260c..8e591d8f46b9d 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -19,6 +19,7 @@ }, "scripts": { "buildr": "tsc && node ./script/build --config=RelWithDebInfo", + "preprepare": "node -e \"require('node:fs').copyFileSync('./node_modules/long/index.d.ts', './node_modules/long/umd/index.d.ts')\"", "prepare": "tsc --build script test .", "rebuild": "tsc && node ./script/build --rebuild", "rebuildd": "tsc && node ./script/build --rebuild --config=Debug", @@ -39,7 +40,7 @@ "jsonc": "^2.0.0", "minimist": "^1.2.8", "node-addon-api": "^6.0.0", - "onnx-proto": "^8.0.1" + "protobufjs": "^7.2.4" }, "main": "dist/index.js", "os": [ diff --git a/js/node/src/session_options_helper.cc b/js/node/src/session_options_helper.cc index 70e63da7cefa7..a0de832d87fe5 100644 --- a/js/node/src/session_options_helper.cc +++ b/js/node/src/session_options_helper.cc @@ -16,7 +16,6 @@ #include "core/providers/dml/dml_provider_factory.h" #endif #ifdef USE_TENSORRT -#include "core/providers/tensorrt/tensorrt_provider_factory.h" #include "core/providers/tensorrt/tensorrt_provider_options.h" #endif #ifdef USE_COREML diff --git a/js/node/test/ort-schema/protobuf/.gitignore b/js/node/test/ort-schema/protobuf/.gitignore new file mode 100644 index 0000000000000..092bb6c1c9fb4 --- /dev/null +++ b/js/node/test/ort-schema/protobuf/.gitignore @@ -0,0 +1,2 @@ +!onnx.js +!onnx.d.ts diff --git a/js/node/test/ort-schema/protobuf/README.md b/js/node/test/ort-schema/protobuf/README.md new file mode 100644 index 0000000000000..f5f52c602f1ad --- /dev/null +++ b/js/node/test/ort-schema/protobuf/README.md @@ -0,0 +1,21 @@ +# ONNX protobuf + +This directory contains generated protobuf definition for onnx: + +- onnx.js +- onnx.d.ts + +These files are generated from [a fork of onnx-proto](https://github.com/fs-eire/onnx-proto/tree/update-v9). + +The ONNX protobuf uses protobufjs@7.2.4, which depends on long@5.2.3, the version contains 2 bugs: + +- type export does not work with commonjs. described in https://github.com/dcodeIO/long.js/pull/124. added a "postinstall" script to fix. +- in the generated typescript declaration file 'onnx.d.ts', the following line: + ```ts + import Long = require("long"); + ``` + need to be replaced to fix type import error: + ```ts + import Long from "long"; + ``` + this replacement is done and code format is also applied to file 'onnx.d.ts'. diff --git a/js/node/test/ort-schema/protobuf/onnx.d.ts b/js/node/test/ort-schema/protobuf/onnx.d.ts new file mode 100644 index 0000000000000..c60264dca2a8d --- /dev/null +++ b/js/node/test/ort-schema/protobuf/onnx.d.ts @@ -0,0 +1,2627 @@ +import Long from 'long'; +import * as $protobuf from 'protobufjs'; + +/** Namespace onnx. */ +export namespace onnx { + + /** Version enum. */ + enum Version { + _START_VERSION = 0, + IR_VERSION_2017_10_10 = 1, + IR_VERSION_2017_10_30 = 2, + IR_VERSION_2017_11_3 = 3, + IR_VERSION_2019_1_22 = 4, + IR_VERSION_2019_3_18 = 5, + IR_VERSION_2019_9_19 = 6, + IR_VERSION_2020_5_8 = 7, + IR_VERSION_2021_7_30 = 8, + IR_VERSION = 9 + } + + /** Properties of an AttributeProto. */ + interface IAttributeProto { + /** AttributeProto name */ + name?: (string|null); + + /** AttributeProto refAttrName */ + refAttrName?: (string|null); + + /** AttributeProto docString */ + docString?: (string|null); + + /** AttributeProto type */ + type?: (onnx.AttributeProto.AttributeType|null); + + /** AttributeProto f */ + f?: (number|null); + + /** AttributeProto i */ + i?: (number|Long|null); + + /** AttributeProto s */ + s?: (Uint8Array|null); + + /** AttributeProto t */ + t?: (onnx.ITensorProto|null); + + /** AttributeProto g */ + g?: (onnx.IGraphProto|null); + + /** AttributeProto sparseTensor */ + sparseTensor?: (onnx.ISparseTensorProto|null); + + /** AttributeProto tp */ + tp?: (onnx.ITypeProto|null); + + /** AttributeProto floats */ + floats?: (number[]|null); + + /** AttributeProto ints */ + ints?: ((number | Long)[]|null); + + /** AttributeProto strings */ + strings?: (Uint8Array[]|null); + + /** AttributeProto tensors */ + tensors?: (onnx.ITensorProto[]|null); + + /** AttributeProto graphs */ + graphs?: (onnx.IGraphProto[]|null); + + /** AttributeProto sparseTensors */ + sparseTensors?: (onnx.ISparseTensorProto[]|null); + + /** AttributeProto typeProtos */ + typeProtos?: (onnx.ITypeProto[]|null); + } + + /** Represents an AttributeProto. */ + class AttributeProto implements IAttributeProto { + /** + * Constructs a new AttributeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IAttributeProto); + + /** AttributeProto name. */ + public name: string; + + /** AttributeProto refAttrName. */ + public refAttrName: string; + + /** AttributeProto docString. */ + public docString: string; + + /** AttributeProto type. */ + public type: onnx.AttributeProto.AttributeType; + + /** AttributeProto f. */ + public f: number; + + /** AttributeProto i. */ + public i: (number|Long); + + /** AttributeProto s. */ + public s: Uint8Array; + + /** AttributeProto t. */ + public t?: (onnx.ITensorProto|null); + + /** AttributeProto g. */ + public g?: (onnx.IGraphProto|null); + + /** AttributeProto sparseTensor. */ + public sparseTensor?: (onnx.ISparseTensorProto|null); + + /** AttributeProto tp. */ + public tp?: (onnx.ITypeProto|null); + + /** AttributeProto floats. */ + public floats: number[]; + + /** AttributeProto ints. */ + public ints: (number|Long)[]; + + /** AttributeProto strings. */ + public strings: Uint8Array[]; + + /** AttributeProto tensors. */ + public tensors: onnx.ITensorProto[]; + + /** AttributeProto graphs. */ + public graphs: onnx.IGraphProto[]; + + /** AttributeProto sparseTensors. */ + public sparseTensors: onnx.ISparseTensorProto[]; + + /** AttributeProto typeProtos. */ + public typeProtos: onnx.ITypeProto[]; + + /** + * Creates a new AttributeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns AttributeProto instance + */ + public static create(properties?: onnx.IAttributeProto): onnx.AttributeProto; + + /** + * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} + * messages. + * @param message AttributeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IAttributeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link + * onnx.AttributeProto.verify|verify} messages. + * @param message AttributeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IAttributeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an AttributeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.AttributeProto; + + /** + * Decodes an AttributeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.AttributeProto; + + /** + * Verifies an AttributeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns AttributeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.AttributeProto; + + /** + * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. + * @param message AttributeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.AttributeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this AttributeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for AttributeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace AttributeProto { + + /** AttributeType enum. */ + enum AttributeType { + UNDEFINED = 0, + FLOAT = 1, + INT = 2, + STRING = 3, + TENSOR = 4, + GRAPH = 5, + SPARSE_TENSOR = 11, + TYPE_PROTO = 13, + FLOATS = 6, + INTS = 7, + STRINGS = 8, + TENSORS = 9, + GRAPHS = 10, + SPARSE_TENSORS = 12, + TYPE_PROTOS = 14 + } + } + + /** Properties of a ValueInfoProto. */ + interface IValueInfoProto { + /** ValueInfoProto name */ + name?: (string|null); + + /** ValueInfoProto type */ + type?: (onnx.ITypeProto|null); + + /** ValueInfoProto docString */ + docString?: (string|null); + } + + /** Represents a ValueInfoProto. */ + class ValueInfoProto implements IValueInfoProto { + /** + * Constructs a new ValueInfoProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IValueInfoProto); + + /** ValueInfoProto name. */ + public name: string; + + /** ValueInfoProto type. */ + public type?: (onnx.ITypeProto|null); + + /** ValueInfoProto docString. */ + public docString: string; + + /** + * Creates a new ValueInfoProto instance using the specified properties. + * @param [properties] Properties to set + * @returns ValueInfoProto instance + */ + public static create(properties?: onnx.IValueInfoProto): onnx.ValueInfoProto; + + /** + * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} + * messages. + * @param message ValueInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IValueInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link + * onnx.ValueInfoProto.verify|verify} messages. + * @param message ValueInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IValueInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.ValueInfoProto; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.ValueInfoProto; + + /** + * Verifies a ValueInfoProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns ValueInfoProto + */ + public static fromObject(object: {[k: string]: any}): onnx.ValueInfoProto; + + /** + * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. + * @param message ValueInfoProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.ValueInfoProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this ValueInfoProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for ValueInfoProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a NodeProto. */ + interface INodeProto { + /** NodeProto input */ + input?: (string[]|null); + + /** NodeProto output */ + output?: (string[]|null); + + /** NodeProto name */ + name?: (string|null); + + /** NodeProto opType */ + opType?: (string|null); + + /** NodeProto domain */ + domain?: (string|null); + + /** NodeProto attribute */ + attribute?: (onnx.IAttributeProto[]|null); + + /** NodeProto docString */ + docString?: (string|null); + } + + /** Represents a NodeProto. */ + class NodeProto implements INodeProto { + /** + * Constructs a new NodeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.INodeProto); + + /** NodeProto input. */ + public input: string[]; + + /** NodeProto output. */ + public output: string[]; + + /** NodeProto name. */ + public name: string; + + /** NodeProto opType. */ + public opType: string; + + /** NodeProto domain. */ + public domain: string; + + /** NodeProto attribute. */ + public attribute: onnx.IAttributeProto[]; + + /** NodeProto docString. */ + public docString: string; + + /** + * Creates a new NodeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns NodeProto instance + */ + public static create(properties?: onnx.INodeProto): onnx.NodeProto; + + /** + * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @param message NodeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.INodeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link + * onnx.NodeProto.verify|verify} messages. + * @param message NodeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.INodeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a NodeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.NodeProto; + + /** + * Decodes a NodeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.NodeProto; + + /** + * Verifies a NodeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns NodeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.NodeProto; + + /** + * Creates a plain object from a NodeProto message. Also converts values to other types if specified. + * @param message NodeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.NodeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this NodeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for NodeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TrainingInfoProto. */ + interface ITrainingInfoProto { + /** TrainingInfoProto initialization */ + initialization?: (onnx.IGraphProto|null); + + /** TrainingInfoProto algorithm */ + algorithm?: (onnx.IGraphProto|null); + + /** TrainingInfoProto initializationBinding */ + initializationBinding?: (onnx.IStringStringEntryProto[]|null); + + /** TrainingInfoProto updateBinding */ + updateBinding?: (onnx.IStringStringEntryProto[]|null); + } + + /** Represents a TrainingInfoProto. */ + class TrainingInfoProto implements ITrainingInfoProto { + /** + * Constructs a new TrainingInfoProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITrainingInfoProto); + + /** TrainingInfoProto initialization. */ + public initialization?: (onnx.IGraphProto|null); + + /** TrainingInfoProto algorithm. */ + public algorithm?: (onnx.IGraphProto|null); + + /** TrainingInfoProto initializationBinding. */ + public initializationBinding: onnx.IStringStringEntryProto[]; + + /** TrainingInfoProto updateBinding. */ + public updateBinding: onnx.IStringStringEntryProto[]; + + /** + * Creates a new TrainingInfoProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TrainingInfoProto instance + */ + public static create(properties?: onnx.ITrainingInfoProto): onnx.TrainingInfoProto; + + /** + * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} + * messages. + * @param message TrainingInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITrainingInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link + * onnx.TrainingInfoProto.verify|verify} messages. + * @param message TrainingInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITrainingInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TrainingInfoProto; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TrainingInfoProto; + + /** + * Verifies a TrainingInfoProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TrainingInfoProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TrainingInfoProto; + + /** + * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. + * @param message TrainingInfoProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TrainingInfoProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TrainingInfoProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TrainingInfoProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a ModelProto. */ + interface IModelProto { + /** ModelProto irVersion */ + irVersion?: (number|Long|null); + + /** ModelProto opsetImport */ + opsetImport?: (onnx.IOperatorSetIdProto[]|null); + + /** ModelProto producerName */ + producerName?: (string|null); + + /** ModelProto producerVersion */ + producerVersion?: (string|null); + + /** ModelProto domain */ + domain?: (string|null); + + /** ModelProto modelVersion */ + modelVersion?: (number|Long|null); + + /** ModelProto docString */ + docString?: (string|null); + + /** ModelProto graph */ + graph?: (onnx.IGraphProto|null); + + /** ModelProto metadataProps */ + metadataProps?: (onnx.IStringStringEntryProto[]|null); + + /** ModelProto trainingInfo */ + trainingInfo?: (onnx.ITrainingInfoProto[]|null); + + /** ModelProto functions */ + functions?: (onnx.IFunctionProto[]|null); + } + + /** Represents a ModelProto. */ + class ModelProto implements IModelProto { + /** + * Constructs a new ModelProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IModelProto); + + /** ModelProto irVersion. */ + public irVersion: (number|Long); + + /** ModelProto opsetImport. */ + public opsetImport: onnx.IOperatorSetIdProto[]; + + /** ModelProto producerName. */ + public producerName: string; + + /** ModelProto producerVersion. */ + public producerVersion: string; + + /** ModelProto domain. */ + public domain: string; + + /** ModelProto modelVersion. */ + public modelVersion: (number|Long); + + /** ModelProto docString. */ + public docString: string; + + /** ModelProto graph. */ + public graph?: (onnx.IGraphProto|null); + + /** ModelProto metadataProps. */ + public metadataProps: onnx.IStringStringEntryProto[]; + + /** ModelProto trainingInfo. */ + public trainingInfo: onnx.ITrainingInfoProto[]; + + /** ModelProto functions. */ + public functions: onnx.IFunctionProto[]; + + /** + * Creates a new ModelProto instance using the specified properties. + * @param [properties] Properties to set + * @returns ModelProto instance + */ + public static create(properties?: onnx.IModelProto): onnx.ModelProto; + + /** + * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @param message ModelProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IModelProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link + * onnx.ModelProto.verify|verify} messages. + * @param message ModelProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IModelProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a ModelProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.ModelProto; + + /** + * Decodes a ModelProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.ModelProto; + + /** + * Verifies a ModelProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns ModelProto + */ + public static fromObject(object: {[k: string]: any}): onnx.ModelProto; + + /** + * Creates a plain object from a ModelProto message. Also converts values to other types if specified. + * @param message ModelProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.ModelProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this ModelProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for ModelProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a StringStringEntryProto. */ + interface IStringStringEntryProto { + /** StringStringEntryProto key */ + key?: (string|null); + + /** StringStringEntryProto value */ + value?: (string|null); + } + + /** Represents a StringStringEntryProto. */ + class StringStringEntryProto implements IStringStringEntryProto { + /** + * Constructs a new StringStringEntryProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IStringStringEntryProto); + + /** StringStringEntryProto key. */ + public key: string; + + /** StringStringEntryProto value. */ + public value: string; + + /** + * Creates a new StringStringEntryProto instance using the specified properties. + * @param [properties] Properties to set + * @returns StringStringEntryProto instance + */ + public static create(properties?: onnx.IStringStringEntryProto): onnx.StringStringEntryProto; + + /** + * Encodes the specified StringStringEntryProto message. Does not implicitly {@link + * onnx.StringStringEntryProto.verify|verify} messages. + * @param message StringStringEntryProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IStringStringEntryProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link + * onnx.StringStringEntryProto.verify|verify} messages. + * @param message StringStringEntryProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IStringStringEntryProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.StringStringEntryProto; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.StringStringEntryProto; + + /** + * Verifies a StringStringEntryProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal + * types. + * @param object Plain object + * @returns StringStringEntryProto + */ + public static fromObject(object: {[k: string]: any}): onnx.StringStringEntryProto; + + /** + * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. + * @param message StringStringEntryProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.StringStringEntryProto, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this StringStringEntryProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for StringStringEntryProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorAnnotation. */ + interface ITensorAnnotation { + /** TensorAnnotation tensorName */ + tensorName?: (string|null); + + /** TensorAnnotation quantParameterTensorNames */ + quantParameterTensorNames?: (onnx.IStringStringEntryProto[]|null); + } + + /** Represents a TensorAnnotation. */ + class TensorAnnotation implements ITensorAnnotation { + /** + * Constructs a new TensorAnnotation. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorAnnotation); + + /** TensorAnnotation tensorName. */ + public tensorName: string; + + /** TensorAnnotation quantParameterTensorNames. */ + public quantParameterTensorNames: onnx.IStringStringEntryProto[]; + + /** + * Creates a new TensorAnnotation instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorAnnotation instance + */ + public static create(properties?: onnx.ITensorAnnotation): onnx.TensorAnnotation; + + /** + * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} + * messages. + * @param message TensorAnnotation message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorAnnotation, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link + * onnx.TensorAnnotation.verify|verify} messages. + * @param message TensorAnnotation message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorAnnotation, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorAnnotation; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorAnnotation; + + /** + * Verifies a TensorAnnotation message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorAnnotation + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorAnnotation; + + /** + * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. + * @param message TensorAnnotation + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorAnnotation, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorAnnotation to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorAnnotation + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a GraphProto. */ + interface IGraphProto { + /** GraphProto node */ + node?: (onnx.INodeProto[]|null); + + /** GraphProto name */ + name?: (string|null); + + /** GraphProto initializer */ + initializer?: (onnx.ITensorProto[]|null); + + /** GraphProto sparseInitializer */ + sparseInitializer?: (onnx.ISparseTensorProto[]|null); + + /** GraphProto docString */ + docString?: (string|null); + + /** GraphProto input */ + input?: (onnx.IValueInfoProto[]|null); + + /** GraphProto output */ + output?: (onnx.IValueInfoProto[]|null); + + /** GraphProto valueInfo */ + valueInfo?: (onnx.IValueInfoProto[]|null); + + /** GraphProto quantizationAnnotation */ + quantizationAnnotation?: (onnx.ITensorAnnotation[]|null); + } + + /** Represents a GraphProto. */ + class GraphProto implements IGraphProto { + /** + * Constructs a new GraphProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IGraphProto); + + /** GraphProto node. */ + public node: onnx.INodeProto[]; + + /** GraphProto name. */ + public name: string; + + /** GraphProto initializer. */ + public initializer: onnx.ITensorProto[]; + + /** GraphProto sparseInitializer. */ + public sparseInitializer: onnx.ISparseTensorProto[]; + + /** GraphProto docString. */ + public docString: string; + + /** GraphProto input. */ + public input: onnx.IValueInfoProto[]; + + /** GraphProto output. */ + public output: onnx.IValueInfoProto[]; + + /** GraphProto valueInfo. */ + public valueInfo: onnx.IValueInfoProto[]; + + /** GraphProto quantizationAnnotation. */ + public quantizationAnnotation: onnx.ITensorAnnotation[]; + + /** + * Creates a new GraphProto instance using the specified properties. + * @param [properties] Properties to set + * @returns GraphProto instance + */ + public static create(properties?: onnx.IGraphProto): onnx.GraphProto; + + /** + * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @param message GraphProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IGraphProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link + * onnx.GraphProto.verify|verify} messages. + * @param message GraphProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IGraphProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a GraphProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.GraphProto; + + /** + * Decodes a GraphProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.GraphProto; + + /** + * Verifies a GraphProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns GraphProto + */ + public static fromObject(object: {[k: string]: any}): onnx.GraphProto; + + /** + * Creates a plain object from a GraphProto message. Also converts values to other types if specified. + * @param message GraphProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.GraphProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this GraphProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for GraphProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorProto. */ + interface ITensorProto { + /** TensorProto dims */ + dims?: ((number | Long)[]|null); + + /** TensorProto dataType */ + dataType?: (number|null); + + /** TensorProto segment */ + segment?: (onnx.TensorProto.ISegment|null); + + /** TensorProto floatData */ + floatData?: (number[]|null); + + /** TensorProto int32Data */ + int32Data?: (number[]|null); + + /** TensorProto stringData */ + stringData?: (Uint8Array[]|null); + + /** TensorProto int64Data */ + int64Data?: ((number | Long)[]|null); + + /** TensorProto name */ + name?: (string|null); + + /** TensorProto docString */ + docString?: (string|null); + + /** TensorProto rawData */ + rawData?: (Uint8Array|null); + + /** TensorProto externalData */ + externalData?: (onnx.IStringStringEntryProto[]|null); + + /** TensorProto dataLocation */ + dataLocation?: (onnx.TensorProto.DataLocation|null); + + /** TensorProto doubleData */ + doubleData?: (number[]|null); + + /** TensorProto uint64Data */ + uint64Data?: ((number | Long)[]|null); + } + + /** Represents a TensorProto. */ + class TensorProto implements ITensorProto { + /** + * Constructs a new TensorProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorProto); + + /** TensorProto dims. */ + public dims: (number|Long)[]; + + /** TensorProto dataType. */ + public dataType: number; + + /** TensorProto segment. */ + public segment?: (onnx.TensorProto.ISegment|null); + + /** TensorProto floatData. */ + public floatData: number[]; + + /** TensorProto int32Data. */ + public int32Data: number[]; + + /** TensorProto stringData. */ + public stringData: Uint8Array[]; + + /** TensorProto int64Data. */ + public int64Data: (number|Long)[]; + + /** TensorProto name. */ + public name: string; + + /** TensorProto docString. */ + public docString: string; + + /** TensorProto rawData. */ + public rawData: Uint8Array; + + /** TensorProto externalData. */ + public externalData: onnx.IStringStringEntryProto[]; + + /** TensorProto dataLocation. */ + public dataLocation: onnx.TensorProto.DataLocation; + + /** TensorProto doubleData. */ + public doubleData: number[]; + + /** TensorProto uint64Data. */ + public uint64Data: (number|Long)[]; + + /** + * Creates a new TensorProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorProto instance + */ + public static create(properties?: onnx.ITensorProto): onnx.TensorProto; + + /** + * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @param message TensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link + * onnx.TensorProto.verify|verify} messages. + * @param message TensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorProto; + + /** + * Decodes a TensorProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorProto; + + /** + * Verifies a TensorProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorProto; + + /** + * Creates a plain object from a TensorProto message. Also converts values to other types if specified. + * @param message TensorProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TensorProto { + + /** DataType enum. */ + enum DataType { + UNDEFINED = 0, + FLOAT = 1, + UINT8 = 2, + INT8 = 3, + UINT16 = 4, + INT16 = 5, + INT32 = 6, + INT64 = 7, + STRING = 8, + BOOL = 9, + FLOAT16 = 10, + DOUBLE = 11, + UINT32 = 12, + UINT64 = 13, + COMPLEX64 = 14, + COMPLEX128 = 15, + BFLOAT16 = 16, + FLOAT8E4M3FN = 17, + FLOAT8E4M3FNUZ = 18, + FLOAT8E5M2 = 19, + FLOAT8E5M2FNUZ = 20 + } + + /** Properties of a Segment. */ + interface ISegment { + /** Segment begin */ + begin?: (number|Long|null); + + /** Segment end */ + end?: (number|Long|null); + } + + /** Represents a Segment. */ + class Segment implements ISegment { + /** + * Constructs a new Segment. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TensorProto.ISegment); + + /** Segment begin. */ + public begin: (number|Long); + + /** Segment end. */ + public end: (number|Long); + + /** + * Creates a new Segment instance using the specified properties. + * @param [properties] Properties to set + * @returns Segment instance + */ + public static create(properties?: onnx.TensorProto.ISegment): onnx.TensorProto.Segment; + + /** + * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} + * messages. + * @param message Segment message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TensorProto.ISegment, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Segment message, length delimited. Does not implicitly {@link + * onnx.TensorProto.Segment.verify|verify} messages. + * @param message Segment message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TensorProto.ISegment, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Segment message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorProto.Segment; + + /** + * Decodes a Segment message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorProto.Segment; + + /** + * Verifies a Segment message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Segment message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Segment + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorProto.Segment; + + /** + * Creates a plain object from a Segment message. Also converts values to other types if specified. + * @param message Segment + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorProto.Segment, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Segment to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Segment + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** DataLocation enum. */ + enum DataLocation { DEFAULT = 0, EXTERNAL = 1 } + } + + /** Properties of a SparseTensorProto. */ + interface ISparseTensorProto { + /** SparseTensorProto values */ + values?: (onnx.ITensorProto|null); + + /** SparseTensorProto indices */ + indices?: (onnx.ITensorProto|null); + + /** SparseTensorProto dims */ + dims?: ((number | Long)[]|null); + } + + /** Represents a SparseTensorProto. */ + class SparseTensorProto implements ISparseTensorProto { + /** + * Constructs a new SparseTensorProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ISparseTensorProto); + + /** SparseTensorProto values. */ + public values?: (onnx.ITensorProto|null); + + /** SparseTensorProto indices. */ + public indices?: (onnx.ITensorProto|null); + + /** SparseTensorProto dims. */ + public dims: (number|Long)[]; + + /** + * Creates a new SparseTensorProto instance using the specified properties. + * @param [properties] Properties to set + * @returns SparseTensorProto instance + */ + public static create(properties?: onnx.ISparseTensorProto): onnx.SparseTensorProto; + + /** + * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} + * messages. + * @param message SparseTensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ISparseTensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link + * onnx.SparseTensorProto.verify|verify} messages. + * @param message SparseTensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ISparseTensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.SparseTensorProto; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.SparseTensorProto; + + /** + * Verifies a SparseTensorProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns SparseTensorProto + */ + public static fromObject(object: {[k: string]: any}): onnx.SparseTensorProto; + + /** + * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. + * @param message SparseTensorProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.SparseTensorProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this SparseTensorProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for SparseTensorProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorShapeProto. */ + interface ITensorShapeProto { + /** TensorShapeProto dim */ + dim?: (onnx.TensorShapeProto.IDimension[]|null); + } + + /** Represents a TensorShapeProto. */ + class TensorShapeProto implements ITensorShapeProto { + /** + * Constructs a new TensorShapeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorShapeProto); + + /** TensorShapeProto dim. */ + public dim: onnx.TensorShapeProto.IDimension[]; + + /** + * Creates a new TensorShapeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorShapeProto instance + */ + public static create(properties?: onnx.ITensorShapeProto): onnx.TensorShapeProto; + + /** + * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} + * messages. + * @param message TensorShapeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorShapeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link + * onnx.TensorShapeProto.verify|verify} messages. + * @param message TensorShapeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorShapeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorShapeProto; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorShapeProto; + + /** + * Verifies a TensorShapeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorShapeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorShapeProto; + + /** + * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. + * @param message TensorShapeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorShapeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorShapeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorShapeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TensorShapeProto { + + /** Properties of a Dimension. */ + interface IDimension { + /** Dimension dimValue */ + dimValue?: (number|Long|null); + + /** Dimension dimParam */ + dimParam?: (string|null); + + /** Dimension denotation */ + denotation?: (string|null); + } + + /** Represents a Dimension. */ + class Dimension implements IDimension { + /** + * Constructs a new Dimension. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TensorShapeProto.IDimension); + + /** Dimension dimValue. */ + public dimValue?: (number|Long|null); + + /** Dimension dimParam. */ + public dimParam?: (string|null); + + /** Dimension denotation. */ + public denotation: string; + + /** Dimension value. */ + public value?: ('dimValue'|'dimParam'); + + /** + * Creates a new Dimension instance using the specified properties. + * @param [properties] Properties to set + * @returns Dimension instance + */ + public static create(properties?: onnx.TensorShapeProto.IDimension): onnx.TensorShapeProto.Dimension; + + /** + * Encodes the specified Dimension message. Does not implicitly {@link + * onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @param message Dimension message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TensorShapeProto.IDimension, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Dimension message, length delimited. Does not implicitly {@link + * onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @param message Dimension message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TensorShapeProto.IDimension, writer?: $protobuf.Writer): + $protobuf.Writer; + + /** + * Decodes a Dimension message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorShapeProto.Dimension; + + /** + * Decodes a Dimension message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorShapeProto.Dimension; + + /** + * Verifies a Dimension message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Dimension message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Dimension + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorShapeProto.Dimension; + + /** + * Creates a plain object from a Dimension message. Also converts values to other types if specified. + * @param message Dimension + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorShapeProto.Dimension, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Dimension to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Dimension + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + } + + /** Properties of a TypeProto. */ + interface ITypeProto { + /** TypeProto tensorType */ + tensorType?: (onnx.TypeProto.ITensor|null); + + /** TypeProto sequenceType */ + sequenceType?: (onnx.TypeProto.ISequence|null); + + /** TypeProto mapType */ + mapType?: (onnx.TypeProto.IMap|null); + + /** TypeProto optionalType */ + optionalType?: (onnx.TypeProto.IOptional|null); + + /** TypeProto sparseTensorType */ + sparseTensorType?: (onnx.TypeProto.ISparseTensor|null); + + /** TypeProto denotation */ + denotation?: (string|null); + } + + /** Represents a TypeProto. */ + class TypeProto implements ITypeProto { + /** + * Constructs a new TypeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITypeProto); + + /** TypeProto tensorType. */ + public tensorType?: (onnx.TypeProto.ITensor|null); + + /** TypeProto sequenceType. */ + public sequenceType?: (onnx.TypeProto.ISequence|null); + + /** TypeProto mapType. */ + public mapType?: (onnx.TypeProto.IMap|null); + + /** TypeProto optionalType. */ + public optionalType?: (onnx.TypeProto.IOptional|null); + + /** TypeProto sparseTensorType. */ + public sparseTensorType?: (onnx.TypeProto.ISparseTensor|null); + + /** TypeProto denotation. */ + public denotation: string; + + /** TypeProto value. */ + public value?: ('tensorType'|'sequenceType'|'mapType'|'optionalType'|'sparseTensorType'); + + /** + * Creates a new TypeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TypeProto instance + */ + public static create(properties?: onnx.ITypeProto): onnx.TypeProto; + + /** + * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @param message TypeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITypeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TypeProto message, length delimited. Does not implicitly {@link + * onnx.TypeProto.verify|verify} messages. + * @param message TypeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITypeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TypeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto; + + /** + * Decodes a TypeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto; + + /** + * Verifies a TypeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TypeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TypeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto; + + /** + * Creates a plain object from a TypeProto message. Also converts values to other types if specified. + * @param message TypeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TypeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TypeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TypeProto { + + /** Properties of a Tensor. */ + interface ITensor { + /** Tensor elemType */ + elemType?: (number|null); + + /** Tensor shape */ + shape?: (onnx.ITensorShapeProto|null); + } + + /** Represents a Tensor. */ + class Tensor implements ITensor { + /** + * Constructs a new Tensor. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ITensor); + + /** Tensor elemType. */ + public elemType: number; + + /** Tensor shape. */ + public shape?: (onnx.ITensorShapeProto|null); + + /** + * Creates a new Tensor instance using the specified properties. + * @param [properties] Properties to set + * @returns Tensor instance + */ + public static create(properties?: onnx.TypeProto.ITensor): onnx.TypeProto.Tensor; + + /** + * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @param message Tensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ITensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Tensor message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Tensor.verify|verify} messages. + * @param message Tensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ITensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Tensor message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Tensor; + + /** + * Decodes a Tensor message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Tensor; + + /** + * Verifies a Tensor message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Tensor message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Tensor + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Tensor; + + /** + * Creates a plain object from a Tensor message. Also converts values to other types if specified. + * @param message Tensor + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Tensor, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Tensor to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Tensor + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a Sequence. */ + interface ISequence { + /** Sequence elemType */ + elemType?: (onnx.ITypeProto|null); + } + + /** Represents a Sequence. */ + class Sequence implements ISequence { + /** + * Constructs a new Sequence. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ISequence); + + /** Sequence elemType. */ + public elemType?: (onnx.ITypeProto|null); + + /** + * Creates a new Sequence instance using the specified properties. + * @param [properties] Properties to set + * @returns Sequence instance + */ + public static create(properties?: onnx.TypeProto.ISequence): onnx.TypeProto.Sequence; + + /** + * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} + * messages. + * @param message Sequence message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ISequence, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Sequence message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Sequence.verify|verify} messages. + * @param message Sequence message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ISequence, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Sequence message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Sequence; + + /** + * Decodes a Sequence message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Sequence; + + /** + * Verifies a Sequence message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Sequence message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Sequence + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Sequence; + + /** + * Creates a plain object from a Sequence message. Also converts values to other types if specified. + * @param message Sequence + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Sequence, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Sequence to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Sequence + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a Map. */ + interface IMap { + /** Map keyType */ + keyType?: (number|null); + + /** Map valueType */ + valueType?: (onnx.ITypeProto|null); + } + + /** Represents a Map. */ + class Map implements IMap { + /** + * Constructs a new Map. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.IMap); + + /** Map keyType. */ + public keyType: number; + + /** Map valueType. */ + public valueType?: (onnx.ITypeProto|null); + + /** + * Creates a new Map instance using the specified properties. + * @param [properties] Properties to set + * @returns Map instance + */ + public static create(properties?: onnx.TypeProto.IMap): onnx.TypeProto.Map; + + /** + * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @param message Map message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.IMap, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Map message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Map.verify|verify} messages. + * @param message Map message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.IMap, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Map message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Map; + + /** + * Decodes a Map message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Map; + + /** + * Verifies a Map message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Map message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Map + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Map; + + /** + * Creates a plain object from a Map message. Also converts values to other types if specified. + * @param message Map + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Map, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this Map to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Map + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of an Optional. */ + interface IOptional { + /** Optional elemType */ + elemType?: (onnx.ITypeProto|null); + } + + /** Represents an Optional. */ + class Optional implements IOptional { + /** + * Constructs a new Optional. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.IOptional); + + /** Optional elemType. */ + public elemType?: (onnx.ITypeProto|null); + + /** + * Creates a new Optional instance using the specified properties. + * @param [properties] Properties to set + * @returns Optional instance + */ + public static create(properties?: onnx.TypeProto.IOptional): onnx.TypeProto.Optional; + + /** + * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} + * messages. + * @param message Optional message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.IOptional, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Optional message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Optional.verify|verify} messages. + * @param message Optional message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.IOptional, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an Optional message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Optional; + + /** + * Decodes an Optional message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Optional; + + /** + * Verifies an Optional message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an Optional message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Optional + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Optional; + + /** + * Creates a plain object from an Optional message. Also converts values to other types if specified. + * @param message Optional + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Optional, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Optional to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Optional + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a SparseTensor. */ + interface ISparseTensor { + /** SparseTensor elemType */ + elemType?: (number|null); + + /** SparseTensor shape */ + shape?: (onnx.ITensorShapeProto|null); + } + + /** Represents a SparseTensor. */ + class SparseTensor implements ISparseTensor { + /** + * Constructs a new SparseTensor. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ISparseTensor); + + /** SparseTensor elemType. */ + public elemType: number; + + /** SparseTensor shape. */ + public shape?: (onnx.ITensorShapeProto|null); + + /** + * Creates a new SparseTensor instance using the specified properties. + * @param [properties] Properties to set + * @returns SparseTensor instance + */ + public static create(properties?: onnx.TypeProto.ISparseTensor): onnx.TypeProto.SparseTensor; + + /** + * Encodes the specified SparseTensor message. Does not implicitly {@link + * onnx.TypeProto.SparseTensor.verify|verify} messages. + * @param message SparseTensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ISparseTensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified SparseTensor message, length delimited. Does not implicitly {@link + * onnx.TypeProto.SparseTensor.verify|verify} messages. + * @param message SparseTensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ISparseTensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a SparseTensor message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.SparseTensor; + + /** + * Decodes a SparseTensor message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.SparseTensor; + + /** + * Verifies a SparseTensor message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns SparseTensor + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.SparseTensor; + + /** + * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. + * @param message SparseTensor + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.SparseTensor, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this SparseTensor to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for SparseTensor + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + } + + /** Properties of an OperatorSetIdProto. */ + interface IOperatorSetIdProto { + /** OperatorSetIdProto domain */ + domain?: (string|null); + + /** OperatorSetIdProto version */ + version?: (number|Long|null); + } + + /** Represents an OperatorSetIdProto. */ + class OperatorSetIdProto implements IOperatorSetIdProto { + /** + * Constructs a new OperatorSetIdProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IOperatorSetIdProto); + + /** OperatorSetIdProto domain. */ + public domain: string; + + /** OperatorSetIdProto version. */ + public version: (number|Long); + + /** + * Creates a new OperatorSetIdProto instance using the specified properties. + * @param [properties] Properties to set + * @returns OperatorSetIdProto instance + */ + public static create(properties?: onnx.IOperatorSetIdProto): onnx.OperatorSetIdProto; + + /** + * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link + * onnx.OperatorSetIdProto.verify|verify} messages. + * @param message OperatorSetIdProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IOperatorSetIdProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link + * onnx.OperatorSetIdProto.verify|verify} messages. + * @param message OperatorSetIdProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IOperatorSetIdProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.OperatorSetIdProto; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.OperatorSetIdProto; + + /** + * Verifies an OperatorSetIdProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal + * types. + * @param object Plain object + * @returns OperatorSetIdProto + */ + public static fromObject(object: {[k: string]: any}): onnx.OperatorSetIdProto; + + /** + * Creates a plain object from an OperatorSetIdProto message. Also converts values to other types if specified. + * @param message OperatorSetIdProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.OperatorSetIdProto, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this OperatorSetIdProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for OperatorSetIdProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** OperatorStatus enum. */ + enum OperatorStatus { EXPERIMENTAL = 0, STABLE = 1 } + + /** Properties of a FunctionProto. */ + interface IFunctionProto { + /** FunctionProto name */ + name?: (string|null); + + /** FunctionProto input */ + input?: (string[]|null); + + /** FunctionProto output */ + output?: (string[]|null); + + /** FunctionProto attribute */ + attribute?: (string[]|null); + + /** FunctionProto attributeProto */ + attributeProto?: (onnx.IAttributeProto[]|null); + + /** FunctionProto node */ + node?: (onnx.INodeProto[]|null); + + /** FunctionProto docString */ + docString?: (string|null); + + /** FunctionProto opsetImport */ + opsetImport?: (onnx.IOperatorSetIdProto[]|null); + + /** FunctionProto domain */ + domain?: (string|null); + } + + /** Represents a FunctionProto. */ + class FunctionProto implements IFunctionProto { + /** + * Constructs a new FunctionProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IFunctionProto); + + /** FunctionProto name. */ + public name: string; + + /** FunctionProto input. */ + public input: string[]; + + /** FunctionProto output. */ + public output: string[]; + + /** FunctionProto attribute. */ + public attribute: string[]; + + /** FunctionProto attributeProto. */ + public attributeProto: onnx.IAttributeProto[]; + + /** FunctionProto node. */ + public node: onnx.INodeProto[]; + + /** FunctionProto docString. */ + public docString: string; + + /** FunctionProto opsetImport. */ + public opsetImport: onnx.IOperatorSetIdProto[]; + + /** FunctionProto domain. */ + public domain: string; + + /** + * Creates a new FunctionProto instance using the specified properties. + * @param [properties] Properties to set + * @returns FunctionProto instance + */ + public static create(properties?: onnx.IFunctionProto): onnx.FunctionProto; + + /** + * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} + * messages. + * @param message FunctionProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IFunctionProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link + * onnx.FunctionProto.verify|verify} messages. + * @param message FunctionProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IFunctionProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a FunctionProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.FunctionProto; + + /** + * Decodes a FunctionProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.FunctionProto; + + /** + * Verifies a FunctionProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns FunctionProto + */ + public static fromObject(object: {[k: string]: any}): onnx.FunctionProto; + + /** + * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. + * @param message FunctionProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.FunctionProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this FunctionProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for FunctionProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } +} diff --git a/js/node/test/ort-schema/protobuf/onnx.js b/js/node/test/ort-schema/protobuf/onnx.js new file mode 100644 index 0000000000000..681855132d4e8 --- /dev/null +++ b/js/node/test/ort-schema/protobuf/onnx.js @@ -0,0 +1,7658 @@ +/*eslint-disable block-scoped-var, id-length, no-control-regex, no-magic-numbers, no-prototype-builtins, no-redeclare, no-shadow, no-var, sort-vars*/ +"use strict"; + +var $protobuf = require("protobufjs/minimal"); + +// Common aliases +var $Reader = $protobuf.Reader, $Writer = $protobuf.Writer, $util = $protobuf.util; + +// Exported root namespace +var $root = $protobuf.roots["default"] || ($protobuf.roots["default"] = {}); + +$root.onnx = (function() { + + /** + * Namespace onnx. + * @exports onnx + * @namespace + */ + var onnx = {}; + + /** + * Version enum. + * @name onnx.Version + * @enum {number} + * @property {number} _START_VERSION=0 _START_VERSION value + * @property {number} IR_VERSION_2017_10_10=1 IR_VERSION_2017_10_10 value + * @property {number} IR_VERSION_2017_10_30=2 IR_VERSION_2017_10_30 value + * @property {number} IR_VERSION_2017_11_3=3 IR_VERSION_2017_11_3 value + * @property {number} IR_VERSION_2019_1_22=4 IR_VERSION_2019_1_22 value + * @property {number} IR_VERSION_2019_3_18=5 IR_VERSION_2019_3_18 value + * @property {number} IR_VERSION_2019_9_19=6 IR_VERSION_2019_9_19 value + * @property {number} IR_VERSION_2020_5_8=7 IR_VERSION_2020_5_8 value + * @property {number} IR_VERSION_2021_7_30=8 IR_VERSION_2021_7_30 value + * @property {number} IR_VERSION=9 IR_VERSION value + */ + onnx.Version = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "_START_VERSION"] = 0; + values[valuesById[1] = "IR_VERSION_2017_10_10"] = 1; + values[valuesById[2] = "IR_VERSION_2017_10_30"] = 2; + values[valuesById[3] = "IR_VERSION_2017_11_3"] = 3; + values[valuesById[4] = "IR_VERSION_2019_1_22"] = 4; + values[valuesById[5] = "IR_VERSION_2019_3_18"] = 5; + values[valuesById[6] = "IR_VERSION_2019_9_19"] = 6; + values[valuesById[7] = "IR_VERSION_2020_5_8"] = 7; + values[valuesById[8] = "IR_VERSION_2021_7_30"] = 8; + values[valuesById[9] = "IR_VERSION"] = 9; + return values; + })(); + + onnx.AttributeProto = (function() { + + /** + * Properties of an AttributeProto. + * @memberof onnx + * @interface IAttributeProto + * @property {string|null} [name] AttributeProto name + * @property {string|null} [refAttrName] AttributeProto refAttrName + * @property {string|null} [docString] AttributeProto docString + * @property {onnx.AttributeProto.AttributeType|null} [type] AttributeProto type + * @property {number|null} [f] AttributeProto f + * @property {number|Long|null} [i] AttributeProto i + * @property {Uint8Array|null} [s] AttributeProto s + * @property {onnx.ITensorProto|null} [t] AttributeProto t + * @property {onnx.IGraphProto|null} [g] AttributeProto g + * @property {onnx.ISparseTensorProto|null} [sparseTensor] AttributeProto sparseTensor + * @property {onnx.ITypeProto|null} [tp] AttributeProto tp + * @property {Array.|null} [floats] AttributeProto floats + * @property {Array.|null} [ints] AttributeProto ints + * @property {Array.|null} [strings] AttributeProto strings + * @property {Array.|null} [tensors] AttributeProto tensors + * @property {Array.|null} [graphs] AttributeProto graphs + * @property {Array.|null} [sparseTensors] AttributeProto sparseTensors + * @property {Array.|null} [typeProtos] AttributeProto typeProtos + */ + + /** + * Constructs a new AttributeProto. + * @memberof onnx + * @classdesc Represents an AttributeProto. + * @implements IAttributeProto + * @constructor + * @param {onnx.IAttributeProto=} [properties] Properties to set + */ + function AttributeProto(properties) { + this.floats = []; + this.ints = []; + this.strings = []; + this.tensors = []; + this.graphs = []; + this.sparseTensors = []; + this.typeProtos = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * AttributeProto name. + * @member {string} name + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.name = ""; + + /** + * AttributeProto refAttrName. + * @member {string} refAttrName + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.refAttrName = ""; + + /** + * AttributeProto docString. + * @member {string} docString + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.docString = ""; + + /** + * AttributeProto type. + * @member {onnx.AttributeProto.AttributeType} type + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.type = 0; + + /** + * AttributeProto f. + * @member {number} f + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.f = 0; + + /** + * AttributeProto i. + * @member {number|Long} i + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.i = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * AttributeProto s. + * @member {Uint8Array} s + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.s = $util.newBuffer([]); + + /** + * AttributeProto t. + * @member {onnx.ITensorProto|null|undefined} t + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.t = null; + + /** + * AttributeProto g. + * @member {onnx.IGraphProto|null|undefined} g + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.g = null; + + /** + * AttributeProto sparseTensor. + * @member {onnx.ISparseTensorProto|null|undefined} sparseTensor + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.sparseTensor = null; + + /** + * AttributeProto tp. + * @member {onnx.ITypeProto|null|undefined} tp + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.tp = null; + + /** + * AttributeProto floats. + * @member {Array.} floats + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.floats = $util.emptyArray; + + /** + * AttributeProto ints. + * @member {Array.} ints + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.ints = $util.emptyArray; + + /** + * AttributeProto strings. + * @member {Array.} strings + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.strings = $util.emptyArray; + + /** + * AttributeProto tensors. + * @member {Array.} tensors + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.tensors = $util.emptyArray; + + /** + * AttributeProto graphs. + * @member {Array.} graphs + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.graphs = $util.emptyArray; + + /** + * AttributeProto sparseTensors. + * @member {Array.} sparseTensors + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.sparseTensors = $util.emptyArray; + + /** + * AttributeProto typeProtos. + * @member {Array.} typeProtos + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.typeProtos = $util.emptyArray; + + /** + * Creates a new AttributeProto instance using the specified properties. + * @function create + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto=} [properties] Properties to set + * @returns {onnx.AttributeProto} AttributeProto instance + */ + AttributeProto.create = function create(properties) { + return new AttributeProto(properties); + }; + + /** + * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. + * @function encode + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.f != null && Object.hasOwnProperty.call(message, "f")) + writer.uint32(/* id 2, wireType 5 =*/21).float(message.f); + if (message.i != null && Object.hasOwnProperty.call(message, "i")) + writer.uint32(/* id 3, wireType 0 =*/24).int64(message.i); + if (message.s != null && Object.hasOwnProperty.call(message, "s")) + writer.uint32(/* id 4, wireType 2 =*/34).bytes(message.s); + if (message.t != null && Object.hasOwnProperty.call(message, "t")) + $root.onnx.TensorProto.encode(message.t, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.g != null && Object.hasOwnProperty.call(message, "g")) + $root.onnx.GraphProto.encode(message.g, writer.uint32(/* id 6, wireType 2 =*/50).fork()).ldelim(); + if (message.floats != null && message.floats.length) { + writer.uint32(/* id 7, wireType 2 =*/58).fork(); + for (var i = 0; i < message.floats.length; ++i) + writer.float(message.floats[i]); + writer.ldelim(); + } + if (message.ints != null && message.ints.length) { + writer.uint32(/* id 8, wireType 2 =*/66).fork(); + for (var i = 0; i < message.ints.length; ++i) + writer.int64(message.ints[i]); + writer.ldelim(); + } + if (message.strings != null && message.strings.length) + for (var i = 0; i < message.strings.length; ++i) + writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.strings[i]); + if (message.tensors != null && message.tensors.length) + for (var i = 0; i < message.tensors.length; ++i) + $root.onnx.TensorProto.encode(message.tensors[i], writer.uint32(/* id 10, wireType 2 =*/82).fork()).ldelim(); + if (message.graphs != null && message.graphs.length) + for (var i = 0; i < message.graphs.length; ++i) + $root.onnx.GraphProto.encode(message.graphs[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 13, wireType 2 =*/106).string(message.docString); + if (message.tp != null && Object.hasOwnProperty.call(message, "tp")) + $root.onnx.TypeProto.encode(message.tp, writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.typeProtos != null && message.typeProtos.length) + for (var i = 0; i < message.typeProtos.length; ++i) + $root.onnx.TypeProto.encode(message.typeProtos[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); + if (message.type != null && Object.hasOwnProperty.call(message, "type")) + writer.uint32(/* id 20, wireType 0 =*/160).int32(message.type); + if (message.refAttrName != null && Object.hasOwnProperty.call(message, "refAttrName")) + writer.uint32(/* id 21, wireType 2 =*/170).string(message.refAttrName); + if (message.sparseTensor != null && Object.hasOwnProperty.call(message, "sparseTensor")) + $root.onnx.SparseTensorProto.encode(message.sparseTensor, writer.uint32(/* id 22, wireType 2 =*/178).fork()).ldelim(); + if (message.sparseTensors != null && message.sparseTensors.length) + for (var i = 0; i < message.sparseTensors.length; ++i) + $root.onnx.SparseTensorProto.encode(message.sparseTensors[i], writer.uint32(/* id 23, wireType 2 =*/186).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an AttributeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.AttributeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.AttributeProto} AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + AttributeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.AttributeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 21: { + message.refAttrName = reader.string(); + break; + } + case 13: { + message.docString = reader.string(); + break; + } + case 20: { + message.type = reader.int32(); + break; + } + case 2: { + message.f = reader.float(); + break; + } + case 3: { + message.i = reader.int64(); + break; + } + case 4: { + message.s = reader.bytes(); + break; + } + case 5: { + message.t = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 6: { + message.g = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 22: { + message.sparseTensor = $root.onnx.SparseTensorProto.decode(reader, reader.uint32()); + break; + } + case 14: { + message.tp = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 7: { + if (!(message.floats && message.floats.length)) + message.floats = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.floats.push(reader.float()); + } else + message.floats.push(reader.float()); + break; + } + case 8: { + if (!(message.ints && message.ints.length)) + message.ints = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.ints.push(reader.int64()); + } else + message.ints.push(reader.int64()); + break; + } + case 9: { + if (!(message.strings && message.strings.length)) + message.strings = []; + message.strings.push(reader.bytes()); + break; + } + case 10: { + if (!(message.tensors && message.tensors.length)) + message.tensors = []; + message.tensors.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 11: { + if (!(message.graphs && message.graphs.length)) + message.graphs = []; + message.graphs.push($root.onnx.GraphProto.decode(reader, reader.uint32())); + break; + } + case 23: { + if (!(message.sparseTensors && message.sparseTensors.length)) + message.sparseTensors = []; + message.sparseTensors.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if (!(message.typeProtos && message.typeProtos.length)) + message.typeProtos = []; + message.typeProtos.push($root.onnx.TypeProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an AttributeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.AttributeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.AttributeProto} AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + AttributeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an AttributeProto message. + * @function verify + * @memberof onnx.AttributeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + AttributeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) + if (!$util.isString(message.refAttrName)) + return "refAttrName: string expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.type != null && message.hasOwnProperty("type")) + switch (message.type) { + default: + return "type: enum value expected"; + case 0: + case 1: + case 2: + case 3: + case 4: + case 5: + case 11: + case 13: + case 6: + case 7: + case 8: + case 9: + case 10: + case 12: + case 14: + break; + } + if (message.f != null && message.hasOwnProperty("f")) + if (typeof message.f !== "number") + return "f: number expected"; + if (message.i != null && message.hasOwnProperty("i")) + if (!$util.isInteger(message.i) && !(message.i && $util.isInteger(message.i.low) && $util.isInteger(message.i.high))) + return "i: integer|Long expected"; + if (message.s != null && message.hasOwnProperty("s")) + if (!(message.s && typeof message.s.length === "number" || $util.isString(message.s))) + return "s: buffer expected"; + if (message.t != null && message.hasOwnProperty("t")) { + var error = $root.onnx.TensorProto.verify(message.t); + if (error) + return "t." + error; + } + if (message.g != null && message.hasOwnProperty("g")) { + var error = $root.onnx.GraphProto.verify(message.g); + if (error) + return "g." + error; + } + if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseTensor); + if (error) + return "sparseTensor." + error; + } + if (message.tp != null && message.hasOwnProperty("tp")) { + var error = $root.onnx.TypeProto.verify(message.tp); + if (error) + return "tp." + error; + } + if (message.floats != null && message.hasOwnProperty("floats")) { + if (!Array.isArray(message.floats)) + return "floats: array expected"; + for (var i = 0; i < message.floats.length; ++i) + if (typeof message.floats[i] !== "number") + return "floats: number[] expected"; + } + if (message.ints != null && message.hasOwnProperty("ints")) { + if (!Array.isArray(message.ints)) + return "ints: array expected"; + for (var i = 0; i < message.ints.length; ++i) + if (!$util.isInteger(message.ints[i]) && !(message.ints[i] && $util.isInteger(message.ints[i].low) && $util.isInteger(message.ints[i].high))) + return "ints: integer|Long[] expected"; + } + if (message.strings != null && message.hasOwnProperty("strings")) { + if (!Array.isArray(message.strings)) + return "strings: array expected"; + for (var i = 0; i < message.strings.length; ++i) + if (!(message.strings[i] && typeof message.strings[i].length === "number" || $util.isString(message.strings[i]))) + return "strings: buffer[] expected"; + } + if (message.tensors != null && message.hasOwnProperty("tensors")) { + if (!Array.isArray(message.tensors)) + return "tensors: array expected"; + for (var i = 0; i < message.tensors.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.tensors[i]); + if (error) + return "tensors." + error; + } + } + if (message.graphs != null && message.hasOwnProperty("graphs")) { + if (!Array.isArray(message.graphs)) + return "graphs: array expected"; + for (var i = 0; i < message.graphs.length; ++i) { + var error = $root.onnx.GraphProto.verify(message.graphs[i]); + if (error) + return "graphs." + error; + } + } + if (message.sparseTensors != null && message.hasOwnProperty("sparseTensors")) { + if (!Array.isArray(message.sparseTensors)) + return "sparseTensors: array expected"; + for (var i = 0; i < message.sparseTensors.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseTensors[i]); + if (error) + return "sparseTensors." + error; + } + } + if (message.typeProtos != null && message.hasOwnProperty("typeProtos")) { + if (!Array.isArray(message.typeProtos)) + return "typeProtos: array expected"; + for (var i = 0; i < message.typeProtos.length; ++i) { + var error = $root.onnx.TypeProto.verify(message.typeProtos[i]); + if (error) + return "typeProtos." + error; + } + } + return null; + }; + + /** + * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.AttributeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.AttributeProto} AttributeProto + */ + AttributeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.AttributeProto) + return object; + var message = new $root.onnx.AttributeProto(); + if (object.name != null) + message.name = String(object.name); + if (object.refAttrName != null) + message.refAttrName = String(object.refAttrName); + if (object.docString != null) + message.docString = String(object.docString); + switch (object.type) { + default: + if (typeof object.type === "number") { + message.type = object.type; + break; + } + break; + case "UNDEFINED": + case 0: + message.type = 0; + break; + case "FLOAT": + case 1: + message.type = 1; + break; + case "INT": + case 2: + message.type = 2; + break; + case "STRING": + case 3: + message.type = 3; + break; + case "TENSOR": + case 4: + message.type = 4; + break; + case "GRAPH": + case 5: + message.type = 5; + break; + case "SPARSE_TENSOR": + case 11: + message.type = 11; + break; + case "TYPE_PROTO": + case 13: + message.type = 13; + break; + case "FLOATS": + case 6: + message.type = 6; + break; + case "INTS": + case 7: + message.type = 7; + break; + case "STRINGS": + case 8: + message.type = 8; + break; + case "TENSORS": + case 9: + message.type = 9; + break; + case "GRAPHS": + case 10: + message.type = 10; + break; + case "SPARSE_TENSORS": + case 12: + message.type = 12; + break; + case "TYPE_PROTOS": + case 14: + message.type = 14; + break; + } + if (object.f != null) + message.f = Number(object.f); + if (object.i != null) + if ($util.Long) + (message.i = $util.Long.fromValue(object.i)).unsigned = false; + else if (typeof object.i === "string") + message.i = parseInt(object.i, 10); + else if (typeof object.i === "number") + message.i = object.i; + else if (typeof object.i === "object") + message.i = new $util.LongBits(object.i.low >>> 0, object.i.high >>> 0).toNumber(); + if (object.s != null) + if (typeof object.s === "string") + $util.base64.decode(object.s, message.s = $util.newBuffer($util.base64.length(object.s)), 0); + else if (object.s.length >= 0) + message.s = object.s; + if (object.t != null) { + if (typeof object.t !== "object") + throw TypeError(".onnx.AttributeProto.t: object expected"); + message.t = $root.onnx.TensorProto.fromObject(object.t); + } + if (object.g != null) { + if (typeof object.g !== "object") + throw TypeError(".onnx.AttributeProto.g: object expected"); + message.g = $root.onnx.GraphProto.fromObject(object.g); + } + if (object.sparseTensor != null) { + if (typeof object.sparseTensor !== "object") + throw TypeError(".onnx.AttributeProto.sparseTensor: object expected"); + message.sparseTensor = $root.onnx.SparseTensorProto.fromObject(object.sparseTensor); + } + if (object.tp != null) { + if (typeof object.tp !== "object") + throw TypeError(".onnx.AttributeProto.tp: object expected"); + message.tp = $root.onnx.TypeProto.fromObject(object.tp); + } + if (object.floats) { + if (!Array.isArray(object.floats)) + throw TypeError(".onnx.AttributeProto.floats: array expected"); + message.floats = []; + for (var i = 0; i < object.floats.length; ++i) + message.floats[i] = Number(object.floats[i]); + } + if (object.ints) { + if (!Array.isArray(object.ints)) + throw TypeError(".onnx.AttributeProto.ints: array expected"); + message.ints = []; + for (var i = 0; i < object.ints.length; ++i) + if ($util.Long) + (message.ints[i] = $util.Long.fromValue(object.ints[i])).unsigned = false; + else if (typeof object.ints[i] === "string") + message.ints[i] = parseInt(object.ints[i], 10); + else if (typeof object.ints[i] === "number") + message.ints[i] = object.ints[i]; + else if (typeof object.ints[i] === "object") + message.ints[i] = new $util.LongBits(object.ints[i].low >>> 0, object.ints[i].high >>> 0).toNumber(); + } + if (object.strings) { + if (!Array.isArray(object.strings)) + throw TypeError(".onnx.AttributeProto.strings: array expected"); + message.strings = []; + for (var i = 0; i < object.strings.length; ++i) + if (typeof object.strings[i] === "string") + $util.base64.decode(object.strings[i], message.strings[i] = $util.newBuffer($util.base64.length(object.strings[i])), 0); + else if (object.strings[i].length >= 0) + message.strings[i] = object.strings[i]; + } + if (object.tensors) { + if (!Array.isArray(object.tensors)) + throw TypeError(".onnx.AttributeProto.tensors: array expected"); + message.tensors = []; + for (var i = 0; i < object.tensors.length; ++i) { + if (typeof object.tensors[i] !== "object") + throw TypeError(".onnx.AttributeProto.tensors: object expected"); + message.tensors[i] = $root.onnx.TensorProto.fromObject(object.tensors[i]); + } + } + if (object.graphs) { + if (!Array.isArray(object.graphs)) + throw TypeError(".onnx.AttributeProto.graphs: array expected"); + message.graphs = []; + for (var i = 0; i < object.graphs.length; ++i) { + if (typeof object.graphs[i] !== "object") + throw TypeError(".onnx.AttributeProto.graphs: object expected"); + message.graphs[i] = $root.onnx.GraphProto.fromObject(object.graphs[i]); + } + } + if (object.sparseTensors) { + if (!Array.isArray(object.sparseTensors)) + throw TypeError(".onnx.AttributeProto.sparseTensors: array expected"); + message.sparseTensors = []; + for (var i = 0; i < object.sparseTensors.length; ++i) { + if (typeof object.sparseTensors[i] !== "object") + throw TypeError(".onnx.AttributeProto.sparseTensors: object expected"); + message.sparseTensors[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseTensors[i]); + } + } + if (object.typeProtos) { + if (!Array.isArray(object.typeProtos)) + throw TypeError(".onnx.AttributeProto.typeProtos: array expected"); + message.typeProtos = []; + for (var i = 0; i < object.typeProtos.length; ++i) { + if (typeof object.typeProtos[i] !== "object") + throw TypeError(".onnx.AttributeProto.typeProtos: object expected"); + message.typeProtos[i] = $root.onnx.TypeProto.fromObject(object.typeProtos[i]); + } + } + return message; + }; + + /** + * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.AttributeProto + * @static + * @param {onnx.AttributeProto} message AttributeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + AttributeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.floats = []; + object.ints = []; + object.strings = []; + object.tensors = []; + object.graphs = []; + object.typeProtos = []; + object.sparseTensors = []; + } + if (options.defaults) { + object.name = ""; + object.f = 0; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.i = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.i = options.longs === String ? "0" : 0; + if (options.bytes === String) + object.s = ""; + else { + object.s = []; + if (options.bytes !== Array) + object.s = $util.newBuffer(object.s); + } + object.t = null; + object.g = null; + object.docString = ""; + object.tp = null; + object.type = options.enums === String ? "UNDEFINED" : 0; + object.refAttrName = ""; + object.sparseTensor = null; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.f != null && message.hasOwnProperty("f")) + object.f = options.json && !isFinite(message.f) ? String(message.f) : message.f; + if (message.i != null && message.hasOwnProperty("i")) + if (typeof message.i === "number") + object.i = options.longs === String ? String(message.i) : message.i; + else + object.i = options.longs === String ? $util.Long.prototype.toString.call(message.i) : options.longs === Number ? new $util.LongBits(message.i.low >>> 0, message.i.high >>> 0).toNumber() : message.i; + if (message.s != null && message.hasOwnProperty("s")) + object.s = options.bytes === String ? $util.base64.encode(message.s, 0, message.s.length) : options.bytes === Array ? Array.prototype.slice.call(message.s) : message.s; + if (message.t != null && message.hasOwnProperty("t")) + object.t = $root.onnx.TensorProto.toObject(message.t, options); + if (message.g != null && message.hasOwnProperty("g")) + object.g = $root.onnx.GraphProto.toObject(message.g, options); + if (message.floats && message.floats.length) { + object.floats = []; + for (var j = 0; j < message.floats.length; ++j) + object.floats[j] = options.json && !isFinite(message.floats[j]) ? String(message.floats[j]) : message.floats[j]; + } + if (message.ints && message.ints.length) { + object.ints = []; + for (var j = 0; j < message.ints.length; ++j) + if (typeof message.ints[j] === "number") + object.ints[j] = options.longs === String ? String(message.ints[j]) : message.ints[j]; + else + object.ints[j] = options.longs === String ? $util.Long.prototype.toString.call(message.ints[j]) : options.longs === Number ? new $util.LongBits(message.ints[j].low >>> 0, message.ints[j].high >>> 0).toNumber() : message.ints[j]; + } + if (message.strings && message.strings.length) { + object.strings = []; + for (var j = 0; j < message.strings.length; ++j) + object.strings[j] = options.bytes === String ? $util.base64.encode(message.strings[j], 0, message.strings[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.strings[j]) : message.strings[j]; + } + if (message.tensors && message.tensors.length) { + object.tensors = []; + for (var j = 0; j < message.tensors.length; ++j) + object.tensors[j] = $root.onnx.TensorProto.toObject(message.tensors[j], options); + } + if (message.graphs && message.graphs.length) { + object.graphs = []; + for (var j = 0; j < message.graphs.length; ++j) + object.graphs[j] = $root.onnx.GraphProto.toObject(message.graphs[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.tp != null && message.hasOwnProperty("tp")) + object.tp = $root.onnx.TypeProto.toObject(message.tp, options); + if (message.typeProtos && message.typeProtos.length) { + object.typeProtos = []; + for (var j = 0; j < message.typeProtos.length; ++j) + object.typeProtos[j] = $root.onnx.TypeProto.toObject(message.typeProtos[j], options); + } + if (message.type != null && message.hasOwnProperty("type")) + object.type = options.enums === String ? $root.onnx.AttributeProto.AttributeType[message.type] === undefined ? message.type : $root.onnx.AttributeProto.AttributeType[message.type] : message.type; + if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) + object.refAttrName = message.refAttrName; + if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) + object.sparseTensor = $root.onnx.SparseTensorProto.toObject(message.sparseTensor, options); + if (message.sparseTensors && message.sparseTensors.length) { + object.sparseTensors = []; + for (var j = 0; j < message.sparseTensors.length; ++j) + object.sparseTensors[j] = $root.onnx.SparseTensorProto.toObject(message.sparseTensors[j], options); + } + return object; + }; + + /** + * Converts this AttributeProto to JSON. + * @function toJSON + * @memberof onnx.AttributeProto + * @instance + * @returns {Object.} JSON object + */ + AttributeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for AttributeProto + * @function getTypeUrl + * @memberof onnx.AttributeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + AttributeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.AttributeProto"; + }; + + /** + * AttributeType enum. + * @name onnx.AttributeProto.AttributeType + * @enum {number} + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} INT=2 INT value + * @property {number} STRING=3 STRING value + * @property {number} TENSOR=4 TENSOR value + * @property {number} GRAPH=5 GRAPH value + * @property {number} SPARSE_TENSOR=11 SPARSE_TENSOR value + * @property {number} TYPE_PROTO=13 TYPE_PROTO value + * @property {number} FLOATS=6 FLOATS value + * @property {number} INTS=7 INTS value + * @property {number} STRINGS=8 STRINGS value + * @property {number} TENSORS=9 TENSORS value + * @property {number} GRAPHS=10 GRAPHS value + * @property {number} SPARSE_TENSORS=12 SPARSE_TENSORS value + * @property {number} TYPE_PROTOS=14 TYPE_PROTOS value + */ + AttributeProto.AttributeType = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "UNDEFINED"] = 0; + values[valuesById[1] = "FLOAT"] = 1; + values[valuesById[2] = "INT"] = 2; + values[valuesById[3] = "STRING"] = 3; + values[valuesById[4] = "TENSOR"] = 4; + values[valuesById[5] = "GRAPH"] = 5; + values[valuesById[11] = "SPARSE_TENSOR"] = 11; + values[valuesById[13] = "TYPE_PROTO"] = 13; + values[valuesById[6] = "FLOATS"] = 6; + values[valuesById[7] = "INTS"] = 7; + values[valuesById[8] = "STRINGS"] = 8; + values[valuesById[9] = "TENSORS"] = 9; + values[valuesById[10] = "GRAPHS"] = 10; + values[valuesById[12] = "SPARSE_TENSORS"] = 12; + values[valuesById[14] = "TYPE_PROTOS"] = 14; + return values; + })(); + + return AttributeProto; + })(); + + onnx.ValueInfoProto = (function() { + + /** + * Properties of a ValueInfoProto. + * @memberof onnx + * @interface IValueInfoProto + * @property {string|null} [name] ValueInfoProto name + * @property {onnx.ITypeProto|null} [type] ValueInfoProto type + * @property {string|null} [docString] ValueInfoProto docString + */ + + /** + * Constructs a new ValueInfoProto. + * @memberof onnx + * @classdesc Represents a ValueInfoProto. + * @implements IValueInfoProto + * @constructor + * @param {onnx.IValueInfoProto=} [properties] Properties to set + */ + function ValueInfoProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * ValueInfoProto name. + * @member {string} name + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.name = ""; + + /** + * ValueInfoProto type. + * @member {onnx.ITypeProto|null|undefined} type + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.type = null; + + /** + * ValueInfoProto docString. + * @member {string} docString + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.docString = ""; + + /** + * Creates a new ValueInfoProto instance using the specified properties. + * @function create + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto=} [properties] Properties to set + * @returns {onnx.ValueInfoProto} ValueInfoProto instance + */ + ValueInfoProto.create = function create(properties) { + return new ValueInfoProto(properties); + }; + + /** + * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. + * @function encode + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.type != null && Object.hasOwnProperty.call(message, "type")) + $root.onnx.TypeProto.encode(message.type, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.docString); + return writer; + }; + + /** + * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.ValueInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 2: { + message.type = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 3: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ValueInfoProto message. + * @function verify + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ValueInfoProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.type != null && message.hasOwnProperty("type")) { + var error = $root.onnx.TypeProto.verify(message.type); + if (error) + return "type." + error; + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + return null; + }; + + /** + * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ValueInfoProto} ValueInfoProto + */ + ValueInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ValueInfoProto) + return object; + var message = new $root.onnx.ValueInfoProto(); + if (object.name != null) + message.name = String(object.name); + if (object.type != null) { + if (typeof object.type !== "object") + throw TypeError(".onnx.ValueInfoProto.type: object expected"); + message.type = $root.onnx.TypeProto.fromObject(object.type); + } + if (object.docString != null) + message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.ValueInfoProto} message ValueInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ValueInfoProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.name = ""; + object.type = null; + object.docString = ""; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.type != null && message.hasOwnProperty("type")) + object.type = $root.onnx.TypeProto.toObject(message.type, options); + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + return object; + }; + + /** + * Converts this ValueInfoProto to JSON. + * @function toJSON + * @memberof onnx.ValueInfoProto + * @instance + * @returns {Object.} JSON object + */ + ValueInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ValueInfoProto + * @function getTypeUrl + * @memberof onnx.ValueInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ValueInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.ValueInfoProto"; + }; + + return ValueInfoProto; + })(); + + onnx.NodeProto = (function() { + + /** + * Properties of a NodeProto. + * @memberof onnx + * @interface INodeProto + * @property {Array.|null} [input] NodeProto input + * @property {Array.|null} [output] NodeProto output + * @property {string|null} [name] NodeProto name + * @property {string|null} [opType] NodeProto opType + * @property {string|null} [domain] NodeProto domain + * @property {Array.|null} [attribute] NodeProto attribute + * @property {string|null} [docString] NodeProto docString + */ + + /** + * Constructs a new NodeProto. + * @memberof onnx + * @classdesc Represents a NodeProto. + * @implements INodeProto + * @constructor + * @param {onnx.INodeProto=} [properties] Properties to set + */ + function NodeProto(properties) { + this.input = []; + this.output = []; + this.attribute = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * NodeProto input. + * @member {Array.} input + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.input = $util.emptyArray; + + /** + * NodeProto output. + * @member {Array.} output + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.output = $util.emptyArray; + + /** + * NodeProto name. + * @member {string} name + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.name = ""; + + /** + * NodeProto opType. + * @member {string} opType + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.opType = ""; + + /** + * NodeProto domain. + * @member {string} domain + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.domain = ""; + + /** + * NodeProto attribute. + * @member {Array.} attribute + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.attribute = $util.emptyArray; + + /** + * NodeProto docString. + * @member {string} docString + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.docString = ""; + + /** + * Creates a new NodeProto instance using the specified properties. + * @function create + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto=} [properties] Properties to set + * @returns {onnx.NodeProto} NodeProto instance + */ + NodeProto.create = function create(properties) { + return new NodeProto(properties); + }; + + /** + * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encode + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.input[i]); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.output[i]); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.name); + if (message.opType != null && Object.hasOwnProperty.call(message, "opType")) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.opType); + if (message.attribute != null && message.attribute.length) + for (var i = 0; i < message.attribute.length; ++i) + $root.onnx.AttributeProto.encode(message.attribute[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 7, wireType 2 =*/58).string(message.domain); + return writer; + }; + + /** + * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.NodeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push(reader.string()); + break; + } + case 2: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push(reader.string()); + break; + } + case 3: { + message.name = reader.string(); + break; + } + case 4: { + message.opType = reader.string(); + break; + } + case 7: { + message.domain = reader.string(); + break; + } + case 5: { + if (!(message.attribute && message.attribute.length)) + message.attribute = []; + message.attribute.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a NodeProto message. + * @function verify + * @memberof onnx.NodeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + NodeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) + return "input: string[] expected"; + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) + return "output: string[] expected"; + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.opType != null && message.hasOwnProperty("opType")) + if (!$util.isString(message.opType)) + return "opType: string expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.attribute != null && message.hasOwnProperty("attribute")) { + if (!Array.isArray(message.attribute)) + return "attribute: array expected"; + for (var i = 0; i < message.attribute.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attribute[i]); + if (error) + return "attribute." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + return null; + }; + + /** + * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.NodeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.NodeProto} NodeProto + */ + NodeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.NodeProto) + return object; + var message = new $root.onnx.NodeProto(); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.NodeProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) + message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.NodeProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) + message.output[i] = String(object.output[i]); + } + if (object.name != null) + message.name = String(object.name); + if (object.opType != null) + message.opType = String(object.opType); + if (object.domain != null) + message.domain = String(object.domain); + if (object.attribute) { + if (!Array.isArray(object.attribute)) + throw TypeError(".onnx.NodeProto.attribute: array expected"); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) { + if (typeof object.attribute[i] !== "object") + throw TypeError(".onnx.NodeProto.attribute: object expected"); + message.attribute[i] = $root.onnx.AttributeProto.fromObject(object.attribute[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a NodeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.NodeProto + * @static + * @param {onnx.NodeProto} message NodeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + NodeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + } + if (options.defaults) { + object.name = ""; + object.opType = ""; + object.docString = ""; + object.domain = ""; + } + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = message.output[j]; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.opType != null && message.hasOwnProperty("opType")) + object.opType = message.opType; + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) + object.attribute[j] = $root.onnx.AttributeProto.toObject(message.attribute[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + return object; + }; + + /** + * Converts this NodeProto to JSON. + * @function toJSON + * @memberof onnx.NodeProto + * @instance + * @returns {Object.} JSON object + */ + NodeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for NodeProto + * @function getTypeUrl + * @memberof onnx.NodeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + NodeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.NodeProto"; + }; + + return NodeProto; + })(); + + onnx.TrainingInfoProto = (function() { + + /** + * Properties of a TrainingInfoProto. + * @memberof onnx + * @interface ITrainingInfoProto + * @property {onnx.IGraphProto|null} [initialization] TrainingInfoProto initialization + * @property {onnx.IGraphProto|null} [algorithm] TrainingInfoProto algorithm + * @property {Array.|null} [initializationBinding] TrainingInfoProto initializationBinding + * @property {Array.|null} [updateBinding] TrainingInfoProto updateBinding + */ + + /** + * Constructs a new TrainingInfoProto. + * @memberof onnx + * @classdesc Represents a TrainingInfoProto. + * @implements ITrainingInfoProto + * @constructor + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + */ + function TrainingInfoProto(properties) { + this.initializationBinding = []; + this.updateBinding = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TrainingInfoProto initialization. + * @member {onnx.IGraphProto|null|undefined} initialization + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initialization = null; + + /** + * TrainingInfoProto algorithm. + * @member {onnx.IGraphProto|null|undefined} algorithm + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.algorithm = null; + + /** + * TrainingInfoProto initializationBinding. + * @member {Array.} initializationBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initializationBinding = $util.emptyArray; + + /** + * TrainingInfoProto updateBinding. + * @member {Array.} updateBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.updateBinding = $util.emptyArray; + + /** + * Creates a new TrainingInfoProto instance using the specified properties. + * @function create + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + * @returns {onnx.TrainingInfoProto} TrainingInfoProto instance + */ + TrainingInfoProto.create = function create(properties) { + return new TrainingInfoProto(properties); + }; + + /** + * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. + * @function encode + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.initialization != null && Object.hasOwnProperty.call(message, "initialization")) + $root.onnx.GraphProto.encode(message.initialization, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.algorithm != null && Object.hasOwnProperty.call(message, "algorithm")) + $root.onnx.GraphProto.encode(message.algorithm, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.initializationBinding != null && message.initializationBinding.length) + for (var i = 0; i < message.initializationBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.initializationBinding[i], writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); + if (message.updateBinding != null && message.updateBinding.length) + for (var i = 0; i < message.updateBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.updateBinding[i], writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TrainingInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.initialization = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.algorithm = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.initializationBinding && message.initializationBinding.length)) + message.initializationBinding = []; + message.initializationBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 4: { + if (!(message.updateBinding && message.updateBinding.length)) + message.updateBinding = []; + message.updateBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TrainingInfoProto message. + * @function verify + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TrainingInfoProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.initialization != null && message.hasOwnProperty("initialization")) { + var error = $root.onnx.GraphProto.verify(message.initialization); + if (error) + return "initialization." + error; + } + if (message.algorithm != null && message.hasOwnProperty("algorithm")) { + var error = $root.onnx.GraphProto.verify(message.algorithm); + if (error) + return "algorithm." + error; + } + if (message.initializationBinding != null && message.hasOwnProperty("initializationBinding")) { + if (!Array.isArray(message.initializationBinding)) + return "initializationBinding: array expected"; + for (var i = 0; i < message.initializationBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.initializationBinding[i]); + if (error) + return "initializationBinding." + error; + } + } + if (message.updateBinding != null && message.hasOwnProperty("updateBinding")) { + if (!Array.isArray(message.updateBinding)) + return "updateBinding: array expected"; + for (var i = 0; i < message.updateBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.updateBinding[i]); + if (error) + return "updateBinding." + error; + } + } + return null; + }; + + /** + * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + */ + TrainingInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TrainingInfoProto) + return object; + var message = new $root.onnx.TrainingInfoProto(); + if (object.initialization != null) { + if (typeof object.initialization !== "object") + throw TypeError(".onnx.TrainingInfoProto.initialization: object expected"); + message.initialization = $root.onnx.GraphProto.fromObject(object.initialization); + } + if (object.algorithm != null) { + if (typeof object.algorithm !== "object") + throw TypeError(".onnx.TrainingInfoProto.algorithm: object expected"); + message.algorithm = $root.onnx.GraphProto.fromObject(object.algorithm); + } + if (object.initializationBinding) { + if (!Array.isArray(object.initializationBinding)) + throw TypeError(".onnx.TrainingInfoProto.initializationBinding: array expected"); + message.initializationBinding = []; + for (var i = 0; i < object.initializationBinding.length; ++i) { + if (typeof object.initializationBinding[i] !== "object") + throw TypeError(".onnx.TrainingInfoProto.initializationBinding: object expected"); + message.initializationBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.initializationBinding[i]); + } + } + if (object.updateBinding) { + if (!Array.isArray(object.updateBinding)) + throw TypeError(".onnx.TrainingInfoProto.updateBinding: array expected"); + message.updateBinding = []; + for (var i = 0; i < object.updateBinding.length; ++i) { + if (typeof object.updateBinding[i] !== "object") + throw TypeError(".onnx.TrainingInfoProto.updateBinding: object expected"); + message.updateBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.updateBinding[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.TrainingInfoProto} message TrainingInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TrainingInfoProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.initializationBinding = []; + object.updateBinding = []; + } + if (options.defaults) { + object.initialization = null; + object.algorithm = null; + } + if (message.initialization != null && message.hasOwnProperty("initialization")) + object.initialization = $root.onnx.GraphProto.toObject(message.initialization, options); + if (message.algorithm != null && message.hasOwnProperty("algorithm")) + object.algorithm = $root.onnx.GraphProto.toObject(message.algorithm, options); + if (message.initializationBinding && message.initializationBinding.length) { + object.initializationBinding = []; + for (var j = 0; j < message.initializationBinding.length; ++j) + object.initializationBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.initializationBinding[j], options); + } + if (message.updateBinding && message.updateBinding.length) { + object.updateBinding = []; + for (var j = 0; j < message.updateBinding.length; ++j) + object.updateBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.updateBinding[j], options); + } + return object; + }; + + /** + * Converts this TrainingInfoProto to JSON. + * @function toJSON + * @memberof onnx.TrainingInfoProto + * @instance + * @returns {Object.} JSON object + */ + TrainingInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TrainingInfoProto + * @function getTypeUrl + * @memberof onnx.TrainingInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TrainingInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TrainingInfoProto"; + }; + + return TrainingInfoProto; + })(); + + onnx.ModelProto = (function() { + + /** + * Properties of a ModelProto. + * @memberof onnx + * @interface IModelProto + * @property {number|Long|null} [irVersion] ModelProto irVersion + * @property {Array.|null} [opsetImport] ModelProto opsetImport + * @property {string|null} [producerName] ModelProto producerName + * @property {string|null} [producerVersion] ModelProto producerVersion + * @property {string|null} [domain] ModelProto domain + * @property {number|Long|null} [modelVersion] ModelProto modelVersion + * @property {string|null} [docString] ModelProto docString + * @property {onnx.IGraphProto|null} [graph] ModelProto graph + * @property {Array.|null} [metadataProps] ModelProto metadataProps + * @property {Array.|null} [trainingInfo] ModelProto trainingInfo + * @property {Array.|null} [functions] ModelProto functions + */ + + /** + * Constructs a new ModelProto. + * @memberof onnx + * @classdesc Represents a ModelProto. + * @implements IModelProto + * @constructor + * @param {onnx.IModelProto=} [properties] Properties to set + */ + function ModelProto(properties) { + this.opsetImport = []; + this.metadataProps = []; + this.trainingInfo = []; + this.functions = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * ModelProto irVersion. + * @member {number|Long} irVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.irVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * ModelProto opsetImport. + * @member {Array.} opsetImport + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.opsetImport = $util.emptyArray; + + /** + * ModelProto producerName. + * @member {string} producerName + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerName = ""; + + /** + * ModelProto producerVersion. + * @member {string} producerVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerVersion = ""; + + /** + * ModelProto domain. + * @member {string} domain + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.domain = ""; + + /** + * ModelProto modelVersion. + * @member {number|Long} modelVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.modelVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * ModelProto docString. + * @member {string} docString + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.docString = ""; + + /** + * ModelProto graph. + * @member {onnx.IGraphProto|null|undefined} graph + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.graph = null; + + /** + * ModelProto metadataProps. + * @member {Array.} metadataProps + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.metadataProps = $util.emptyArray; + + /** + * ModelProto trainingInfo. + * @member {Array.} trainingInfo + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.trainingInfo = $util.emptyArray; + + /** + * ModelProto functions. + * @member {Array.} functions + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.functions = $util.emptyArray; + + /** + * Creates a new ModelProto instance using the specified properties. + * @function create + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto=} [properties] Properties to set + * @returns {onnx.ModelProto} ModelProto instance + */ + ModelProto.create = function create(properties) { + return new ModelProto(properties); + }; + + /** + * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @function encode + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.irVersion != null && Object.hasOwnProperty.call(message, "irVersion")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.irVersion); + if (message.producerName != null && Object.hasOwnProperty.call(message, "producerName")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.producerName); + if (message.producerVersion != null && Object.hasOwnProperty.call(message, "producerVersion")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.producerVersion); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.domain); + if (message.modelVersion != null && Object.hasOwnProperty.call(message, "modelVersion")) + writer.uint32(/* id 5, wireType 0 =*/40).int64(message.modelVersion); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); + if (message.graph != null && Object.hasOwnProperty.call(message, "graph")) + $root.onnx.GraphProto.encode(message.graph, writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim(); + if (message.opsetImport != null && message.opsetImport.length) + for (var i = 0; i < message.opsetImport.length; ++i) + $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); + if (message.metadataProps != null && message.metadataProps.length) + for (var i = 0; i < message.metadataProps.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.metadataProps[i], writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.trainingInfo != null && message.trainingInfo.length) + for (var i = 0; i < message.trainingInfo.length; ++i) + $root.onnx.TrainingInfoProto.encode(message.trainingInfo[i], writer.uint32(/* id 20, wireType 2 =*/162).fork()).ldelim(); + if (message.functions != null && message.functions.length) + for (var i = 0; i < message.functions.length; ++i) + $root.onnx.FunctionProto.encode(message.functions[i], writer.uint32(/* id 25, wireType 2 =*/202).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.ModelProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.irVersion = reader.int64(); + break; + } + case 8: { + if (!(message.opsetImport && message.opsetImport.length)) + message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.producerName = reader.string(); + break; + } + case 3: { + message.producerVersion = reader.string(); + break; + } + case 4: { + message.domain = reader.string(); + break; + } + case 5: { + message.modelVersion = reader.int64(); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + case 7: { + message.graph = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 14: { + if (!(message.metadataProps && message.metadataProps.length)) + message.metadataProps = []; + message.metadataProps.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 20: { + if (!(message.trainingInfo && message.trainingInfo.length)) + message.trainingInfo = []; + message.trainingInfo.push($root.onnx.TrainingInfoProto.decode(reader, reader.uint32())); + break; + } + case 25: { + if (!(message.functions && message.functions.length)) + message.functions = []; + message.functions.push($root.onnx.FunctionProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ModelProto message. + * @function verify + * @memberof onnx.ModelProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ModelProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.irVersion != null && message.hasOwnProperty("irVersion")) + if (!$util.isInteger(message.irVersion) && !(message.irVersion && $util.isInteger(message.irVersion.low) && $util.isInteger(message.irVersion.high))) + return "irVersion: integer|Long expected"; + if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { + if (!Array.isArray(message.opsetImport)) + return "opsetImport: array expected"; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) + return "opsetImport." + error; + } + } + if (message.producerName != null && message.hasOwnProperty("producerName")) + if (!$util.isString(message.producerName)) + return "producerName: string expected"; + if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) + if (!$util.isString(message.producerVersion)) + return "producerVersion: string expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) + if (!$util.isInteger(message.modelVersion) && !(message.modelVersion && $util.isInteger(message.modelVersion.low) && $util.isInteger(message.modelVersion.high))) + return "modelVersion: integer|Long expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.graph != null && message.hasOwnProperty("graph")) { + var error = $root.onnx.GraphProto.verify(message.graph); + if (error) + return "graph." + error; + } + if (message.metadataProps != null && message.hasOwnProperty("metadataProps")) { + if (!Array.isArray(message.metadataProps)) + return "metadataProps: array expected"; + for (var i = 0; i < message.metadataProps.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.metadataProps[i]); + if (error) + return "metadataProps." + error; + } + } + if (message.trainingInfo != null && message.hasOwnProperty("trainingInfo")) { + if (!Array.isArray(message.trainingInfo)) + return "trainingInfo: array expected"; + for (var i = 0; i < message.trainingInfo.length; ++i) { + var error = $root.onnx.TrainingInfoProto.verify(message.trainingInfo[i]); + if (error) + return "trainingInfo." + error; + } + } + if (message.functions != null && message.hasOwnProperty("functions")) { + if (!Array.isArray(message.functions)) + return "functions: array expected"; + for (var i = 0; i < message.functions.length; ++i) { + var error = $root.onnx.FunctionProto.verify(message.functions[i]); + if (error) + return "functions." + error; + } + } + return null; + }; + + /** + * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.ModelProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ModelProto} ModelProto + */ + ModelProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ModelProto) + return object; + var message = new $root.onnx.ModelProto(); + if (object.irVersion != null) + if ($util.Long) + (message.irVersion = $util.Long.fromValue(object.irVersion)).unsigned = false; + else if (typeof object.irVersion === "string") + message.irVersion = parseInt(object.irVersion, 10); + else if (typeof object.irVersion === "number") + message.irVersion = object.irVersion; + else if (typeof object.irVersion === "object") + message.irVersion = new $util.LongBits(object.irVersion.low >>> 0, object.irVersion.high >>> 0).toNumber(); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) + throw TypeError(".onnx.ModelProto.opsetImport: array expected"); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== "object") + throw TypeError(".onnx.ModelProto.opsetImport: object expected"); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.producerName != null) + message.producerName = String(object.producerName); + if (object.producerVersion != null) + message.producerVersion = String(object.producerVersion); + if (object.domain != null) + message.domain = String(object.domain); + if (object.modelVersion != null) + if ($util.Long) + (message.modelVersion = $util.Long.fromValue(object.modelVersion)).unsigned = false; + else if (typeof object.modelVersion === "string") + message.modelVersion = parseInt(object.modelVersion, 10); + else if (typeof object.modelVersion === "number") + message.modelVersion = object.modelVersion; + else if (typeof object.modelVersion === "object") + message.modelVersion = new $util.LongBits(object.modelVersion.low >>> 0, object.modelVersion.high >>> 0).toNumber(); + if (object.docString != null) + message.docString = String(object.docString); + if (object.graph != null) { + if (typeof object.graph !== "object") + throw TypeError(".onnx.ModelProto.graph: object expected"); + message.graph = $root.onnx.GraphProto.fromObject(object.graph); + } + if (object.metadataProps) { + if (!Array.isArray(object.metadataProps)) + throw TypeError(".onnx.ModelProto.metadataProps: array expected"); + message.metadataProps = []; + for (var i = 0; i < object.metadataProps.length; ++i) { + if (typeof object.metadataProps[i] !== "object") + throw TypeError(".onnx.ModelProto.metadataProps: object expected"); + message.metadataProps[i] = $root.onnx.StringStringEntryProto.fromObject(object.metadataProps[i]); + } + } + if (object.trainingInfo) { + if (!Array.isArray(object.trainingInfo)) + throw TypeError(".onnx.ModelProto.trainingInfo: array expected"); + message.trainingInfo = []; + for (var i = 0; i < object.trainingInfo.length; ++i) { + if (typeof object.trainingInfo[i] !== "object") + throw TypeError(".onnx.ModelProto.trainingInfo: object expected"); + message.trainingInfo[i] = $root.onnx.TrainingInfoProto.fromObject(object.trainingInfo[i]); + } + } + if (object.functions) { + if (!Array.isArray(object.functions)) + throw TypeError(".onnx.ModelProto.functions: array expected"); + message.functions = []; + for (var i = 0; i < object.functions.length; ++i) { + if (typeof object.functions[i] !== "object") + throw TypeError(".onnx.ModelProto.functions: object expected"); + message.functions[i] = $root.onnx.FunctionProto.fromObject(object.functions[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a ModelProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.ModelProto + * @static + * @param {onnx.ModelProto} message ModelProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ModelProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.opsetImport = []; + object.metadataProps = []; + object.trainingInfo = []; + object.functions = []; + } + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.irVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.irVersion = options.longs === String ? "0" : 0; + object.producerName = ""; + object.producerVersion = ""; + object.domain = ""; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.modelVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.modelVersion = options.longs === String ? "0" : 0; + object.docString = ""; + object.graph = null; + } + if (message.irVersion != null && message.hasOwnProperty("irVersion")) + if (typeof message.irVersion === "number") + object.irVersion = options.longs === String ? String(message.irVersion) : message.irVersion; + else + object.irVersion = options.longs === String ? $util.Long.prototype.toString.call(message.irVersion) : options.longs === Number ? new $util.LongBits(message.irVersion.low >>> 0, message.irVersion.high >>> 0).toNumber() : message.irVersion; + if (message.producerName != null && message.hasOwnProperty("producerName")) + object.producerName = message.producerName; + if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) + object.producerVersion = message.producerVersion; + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) + if (typeof message.modelVersion === "number") + object.modelVersion = options.longs === String ? String(message.modelVersion) : message.modelVersion; + else + object.modelVersion = options.longs === String ? $util.Long.prototype.toString.call(message.modelVersion) : options.longs === Number ? new $util.LongBits(message.modelVersion.low >>> 0, message.modelVersion.high >>> 0).toNumber() : message.modelVersion; + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.graph != null && message.hasOwnProperty("graph")) + object.graph = $root.onnx.GraphProto.toObject(message.graph, options); + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.metadataProps && message.metadataProps.length) { + object.metadataProps = []; + for (var j = 0; j < message.metadataProps.length; ++j) + object.metadataProps[j] = $root.onnx.StringStringEntryProto.toObject(message.metadataProps[j], options); + } + if (message.trainingInfo && message.trainingInfo.length) { + object.trainingInfo = []; + for (var j = 0; j < message.trainingInfo.length; ++j) + object.trainingInfo[j] = $root.onnx.TrainingInfoProto.toObject(message.trainingInfo[j], options); + } + if (message.functions && message.functions.length) { + object.functions = []; + for (var j = 0; j < message.functions.length; ++j) + object.functions[j] = $root.onnx.FunctionProto.toObject(message.functions[j], options); + } + return object; + }; + + /** + * Converts this ModelProto to JSON. + * @function toJSON + * @memberof onnx.ModelProto + * @instance + * @returns {Object.} JSON object + */ + ModelProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ModelProto + * @function getTypeUrl + * @memberof onnx.ModelProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ModelProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.ModelProto"; + }; + + return ModelProto; + })(); + + onnx.StringStringEntryProto = (function() { + + /** + * Properties of a StringStringEntryProto. + * @memberof onnx + * @interface IStringStringEntryProto + * @property {string|null} [key] StringStringEntryProto key + * @property {string|null} [value] StringStringEntryProto value + */ + + /** + * Constructs a new StringStringEntryProto. + * @memberof onnx + * @classdesc Represents a StringStringEntryProto. + * @implements IStringStringEntryProto + * @constructor + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + */ + function StringStringEntryProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * StringStringEntryProto key. + * @member {string} key + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.key = ""; + + /** + * StringStringEntryProto value. + * @member {string} value + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.value = ""; + + /** + * Creates a new StringStringEntryProto instance using the specified properties. + * @function create + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + * @returns {onnx.StringStringEntryProto} StringStringEntryProto instance + */ + StringStringEntryProto.create = function create(properties) { + return new StringStringEntryProto(properties); + }; + + /** + * Encodes the specified StringStringEntryProto message. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. + * @function encode + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.key != null && Object.hasOwnProperty.call(message, "key")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.key); + if (message.value != null && Object.hasOwnProperty.call(message, "value")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.value); + return writer; + }; + + /** + * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.StringStringEntryProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + StringStringEntryProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.StringStringEntryProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.key = reader.string(); + break; + } + case 2: { + message.value = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.StringStringEntryProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + StringStringEntryProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a StringStringEntryProto message. + * @function verify + * @memberof onnx.StringStringEntryProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + StringStringEntryProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.key != null && message.hasOwnProperty("key")) + if (!$util.isString(message.key)) + return "key: string expected"; + if (message.value != null && message.hasOwnProperty("value")) + if (!$util.isString(message.value)) + return "value: string expected"; + return null; + }; + + /** + * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.StringStringEntryProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + */ + StringStringEntryProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.StringStringEntryProto) + return object; + var message = new $root.onnx.StringStringEntryProto(); + if (object.key != null) + message.key = String(object.key); + if (object.value != null) + message.value = String(object.value); + return message; + }; + + /** + * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.StringStringEntryProto} message StringStringEntryProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + StringStringEntryProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.key = ""; + object.value = ""; + } + if (message.key != null && message.hasOwnProperty("key")) + object.key = message.key; + if (message.value != null && message.hasOwnProperty("value")) + object.value = message.value; + return object; + }; + + /** + * Converts this StringStringEntryProto to JSON. + * @function toJSON + * @memberof onnx.StringStringEntryProto + * @instance + * @returns {Object.} JSON object + */ + StringStringEntryProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for StringStringEntryProto + * @function getTypeUrl + * @memberof onnx.StringStringEntryProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + StringStringEntryProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.StringStringEntryProto"; + }; + + return StringStringEntryProto; + })(); + + onnx.TensorAnnotation = (function() { + + /** + * Properties of a TensorAnnotation. + * @memberof onnx + * @interface ITensorAnnotation + * @property {string|null} [tensorName] TensorAnnotation tensorName + * @property {Array.|null} [quantParameterTensorNames] TensorAnnotation quantParameterTensorNames + */ + + /** + * Constructs a new TensorAnnotation. + * @memberof onnx + * @classdesc Represents a TensorAnnotation. + * @implements ITensorAnnotation + * @constructor + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + */ + function TensorAnnotation(properties) { + this.quantParameterTensorNames = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorAnnotation tensorName. + * @member {string} tensorName + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.tensorName = ""; + + /** + * TensorAnnotation quantParameterTensorNames. + * @member {Array.} quantParameterTensorNames + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.quantParameterTensorNames = $util.emptyArray; + + /** + * Creates a new TensorAnnotation instance using the specified properties. + * @function create + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + * @returns {onnx.TensorAnnotation} TensorAnnotation instance + */ + TensorAnnotation.create = function create(properties) { + return new TensorAnnotation(properties); + }; + + /** + * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. + * @function encode + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.tensorName != null && Object.hasOwnProperty.call(message, "tensorName")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.tensorName); + if (message.quantParameterTensorNames != null && message.quantParameterTensorNames.length) + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.quantParameterTensorNames[i], writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorAnnotation(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorName = reader.string(); + break; + } + case 2: { + if (!(message.quantParameterTensorNames && message.quantParameterTensorNames.length)) + message.quantParameterTensorNames = []; + message.quantParameterTensorNames.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorAnnotation message. + * @function verify + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorAnnotation.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.tensorName != null && message.hasOwnProperty("tensorName")) + if (!$util.isString(message.tensorName)) + return "tensorName: string expected"; + if (message.quantParameterTensorNames != null && message.hasOwnProperty("quantParameterTensorNames")) { + if (!Array.isArray(message.quantParameterTensorNames)) + return "quantParameterTensorNames: array expected"; + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.quantParameterTensorNames[i]); + if (error) + return "quantParameterTensorNames." + error; + } + } + return null; + }; + + /** + * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorAnnotation} TensorAnnotation + */ + TensorAnnotation.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorAnnotation) + return object; + var message = new $root.onnx.TensorAnnotation(); + if (object.tensorName != null) + message.tensorName = String(object.tensorName); + if (object.quantParameterTensorNames) { + if (!Array.isArray(object.quantParameterTensorNames)) + throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: array expected"); + message.quantParameterTensorNames = []; + for (var i = 0; i < object.quantParameterTensorNames.length; ++i) { + if (typeof object.quantParameterTensorNames[i] !== "object") + throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: object expected"); + message.quantParameterTensorNames[i] = $root.onnx.StringStringEntryProto.fromObject(object.quantParameterTensorNames[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.TensorAnnotation} message TensorAnnotation + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorAnnotation.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.quantParameterTensorNames = []; + if (options.defaults) + object.tensorName = ""; + if (message.tensorName != null && message.hasOwnProperty("tensorName")) + object.tensorName = message.tensorName; + if (message.quantParameterTensorNames && message.quantParameterTensorNames.length) { + object.quantParameterTensorNames = []; + for (var j = 0; j < message.quantParameterTensorNames.length; ++j) + object.quantParameterTensorNames[j] = $root.onnx.StringStringEntryProto.toObject(message.quantParameterTensorNames[j], options); + } + return object; + }; + + /** + * Converts this TensorAnnotation to JSON. + * @function toJSON + * @memberof onnx.TensorAnnotation + * @instance + * @returns {Object.} JSON object + */ + TensorAnnotation.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorAnnotation + * @function getTypeUrl + * @memberof onnx.TensorAnnotation + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorAnnotation.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorAnnotation"; + }; + + return TensorAnnotation; + })(); + + onnx.GraphProto = (function() { + + /** + * Properties of a GraphProto. + * @memberof onnx + * @interface IGraphProto + * @property {Array.|null} [node] GraphProto node + * @property {string|null} [name] GraphProto name + * @property {Array.|null} [initializer] GraphProto initializer + * @property {Array.|null} [sparseInitializer] GraphProto sparseInitializer + * @property {string|null} [docString] GraphProto docString + * @property {Array.|null} [input] GraphProto input + * @property {Array.|null} [output] GraphProto output + * @property {Array.|null} [valueInfo] GraphProto valueInfo + * @property {Array.|null} [quantizationAnnotation] GraphProto quantizationAnnotation + */ + + /** + * Constructs a new GraphProto. + * @memberof onnx + * @classdesc Represents a GraphProto. + * @implements IGraphProto + * @constructor + * @param {onnx.IGraphProto=} [properties] Properties to set + */ + function GraphProto(properties) { + this.node = []; + this.initializer = []; + this.sparseInitializer = []; + this.input = []; + this.output = []; + this.valueInfo = []; + this.quantizationAnnotation = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * GraphProto node. + * @member {Array.} node + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.node = $util.emptyArray; + + /** + * GraphProto name. + * @member {string} name + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.name = ""; + + /** + * GraphProto initializer. + * @member {Array.} initializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.initializer = $util.emptyArray; + + /** + * GraphProto sparseInitializer. + * @member {Array.} sparseInitializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.sparseInitializer = $util.emptyArray; + + /** + * GraphProto docString. + * @member {string} docString + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.docString = ""; + + /** + * GraphProto input. + * @member {Array.} input + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.input = $util.emptyArray; + + /** + * GraphProto output. + * @member {Array.} output + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.output = $util.emptyArray; + + /** + * GraphProto valueInfo. + * @member {Array.} valueInfo + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.valueInfo = $util.emptyArray; + + /** + * GraphProto quantizationAnnotation. + * @member {Array.} quantizationAnnotation + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.quantizationAnnotation = $util.emptyArray; + + /** + * Creates a new GraphProto instance using the specified properties. + * @function create + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto=} [properties] Properties to set + * @returns {onnx.GraphProto} GraphProto instance + */ + GraphProto.create = function create(properties) { + return new GraphProto(properties); + }; + + /** + * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @function encode + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.node != null && message.node.length) + for (var i = 0; i < message.node.length; ++i) + $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.name); + if (message.initializer != null && message.initializer.length) + for (var i = 0; i < message.initializer.length; ++i) + $root.onnx.TensorProto.encode(message.initializer[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 10, wireType 2 =*/82).string(message.docString); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + $root.onnx.ValueInfoProto.encode(message.input[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + $root.onnx.ValueInfoProto.encode(message.output[i], writer.uint32(/* id 12, wireType 2 =*/98).fork()).ldelim(); + if (message.valueInfo != null && message.valueInfo.length) + for (var i = 0; i < message.valueInfo.length; ++i) + $root.onnx.ValueInfoProto.encode(message.valueInfo[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); + if (message.quantizationAnnotation != null && message.quantizationAnnotation.length) + for (var i = 0; i < message.quantizationAnnotation.length; ++i) + $root.onnx.TensorAnnotation.encode(message.quantizationAnnotation[i], writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.sparseInitializer != null && message.sparseInitializer.length) + for (var i = 0; i < message.sparseInitializer.length; ++i) + $root.onnx.SparseTensorProto.encode(message.sparseInitializer[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.GraphProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.node && message.node.length)) + message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.name = reader.string(); + break; + } + case 5: { + if (!(message.initializer && message.initializer.length)) + message.initializer = []; + message.initializer.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if (!(message.sparseInitializer && message.sparseInitializer.length)) + message.sparseInitializer = []; + message.sparseInitializer.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.docString = reader.string(); + break; + } + case 11: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 12: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 13: { + if (!(message.valueInfo && message.valueInfo.length)) + message.valueInfo = []; + message.valueInfo.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 14: { + if (!(message.quantizationAnnotation && message.quantizationAnnotation.length)) + message.quantizationAnnotation = []; + message.quantizationAnnotation.push($root.onnx.TensorAnnotation.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a GraphProto message. + * @function verify + * @memberof onnx.GraphProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + GraphProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.node != null && message.hasOwnProperty("node")) { + if (!Array.isArray(message.node)) + return "node: array expected"; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) + return "node." + error; + } + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.initializer != null && message.hasOwnProperty("initializer")) { + if (!Array.isArray(message.initializer)) + return "initializer: array expected"; + for (var i = 0; i < message.initializer.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.initializer[i]); + if (error) + return "initializer." + error; + } + } + if (message.sparseInitializer != null && message.hasOwnProperty("sparseInitializer")) { + if (!Array.isArray(message.sparseInitializer)) + return "sparseInitializer: array expected"; + for (var i = 0; i < message.sparseInitializer.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseInitializer[i]); + if (error) + return "sparseInitializer." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.input[i]); + if (error) + return "input." + error; + } + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.output[i]); + if (error) + return "output." + error; + } + } + if (message.valueInfo != null && message.hasOwnProperty("valueInfo")) { + if (!Array.isArray(message.valueInfo)) + return "valueInfo: array expected"; + for (var i = 0; i < message.valueInfo.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.valueInfo[i]); + if (error) + return "valueInfo." + error; + } + } + if (message.quantizationAnnotation != null && message.hasOwnProperty("quantizationAnnotation")) { + if (!Array.isArray(message.quantizationAnnotation)) + return "quantizationAnnotation: array expected"; + for (var i = 0; i < message.quantizationAnnotation.length; ++i) { + var error = $root.onnx.TensorAnnotation.verify(message.quantizationAnnotation[i]); + if (error) + return "quantizationAnnotation." + error; + } + } + return null; + }; + + /** + * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.GraphProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.GraphProto} GraphProto + */ + GraphProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.GraphProto) + return object; + var message = new $root.onnx.GraphProto(); + if (object.node) { + if (!Array.isArray(object.node)) + throw TypeError(".onnx.GraphProto.node: array expected"); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== "object") + throw TypeError(".onnx.GraphProto.node: object expected"); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } + } + if (object.name != null) + message.name = String(object.name); + if (object.initializer) { + if (!Array.isArray(object.initializer)) + throw TypeError(".onnx.GraphProto.initializer: array expected"); + message.initializer = []; + for (var i = 0; i < object.initializer.length; ++i) { + if (typeof object.initializer[i] !== "object") + throw TypeError(".onnx.GraphProto.initializer: object expected"); + message.initializer[i] = $root.onnx.TensorProto.fromObject(object.initializer[i]); + } + } + if (object.sparseInitializer) { + if (!Array.isArray(object.sparseInitializer)) + throw TypeError(".onnx.GraphProto.sparseInitializer: array expected"); + message.sparseInitializer = []; + for (var i = 0; i < object.sparseInitializer.length; ++i) { + if (typeof object.sparseInitializer[i] !== "object") + throw TypeError(".onnx.GraphProto.sparseInitializer: object expected"); + message.sparseInitializer[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseInitializer[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.GraphProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) { + if (typeof object.input[i] !== "object") + throw TypeError(".onnx.GraphProto.input: object expected"); + message.input[i] = $root.onnx.ValueInfoProto.fromObject(object.input[i]); + } + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.GraphProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) { + if (typeof object.output[i] !== "object") + throw TypeError(".onnx.GraphProto.output: object expected"); + message.output[i] = $root.onnx.ValueInfoProto.fromObject(object.output[i]); + } + } + if (object.valueInfo) { + if (!Array.isArray(object.valueInfo)) + throw TypeError(".onnx.GraphProto.valueInfo: array expected"); + message.valueInfo = []; + for (var i = 0; i < object.valueInfo.length; ++i) { + if (typeof object.valueInfo[i] !== "object") + throw TypeError(".onnx.GraphProto.valueInfo: object expected"); + message.valueInfo[i] = $root.onnx.ValueInfoProto.fromObject(object.valueInfo[i]); + } + } + if (object.quantizationAnnotation) { + if (!Array.isArray(object.quantizationAnnotation)) + throw TypeError(".onnx.GraphProto.quantizationAnnotation: array expected"); + message.quantizationAnnotation = []; + for (var i = 0; i < object.quantizationAnnotation.length; ++i) { + if (typeof object.quantizationAnnotation[i] !== "object") + throw TypeError(".onnx.GraphProto.quantizationAnnotation: object expected"); + message.quantizationAnnotation[i] = $root.onnx.TensorAnnotation.fromObject(object.quantizationAnnotation[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a GraphProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.GraphProto + * @static + * @param {onnx.GraphProto} message GraphProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + GraphProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.node = []; + object.initializer = []; + object.input = []; + object.output = []; + object.valueInfo = []; + object.quantizationAnnotation = []; + object.sparseInitializer = []; + } + if (options.defaults) { + object.name = ""; + object.docString = ""; + } + if (message.node && message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.initializer && message.initializer.length) { + object.initializer = []; + for (var j = 0; j < message.initializer.length; ++j) + object.initializer[j] = $root.onnx.TensorProto.toObject(message.initializer[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = $root.onnx.ValueInfoProto.toObject(message.input[j], options); + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = $root.onnx.ValueInfoProto.toObject(message.output[j], options); + } + if (message.valueInfo && message.valueInfo.length) { + object.valueInfo = []; + for (var j = 0; j < message.valueInfo.length; ++j) + object.valueInfo[j] = $root.onnx.ValueInfoProto.toObject(message.valueInfo[j], options); + } + if (message.quantizationAnnotation && message.quantizationAnnotation.length) { + object.quantizationAnnotation = []; + for (var j = 0; j < message.quantizationAnnotation.length; ++j) + object.quantizationAnnotation[j] = $root.onnx.TensorAnnotation.toObject(message.quantizationAnnotation[j], options); + } + if (message.sparseInitializer && message.sparseInitializer.length) { + object.sparseInitializer = []; + for (var j = 0; j < message.sparseInitializer.length; ++j) + object.sparseInitializer[j] = $root.onnx.SparseTensorProto.toObject(message.sparseInitializer[j], options); + } + return object; + }; + + /** + * Converts this GraphProto to JSON. + * @function toJSON + * @memberof onnx.GraphProto + * @instance + * @returns {Object.} JSON object + */ + GraphProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for GraphProto + * @function getTypeUrl + * @memberof onnx.GraphProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + GraphProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.GraphProto"; + }; + + return GraphProto; + })(); + + onnx.TensorProto = (function() { + + /** + * Properties of a TensorProto. + * @memberof onnx + * @interface ITensorProto + * @property {Array.|null} [dims] TensorProto dims + * @property {number|null} [dataType] TensorProto dataType + * @property {onnx.TensorProto.ISegment|null} [segment] TensorProto segment + * @property {Array.|null} [floatData] TensorProto floatData + * @property {Array.|null} [int32Data] TensorProto int32Data + * @property {Array.|null} [stringData] TensorProto stringData + * @property {Array.|null} [int64Data] TensorProto int64Data + * @property {string|null} [name] TensorProto name + * @property {string|null} [docString] TensorProto docString + * @property {Uint8Array|null} [rawData] TensorProto rawData + * @property {Array.|null} [externalData] TensorProto externalData + * @property {onnx.TensorProto.DataLocation|null} [dataLocation] TensorProto dataLocation + * @property {Array.|null} [doubleData] TensorProto doubleData + * @property {Array.|null} [uint64Data] TensorProto uint64Data + */ + + /** + * Constructs a new TensorProto. + * @memberof onnx + * @classdesc Represents a TensorProto. + * @implements ITensorProto + * @constructor + * @param {onnx.ITensorProto=} [properties] Properties to set + */ + function TensorProto(properties) { + this.dims = []; + this.floatData = []; + this.int32Data = []; + this.stringData = []; + this.int64Data = []; + this.externalData = []; + this.doubleData = []; + this.uint64Data = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorProto dims. + * @member {Array.} dims + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dims = $util.emptyArray; + + /** + * TensorProto dataType. + * @member {number} dataType + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataType = 0; + + /** + * TensorProto segment. + * @member {onnx.TensorProto.ISegment|null|undefined} segment + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.segment = null; + + /** + * TensorProto floatData. + * @member {Array.} floatData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.floatData = $util.emptyArray; + + /** + * TensorProto int32Data. + * @member {Array.} int32Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int32Data = $util.emptyArray; + + /** + * TensorProto stringData. + * @member {Array.} stringData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.stringData = $util.emptyArray; + + /** + * TensorProto int64Data. + * @member {Array.} int64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int64Data = $util.emptyArray; + + /** + * TensorProto name. + * @member {string} name + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.name = ""; + + /** + * TensorProto docString. + * @member {string} docString + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.docString = ""; + + /** + * TensorProto rawData. + * @member {Uint8Array} rawData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.rawData = $util.newBuffer([]); + + /** + * TensorProto externalData. + * @member {Array.} externalData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.externalData = $util.emptyArray; + + /** + * TensorProto dataLocation. + * @member {onnx.TensorProto.DataLocation} dataLocation + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataLocation = 0; + + /** + * TensorProto doubleData. + * @member {Array.} doubleData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.doubleData = $util.emptyArray; + + /** + * TensorProto uint64Data. + * @member {Array.} uint64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.uint64Data = $util.emptyArray; + + /** + * Creates a new TensorProto instance using the specified properties. + * @function create + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto=} [properties] Properties to set + * @returns {onnx.TensorProto} TensorProto instance + */ + TensorProto.create = function create(properties) { + return new TensorProto(properties); + }; + + /** + * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 1, wireType 2 =*/10).fork(); + for (var i = 0; i < message.dims.length; ++i) + writer.int64(message.dims[i]); + writer.ldelim(); + } + if (message.dataType != null && Object.hasOwnProperty.call(message, "dataType")) + writer.uint32(/* id 2, wireType 0 =*/16).int32(message.dataType); + if (message.segment != null && Object.hasOwnProperty.call(message, "segment")) + $root.onnx.TensorProto.Segment.encode(message.segment, writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); + if (message.floatData != null && message.floatData.length) { + writer.uint32(/* id 4, wireType 2 =*/34).fork(); + for (var i = 0; i < message.floatData.length; ++i) + writer.float(message.floatData[i]); + writer.ldelim(); + } + if (message.int32Data != null && message.int32Data.length) { + writer.uint32(/* id 5, wireType 2 =*/42).fork(); + for (var i = 0; i < message.int32Data.length; ++i) + writer.int32(message.int32Data[i]); + writer.ldelim(); + } + if (message.stringData != null && message.stringData.length) + for (var i = 0; i < message.stringData.length; ++i) + writer.uint32(/* id 6, wireType 2 =*/50).bytes(message.stringData[i]); + if (message.int64Data != null && message.int64Data.length) { + writer.uint32(/* id 7, wireType 2 =*/58).fork(); + for (var i = 0; i < message.int64Data.length; ++i) + writer.int64(message.int64Data[i]); + writer.ldelim(); + } + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 8, wireType 2 =*/66).string(message.name); + if (message.rawData != null && Object.hasOwnProperty.call(message, "rawData")) + writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.rawData); + if (message.doubleData != null && message.doubleData.length) { + writer.uint32(/* id 10, wireType 2 =*/82).fork(); + for (var i = 0; i < message.doubleData.length; ++i) + writer.double(message.doubleData[i]); + writer.ldelim(); + } + if (message.uint64Data != null && message.uint64Data.length) { + writer.uint32(/* id 11, wireType 2 =*/90).fork(); + for (var i = 0; i < message.uint64Data.length; ++i) + writer.uint64(message.uint64Data[i]); + writer.ldelim(); + } + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 12, wireType 2 =*/98).string(message.docString); + if (message.externalData != null && message.externalData.length) + for (var i = 0; i < message.externalData.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.externalData[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); + if (message.dataLocation != null && Object.hasOwnProperty.call(message, "dataLocation")) + writer.uint32(/* id 14, wireType 0 =*/112).int32(message.dataLocation); + return writer; + }; + + /** + * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dims && message.dims.length)) + message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.dims.push(reader.int64()); + } else + message.dims.push(reader.int64()); + break; + } + case 2: { + message.dataType = reader.int32(); + break; + } + case 3: { + message.segment = $root.onnx.TensorProto.Segment.decode(reader, reader.uint32()); + break; + } + case 4: { + if (!(message.floatData && message.floatData.length)) + message.floatData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.floatData.push(reader.float()); + } else + message.floatData.push(reader.float()); + break; + } + case 5: { + if (!(message.int32Data && message.int32Data.length)) + message.int32Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.int32Data.push(reader.int32()); + } else + message.int32Data.push(reader.int32()); + break; + } + case 6: { + if (!(message.stringData && message.stringData.length)) + message.stringData = []; + message.stringData.push(reader.bytes()); + break; + } + case 7: { + if (!(message.int64Data && message.int64Data.length)) + message.int64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.int64Data.push(reader.int64()); + } else + message.int64Data.push(reader.int64()); + break; + } + case 8: { + message.name = reader.string(); + break; + } + case 12: { + message.docString = reader.string(); + break; + } + case 9: { + message.rawData = reader.bytes(); + break; + } + case 13: { + if (!(message.externalData && message.externalData.length)) + message.externalData = []; + message.externalData.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 14: { + message.dataLocation = reader.int32(); + break; + } + case 10: { + if (!(message.doubleData && message.doubleData.length)) + message.doubleData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.doubleData.push(reader.double()); + } else + message.doubleData.push(reader.double()); + break; + } + case 11: { + if (!(message.uint64Data && message.uint64Data.length)) + message.uint64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.uint64Data.push(reader.uint64()); + } else + message.uint64Data.push(reader.uint64()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorProto message. + * @function verify + * @memberof onnx.TensorProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.dims != null && message.hasOwnProperty("dims")) { + if (!Array.isArray(message.dims)) + return "dims: array expected"; + for (var i = 0; i < message.dims.length; ++i) + if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) + return "dims: integer|Long[] expected"; + } + if (message.dataType != null && message.hasOwnProperty("dataType")) + if (!$util.isInteger(message.dataType)) + return "dataType: integer expected"; + if (message.segment != null && message.hasOwnProperty("segment")) { + var error = $root.onnx.TensorProto.Segment.verify(message.segment); + if (error) + return "segment." + error; + } + if (message.floatData != null && message.hasOwnProperty("floatData")) { + if (!Array.isArray(message.floatData)) + return "floatData: array expected"; + for (var i = 0; i < message.floatData.length; ++i) + if (typeof message.floatData[i] !== "number") + return "floatData: number[] expected"; + } + if (message.int32Data != null && message.hasOwnProperty("int32Data")) { + if (!Array.isArray(message.int32Data)) + return "int32Data: array expected"; + for (var i = 0; i < message.int32Data.length; ++i) + if (!$util.isInteger(message.int32Data[i])) + return "int32Data: integer[] expected"; + } + if (message.stringData != null && message.hasOwnProperty("stringData")) { + if (!Array.isArray(message.stringData)) + return "stringData: array expected"; + for (var i = 0; i < message.stringData.length; ++i) + if (!(message.stringData[i] && typeof message.stringData[i].length === "number" || $util.isString(message.stringData[i]))) + return "stringData: buffer[] expected"; + } + if (message.int64Data != null && message.hasOwnProperty("int64Data")) { + if (!Array.isArray(message.int64Data)) + return "int64Data: array expected"; + for (var i = 0; i < message.int64Data.length; ++i) + if (!$util.isInteger(message.int64Data[i]) && !(message.int64Data[i] && $util.isInteger(message.int64Data[i].low) && $util.isInteger(message.int64Data[i].high))) + return "int64Data: integer|Long[] expected"; + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.rawData != null && message.hasOwnProperty("rawData")) + if (!(message.rawData && typeof message.rawData.length === "number" || $util.isString(message.rawData))) + return "rawData: buffer expected"; + if (message.externalData != null && message.hasOwnProperty("externalData")) { + if (!Array.isArray(message.externalData)) + return "externalData: array expected"; + for (var i = 0; i < message.externalData.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.externalData[i]); + if (error) + return "externalData." + error; + } + } + if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) + switch (message.dataLocation) { + default: + return "dataLocation: enum value expected"; + case 0: + case 1: + break; + } + if (message.doubleData != null && message.hasOwnProperty("doubleData")) { + if (!Array.isArray(message.doubleData)) + return "doubleData: array expected"; + for (var i = 0; i < message.doubleData.length; ++i) + if (typeof message.doubleData[i] !== "number") + return "doubleData: number[] expected"; + } + if (message.uint64Data != null && message.hasOwnProperty("uint64Data")) { + if (!Array.isArray(message.uint64Data)) + return "uint64Data: array expected"; + for (var i = 0; i < message.uint64Data.length; ++i) + if (!$util.isInteger(message.uint64Data[i]) && !(message.uint64Data[i] && $util.isInteger(message.uint64Data[i].low) && $util.isInteger(message.uint64Data[i].high))) + return "uint64Data: integer|Long[] expected"; + } + return null; + }; + + /** + * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorProto} TensorProto + */ + TensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto) + return object; + var message = new $root.onnx.TensorProto(); + if (object.dims) { + if (!Array.isArray(object.dims)) + throw TypeError(".onnx.TensorProto.dims: array expected"); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) + (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === "string") + message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === "number") + message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === "object") + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + if (object.dataType != null) + message.dataType = object.dataType | 0; + if (object.segment != null) { + if (typeof object.segment !== "object") + throw TypeError(".onnx.TensorProto.segment: object expected"); + message.segment = $root.onnx.TensorProto.Segment.fromObject(object.segment); + } + if (object.floatData) { + if (!Array.isArray(object.floatData)) + throw TypeError(".onnx.TensorProto.floatData: array expected"); + message.floatData = []; + for (var i = 0; i < object.floatData.length; ++i) + message.floatData[i] = Number(object.floatData[i]); + } + if (object.int32Data) { + if (!Array.isArray(object.int32Data)) + throw TypeError(".onnx.TensorProto.int32Data: array expected"); + message.int32Data = []; + for (var i = 0; i < object.int32Data.length; ++i) + message.int32Data[i] = object.int32Data[i] | 0; + } + if (object.stringData) { + if (!Array.isArray(object.stringData)) + throw TypeError(".onnx.TensorProto.stringData: array expected"); + message.stringData = []; + for (var i = 0; i < object.stringData.length; ++i) + if (typeof object.stringData[i] === "string") + $util.base64.decode(object.stringData[i], message.stringData[i] = $util.newBuffer($util.base64.length(object.stringData[i])), 0); + else if (object.stringData[i].length >= 0) + message.stringData[i] = object.stringData[i]; + } + if (object.int64Data) { + if (!Array.isArray(object.int64Data)) + throw TypeError(".onnx.TensorProto.int64Data: array expected"); + message.int64Data = []; + for (var i = 0; i < object.int64Data.length; ++i) + if ($util.Long) + (message.int64Data[i] = $util.Long.fromValue(object.int64Data[i])).unsigned = false; + else if (typeof object.int64Data[i] === "string") + message.int64Data[i] = parseInt(object.int64Data[i], 10); + else if (typeof object.int64Data[i] === "number") + message.int64Data[i] = object.int64Data[i]; + else if (typeof object.int64Data[i] === "object") + message.int64Data[i] = new $util.LongBits(object.int64Data[i].low >>> 0, object.int64Data[i].high >>> 0).toNumber(); + } + if (object.name != null) + message.name = String(object.name); + if (object.docString != null) + message.docString = String(object.docString); + if (object.rawData != null) + if (typeof object.rawData === "string") + $util.base64.decode(object.rawData, message.rawData = $util.newBuffer($util.base64.length(object.rawData)), 0); + else if (object.rawData.length >= 0) + message.rawData = object.rawData; + if (object.externalData) { + if (!Array.isArray(object.externalData)) + throw TypeError(".onnx.TensorProto.externalData: array expected"); + message.externalData = []; + for (var i = 0; i < object.externalData.length; ++i) { + if (typeof object.externalData[i] !== "object") + throw TypeError(".onnx.TensorProto.externalData: object expected"); + message.externalData[i] = $root.onnx.StringStringEntryProto.fromObject(object.externalData[i]); + } + } + switch (object.dataLocation) { + default: + if (typeof object.dataLocation === "number") { + message.dataLocation = object.dataLocation; + break; + } + break; + case "DEFAULT": + case 0: + message.dataLocation = 0; + break; + case "EXTERNAL": + case 1: + message.dataLocation = 1; + break; + } + if (object.doubleData) { + if (!Array.isArray(object.doubleData)) + throw TypeError(".onnx.TensorProto.doubleData: array expected"); + message.doubleData = []; + for (var i = 0; i < object.doubleData.length; ++i) + message.doubleData[i] = Number(object.doubleData[i]); + } + if (object.uint64Data) { + if (!Array.isArray(object.uint64Data)) + throw TypeError(".onnx.TensorProto.uint64Data: array expected"); + message.uint64Data = []; + for (var i = 0; i < object.uint64Data.length; ++i) + if ($util.Long) + (message.uint64Data[i] = $util.Long.fromValue(object.uint64Data[i])).unsigned = true; + else if (typeof object.uint64Data[i] === "string") + message.uint64Data[i] = parseInt(object.uint64Data[i], 10); + else if (typeof object.uint64Data[i] === "number") + message.uint64Data[i] = object.uint64Data[i]; + else if (typeof object.uint64Data[i] === "object") + message.uint64Data[i] = new $util.LongBits(object.uint64Data[i].low >>> 0, object.uint64Data[i].high >>> 0).toNumber(true); + } + return message; + }; + + /** + * Creates a plain object from a TensorProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorProto + * @static + * @param {onnx.TensorProto} message TensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.dims = []; + object.floatData = []; + object.int32Data = []; + object.stringData = []; + object.int64Data = []; + object.doubleData = []; + object.uint64Data = []; + object.externalData = []; + } + if (options.defaults) { + object.dataType = 0; + object.segment = null; + object.name = ""; + if (options.bytes === String) + object.rawData = ""; + else { + object.rawData = []; + if (options.bytes !== Array) + object.rawData = $util.newBuffer(object.rawData); + } + object.docString = ""; + object.dataLocation = options.enums === String ? "DEFAULT" : 0; + } + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === "number") + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; + } + if (message.dataType != null && message.hasOwnProperty("dataType")) + object.dataType = message.dataType; + if (message.segment != null && message.hasOwnProperty("segment")) + object.segment = $root.onnx.TensorProto.Segment.toObject(message.segment, options); + if (message.floatData && message.floatData.length) { + object.floatData = []; + for (var j = 0; j < message.floatData.length; ++j) + object.floatData[j] = options.json && !isFinite(message.floatData[j]) ? String(message.floatData[j]) : message.floatData[j]; + } + if (message.int32Data && message.int32Data.length) { + object.int32Data = []; + for (var j = 0; j < message.int32Data.length; ++j) + object.int32Data[j] = message.int32Data[j]; + } + if (message.stringData && message.stringData.length) { + object.stringData = []; + for (var j = 0; j < message.stringData.length; ++j) + object.stringData[j] = options.bytes === String ? $util.base64.encode(message.stringData[j], 0, message.stringData[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.stringData[j]) : message.stringData[j]; + } + if (message.int64Data && message.int64Data.length) { + object.int64Data = []; + for (var j = 0; j < message.int64Data.length; ++j) + if (typeof message.int64Data[j] === "number") + object.int64Data[j] = options.longs === String ? String(message.int64Data[j]) : message.int64Data[j]; + else + object.int64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.int64Data[j]) : options.longs === Number ? new $util.LongBits(message.int64Data[j].low >>> 0, message.int64Data[j].high >>> 0).toNumber() : message.int64Data[j]; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.rawData != null && message.hasOwnProperty("rawData")) + object.rawData = options.bytes === String ? $util.base64.encode(message.rawData, 0, message.rawData.length) : options.bytes === Array ? Array.prototype.slice.call(message.rawData) : message.rawData; + if (message.doubleData && message.doubleData.length) { + object.doubleData = []; + for (var j = 0; j < message.doubleData.length; ++j) + object.doubleData[j] = options.json && !isFinite(message.doubleData[j]) ? String(message.doubleData[j]) : message.doubleData[j]; + } + if (message.uint64Data && message.uint64Data.length) { + object.uint64Data = []; + for (var j = 0; j < message.uint64Data.length; ++j) + if (typeof message.uint64Data[j] === "number") + object.uint64Data[j] = options.longs === String ? String(message.uint64Data[j]) : message.uint64Data[j]; + else + object.uint64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.uint64Data[j]) : options.longs === Number ? new $util.LongBits(message.uint64Data[j].low >>> 0, message.uint64Data[j].high >>> 0).toNumber(true) : message.uint64Data[j]; + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.externalData && message.externalData.length) { + object.externalData = []; + for (var j = 0; j < message.externalData.length; ++j) + object.externalData[j] = $root.onnx.StringStringEntryProto.toObject(message.externalData[j], options); + } + if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) + object.dataLocation = options.enums === String ? $root.onnx.TensorProto.DataLocation[message.dataLocation] === undefined ? message.dataLocation : $root.onnx.TensorProto.DataLocation[message.dataLocation] : message.dataLocation; + return object; + }; + + /** + * Converts this TensorProto to JSON. + * @function toJSON + * @memberof onnx.TensorProto + * @instance + * @returns {Object.} JSON object + */ + TensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorProto + * @function getTypeUrl + * @memberof onnx.TensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorProto"; + }; + + /** + * DataType enum. + * @name onnx.TensorProto.DataType + * @enum {number} + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} UINT8=2 UINT8 value + * @property {number} INT8=3 INT8 value + * @property {number} UINT16=4 UINT16 value + * @property {number} INT16=5 INT16 value + * @property {number} INT32=6 INT32 value + * @property {number} INT64=7 INT64 value + * @property {number} STRING=8 STRING value + * @property {number} BOOL=9 BOOL value + * @property {number} FLOAT16=10 FLOAT16 value + * @property {number} DOUBLE=11 DOUBLE value + * @property {number} UINT32=12 UINT32 value + * @property {number} UINT64=13 UINT64 value + * @property {number} COMPLEX64=14 COMPLEX64 value + * @property {number} COMPLEX128=15 COMPLEX128 value + * @property {number} BFLOAT16=16 BFLOAT16 value + * @property {number} FLOAT8E4M3FN=17 FLOAT8E4M3FN value + * @property {number} FLOAT8E4M3FNUZ=18 FLOAT8E4M3FNUZ value + * @property {number} FLOAT8E5M2=19 FLOAT8E5M2 value + * @property {number} FLOAT8E5M2FNUZ=20 FLOAT8E5M2FNUZ value + */ + TensorProto.DataType = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "UNDEFINED"] = 0; + values[valuesById[1] = "FLOAT"] = 1; + values[valuesById[2] = "UINT8"] = 2; + values[valuesById[3] = "INT8"] = 3; + values[valuesById[4] = "UINT16"] = 4; + values[valuesById[5] = "INT16"] = 5; + values[valuesById[6] = "INT32"] = 6; + values[valuesById[7] = "INT64"] = 7; + values[valuesById[8] = "STRING"] = 8; + values[valuesById[9] = "BOOL"] = 9; + values[valuesById[10] = "FLOAT16"] = 10; + values[valuesById[11] = "DOUBLE"] = 11; + values[valuesById[12] = "UINT32"] = 12; + values[valuesById[13] = "UINT64"] = 13; + values[valuesById[14] = "COMPLEX64"] = 14; + values[valuesById[15] = "COMPLEX128"] = 15; + values[valuesById[16] = "BFLOAT16"] = 16; + values[valuesById[17] = "FLOAT8E4M3FN"] = 17; + values[valuesById[18] = "FLOAT8E4M3FNUZ"] = 18; + values[valuesById[19] = "FLOAT8E5M2"] = 19; + values[valuesById[20] = "FLOAT8E5M2FNUZ"] = 20; + return values; + })(); + + TensorProto.Segment = (function() { + + /** + * Properties of a Segment. + * @memberof onnx.TensorProto + * @interface ISegment + * @property {number|Long|null} [begin] Segment begin + * @property {number|Long|null} [end] Segment end + */ + + /** + * Constructs a new Segment. + * @memberof onnx.TensorProto + * @classdesc Represents a Segment. + * @implements ISegment + * @constructor + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + */ + function Segment(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Segment begin. + * @member {number|Long} begin + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.begin = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Segment end. + * @member {number|Long} end + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.end = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Creates a new Segment instance using the specified properties. + * @function create + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + * @returns {onnx.TensorProto.Segment} Segment instance + */ + Segment.create = function create(properties) { + return new Segment(properties); + }; + + /** + * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * @function encode + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.begin != null && Object.hasOwnProperty.call(message, "begin")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.begin); + if (message.end != null && Object.hasOwnProperty.call(message, "end")) + writer.uint32(/* id 2, wireType 0 =*/16).int64(message.end); + return writer; + }; + + /** + * Encodes the specified Segment message, length delimited. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Segment message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorProto.Segment(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.begin = reader.int64(); + break; + } + case 2: { + message.end = reader.int64(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Segment message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Segment message. + * @function verify + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Segment.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.begin != null && message.hasOwnProperty("begin")) + if (!$util.isInteger(message.begin) && !(message.begin && $util.isInteger(message.begin.low) && $util.isInteger(message.begin.high))) + return "begin: integer|Long expected"; + if (message.end != null && message.hasOwnProperty("end")) + if (!$util.isInteger(message.end) && !(message.end && $util.isInteger(message.end.low) && $util.isInteger(message.end.high))) + return "end: integer|Long expected"; + return null; + }; + + /** + * Creates a Segment message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorProto.Segment} Segment + */ + Segment.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto.Segment) + return object; + var message = new $root.onnx.TensorProto.Segment(); + if (object.begin != null) + if ($util.Long) + (message.begin = $util.Long.fromValue(object.begin)).unsigned = false; + else if (typeof object.begin === "string") + message.begin = parseInt(object.begin, 10); + else if (typeof object.begin === "number") + message.begin = object.begin; + else if (typeof object.begin === "object") + message.begin = new $util.LongBits(object.begin.low >>> 0, object.begin.high >>> 0).toNumber(); + if (object.end != null) + if ($util.Long) + (message.end = $util.Long.fromValue(object.end)).unsigned = false; + else if (typeof object.end === "string") + message.end = parseInt(object.end, 10); + else if (typeof object.end === "number") + message.end = object.end; + else if (typeof object.end === "object") + message.end = new $util.LongBits(object.end.low >>> 0, object.end.high >>> 0).toNumber(); + return message; + }; + + /** + * Creates a plain object from a Segment message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.Segment} message Segment + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Segment.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.begin = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.begin = options.longs === String ? "0" : 0; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.end = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.end = options.longs === String ? "0" : 0; + } + if (message.begin != null && message.hasOwnProperty("begin")) + if (typeof message.begin === "number") + object.begin = options.longs === String ? String(message.begin) : message.begin; + else + object.begin = options.longs === String ? $util.Long.prototype.toString.call(message.begin) : options.longs === Number ? new $util.LongBits(message.begin.low >>> 0, message.begin.high >>> 0).toNumber() : message.begin; + if (message.end != null && message.hasOwnProperty("end")) + if (typeof message.end === "number") + object.end = options.longs === String ? String(message.end) : message.end; + else + object.end = options.longs === String ? $util.Long.prototype.toString.call(message.end) : options.longs === Number ? new $util.LongBits(message.end.low >>> 0, message.end.high >>> 0).toNumber() : message.end; + return object; + }; + + /** + * Converts this Segment to JSON. + * @function toJSON + * @memberof onnx.TensorProto.Segment + * @instance + * @returns {Object.} JSON object + */ + Segment.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Segment + * @function getTypeUrl + * @memberof onnx.TensorProto.Segment + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Segment.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorProto.Segment"; + }; + + return Segment; + })(); + + /** + * DataLocation enum. + * @name onnx.TensorProto.DataLocation + * @enum {number} + * @property {number} DEFAULT=0 DEFAULT value + * @property {number} EXTERNAL=1 EXTERNAL value + */ + TensorProto.DataLocation = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "DEFAULT"] = 0; + values[valuesById[1] = "EXTERNAL"] = 1; + return values; + })(); + + return TensorProto; + })(); + + onnx.SparseTensorProto = (function() { + + /** + * Properties of a SparseTensorProto. + * @memberof onnx + * @interface ISparseTensorProto + * @property {onnx.ITensorProto|null} [values] SparseTensorProto values + * @property {onnx.ITensorProto|null} [indices] SparseTensorProto indices + * @property {Array.|null} [dims] SparseTensorProto dims + */ + + /** + * Constructs a new SparseTensorProto. + * @memberof onnx + * @classdesc Represents a SparseTensorProto. + * @implements ISparseTensorProto + * @constructor + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + */ + function SparseTensorProto(properties) { + this.dims = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensorProto values. + * @member {onnx.ITensorProto|null|undefined} values + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.values = null; + + /** + * SparseTensorProto indices. + * @member {onnx.ITensorProto|null|undefined} indices + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.indices = null; + + /** + * SparseTensorProto dims. + * @member {Array.} dims + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.dims = $util.emptyArray; + + /** + * Creates a new SparseTensorProto instance using the specified properties. + * @function create + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + * @returns {onnx.SparseTensorProto} SparseTensorProto instance + */ + SparseTensorProto.create = function create(properties) { + return new SparseTensorProto(properties); + }; + + /** + * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. + * @function encode + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.values != null && Object.hasOwnProperty.call(message, "values")) + $root.onnx.TensorProto.encode(message.values, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.indices != null && Object.hasOwnProperty.call(message, "indices")) + $root.onnx.TensorProto.encode(message.indices, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 3, wireType 2 =*/26).fork(); + for (var i = 0; i < message.dims.length; ++i) + writer.int64(message.dims[i]); + writer.ldelim(); + } + return writer; + }; + + /** + * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.SparseTensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.values = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.indices = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.dims && message.dims.length)) + message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.dims.push(reader.int64()); + } else + message.dims.push(reader.int64()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a SparseTensorProto message. + * @function verify + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensorProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.values != null && message.hasOwnProperty("values")) { + var error = $root.onnx.TensorProto.verify(message.values); + if (error) + return "values." + error; + } + if (message.indices != null && message.hasOwnProperty("indices")) { + var error = $root.onnx.TensorProto.verify(message.indices); + if (error) + return "indices." + error; + } + if (message.dims != null && message.hasOwnProperty("dims")) { + if (!Array.isArray(message.dims)) + return "dims: array expected"; + for (var i = 0; i < message.dims.length; ++i) + if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) + return "dims: integer|Long[] expected"; + } + return null; + }; + + /** + * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.SparseTensorProto} SparseTensorProto + */ + SparseTensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.SparseTensorProto) + return object; + var message = new $root.onnx.SparseTensorProto(); + if (object.values != null) { + if (typeof object.values !== "object") + throw TypeError(".onnx.SparseTensorProto.values: object expected"); + message.values = $root.onnx.TensorProto.fromObject(object.values); + } + if (object.indices != null) { + if (typeof object.indices !== "object") + throw TypeError(".onnx.SparseTensorProto.indices: object expected"); + message.indices = $root.onnx.TensorProto.fromObject(object.indices); + } + if (object.dims) { + if (!Array.isArray(object.dims)) + throw TypeError(".onnx.SparseTensorProto.dims: array expected"); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) + (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === "string") + message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === "number") + message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === "object") + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.SparseTensorProto} message SparseTensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensorProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.dims = []; + if (options.defaults) { + object.values = null; + object.indices = null; + } + if (message.values != null && message.hasOwnProperty("values")) + object.values = $root.onnx.TensorProto.toObject(message.values, options); + if (message.indices != null && message.hasOwnProperty("indices")) + object.indices = $root.onnx.TensorProto.toObject(message.indices, options); + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === "number") + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; + } + return object; + }; + + /** + * Converts this SparseTensorProto to JSON. + * @function toJSON + * @memberof onnx.SparseTensorProto + * @instance + * @returns {Object.} JSON object + */ + SparseTensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensorProto + * @function getTypeUrl + * @memberof onnx.SparseTensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.SparseTensorProto"; + }; + + return SparseTensorProto; + })(); + + onnx.TensorShapeProto = (function() { + + /** + * Properties of a TensorShapeProto. + * @memberof onnx + * @interface ITensorShapeProto + * @property {Array.|null} [dim] TensorShapeProto dim + */ + + /** + * Constructs a new TensorShapeProto. + * @memberof onnx + * @classdesc Represents a TensorShapeProto. + * @implements ITensorShapeProto + * @constructor + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + */ + function TensorShapeProto(properties) { + this.dim = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorShapeProto dim. + * @member {Array.} dim + * @memberof onnx.TensorShapeProto + * @instance + */ + TensorShapeProto.prototype.dim = $util.emptyArray; + + /** + * Creates a new TensorShapeProto instance using the specified properties. + * @function create + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + * @returns {onnx.TensorShapeProto} TensorShapeProto instance + */ + TensorShapeProto.create = function create(properties) { + return new TensorShapeProto(properties); + }; + + /** + * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dim != null && message.dim.length) + for (var i = 0; i < message.dim.length; ++i) + $root.onnx.TensorShapeProto.Dimension.encode(message.dim[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dim && message.dim.length)) + message.dim = []; + message.dim.push($root.onnx.TensorShapeProto.Dimension.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorShapeProto message. + * @function verify + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorShapeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.dim != null && message.hasOwnProperty("dim")) { + if (!Array.isArray(message.dim)) + return "dim: array expected"; + for (var i = 0; i < message.dim.length; ++i) { + var error = $root.onnx.TensorShapeProto.Dimension.verify(message.dim[i]); + if (error) + return "dim." + error; + } + } + return null; + }; + + /** + * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto} TensorShapeProto + */ + TensorShapeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto) + return object; + var message = new $root.onnx.TensorShapeProto(); + if (object.dim) { + if (!Array.isArray(object.dim)) + throw TypeError(".onnx.TensorShapeProto.dim: array expected"); + message.dim = []; + for (var i = 0; i < object.dim.length; ++i) { + if (typeof object.dim[i] !== "object") + throw TypeError(".onnx.TensorShapeProto.dim: object expected"); + message.dim[i] = $root.onnx.TensorShapeProto.Dimension.fromObject(object.dim[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.TensorShapeProto} message TensorShapeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorShapeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.dim = []; + if (message.dim && message.dim.length) { + object.dim = []; + for (var j = 0; j < message.dim.length; ++j) + object.dim[j] = $root.onnx.TensorShapeProto.Dimension.toObject(message.dim[j], options); + } + return object; + }; + + /** + * Converts this TensorShapeProto to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto + * @instance + * @returns {Object.} JSON object + */ + TensorShapeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorShapeProto + * @function getTypeUrl + * @memberof onnx.TensorShapeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorShapeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorShapeProto"; + }; + + TensorShapeProto.Dimension = (function() { + + /** + * Properties of a Dimension. + * @memberof onnx.TensorShapeProto + * @interface IDimension + * @property {number|Long|null} [dimValue] Dimension dimValue + * @property {string|null} [dimParam] Dimension dimParam + * @property {string|null} [denotation] Dimension denotation + */ + + /** + * Constructs a new Dimension. + * @memberof onnx.TensorShapeProto + * @classdesc Represents a Dimension. + * @implements IDimension + * @constructor + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + */ + function Dimension(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Dimension dimValue. + * @member {number|Long|null|undefined} dimValue + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimValue = null; + + /** + * Dimension dimParam. + * @member {string|null|undefined} dimParam + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimParam = null; + + /** + * Dimension denotation. + * @member {string} denotation + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.denotation = ""; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * Dimension value. + * @member {"dimValue"|"dimParam"|undefined} value + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Object.defineProperty(Dimension.prototype, "value", { + get: $util.oneOfGetter($oneOfFields = ["dimValue", "dimParam"]), + set: $util.oneOfSetter($oneOfFields) + }); + + /** + * Creates a new Dimension instance using the specified properties. + * @function create + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + * @returns {onnx.TensorShapeProto.Dimension} Dimension instance + */ + Dimension.create = function create(properties) { + return new Dimension(properties); + }; + + /** + * Encodes the specified Dimension message. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @function encode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dimValue != null && Object.hasOwnProperty.call(message, "dimValue")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.dimValue); + if (message.dimParam != null && Object.hasOwnProperty.call(message, "dimParam")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.dimParam); + if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.denotation); + return writer; + }; + + /** + * Encodes the specified Dimension message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Dimension message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto.Dimension(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.dimValue = reader.int64(); + break; + } + case 2: { + message.dimParam = reader.string(); + break; + } + case 3: { + message.denotation = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Dimension message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Dimension message. + * @function verify + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Dimension.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + var properties = {}; + if (message.dimValue != null && message.hasOwnProperty("dimValue")) { + properties.value = 1; + if (!$util.isInteger(message.dimValue) && !(message.dimValue && $util.isInteger(message.dimValue.low) && $util.isInteger(message.dimValue.high))) + return "dimValue: integer|Long expected"; + } + if (message.dimParam != null && message.hasOwnProperty("dimParam")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + if (!$util.isString(message.dimParam)) + return "dimParam: string expected"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + if (!$util.isString(message.denotation)) + return "denotation: string expected"; + return null; + }; + + /** + * Creates a Dimension message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto.Dimension} Dimension + */ + Dimension.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto.Dimension) + return object; + var message = new $root.onnx.TensorShapeProto.Dimension(); + if (object.dimValue != null) + if ($util.Long) + (message.dimValue = $util.Long.fromValue(object.dimValue)).unsigned = false; + else if (typeof object.dimValue === "string") + message.dimValue = parseInt(object.dimValue, 10); + else if (typeof object.dimValue === "number") + message.dimValue = object.dimValue; + else if (typeof object.dimValue === "object") + message.dimValue = new $util.LongBits(object.dimValue.low >>> 0, object.dimValue.high >>> 0).toNumber(); + if (object.dimParam != null) + message.dimParam = String(object.dimParam); + if (object.denotation != null) + message.denotation = String(object.denotation); + return message; + }; + + /** + * Creates a plain object from a Dimension message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.Dimension} message Dimension + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Dimension.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.denotation = ""; + if (message.dimValue != null && message.hasOwnProperty("dimValue")) { + if (typeof message.dimValue === "number") + object.dimValue = options.longs === String ? String(message.dimValue) : message.dimValue; + else + object.dimValue = options.longs === String ? $util.Long.prototype.toString.call(message.dimValue) : options.longs === Number ? new $util.LongBits(message.dimValue.low >>> 0, message.dimValue.high >>> 0).toNumber() : message.dimValue; + if (options.oneofs) + object.value = "dimValue"; + } + if (message.dimParam != null && message.hasOwnProperty("dimParam")) { + object.dimParam = message.dimParam; + if (options.oneofs) + object.value = "dimParam"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + object.denotation = message.denotation; + return object; + }; + + /** + * Converts this Dimension to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto.Dimension + * @instance + * @returns {Object.} JSON object + */ + Dimension.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Dimension + * @function getTypeUrl + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Dimension.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorShapeProto.Dimension"; + }; + + return Dimension; + })(); + + return TensorShapeProto; + })(); + + onnx.TypeProto = (function() { + + /** + * Properties of a TypeProto. + * @memberof onnx + * @interface ITypeProto + * @property {onnx.TypeProto.ITensor|null} [tensorType] TypeProto tensorType + * @property {onnx.TypeProto.ISequence|null} [sequenceType] TypeProto sequenceType + * @property {onnx.TypeProto.IMap|null} [mapType] TypeProto mapType + * @property {onnx.TypeProto.IOptional|null} [optionalType] TypeProto optionalType + * @property {onnx.TypeProto.ISparseTensor|null} [sparseTensorType] TypeProto sparseTensorType + * @property {string|null} [denotation] TypeProto denotation + */ + + /** + * Constructs a new TypeProto. + * @memberof onnx + * @classdesc Represents a TypeProto. + * @implements ITypeProto + * @constructor + * @param {onnx.ITypeProto=} [properties] Properties to set + */ + function TypeProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TypeProto tensorType. + * @member {onnx.TypeProto.ITensor|null|undefined} tensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.tensorType = null; + + /** + * TypeProto sequenceType. + * @member {onnx.TypeProto.ISequence|null|undefined} sequenceType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sequenceType = null; + + /** + * TypeProto mapType. + * @member {onnx.TypeProto.IMap|null|undefined} mapType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.mapType = null; + + /** + * TypeProto optionalType. + * @member {onnx.TypeProto.IOptional|null|undefined} optionalType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.optionalType = null; + + /** + * TypeProto sparseTensorType. + * @member {onnx.TypeProto.ISparseTensor|null|undefined} sparseTensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sparseTensorType = null; + + /** + * TypeProto denotation. + * @member {string} denotation + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.denotation = ""; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * TypeProto value. + * @member {"tensorType"|"sequenceType"|"mapType"|"optionalType"|"sparseTensorType"|undefined} value + * @memberof onnx.TypeProto + * @instance + */ + Object.defineProperty(TypeProto.prototype, "value", { + get: $util.oneOfGetter($oneOfFields = ["tensorType", "sequenceType", "mapType", "optionalType", "sparseTensorType"]), + set: $util.oneOfSetter($oneOfFields) + }); + + /** + * Creates a new TypeProto instance using the specified properties. + * @function create + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto=} [properties] Properties to set + * @returns {onnx.TypeProto} TypeProto instance + */ + TypeProto.create = function create(properties) { + return new TypeProto(properties); + }; + + /** + * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.tensorType != null && Object.hasOwnProperty.call(message, "tensorType")) + $root.onnx.TypeProto.Tensor.encode(message.tensorType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.sequenceType != null && Object.hasOwnProperty.call(message, "sequenceType")) + $root.onnx.TypeProto.Sequence.encode(message.sequenceType, writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); + if (message.mapType != null && Object.hasOwnProperty.call(message, "mapType")) + $root.onnx.TypeProto.Map.encode(message.mapType, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.denotation); + if (message.sparseTensorType != null && Object.hasOwnProperty.call(message, "sparseTensorType")) + $root.onnx.TypeProto.SparseTensor.encode(message.sparseTensorType, writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); + if (message.optionalType != null && Object.hasOwnProperty.call(message, "optionalType")) + $root.onnx.TypeProto.Optional.encode(message.optionalType, writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TypeProto message, length delimited. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TypeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorType = $root.onnx.TypeProto.Tensor.decode(reader, reader.uint32()); + break; + } + case 4: { + message.sequenceType = $root.onnx.TypeProto.Sequence.decode(reader, reader.uint32()); + break; + } + case 5: { + message.mapType = $root.onnx.TypeProto.Map.decode(reader, reader.uint32()); + break; + } + case 9: { + message.optionalType = $root.onnx.TypeProto.Optional.decode(reader, reader.uint32()); + break; + } + case 8: { + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.decode(reader, reader.uint32()); + break; + } + case 6: { + message.denotation = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TypeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TypeProto message. + * @function verify + * @memberof onnx.TypeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TypeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + var properties = {}; + if (message.tensorType != null && message.hasOwnProperty("tensorType")) { + properties.value = 1; + { + var error = $root.onnx.TypeProto.Tensor.verify(message.tensorType); + if (error) + return "tensorType." + error; + } + } + if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Sequence.verify(message.sequenceType); + if (error) + return "sequenceType." + error; + } + } + if (message.mapType != null && message.hasOwnProperty("mapType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Map.verify(message.mapType); + if (error) + return "mapType." + error; + } + } + if (message.optionalType != null && message.hasOwnProperty("optionalType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Optional.verify(message.optionalType); + if (error) + return "optionalType." + error; + } + } + if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.SparseTensor.verify(message.sparseTensorType); + if (error) + return "sparseTensorType." + error; + } + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + if (!$util.isString(message.denotation)) + return "denotation: string expected"; + return null; + }; + + /** + * Creates a TypeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto} TypeProto + */ + TypeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto) + return object; + var message = new $root.onnx.TypeProto(); + if (object.tensorType != null) { + if (typeof object.tensorType !== "object") + throw TypeError(".onnx.TypeProto.tensorType: object expected"); + message.tensorType = $root.onnx.TypeProto.Tensor.fromObject(object.tensorType); + } + if (object.sequenceType != null) { + if (typeof object.sequenceType !== "object") + throw TypeError(".onnx.TypeProto.sequenceType: object expected"); + message.sequenceType = $root.onnx.TypeProto.Sequence.fromObject(object.sequenceType); + } + if (object.mapType != null) { + if (typeof object.mapType !== "object") + throw TypeError(".onnx.TypeProto.mapType: object expected"); + message.mapType = $root.onnx.TypeProto.Map.fromObject(object.mapType); + } + if (object.optionalType != null) { + if (typeof object.optionalType !== "object") + throw TypeError(".onnx.TypeProto.optionalType: object expected"); + message.optionalType = $root.onnx.TypeProto.Optional.fromObject(object.optionalType); + } + if (object.sparseTensorType != null) { + if (typeof object.sparseTensorType !== "object") + throw TypeError(".onnx.TypeProto.sparseTensorType: object expected"); + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.fromObject(object.sparseTensorType); + } + if (object.denotation != null) + message.denotation = String(object.denotation); + return message; + }; + + /** + * Creates a plain object from a TypeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto + * @static + * @param {onnx.TypeProto} message TypeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TypeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.denotation = ""; + if (message.tensorType != null && message.hasOwnProperty("tensorType")) { + object.tensorType = $root.onnx.TypeProto.Tensor.toObject(message.tensorType, options); + if (options.oneofs) + object.value = "tensorType"; + } + if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { + object.sequenceType = $root.onnx.TypeProto.Sequence.toObject(message.sequenceType, options); + if (options.oneofs) + object.value = "sequenceType"; + } + if (message.mapType != null && message.hasOwnProperty("mapType")) { + object.mapType = $root.onnx.TypeProto.Map.toObject(message.mapType, options); + if (options.oneofs) + object.value = "mapType"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + object.denotation = message.denotation; + if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) { + object.sparseTensorType = $root.onnx.TypeProto.SparseTensor.toObject(message.sparseTensorType, options); + if (options.oneofs) + object.value = "sparseTensorType"; + } + if (message.optionalType != null && message.hasOwnProperty("optionalType")) { + object.optionalType = $root.onnx.TypeProto.Optional.toObject(message.optionalType, options); + if (options.oneofs) + object.value = "optionalType"; + } + return object; + }; + + /** + * Converts this TypeProto to JSON. + * @function toJSON + * @memberof onnx.TypeProto + * @instance + * @returns {Object.} JSON object + */ + TypeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TypeProto + * @function getTypeUrl + * @memberof onnx.TypeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TypeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto"; + }; + + TypeProto.Tensor = (function() { + + /** + * Properties of a Tensor. + * @memberof onnx.TypeProto + * @interface ITensor + * @property {number|null} [elemType] Tensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] Tensor shape + */ + + /** + * Constructs a new Tensor. + * @memberof onnx.TypeProto + * @classdesc Represents a Tensor. + * @implements ITensor + * @constructor + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + */ + function Tensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Tensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.elemType = 0; + + /** + * Tensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.shape = null; + + /** + * Creates a new Tensor instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + * @returns {onnx.TypeProto.Tensor} Tensor instance + */ + Tensor.create = function create(properties) { + return new Tensor(properties); + }; + + /** + * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Tensor message, length delimited. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Tensor message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Tensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Tensor message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Tensor message. + * @function verify + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Tensor.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) + if (!$util.isInteger(message.elemType)) + return "elemType: integer expected"; + if (message.shape != null && message.hasOwnProperty("shape")) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) + return "shape." + error; + } + return null; + }; + + /** + * Creates a Tensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Tensor} Tensor + */ + Tensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Tensor) + return object; + var message = new $root.onnx.TypeProto.Tensor(); + if (object.elemType != null) + message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== "object") + throw TypeError(".onnx.TypeProto.Tensor.shape: object expected"); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a Tensor message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.Tensor} message Tensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Tensor.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty("shape")) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this Tensor to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Tensor + * @instance + * @returns {Object.} JSON object + */ + Tensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Tensor + * @function getTypeUrl + * @memberof onnx.TypeProto.Tensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Tensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Tensor"; + }; + + return Tensor; + })(); + + TypeProto.Sequence = (function() { + + /** + * Properties of a Sequence. + * @memberof onnx.TypeProto + * @interface ISequence + * @property {onnx.ITypeProto|null} [elemType] Sequence elemType + */ + + /** + * Constructs a new Sequence. + * @memberof onnx.TypeProto + * @classdesc Represents a Sequence. + * @implements ISequence + * @constructor + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + */ + function Sequence(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Sequence elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Sequence + * @instance + */ + Sequence.prototype.elemType = null; + + /** + * Creates a new Sequence instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + * @returns {onnx.TypeProto.Sequence} Sequence instance + */ + Sequence.create = function create(properties) { + return new Sequence(properties); + }; + + /** + * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Sequence message, length delimited. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Sequence message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Sequence(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Sequence message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Sequence message. + * @function verify + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Sequence.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) + return "elemType." + error; + } + return null; + }; + + /** + * Creates a Sequence message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Sequence} Sequence + */ + Sequence.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Sequence) + return object; + var message = new $root.onnx.TypeProto.Sequence(); + if (object.elemType != null) { + if (typeof object.elemType !== "object") + throw TypeError(".onnx.TypeProto.Sequence.elemType: object expected"); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); + } + return message; + }; + + /** + * Creates a plain object from a Sequence message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.Sequence} message Sequence + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Sequence.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.elemType = null; + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Sequence to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Sequence + * @instance + * @returns {Object.} JSON object + */ + Sequence.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Sequence + * @function getTypeUrl + * @memberof onnx.TypeProto.Sequence + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Sequence.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Sequence"; + }; + + return Sequence; + })(); + + TypeProto.Map = (function() { + + /** + * Properties of a Map. + * @memberof onnx.TypeProto + * @interface IMap + * @property {number|null} [keyType] Map keyType + * @property {onnx.ITypeProto|null} [valueType] Map valueType + */ + + /** + * Constructs a new Map. + * @memberof onnx.TypeProto + * @classdesc Represents a Map. + * @implements IMap + * @constructor + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + */ + function Map(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Map keyType. + * @member {number} keyType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.keyType = 0; + + /** + * Map valueType. + * @member {onnx.ITypeProto|null|undefined} valueType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.valueType = null; + + /** + * Creates a new Map instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + * @returns {onnx.TypeProto.Map} Map instance + */ + Map.create = function create(properties) { + return new Map(properties); + }; + + /** + * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.keyType != null && Object.hasOwnProperty.call(message, "keyType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.keyType); + if (message.valueType != null && Object.hasOwnProperty.call(message, "valueType")) + $root.onnx.TypeProto.encode(message.valueType, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Map message, length delimited. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Map message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Map(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.keyType = reader.int32(); + break; + } + case 2: { + message.valueType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Map message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Map message. + * @function verify + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Map.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.keyType != null && message.hasOwnProperty("keyType")) + if (!$util.isInteger(message.keyType)) + return "keyType: integer expected"; + if (message.valueType != null && message.hasOwnProperty("valueType")) { + var error = $root.onnx.TypeProto.verify(message.valueType); + if (error) + return "valueType." + error; + } + return null; + }; + + /** + * Creates a Map message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Map} Map + */ + Map.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Map) + return object; + var message = new $root.onnx.TypeProto.Map(); + if (object.keyType != null) + message.keyType = object.keyType | 0; + if (object.valueType != null) { + if (typeof object.valueType !== "object") + throw TypeError(".onnx.TypeProto.Map.valueType: object expected"); + message.valueType = $root.onnx.TypeProto.fromObject(object.valueType); + } + return message; + }; + + /** + * Creates a plain object from a Map message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.Map} message Map + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Map.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.keyType = 0; + object.valueType = null; + } + if (message.keyType != null && message.hasOwnProperty("keyType")) + object.keyType = message.keyType; + if (message.valueType != null && message.hasOwnProperty("valueType")) + object.valueType = $root.onnx.TypeProto.toObject(message.valueType, options); + return object; + }; + + /** + * Converts this Map to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Map + * @instance + * @returns {Object.} JSON object + */ + Map.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Map + * @function getTypeUrl + * @memberof onnx.TypeProto.Map + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Map.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Map"; + }; + + return Map; + })(); + + TypeProto.Optional = (function() { + + /** + * Properties of an Optional. + * @memberof onnx.TypeProto + * @interface IOptional + * @property {onnx.ITypeProto|null} [elemType] Optional elemType + */ + + /** + * Constructs a new Optional. + * @memberof onnx.TypeProto + * @classdesc Represents an Optional. + * @implements IOptional + * @constructor + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + */ + function Optional(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Optional elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Optional + * @instance + */ + Optional.prototype.elemType = null; + + /** + * Creates a new Optional instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + * @returns {onnx.TypeProto.Optional} Optional instance + */ + Optional.create = function create(properties) { + return new Optional(properties); + }; + + /** + * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Optional message, length delimited. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an Optional message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Optional(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an Optional message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an Optional message. + * @function verify + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Optional.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) + return "elemType." + error; + } + return null; + }; + + /** + * Creates an Optional message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Optional} Optional + */ + Optional.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Optional) + return object; + var message = new $root.onnx.TypeProto.Optional(); + if (object.elemType != null) { + if (typeof object.elemType !== "object") + throw TypeError(".onnx.TypeProto.Optional.elemType: object expected"); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); + } + return message; + }; + + /** + * Creates a plain object from an Optional message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.Optional} message Optional + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Optional.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.elemType = null; + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Optional to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Optional + * @instance + * @returns {Object.} JSON object + */ + Optional.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Optional + * @function getTypeUrl + * @memberof onnx.TypeProto.Optional + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Optional.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Optional"; + }; + + return Optional; + })(); + + TypeProto.SparseTensor = (function() { + + /** + * Properties of a SparseTensor. + * @memberof onnx.TypeProto + * @interface ISparseTensor + * @property {number|null} [elemType] SparseTensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] SparseTensor shape + */ + + /** + * Constructs a new SparseTensor. + * @memberof onnx.TypeProto + * @classdesc Represents a SparseTensor. + * @implements ISparseTensor + * @constructor + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + */ + function SparseTensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.elemType = 0; + + /** + * SparseTensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.shape = null; + + /** + * Creates a new SparseTensor instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + * @returns {onnx.TypeProto.SparseTensor} SparseTensor instance + */ + SparseTensor.create = function create(properties) { + return new SparseTensor(properties); + }; + + /** + * Encodes the specified SparseTensor message. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified SparseTensor message, length delimited. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.SparseTensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a SparseTensor message. + * @function verify + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensor.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) + if (!$util.isInteger(message.elemType)) + return "elemType: integer expected"; + if (message.shape != null && message.hasOwnProperty("shape")) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) + return "shape." + error; + } + return null; + }; + + /** + * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + */ + SparseTensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.SparseTensor) + return object; + var message = new $root.onnx.TypeProto.SparseTensor(); + if (object.elemType != null) + message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== "object") + throw TypeError(".onnx.TypeProto.SparseTensor.shape: object expected"); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.SparseTensor} message SparseTensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensor.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty("shape")) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this SparseTensor to JSON. + * @function toJSON + * @memberof onnx.TypeProto.SparseTensor + * @instance + * @returns {Object.} JSON object + */ + SparseTensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensor + * @function getTypeUrl + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.SparseTensor"; + }; + + return SparseTensor; + })(); + + return TypeProto; + })(); + + onnx.OperatorSetIdProto = (function() { + + /** + * Properties of an OperatorSetIdProto. + * @memberof onnx + * @interface IOperatorSetIdProto + * @property {string|null} [domain] OperatorSetIdProto domain + * @property {number|Long|null} [version] OperatorSetIdProto version + */ + + /** + * Constructs a new OperatorSetIdProto. + * @memberof onnx + * @classdesc Represents an OperatorSetIdProto. + * @implements IOperatorSetIdProto + * @constructor + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + */ + function OperatorSetIdProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * OperatorSetIdProto domain. + * @member {string} domain + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.domain = ""; + + /** + * OperatorSetIdProto version. + * @member {number|Long} version + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.version = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Creates a new OperatorSetIdProto instance using the specified properties. + * @function create + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto instance + */ + OperatorSetIdProto.create = function create(properties) { + return new OperatorSetIdProto(properties); + }; + + /** + * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. + * @function encode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.domain); + if (message.version != null && Object.hasOwnProperty.call(message, "version")) + writer.uint32(/* id 2, wireType 0 =*/16).int64(message.version); + return writer; + }; + + /** + * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.OperatorSetIdProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.domain = reader.string(); + break; + } + case 2: { + message.version = reader.int64(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an OperatorSetIdProto message. + * @function verify + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + OperatorSetIdProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.version != null && message.hasOwnProperty("version")) + if (!$util.isInteger(message.version) && !(message.version && $util.isInteger(message.version.low) && $util.isInteger(message.version.high))) + return "version: integer|Long expected"; + return null; + }; + + /** + * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + */ + OperatorSetIdProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.OperatorSetIdProto) + return object; + var message = new $root.onnx.OperatorSetIdProto(); + if (object.domain != null) + message.domain = String(object.domain); + if (object.version != null) + if ($util.Long) + (message.version = $util.Long.fromValue(object.version)).unsigned = false; + else if (typeof object.version === "string") + message.version = parseInt(object.version, 10); + else if (typeof object.version === "number") + message.version = object.version; + else if (typeof object.version === "object") + message.version = new $util.LongBits(object.version.low >>> 0, object.version.high >>> 0).toNumber(); + return message; + }; + + /** + * Creates a plain object from an OperatorSetIdProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.OperatorSetIdProto} message OperatorSetIdProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + OperatorSetIdProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.domain = ""; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.version = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.version = options.longs === String ? "0" : 0; + } + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.version != null && message.hasOwnProperty("version")) + if (typeof message.version === "number") + object.version = options.longs === String ? String(message.version) : message.version; + else + object.version = options.longs === String ? $util.Long.prototype.toString.call(message.version) : options.longs === Number ? new $util.LongBits(message.version.low >>> 0, message.version.high >>> 0).toNumber() : message.version; + return object; + }; + + /** + * Converts this OperatorSetIdProto to JSON. + * @function toJSON + * @memberof onnx.OperatorSetIdProto + * @instance + * @returns {Object.} JSON object + */ + OperatorSetIdProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for OperatorSetIdProto + * @function getTypeUrl + * @memberof onnx.OperatorSetIdProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + OperatorSetIdProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.OperatorSetIdProto"; + }; + + return OperatorSetIdProto; + })(); + + /** + * OperatorStatus enum. + * @name onnx.OperatorStatus + * @enum {number} + * @property {number} EXPERIMENTAL=0 EXPERIMENTAL value + * @property {number} STABLE=1 STABLE value + */ + onnx.OperatorStatus = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "EXPERIMENTAL"] = 0; + values[valuesById[1] = "STABLE"] = 1; + return values; + })(); + + onnx.FunctionProto = (function() { + + /** + * Properties of a FunctionProto. + * @memberof onnx + * @interface IFunctionProto + * @property {string|null} [name] FunctionProto name + * @property {Array.|null} [input] FunctionProto input + * @property {Array.|null} [output] FunctionProto output + * @property {Array.|null} [attribute] FunctionProto attribute + * @property {Array.|null} [attributeProto] FunctionProto attributeProto + * @property {Array.|null} [node] FunctionProto node + * @property {string|null} [docString] FunctionProto docString + * @property {Array.|null} [opsetImport] FunctionProto opsetImport + * @property {string|null} [domain] FunctionProto domain + */ + + /** + * Constructs a new FunctionProto. + * @memberof onnx + * @classdesc Represents a FunctionProto. + * @implements IFunctionProto + * @constructor + * @param {onnx.IFunctionProto=} [properties] Properties to set + */ + function FunctionProto(properties) { + this.input = []; + this.output = []; + this.attribute = []; + this.attributeProto = []; + this.node = []; + this.opsetImport = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * FunctionProto name. + * @member {string} name + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.name = ""; + + /** + * FunctionProto input. + * @member {Array.} input + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.input = $util.emptyArray; + + /** + * FunctionProto output. + * @member {Array.} output + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.output = $util.emptyArray; + + /** + * FunctionProto attribute. + * @member {Array.} attribute + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.attribute = $util.emptyArray; + + /** + * FunctionProto attributeProto. + * @member {Array.} attributeProto + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.attributeProto = $util.emptyArray; + + /** + * FunctionProto node. + * @member {Array.} node + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.node = $util.emptyArray; + + /** + * FunctionProto docString. + * @member {string} docString + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.docString = ""; + + /** + * FunctionProto opsetImport. + * @member {Array.} opsetImport + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.opsetImport = $util.emptyArray; + + /** + * FunctionProto domain. + * @member {string} domain + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.domain = ""; + + /** + * Creates a new FunctionProto instance using the specified properties. + * @function create + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto=} [properties] Properties to set + * @returns {onnx.FunctionProto} FunctionProto instance + */ + FunctionProto.create = function create(properties) { + return new FunctionProto(properties); + }; + + /** + * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. + * @function encode + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + FunctionProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.input[i]); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + writer.uint32(/* id 5, wireType 2 =*/42).string(message.output[i]); + if (message.attribute != null && message.attribute.length) + for (var i = 0; i < message.attribute.length; ++i) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.attribute[i]); + if (message.node != null && message.node.length) + for (var i = 0; i < message.node.length; ++i) + $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 8, wireType 2 =*/66).string(message.docString); + if (message.opsetImport != null && message.opsetImport.length) + for (var i = 0; i < message.opsetImport.length; ++i) + $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim(); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 10, wireType 2 =*/82).string(message.domain); + if (message.attributeProto != null && message.attributeProto.length) + for (var i = 0; i < message.attributeProto.length; ++i) + $root.onnx.AttributeProto.encode(message.attributeProto[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + FunctionProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a FunctionProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.FunctionProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 4: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push(reader.string()); + break; + } + case 5: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push(reader.string()); + break; + } + case 6: { + if (!(message.attribute && message.attribute.length)) + message.attribute = []; + message.attribute.push(reader.string()); + break; + } + case 11: { + if (!(message.attributeProto && message.attributeProto.length)) + message.attributeProto = []; + message.attributeProto.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 7: { + if (!(message.node && message.node.length)) + message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 8: { + message.docString = reader.string(); + break; + } + case 9: { + if (!(message.opsetImport && message.opsetImport.length)) + message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.domain = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a FunctionProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a FunctionProto message. + * @function verify + * @memberof onnx.FunctionProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + FunctionProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) + return "input: string[] expected"; + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) + return "output: string[] expected"; + } + if (message.attribute != null && message.hasOwnProperty("attribute")) { + if (!Array.isArray(message.attribute)) + return "attribute: array expected"; + for (var i = 0; i < message.attribute.length; ++i) + if (!$util.isString(message.attribute[i])) + return "attribute: string[] expected"; + } + if (message.attributeProto != null && message.hasOwnProperty("attributeProto")) { + if (!Array.isArray(message.attributeProto)) + return "attributeProto: array expected"; + for (var i = 0; i < message.attributeProto.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attributeProto[i]); + if (error) + return "attributeProto." + error; + } + } + if (message.node != null && message.hasOwnProperty("node")) { + if (!Array.isArray(message.node)) + return "node: array expected"; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) + return "node." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { + if (!Array.isArray(message.opsetImport)) + return "opsetImport: array expected"; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) + return "opsetImport." + error; + } + } + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + return null; + }; + + /** + * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.FunctionProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.FunctionProto} FunctionProto + */ + FunctionProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.FunctionProto) + return object; + var message = new $root.onnx.FunctionProto(); + if (object.name != null) + message.name = String(object.name); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.FunctionProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) + message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.FunctionProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) + message.output[i] = String(object.output[i]); + } + if (object.attribute) { + if (!Array.isArray(object.attribute)) + throw TypeError(".onnx.FunctionProto.attribute: array expected"); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) + message.attribute[i] = String(object.attribute[i]); + } + if (object.attributeProto) { + if (!Array.isArray(object.attributeProto)) + throw TypeError(".onnx.FunctionProto.attributeProto: array expected"); + message.attributeProto = []; + for (var i = 0; i < object.attributeProto.length; ++i) { + if (typeof object.attributeProto[i] !== "object") + throw TypeError(".onnx.FunctionProto.attributeProto: object expected"); + message.attributeProto[i] = $root.onnx.AttributeProto.fromObject(object.attributeProto[i]); + } + } + if (object.node) { + if (!Array.isArray(object.node)) + throw TypeError(".onnx.FunctionProto.node: array expected"); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== "object") + throw TypeError(".onnx.FunctionProto.node: object expected"); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) + throw TypeError(".onnx.FunctionProto.opsetImport: array expected"); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== "object") + throw TypeError(".onnx.FunctionProto.opsetImport: object expected"); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.domain != null) + message.domain = String(object.domain); + return message; + }; + + /** + * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.FunctionProto + * @static + * @param {onnx.FunctionProto} message FunctionProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + FunctionProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + object.node = []; + object.opsetImport = []; + object.attributeProto = []; + } + if (options.defaults) { + object.name = ""; + object.docString = ""; + object.domain = ""; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = message.output[j]; + } + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) + object.attribute[j] = message.attribute[j]; + } + if (message.node && message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.attributeProto && message.attributeProto.length) { + object.attributeProto = []; + for (var j = 0; j < message.attributeProto.length; ++j) + object.attributeProto[j] = $root.onnx.AttributeProto.toObject(message.attributeProto[j], options); + } + return object; + }; + + /** + * Converts this FunctionProto to JSON. + * @function toJSON + * @memberof onnx.FunctionProto + * @instance + * @returns {Object.} JSON object + */ + FunctionProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for FunctionProto + * @function getTypeUrl + * @memberof onnx.FunctionProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + FunctionProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.FunctionProto"; + }; + + return FunctionProto; + })(); + + return onnx; +})(); + +module.exports = $root; diff --git a/js/node/test/test-utils.ts b/js/node/test/test-utils.ts index 968e8a1881810..3eef90356a335 100644 --- a/js/node/test/test-utils.ts +++ b/js/node/test/test-utils.ts @@ -4,10 +4,11 @@ import assert from 'assert'; import * as fs from 'fs-extra'; import {jsonc} from 'jsonc'; -import * as onnx_proto from 'onnx-proto'; import {InferenceSession, Tensor} from 'onnxruntime-common'; import * as path from 'path'; +import * as onnx_proto from './ort-schema/protobuf/onnx'; + export const TEST_ROOT = __dirname; export const TEST_DATA_ROOT = path.join(TEST_ROOT, 'testdata'); diff --git a/js/node/tsconfig.json b/js/node/tsconfig.json index c8eb433577515..c154c3e148ed0 100644 --- a/js/node/tsconfig.json +++ b/js/node/tsconfig.json @@ -1,7 +1,6 @@ { "extends": "../tsconfig.json", "compilerOptions": { - "module": "CommonJS", "outDir": "dist" }, "include": ["lib"] diff --git a/js/package-lock.json b/js/package-lock.json index be7b3c9cd7d30..c16a8b59a3a6f 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -6,33 +6,37 @@ "": { "license": "MIT", "devDependencies": { - "@types/fs-extra": "^11.0.1", - "@types/mocha": "^10.0.1", + "@types/fs-extra": "^11.0.2", + "@types/mocha": "^10.0.2", "@types/node": "^18.14.6", "@types/npmlog": "^4.1.4", - "@typescript-eslint/eslint-plugin": "^5.54.1", - "@typescript-eslint/parser": "^5.54.1", + "@typescript-eslint/eslint-plugin": "^6.7.4", + "@typescript-eslint/parser": "^6.7.4", "clang-format": "^1.8.0", - "dir-compare": "^4.0.0", - "eslint": "^8.35.0", + "dir-compare": "^4.2.0", + "esbuild": "^0.19.3", + "esbuild-plugin-polyfill-node": "^0.3.0", + "eslint": "^8.51.0", "eslint-plugin-header": "^3.1.1", - "eslint-plugin-import": "^2.27.5", - "eslint-plugin-jsdoc": "^40.0.1", + "eslint-plugin-import": "^2.28.1", + "eslint-plugin-jsdoc": "^46.8.2", "eslint-plugin-prefer-arrow": "^1.2.3", - "eslint-plugin-unicorn": "^46.0.0", - "fs-extra": "^11.1.0", + "eslint-plugin-unicorn": "^48.0.1", + "fs-extra": "^11.1.1", "jszip": "^3.10.1", "mocha": "^10.2.0", - "node-polyfill-webpack-plugin": "^2.0.1", "npmlog": "^7.0.1", - "prettier": "^3.0.0", - "terser": "^5.16.5", - "ts-loader": "^9.4.2", - "typescript": "^4.9.5", - "webpack": "^5.76.0", - "webpack-bundle-analyzer": "^4.8.0", - "webpack-cli": "^5.0.1", - "worker-loader": "^3.0.8" + "prettier": "^3.0.3", + "typescript": "^5.2.2" + } + }, + "node_modules/@aashutoshrathi/word-wrap": { + "version": "1.2.6", + "resolved": "https://registry.npmjs.org/@aashutoshrathi/word-wrap/-/word-wrap-1.2.6.tgz", + "integrity": "sha512-1Yjs2SvM8TflER/OD3cOjhWWOZb58A2t7wpE2S9XfBYTiIl+XFhQG2bjy4Pu1I+EAlCNUzRDYDdFwFYUKvXcIA==", + "dev": true, + "engines": { + "node": ">=0.10.0" } }, "node_modules/@babel/code-frame": { @@ -48,9 +52,9 @@ } }, "node_modules/@babel/helper-validator-identifier": { - "version": "7.19.1", - "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.19.1.tgz", - "integrity": "sha512-awrNfaMtnHUr653GgGEs++LlAvW6w+DcPrOliSMXWCKo597CwL5Acf/wWdNkf/tfEQE3mjkeD1YOVZOUV/od1w==", + "version": "7.22.20", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.22.20.tgz", + "integrity": "sha512-Y4OZ+ytlatR8AI+8KZfKuL5urKp7qey08ha31L8b3BwewJAoJamTzyvxPR/5D+KkdJCGPq/+8TukHBlY10FX9A==", "dev": true, "engines": { "node": ">=6.9.0" @@ -141,33 +145,376 @@ "node": ">=4" } }, - "node_modules/@discoveryjs/json-ext": { - "version": "0.5.7", - "resolved": "https://registry.npmjs.org/@discoveryjs/json-ext/-/json-ext-0.5.7.tgz", - "integrity": "sha512-dBVuXR082gk3jsFp7Rd/JI4kytwGHecnCoTtXFb7DB6CNHp4rg5k1bhg0nWdLGLnOV71lmDzGQaLMy8iPLY0pw==", + "node_modules/@es-joy/jsdoccomment": { + "version": "0.40.1", + "resolved": "https://registry.npmjs.org/@es-joy/jsdoccomment/-/jsdoccomment-0.40.1.tgz", + "integrity": "sha512-YORCdZSusAlBrFpZ77pJjc5r1bQs5caPWtAu+WWmiSo+8XaUzseapVrfAtiRFbQWnrBxxLLEwF6f6ZG/UgCQCg==", "dev": true, + "dependencies": { + "comment-parser": "1.4.0", + "esquery": "^1.5.0", + "jsdoc-type-pratt-parser": "~4.0.0" + }, "engines": { - "node": ">=10.0.0" + "node": ">=16" } }, - "node_modules/@es-joy/jsdoccomment": { - "version": "0.36.1", - "resolved": "https://registry.npmjs.org/@es-joy/jsdoccomment/-/jsdoccomment-0.36.1.tgz", - "integrity": "sha512-922xqFsTpHs6D0BUiG4toiyPOMc8/jafnWKxz1KWgS4XzKPy2qXf1Pe6UFuNSCQqt6tOuhAWXBNuuyUhJmw9Vg==", + "node_modules/@esbuild/android-arm": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.19.3.tgz", + "integrity": "sha512-Lemgw4io4VZl9GHJmjiBGzQ7ONXRfRPHcUEerndjwiSkbxzrpq0Uggku5MxxrXdwJ+pTj1qyw4jwTu7hkPsgIA==", + "cpu": [ + "arm" + ], "dev": true, - "dependencies": { - "comment-parser": "1.3.1", - "esquery": "^1.4.0", - "jsdoc-type-pratt-parser": "~3.1.0" - }, + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/android-arm64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.19.3.tgz", + "integrity": "sha512-w+Akc0vv5leog550kjJV9Ru+MXMR2VuMrui3C61mnysim0gkFCPOUTAfzTP0qX+HpN9Syu3YA3p1hf3EPqObRw==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/android-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.19.3.tgz", + "integrity": "sha512-FKQJKkK5MXcBHoNZMDNUAg1+WcZlV/cuXrWCoGF/TvdRiYS4znA0m5Il5idUwfxrE20bG/vU1Cr5e1AD6IEIjQ==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/darwin-arm64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.19.3.tgz", + "integrity": "sha512-kw7e3FXU+VsJSSSl2nMKvACYlwtvZB8RUIeVShIEY6PVnuZ3c9+L9lWB2nWeeKWNNYDdtL19foCQ0ZyUL7nqGw==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/darwin-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.19.3.tgz", + "integrity": "sha512-tPfZiwF9rO0jW6Jh9ipi58N5ZLoSjdxXeSrAYypy4psA2Yl1dAMhM71KxVfmjZhJmxRjSnb29YlRXXhh3GqzYw==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/freebsd-arm64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.19.3.tgz", + "integrity": "sha512-ERDyjOgYeKe0Vrlr1iLrqTByB026YLPzTytDTz1DRCYM+JI92Dw2dbpRHYmdqn6VBnQ9Bor6J8ZlNwdZdxjlSg==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/freebsd-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.19.3.tgz", + "integrity": "sha512-nXesBZ2Ad1qL+Rm3crN7NmEVJ5uvfLFPLJev3x1j3feCQXfAhoYrojC681RhpdOph8NsvKBBwpYZHR7W0ifTTA==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-arm": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.19.3.tgz", + "integrity": "sha512-zr48Cg/8zkzZCzDHNxXO/89bf9e+r4HtzNUPoz4GmgAkF1gFAFmfgOdCbR8zMbzFDGb1FqBBhdXUpcTQRYS1cQ==", + "cpu": [ + "arm" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-arm64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.19.3.tgz", + "integrity": "sha512-qXvYKmXj8GcJgWq3aGvxL/JG1ZM3UR272SdPU4QSTzD0eymrM7leiZH77pvY3UetCy0k1xuXZ+VPvoJNdtrsWQ==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-ia32": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.19.3.tgz", + "integrity": "sha512-7XlCKCA0nWcbvYpusARWkFjRQNWNGlt45S+Q18UeS///K6Aw8bB2FKYe9mhVWy/XLShvCweOLZPrnMswIaDXQA==", + "cpu": [ + "ia32" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-loong64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.19.3.tgz", + "integrity": "sha512-qGTgjweER5xqweiWtUIDl9OKz338EQqCwbS9c2Bh5jgEH19xQ1yhgGPNesugmDFq+UUSDtWgZ264st26b3de8A==", + "cpu": [ + "loong64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-mips64el": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.19.3.tgz", + "integrity": "sha512-gy1bFskwEyxVMFRNYSvBauDIWNggD6pyxUksc0MV9UOBD138dKTzr8XnM2R4mBsHwVzeuIH8X5JhmNs2Pzrx+A==", + "cpu": [ + "mips64el" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-ppc64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.19.3.tgz", + "integrity": "sha512-UrYLFu62x1MmmIe85rpR3qou92wB9lEXluwMB/STDzPF9k8mi/9UvNsG07Tt9AqwPQXluMQ6bZbTzYt01+Ue5g==", + "cpu": [ + "ppc64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-riscv64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.19.3.tgz", + "integrity": "sha512-9E73TfyMCbE+1AwFOg3glnzZ5fBAFK4aawssvuMgCRqCYzE0ylVxxzjEfut8xjmKkR320BEoMui4o/t9KA96gA==", + "cpu": [ + "riscv64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-s390x": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.19.3.tgz", + "integrity": "sha512-LlmsbuBdm1/D66TJ3HW6URY8wO6IlYHf+ChOUz8SUAjVTuaisfuwCOAgcxo3Zsu3BZGxmI7yt//yGOxV+lHcEA==", + "cpu": [ + "s390x" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.19.3.tgz", + "integrity": "sha512-ogV0+GwEmvwg/8ZbsyfkYGaLACBQWDvO0Kkh8LKBGKj9Ru8VM39zssrnu9Sxn1wbapA2qNS6BiLdwJZGouyCwQ==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/netbsd-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.19.3.tgz", + "integrity": "sha512-o1jLNe4uzQv2DKXMlmEzf66Wd8MoIhLNO2nlQBHLtWyh2MitDG7sMpfCO3NTcoTMuqHjfufgUQDFRI5C+xsXQw==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/openbsd-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.19.3.tgz", + "integrity": "sha512-AZJCnr5CZgZOdhouLcfRdnk9Zv6HbaBxjcyhq0StNcvAdVZJSKIdOiPB9az2zc06ywl0ePYJz60CjdKsQacp5Q==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/sunos-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.19.3.tgz", + "integrity": "sha512-Acsujgeqg9InR4glTRvLKGZ+1HMtDm94ehTIHKhJjFpgVzZG9/pIcWW/HA/DoMfEyXmANLDuDZ2sNrWcjq1lxw==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "sunos" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/win32-arm64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.19.3.tgz", + "integrity": "sha512-FSrAfjVVy7TifFgYgliiJOyYynhQmqgPj15pzLyJk8BUsnlWNwP/IAy6GAiB1LqtoivowRgidZsfpoYLZH586A==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/win32-ia32": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.19.3.tgz", + "integrity": "sha512-xTScXYi12xLOWZ/sc5RBmMN99BcXp/eEf7scUC0oeiRoiT5Vvo9AycuqCp+xdpDyAU+LkrCqEpUS9fCSZF8J3Q==", + "cpu": [ + "ia32" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/win32-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.19.3.tgz", + "integrity": "sha512-FbUN+0ZRXsypPyWE2IwIkVjDkDnJoMJARWOcFZn4KPPli+QnKqF0z1anvfaYe3ev5HFCpRDLLBDHyOALLppWHw==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ], "engines": { - "node": "^14 || ^16 || ^17 || ^18 || ^19" + "node": ">=12" } }, "node_modules/@eslint-community/eslint-utils": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.2.0.tgz", - "integrity": "sha512-gB8T4H4DEfX2IV9zGDJPOBgP1e/DbfCPDTtEqUMckpvzS1OYtva8JdFYBqMwYk7xAQ429WGF/UPqn8uQ//h2vQ==", + "version": "4.4.0", + "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.4.0.tgz", + "integrity": "sha512-1/sA4dwrzBAyeUoQ6oxahHKmrZvsnLCg4RfxW3ZFGGmQkSNQPFNLV9CUEFQP1x9EYXHTo5p6xdhZM1Ne9p/AfA==", "dev": true, "dependencies": { "eslint-visitor-keys": "^3.3.0" @@ -179,15 +526,24 @@ "eslint": "^6.0.0 || ^7.0.0 || >=8.0.0" } }, + "node_modules/@eslint-community/regexpp": { + "version": "4.9.1", + "resolved": "https://registry.npmjs.org/@eslint-community/regexpp/-/regexpp-4.9.1.tgz", + "integrity": "sha512-Y27x+MBLjXa+0JWDhykM3+JE+il3kHKAEqabfEWq3SDhZjLYb6/BHL/JKFnH3fe207JaXkyDo685Oc2Glt6ifA==", + "dev": true, + "engines": { + "node": "^12.0.0 || ^14.0.0 || >=16.0.0" + } + }, "node_modules/@eslint/eslintrc": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-2.0.0.tgz", - "integrity": "sha512-fluIaaV+GyV24CCu/ggiHdV+j4RNh85yQnAYS/G2mZODZgGmmlrgCydjUcV3YvxCm9x8nMAfThsqTni4KiXT4A==", + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-2.1.2.tgz", + "integrity": "sha512-+wvgpDsrB1YqAMdEUCcnTlpfVBH7Vqn6A/NT3D8WVXFIaKMlErPIZT3oCIAVCOtarRpMtelZLqJeU3t7WY6X6g==", "dev": true, "dependencies": { "ajv": "^6.12.4", "debug": "^4.3.2", - "espree": "^9.4.0", + "espree": "^9.6.0", "globals": "^13.19.0", "ignore": "^5.2.0", "import-fresh": "^3.2.1", @@ -203,18 +559,18 @@ } }, "node_modules/@eslint/js": { - "version": "8.35.0", - "resolved": "https://registry.npmjs.org/@eslint/js/-/js-8.35.0.tgz", - "integrity": "sha512-JXdzbRiWclLVoD8sNUjR443VVlYqiYmDVT6rGUEIEHU5YJW0gaVZwV2xgM7D4arkvASqD0IlLUVjHiFuxaftRw==", + "version": "8.51.0", + "resolved": "https://registry.npmjs.org/@eslint/js/-/js-8.51.0.tgz", + "integrity": "sha512-HxjQ8Qn+4SI3/AFv6sOrDB+g6PpUTDwSJiQqOrnneEk8L71161srI9gjzzZvYVbzHiVg/BvcH95+cK/zfIt4pg==", "dev": true, "engines": { "node": "^12.22.0 || ^14.17.0 || >=16.0.0" } }, "node_modules/@humanwhocodes/config-array": { - "version": "0.11.8", - "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.8.tgz", - "integrity": "sha512-UybHIJzJnR5Qc/MsD9Kr+RpO2h+/P1GhOwdiLPXK5TWk5sgTdu88bTD9UP+CKbPPh5Rni1u0GjAdYQLemG8g+g==", + "version": "0.11.11", + "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.11.tgz", + "integrity": "sha512-N2brEuAadi0CcdeMXUkhbZB84eskAc8MEX1By6qEchoVywSgXPIjou4rYsl0V3Hj0ZnuGycGCjdNgockbzeWNA==", "dev": true, "dependencies": { "@humanwhocodes/object-schema": "^1.2.1", @@ -244,64 +600,12 @@ "integrity": "sha512-ZnQMnLV4e7hDlUvw8H+U8ASL02SS2Gn6+9Ac3wGGLIe7+je2AeAOxPY+izIPJDfFDb7eDjev0Us8MO1iFRN8hA==", "dev": true }, - "node_modules/@jridgewell/gen-mapping": { - "version": "0.3.2", - "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.2.tgz", - "integrity": "sha512-mh65xKQAzI6iBcFzwv28KVWSmCkdRBWoOh+bYQGW3+6OZvbbN3TqMGo5hqYxQniRcH9F2VZIoJCm4pa3BPDK/A==", - "dev": true, - "dependencies": { - "@jridgewell/set-array": "^1.0.1", - "@jridgewell/sourcemap-codec": "^1.4.10", - "@jridgewell/trace-mapping": "^0.3.9" - }, - "engines": { - "node": ">=6.0.0" - } - }, - "node_modules/@jridgewell/resolve-uri": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.0.tgz", - "integrity": "sha512-F2msla3tad+Mfht5cJq7LSXcdudKTWCVYUgw6pLFOOHSTtZlj6SWNYAp+AhuqLmWdBO2X5hPrLcu8cVP8fy28w==", - "dev": true, - "engines": { - "node": ">=6.0.0" - } - }, - "node_modules/@jridgewell/set-array": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.1.2.tgz", - "integrity": "sha512-xnkseuNADM0gt2bs+BvhO0p78Mk762YnZdsuzFV018NoG1Sj1SCQvpSqa7XUaTam5vAGasABV9qXASMKnFMwMw==", - "dev": true, - "engines": { - "node": ">=6.0.0" - } - }, - "node_modules/@jridgewell/source-map": { - "version": "0.3.2", - "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.2.tgz", - "integrity": "sha512-m7O9o2uR8k2ObDysZYzdfhb08VuEml5oWGiosa1VdaPZ/A6QyPkAJuwN0Q1lhULOf6B7MtQmHENS743hWtCrgw==", - "dev": true, - "dependencies": { - "@jridgewell/gen-mapping": "^0.3.0", - "@jridgewell/trace-mapping": "^0.3.9" - } - }, - "node_modules/@jridgewell/sourcemap-codec": { - "version": "1.4.14", - "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.14.tgz", - "integrity": "sha512-XPSJHWmi394fuUuzDnGz1wiKqWfo1yXecHQMRf2l6hztTO+nPru658AyDngaBe7isIxEkRsPR3FZh+s7iVa4Uw==", + "node_modules/@jspm/core": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/@jspm/core/-/core-2.0.1.tgz", + "integrity": "sha512-Lg3PnLp0QXpxwLIAuuJboLeRaIhrgJjeuh797QADg3xz8wGLugQOS5DpsE8A6i6Adgzf+bacllkKZG3J0tGfDw==", "dev": true }, - "node_modules/@jridgewell/trace-mapping": { - "version": "0.3.17", - "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.17.tgz", - "integrity": "sha512-MCNzAp77qzKca9+W/+I0+sEpaUnZoeasnghNeVc41VZCEKaCH73Vq3BZZ/SzWIgrqE4H4ceI+p+b6C0mHf9T4g==", - "dev": true, - "dependencies": { - "@jridgewell/resolve-uri": "3.1.0", - "@jridgewell/sourcemap-codec": "1.4.14" - } - }, "node_modules/@nodelib/fs.scandir": { "version": "2.1.5", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", @@ -337,42 +641,10 @@ "node": ">= 8" } }, - "node_modules/@polka/url": { - "version": "1.0.0-next.21", - "resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.21.tgz", - "integrity": "sha512-a5Sab1C4/icpTZVzZc5Ghpz88yQtGOyNqYXcZgOssB2uuAr+wF/MvN6bgtW32q7HHrvBki+BsZ0OuNv6EV3K9g==", - "dev": true - }, - "node_modules/@types/eslint": { - "version": "8.21.1", - "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.21.1.tgz", - "integrity": "sha512-rc9K8ZpVjNcLs8Fp0dkozd5Pt2Apk1glO4Vgz8ix1u6yFByxfqo5Yavpy65o+93TAe24jr7v+eSBtFLvOQtCRQ==", - "dev": true, - "dependencies": { - "@types/estree": "*", - "@types/json-schema": "*" - } - }, - "node_modules/@types/eslint-scope": { - "version": "3.7.4", - "resolved": "https://registry.npmjs.org/@types/eslint-scope/-/eslint-scope-3.7.4.tgz", - "integrity": "sha512-9K4zoImiZc3HlIp6AVUDE4CWYx22a+lhSZMYNpbjW04+YF0KWj4pJXnEMjdnFTiQibFFmElcsasJXDbdI/EPhA==", - "dev": true, - "dependencies": { - "@types/eslint": "*", - "@types/estree": "*" - } - }, - "node_modules/@types/estree": { - "version": "0.0.51", - "resolved": "https://registry.npmjs.org/@types/estree/-/estree-0.0.51.tgz", - "integrity": "sha512-CuPgU6f3eT/XgKKPqKd/gLZV1Xmvf1a2R5POBOGQa6uv82xpls89HU5zKeVoyR8XzHd1RGNOlQlvUe3CFkjWNQ==", - "dev": true - }, "node_modules/@types/fs-extra": { - "version": "11.0.1", - "resolved": "https://registry.npmjs.org/@types/fs-extra/-/fs-extra-11.0.1.tgz", - "integrity": "sha512-MxObHvNl4A69ofaTRU8DFqvgzzv8s9yRtaPPm5gud9HDNvpB3GPQFvNuTWAI59B9huVGV5jXYJwbCsmBsOGYWA==", + "version": "11.0.2", + "resolved": "https://registry.npmjs.org/@types/fs-extra/-/fs-extra-11.0.2.tgz", + "integrity": "sha512-c0hrgAOVYr21EX8J0jBMXGLMgJqVf/v6yxi0dLaJboW9aQPh16Id+z6w2Tx1hm+piJOLv8xPfVKZCLfjPw/IMQ==", "dev": true, "dependencies": { "@types/jsonfile": "*", @@ -380,9 +652,9 @@ } }, "node_modules/@types/json-schema": { - "version": "7.0.11", - "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.11.tgz", - "integrity": "sha512-wOuvG1SN4Us4rez+tylwwwCV1psiNVOkJeM3AUWUNWg/jDQY2+HE/444y5gc+jBmRqASOm2Oeh5c1axHobwRKQ==", + "version": "7.0.13", + "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.13.tgz", + "integrity": "sha512-RbSSoHliUbnXj3ny0CNFOoxrIDV6SUGyStHsvDqosw6CkdPV8TtWGlfecuK4ToyMEAql6pzNxgCFKanovUzlgQ==", "dev": true }, "node_modules/@types/json5": { @@ -401,9 +673,9 @@ } }, "node_modules/@types/mocha": { - "version": "10.0.1", - "resolved": "https://registry.npmjs.org/@types/mocha/-/mocha-10.0.1.tgz", - "integrity": "sha512-/fvYntiO1GeICvqbQ3doGDIP97vWmvFt83GKguJ6prmQM2iXZfFcq6YE8KteFyRtX2/h5Hf91BYvPodJKFYv5Q==", + "version": "10.0.2", + "resolved": "https://registry.npmjs.org/@types/mocha/-/mocha-10.0.2.tgz", + "integrity": "sha512-NaHL0+0lLNhX6d9rs+NSt97WH/gIlRHmszXbQ/8/MV/eVcFNdeJ/GYhrFuUc8K7WuPhRhTSdMkCp8VMzhUq85w==", "dev": true }, "node_modules/@types/node": { @@ -425,38 +697,39 @@ "dev": true }, "node_modules/@types/semver": { - "version": "7.3.13", - "resolved": "https://registry.npmjs.org/@types/semver/-/semver-7.3.13.tgz", - "integrity": "sha512-21cFJr9z3g5dW8B0CVI9g2O9beqaThGQ6ZFBqHfwhzLDKUxaqTIy3vnfah/UPkfOiF2pLq+tGz+W8RyCskuslw==", + "version": "7.5.3", + "resolved": "https://registry.npmjs.org/@types/semver/-/semver-7.5.3.tgz", + "integrity": "sha512-OxepLK9EuNEIPxWNME+C6WwbRAOOI2o2BaQEGzz5Lu2e4Z5eDnEo+/aVEDMIXywoJitJ7xWd641wrGLZdtwRyw==", "dev": true }, "node_modules/@typescript-eslint/eslint-plugin": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.54.1.tgz", - "integrity": "sha512-a2RQAkosH3d3ZIV08s3DcL/mcGc2M/UC528VkPULFxR9VnVPT8pBu0IyBAJJmVsCmhVfwQX1v6q+QGnmSe1bew==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-6.7.4.tgz", + "integrity": "sha512-DAbgDXwtX+pDkAHwiGhqP3zWUGpW49B7eqmgpPtg+BKJXwdct79ut9+ifqOFPJGClGKSHXn2PTBatCnldJRUoA==", "dev": true, "dependencies": { - "@typescript-eslint/scope-manager": "5.54.1", - "@typescript-eslint/type-utils": "5.54.1", - "@typescript-eslint/utils": "5.54.1", + "@eslint-community/regexpp": "^4.5.1", + "@typescript-eslint/scope-manager": "6.7.4", + "@typescript-eslint/type-utils": "6.7.4", + "@typescript-eslint/utils": "6.7.4", + "@typescript-eslint/visitor-keys": "6.7.4", "debug": "^4.3.4", - "grapheme-splitter": "^1.0.4", - "ignore": "^5.2.0", - "natural-compare-lite": "^1.4.0", - "regexpp": "^3.2.0", - "semver": "^7.3.7", - "tsutils": "^3.21.0" + "graphemer": "^1.4.0", + "ignore": "^5.2.4", + "natural-compare": "^1.4.0", + "semver": "^7.5.4", + "ts-api-utils": "^1.0.1" }, "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + "node": "^16.0.0 || >=18.0.0" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/typescript-eslint" }, "peerDependencies": { - "@typescript-eslint/parser": "^5.0.0", - "eslint": "^6.0.0 || ^7.0.0 || ^8.0.0" + "@typescript-eslint/parser": "^6.0.0 || ^6.0.0-alpha", + "eslint": "^7.0.0 || ^8.0.0" }, "peerDependenciesMeta": { "typescript": { @@ -465,25 +738,26 @@ } }, "node_modules/@typescript-eslint/parser": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-5.54.1.tgz", - "integrity": "sha512-8zaIXJp/nG9Ff9vQNh7TI+C3nA6q6iIsGJ4B4L6MhZ7mHnTMR4YP5vp2xydmFXIy8rpyIVbNAG44871LMt6ujg==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-6.7.4.tgz", + "integrity": "sha512-I5zVZFY+cw4IMZUeNCU7Sh2PO5O57F7Lr0uyhgCJmhN/BuTlnc55KxPonR4+EM3GBdfiCyGZye6DgMjtubQkmA==", "dev": true, "dependencies": { - "@typescript-eslint/scope-manager": "5.54.1", - "@typescript-eslint/types": "5.54.1", - "@typescript-eslint/typescript-estree": "5.54.1", + "@typescript-eslint/scope-manager": "6.7.4", + "@typescript-eslint/types": "6.7.4", + "@typescript-eslint/typescript-estree": "6.7.4", + "@typescript-eslint/visitor-keys": "6.7.4", "debug": "^4.3.4" }, "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + "node": "^16.0.0 || >=18.0.0" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/typescript-eslint" }, "peerDependencies": { - "eslint": "^6.0.0 || ^7.0.0 || ^8.0.0" + "eslint": "^7.0.0 || ^8.0.0" }, "peerDependenciesMeta": { "typescript": { @@ -492,16 +766,16 @@ } }, "node_modules/@typescript-eslint/scope-manager": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-5.54.1.tgz", - "integrity": "sha512-zWKuGliXxvuxyM71UA/EcPxaviw39dB2504LqAmFDjmkpO8qNLHcmzlh6pbHs1h/7YQ9bnsO8CCcYCSA8sykUg==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-6.7.4.tgz", + "integrity": "sha512-SdGqSLUPTXAXi7c3Ob7peAGVnmMoGzZ361VswK2Mqf8UOYcODiYvs8rs5ILqEdfvX1lE7wEZbLyELCW+Yrql1A==", "dev": true, "dependencies": { - "@typescript-eslint/types": "5.54.1", - "@typescript-eslint/visitor-keys": "5.54.1" + "@typescript-eslint/types": "6.7.4", + "@typescript-eslint/visitor-keys": "6.7.4" }, "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + "node": "^16.0.0 || >=18.0.0" }, "funding": { "type": "opencollective", @@ -509,25 +783,25 @@ } }, "node_modules/@typescript-eslint/type-utils": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-5.54.1.tgz", - "integrity": "sha512-WREHsTz0GqVYLIbzIZYbmUUr95DKEKIXZNH57W3s+4bVnuF1TKe2jH8ZNH8rO1CeMY3U4j4UQeqPNkHMiGem3g==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-6.7.4.tgz", + "integrity": "sha512-n+g3zi1QzpcAdHFP9KQF+rEFxMb2KxtnJGID3teA/nxKHOVi3ylKovaqEzGBbVY2pBttU6z85gp0D00ufLzViQ==", "dev": true, "dependencies": { - "@typescript-eslint/typescript-estree": "5.54.1", - "@typescript-eslint/utils": "5.54.1", + "@typescript-eslint/typescript-estree": "6.7.4", + "@typescript-eslint/utils": "6.7.4", "debug": "^4.3.4", - "tsutils": "^3.21.0" + "ts-api-utils": "^1.0.1" }, "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + "node": "^16.0.0 || >=18.0.0" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/typescript-eslint" }, "peerDependencies": { - "eslint": "*" + "eslint": "^7.0.0 || ^8.0.0" }, "peerDependenciesMeta": { "typescript": { @@ -536,12 +810,12 @@ } }, "node_modules/@typescript-eslint/types": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-5.54.1.tgz", - "integrity": "sha512-G9+1vVazrfAfbtmCapJX8jRo2E4MDXxgm/IMOF4oGh3kq7XuK3JRkOg6y2Qu1VsTRmWETyTkWt1wxy7X7/yLkw==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-6.7.4.tgz", + "integrity": "sha512-o9XWK2FLW6eSS/0r/tgjAGsYasLAnOWg7hvZ/dGYSSNjCh+49k5ocPN8OmG5aZcSJ8pclSOyVKP2x03Sj+RrCA==", "dev": true, "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + "node": "^16.0.0 || >=18.0.0" }, "funding": { "type": "opencollective", @@ -549,21 +823,21 @@ } }, "node_modules/@typescript-eslint/typescript-estree": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-5.54.1.tgz", - "integrity": "sha512-bjK5t+S6ffHnVwA0qRPTZrxKSaFYocwFIkZx5k7pvWfsB1I57pO/0M0Skatzzw1sCkjJ83AfGTL0oFIFiDX3bg==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-6.7.4.tgz", + "integrity": "sha512-ty8b5qHKatlNYd9vmpHooQz3Vki3gG+3PchmtsA4TgrZBKWHNjWfkQid7K7xQogBqqc7/BhGazxMD5vr6Ha+iQ==", "dev": true, "dependencies": { - "@typescript-eslint/types": "5.54.1", - "@typescript-eslint/visitor-keys": "5.54.1", + "@typescript-eslint/types": "6.7.4", + "@typescript-eslint/visitor-keys": "6.7.4", "debug": "^4.3.4", "globby": "^11.1.0", "is-glob": "^4.0.3", - "semver": "^7.3.7", - "tsutils": "^3.21.0" + "semver": "^7.5.4", + "ts-api-utils": "^1.0.1" }, "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + "node": "^16.0.0 || >=18.0.0" }, "funding": { "type": "opencollective", @@ -576,250 +850,47 @@ } }, "node_modules/@typescript-eslint/utils": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-5.54.1.tgz", - "integrity": "sha512-IY5dyQM8XD1zfDe5X8jegX6r2EVU5o/WJnLu/znLPWCBF7KNGC+adacXnt5jEYS9JixDcoccI6CvE4RCjHMzCQ==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-6.7.4.tgz", + "integrity": "sha512-PRQAs+HUn85Qdk+khAxsVV+oULy3VkbH3hQ8hxLRJXWBEd7iI+GbQxH5SEUSH7kbEoTp6oT1bOwyga24ELALTA==", "dev": true, "dependencies": { - "@types/json-schema": "^7.0.9", - "@types/semver": "^7.3.12", - "@typescript-eslint/scope-manager": "5.54.1", - "@typescript-eslint/types": "5.54.1", - "@typescript-eslint/typescript-estree": "5.54.1", - "eslint-scope": "^5.1.1", - "eslint-utils": "^3.0.0", - "semver": "^7.3.7" + "@eslint-community/eslint-utils": "^4.4.0", + "@types/json-schema": "^7.0.12", + "@types/semver": "^7.5.0", + "@typescript-eslint/scope-manager": "6.7.4", + "@typescript-eslint/types": "6.7.4", + "@typescript-eslint/typescript-estree": "6.7.4", + "semver": "^7.5.4" }, "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + "node": "^16.0.0 || >=18.0.0" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/typescript-eslint" }, "peerDependencies": { - "eslint": "^6.0.0 || ^7.0.0 || ^8.0.0" + "eslint": "^7.0.0 || ^8.0.0" } }, "node_modules/@typescript-eslint/visitor-keys": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-5.54.1.tgz", - "integrity": "sha512-q8iSoHTgwCfgcRJ2l2x+xCbu8nBlRAlsQ33k24Adj8eoVBE0f8dUeI+bAa8F84Mv05UGbAx57g2zrRsYIooqQg==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-6.7.4.tgz", + "integrity": "sha512-pOW37DUhlTZbvph50x5zZCkFn3xzwkGtNoJHzIM3svpiSkJzwOYr/kVBaXmf+RAQiUDs1AHEZVNPg6UJCJpwRA==", "dev": true, "dependencies": { - "@typescript-eslint/types": "5.54.1", - "eslint-visitor-keys": "^3.3.0" + "@typescript-eslint/types": "6.7.4", + "eslint-visitor-keys": "^3.4.1" }, "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + "node": "^16.0.0 || >=18.0.0" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/typescript-eslint" } }, - "node_modules/@webassemblyjs/ast": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.11.1.tgz", - "integrity": "sha512-ukBh14qFLjxTQNTXocdyksN5QdM28S1CxHt2rdskFyL+xFV7VremuBLVbmCePj+URalXBENx/9Lm7lnhihtCSw==", - "dev": true, - "dependencies": { - "@webassemblyjs/helper-numbers": "1.11.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.1" - } - }, - "node_modules/@webassemblyjs/floating-point-hex-parser": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.11.1.tgz", - "integrity": "sha512-iGRfyc5Bq+NnNuX8b5hwBrRjzf0ocrJPI6GWFodBFzmFnyvrQ83SHKhmilCU/8Jv67i4GJZBMhEzltxzcNagtQ==", - "dev": true - }, - "node_modules/@webassemblyjs/helper-api-error": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-api-error/-/helper-api-error-1.11.1.tgz", - "integrity": "sha512-RlhS8CBCXfRUR/cwo2ho9bkheSXG0+NwooXcc3PAILALf2QLdFyj7KGsKRbVc95hZnhnERon4kW/D3SZpp6Tcg==", - "dev": true - }, - "node_modules/@webassemblyjs/helper-buffer": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.11.1.tgz", - "integrity": "sha512-gwikF65aDNeeXa8JxXa2BAk+REjSyhrNC9ZwdT0f8jc4dQQeDQ7G4m0f2QCLPJiMTTO6wfDmRmj/pW0PsUvIcA==", - "dev": true - }, - "node_modules/@webassemblyjs/helper-numbers": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-numbers/-/helper-numbers-1.11.1.tgz", - "integrity": "sha512-vDkbxiB8zfnPdNK9Rajcey5C0w+QJugEglN0of+kmO8l7lDb77AnlKYQF7aarZuCrv+l0UvqL+68gSDr3k9LPQ==", - "dev": true, - "dependencies": { - "@webassemblyjs/floating-point-hex-parser": "1.11.1", - "@webassemblyjs/helper-api-error": "1.11.1", - "@xtuc/long": "4.2.2" - } - }, - "node_modules/@webassemblyjs/helper-wasm-bytecode": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.11.1.tgz", - "integrity": "sha512-PvpoOGiJwXeTrSf/qfudJhwlvDQxFgelbMqtq52WWiXC6Xgg1IREdngmPN3bs4RoO83PnL/nFrxucXj1+BX62Q==", - "dev": true - }, - "node_modules/@webassemblyjs/helper-wasm-section": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.11.1.tgz", - "integrity": "sha512-10P9No29rYX1j7F3EVPX3JvGPQPae+AomuSTPiF9eBQeChHI6iqjMIwR9JmOJXwpnn/oVGDk7I5IlskuMwU/pg==", - "dev": true, - "dependencies": { - "@webassemblyjs/ast": "1.11.1", - "@webassemblyjs/helper-buffer": "1.11.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.1", - "@webassemblyjs/wasm-gen": "1.11.1" - } - }, - "node_modules/@webassemblyjs/ieee754": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/ieee754/-/ieee754-1.11.1.tgz", - "integrity": "sha512-hJ87QIPtAMKbFq6CGTkZYJivEwZDbQUgYd3qKSadTNOhVY7p+gfP6Sr0lLRVTaG1JjFj+r3YchoqRYxNH3M0GQ==", - "dev": true, - "dependencies": { - "@xtuc/ieee754": "^1.2.0" - } - }, - "node_modules/@webassemblyjs/leb128": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/leb128/-/leb128-1.11.1.tgz", - "integrity": "sha512-BJ2P0hNZ0u+Th1YZXJpzW6miwqQUGcIHT1G/sf72gLVD9DZ5AdYTqPNbHZh6K1M5VmKvFXwGSWZADz+qBWxeRw==", - "dev": true, - "dependencies": { - "@xtuc/long": "4.2.2" - } - }, - "node_modules/@webassemblyjs/utf8": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/utf8/-/utf8-1.11.1.tgz", - "integrity": "sha512-9kqcxAEdMhiwQkHpkNiorZzqpGrodQQ2IGrHHxCy+Ozng0ofyMA0lTqiLkVs1uzTRejX+/O0EOT7KxqVPuXosQ==", - "dev": true - }, - "node_modules/@webassemblyjs/wasm-edit": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.11.1.tgz", - "integrity": "sha512-g+RsupUC1aTHfR8CDgnsVRVZFJqdkFHpsHMfJuWQzWU3tvnLC07UqHICfP+4XyL2tnr1amvl1Sdp06TnYCmVkA==", - "dev": true, - "dependencies": { - "@webassemblyjs/ast": "1.11.1", - "@webassemblyjs/helper-buffer": "1.11.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.1", - "@webassemblyjs/helper-wasm-section": "1.11.1", - "@webassemblyjs/wasm-gen": "1.11.1", - "@webassemblyjs/wasm-opt": "1.11.1", - "@webassemblyjs/wasm-parser": "1.11.1", - "@webassemblyjs/wast-printer": "1.11.1" - } - }, - "node_modules/@webassemblyjs/wasm-gen": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.11.1.tgz", - "integrity": "sha512-F7QqKXwwNlMmsulj6+O7r4mmtAlCWfO/0HdgOxSklZfQcDu0TpLiD1mRt/zF25Bk59FIjEuGAIyn5ei4yMfLhA==", - "dev": true, - "dependencies": { - "@webassemblyjs/ast": "1.11.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.1", - "@webassemblyjs/ieee754": "1.11.1", - "@webassemblyjs/leb128": "1.11.1", - "@webassemblyjs/utf8": "1.11.1" - } - }, - "node_modules/@webassemblyjs/wasm-opt": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.11.1.tgz", - "integrity": "sha512-VqnkNqnZlU5EB64pp1l7hdm3hmQw7Vgqa0KF/KCNO9sIpI6Fk6brDEiX+iCOYrvMuBWDws0NkTOxYEb85XQHHw==", - "dev": true, - "dependencies": { - "@webassemblyjs/ast": "1.11.1", - "@webassemblyjs/helper-buffer": "1.11.1", - "@webassemblyjs/wasm-gen": "1.11.1", - "@webassemblyjs/wasm-parser": "1.11.1" - } - }, - "node_modules/@webassemblyjs/wasm-parser": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.11.1.tgz", - "integrity": "sha512-rrBujw+dJu32gYB7/Lup6UhdkPx9S9SnobZzRVL7VcBH9Bt9bCBLEuX/YXOOtBsOZ4NQrRykKhffRWHvigQvOA==", - "dev": true, - "dependencies": { - "@webassemblyjs/ast": "1.11.1", - "@webassemblyjs/helper-api-error": "1.11.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.1", - "@webassemblyjs/ieee754": "1.11.1", - "@webassemblyjs/leb128": "1.11.1", - "@webassemblyjs/utf8": "1.11.1" - } - }, - "node_modules/@webassemblyjs/wast-printer": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.11.1.tgz", - "integrity": "sha512-IQboUWM4eKzWW+N/jij2sRatKMh99QEelo3Eb2q0qXkvPRISAj8Qxtmw5itwqK+TTkBuUIE45AxYPToqPtL5gg==", - "dev": true, - "dependencies": { - "@webassemblyjs/ast": "1.11.1", - "@xtuc/long": "4.2.2" - } - }, - "node_modules/@webpack-cli/configtest": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/@webpack-cli/configtest/-/configtest-2.0.1.tgz", - "integrity": "sha512-njsdJXJSiS2iNbQVS0eT8A/KPnmyH4pv1APj2K0d1wrZcBLw+yppxOy4CGqa0OxDJkzfL/XELDhD8rocnIwB5A==", - "dev": true, - "engines": { - "node": ">=14.15.0" - }, - "peerDependencies": { - "webpack": "5.x.x", - "webpack-cli": "5.x.x" - } - }, - "node_modules/@webpack-cli/info": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/@webpack-cli/info/-/info-2.0.1.tgz", - "integrity": "sha512-fE1UEWTwsAxRhrJNikE7v4EotYflkEhBL7EbajfkPlf6E37/2QshOy/D48Mw8G5XMFlQtS6YV42vtbG9zBpIQA==", - "dev": true, - "engines": { - "node": ">=14.15.0" - }, - "peerDependencies": { - "webpack": "5.x.x", - "webpack-cli": "5.x.x" - } - }, - "node_modules/@webpack-cli/serve": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/@webpack-cli/serve/-/serve-2.0.1.tgz", - "integrity": "sha512-0G7tNyS+yW8TdgHwZKlDWYXFA6OJQnoLCQvYKkQP0Q2X205PSQ6RNUj0M+1OB/9gRQaUZ/ccYfaxd0nhaWKfjw==", - "dev": true, - "engines": { - "node": ">=14.15.0" - }, - "peerDependencies": { - "webpack": "5.x.x", - "webpack-cli": "5.x.x" - }, - "peerDependenciesMeta": { - "webpack-dev-server": { - "optional": true - } - } - }, - "node_modules/@xtuc/ieee754": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/@xtuc/ieee754/-/ieee754-1.2.0.tgz", - "integrity": "sha512-DX8nKgqcGwsc0eJSqYt5lwP4DH5FlHnmuWWBRy7X0NcaGR0ZtuyeESgMwTYVEtxmsNGY+qit4QYT/MIYTOTPeA==", - "dev": true - }, - "node_modules/@xtuc/long": { - "version": "4.2.2", - "resolved": "https://registry.npmjs.org/@xtuc/long/-/long-4.2.2.tgz", - "integrity": "sha512-NuHqBY1PB/D8xU6s/thBgOAiAP7HOYDQ32+BFZILJ8ivkUkAHQnWfn6WhL79Owj1qmUnoN/YPhktdIoucipkAQ==", - "dev": true - }, "node_modules/abort-controller": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/abort-controller/-/abort-controller-3.0.0.tgz", @@ -833,9 +904,9 @@ } }, "node_modules/acorn": { - "version": "8.8.2", - "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.2.tgz", - "integrity": "sha512-xjIYgE8HBrkpd/sJqOGNspf8uHG+NOHGOw6a/Urj8taM2EXfdNAH2oFcPeIFfsv3+kz/mJrS5VuMqbNLjCa2vw==", + "version": "8.10.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.10.0.tgz", + "integrity": "sha512-F0SAmZ8iUtS//m8DmCTA0jlh6TDKkHQyK6xc6V4KDTyZKA9dnvX9/3sRTVQrWm79glUAZbnmmNcdYwUIHWVybw==", "dev": true, "bin": { "acorn": "bin/acorn" @@ -844,15 +915,6 @@ "node": ">=0.4.0" } }, - "node_modules/acorn-import-assertions": { - "version": "1.8.0", - "resolved": "https://registry.npmjs.org/acorn-import-assertions/-/acorn-import-assertions-1.8.0.tgz", - "integrity": "sha512-m7VZ3jwz4eK6A4Vtt8Ew1/mNbP24u0FhdyfA7fSvnJR6LMdfOYnmuIrrJAgrYfYJ10F/otaHTtrtrtmHdMNzEw==", - "dev": true, - "peerDependencies": { - "acorn": "^8" - } - }, "node_modules/acorn-jsx": { "version": "5.3.2", "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.2.tgz", @@ -862,15 +924,6 @@ "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" } }, - "node_modules/acorn-walk": { - "version": "8.2.0", - "resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz", - "integrity": "sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA==", - "dev": true, - "engines": { - "node": ">=0.4.0" - } - }, "node_modules/ajv": { "version": "6.12.6", "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", @@ -887,15 +940,6 @@ "url": "https://github.com/sponsors/epoberezkin" } }, - "node_modules/ajv-keywords": { - "version": "3.5.2", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-3.5.2.tgz", - "integrity": "sha512-5p6WTN0DdTGVQk6VjcEju19IgaHudalcfabD7yhDGeA6bcQnmL+CpveLJq/3hvfwd1aof6L386Ougkx6RfyMIQ==", - "dev": true, - "peerDependencies": { - "ajv": "^6.9.1" - } - }, "node_modules/ansi-colors": { "version": "4.1.1", "resolved": "https://registry.npmjs.org/ansi-colors/-/ansi-colors-4.1.1.tgz", @@ -948,6 +992,15 @@ "integrity": "sha512-lYe4Gx7QT+MKGbDsA+Z+he/Wtef0BiwDOlK/XkBrdfsh9J/jPPXbX0tE9x9cl27Tmu5gg3QUbUrQYa/y+KOHPQ==", "dev": true }, + "node_modules/are-docs-informative": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/are-docs-informative/-/are-docs-informative-0.0.2.tgz", + "integrity": "sha512-ixiS0nLNNG5jNQzgZJNoUpBKdo9yTYZMGJ+QgT2jmjR7G7+QHRCc4v6LQ3NgE7EBJq+o0ams3waJwkrlBom8Ig==", + "dev": true, + "engines": { + "node": ">=14" + } + }, "node_modules/are-we-there-yet": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/are-we-there-yet/-/are-we-there-yet-4.0.0.tgz", @@ -982,6 +1035,19 @@ "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==", "dev": true }, + "node_modules/array-buffer-byte-length": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/array-buffer-byte-length/-/array-buffer-byte-length-1.0.0.tgz", + "integrity": "sha512-LPuwb2P+NrQw3XhxGc36+XSvuBPopovXYTR9Ew++Du9Yb/bx5AzBfrIsBoj0EZUifjQU+sHL21sseZ3jerWO/A==", + "dev": true, + "dependencies": { + "call-bind": "^1.0.2", + "is-array-buffer": "^3.0.1" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/array-includes": { "version": "3.1.6", "resolved": "https://registry.npmjs.org/array-includes/-/array-includes-3.1.6.tgz", @@ -1010,6 +1076,25 @@ "node": ">=8" } }, + "node_modules/array.prototype.findlastindex": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/array.prototype.findlastindex/-/array.prototype.findlastindex-1.2.3.tgz", + "integrity": "sha512-LzLoiOMAxvy+Gd3BAq3B7VeIgPdo+Q8hthvKtXybMvRV0jrXfJM/t8mw7nNlpEcVlVUnCnM2KSX4XU5HmpodOA==", + "dev": true, + "dependencies": { + "call-bind": "^1.0.2", + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1", + "es-shim-unscopables": "^1.0.0", + "get-intrinsic": "^1.2.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/array.prototype.flat": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/array.prototype.flat/-/array.prototype.flat-1.3.1.tgz", @@ -1046,34 +1131,25 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/asn1.js": { - "version": "5.4.1", - "resolved": "https://registry.npmjs.org/asn1.js/-/asn1.js-5.4.1.tgz", - "integrity": "sha512-+I//4cYPccV8LdmBLiX8CYvf9Sp3vQsrqu2QNXRcrbiWvcx/UdlFiqUJJzxRQxgsZmvhXhn4cSKeSmoFjVdupA==", - "dev": true, - "dependencies": { - "bn.js": "^4.0.0", - "inherits": "^2.0.1", - "minimalistic-assert": "^1.0.0", - "safer-buffer": "^2.1.0" - } - }, - "node_modules/asn1.js/node_modules/bn.js": { - "version": "4.12.0", - "resolved": "https://registry.npmjs.org/bn.js/-/bn.js-4.12.0.tgz", - "integrity": "sha512-c98Bf3tPniI+scsdk237ku1Dc3ujXQTSgyiPUDEOe7tRkhrqridvh8klBv0HCEso1OLOYcHuCv/cS6DNxKH+ZA==", - "dev": true - }, - "node_modules/assert": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/assert/-/assert-2.0.0.tgz", - "integrity": "sha512-se5Cd+js9dXJnu6Ag2JFc00t+HmHOen+8Q+L7O9zI0PqQXr20uk2J0XQqMxZEeo5U50o8Nvmmx7dZrl+Ufr35A==", + "node_modules/arraybuffer.prototype.slice": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/arraybuffer.prototype.slice/-/arraybuffer.prototype.slice-1.0.2.tgz", + "integrity": "sha512-yMBKppFur/fbHu9/6USUe03bZ4knMYiwFBcyiaXB8Go0qNehwX6inYPzK9U0NeQvGxKthcmHcaR8P5MStSRBAw==", "dev": true, "dependencies": { - "es6-object-assign": "^1.1.0", - "is-nan": "^1.2.1", - "object-is": "^1.0.1", - "util": "^0.12.0" + "array-buffer-byte-length": "^1.0.0", + "call-bind": "^1.0.2", + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1", + "get-intrinsic": "^1.2.1", + "is-array-buffer": "^3.0.2", + "is-shared-array-buffer": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" } }, "node_modules/async": { @@ -1120,15 +1196,6 @@ } ] }, - "node_modules/big.js": { - "version": "5.2.2", - "resolved": "https://registry.npmjs.org/big.js/-/big.js-5.2.2.tgz", - "integrity": "sha512-vyL2OymJxmarO8gxMr0mhChsO9QGwhynfuu4+MHTAW6czfq9humCB7rKpUjDd9YUiDPU4mzpyupFSvOClAwbmQ==", - "dev": true, - "engines": { - "node": "*" - } - }, "node_modules/binary-extensions": { "version": "2.2.0", "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", @@ -1138,12 +1205,6 @@ "node": ">=8" } }, - "node_modules/bn.js": { - "version": "5.2.1", - "resolved": "https://registry.npmjs.org/bn.js/-/bn.js-5.2.1.tgz", - "integrity": "sha512-eXRvHzWyYPBuB4NBy0cmYQjGitUrtqwbvlzP3G6VFnNRbsZQIxQ10PbKKHt8gZ/HW/D/747aDl+QkDqg3KQLMQ==", - "dev": true - }, "node_modules/brace-expansion": { "version": "1.1.11", "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", @@ -1166,153 +1227,12 @@ "node": ">=8" } }, - "node_modules/brorand": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/brorand/-/brorand-1.1.0.tgz", - "integrity": "sha512-cKV8tMCEpQs4hK/ik71d6LrPOnpkpGBR0wzxqr68g2m/LB2GxVYQroAjMJZRVM1Y4BCjCKc3vAamxSzOY2RP+w==", - "dev": true - }, "node_modules/browser-stdout": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/browser-stdout/-/browser-stdout-1.3.1.tgz", "integrity": "sha512-qhAVI1+Av2X7qelOfAIYwXONood6XlZE/fXaBSmW/T5SzLAmCgzi+eiWE7fUvbHaeNBQH13UftjpXxsfLkMpgw==", "dev": true }, - "node_modules/browserify-aes": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/browserify-aes/-/browserify-aes-1.2.0.tgz", - "integrity": "sha512-+7CHXqGuspUn/Sl5aO7Ea0xWGAtETPXNSAjHo48JfLdPWcMng33Xe4znFvQweqc/uzk5zSOI3H52CYnjCfb5hA==", - "dev": true, - "dependencies": { - "buffer-xor": "^1.0.3", - "cipher-base": "^1.0.0", - "create-hash": "^1.1.0", - "evp_bytestokey": "^1.0.3", - "inherits": "^2.0.1", - "safe-buffer": "^5.0.1" - } - }, - "node_modules/browserify-cipher": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/browserify-cipher/-/browserify-cipher-1.0.1.tgz", - "integrity": "sha512-sPhkz0ARKbf4rRQt2hTpAHqn47X3llLkUGn+xEJzLjwY8LRs2p0v7ljvI5EyoRO/mexrNunNECisZs+gw2zz1w==", - "dev": true, - "dependencies": { - "browserify-aes": "^1.0.4", - "browserify-des": "^1.0.0", - "evp_bytestokey": "^1.0.0" - } - }, - "node_modules/browserify-des": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/browserify-des/-/browserify-des-1.0.2.tgz", - "integrity": "sha512-BioO1xf3hFwz4kc6iBhI3ieDFompMhrMlnDFC4/0/vd5MokpuAc3R+LYbwTA9A5Yc9pq9UYPqffKpW2ObuwX5A==", - "dev": true, - "dependencies": { - "cipher-base": "^1.0.1", - "des.js": "^1.0.0", - "inherits": "^2.0.1", - "safe-buffer": "^5.1.2" - } - }, - "node_modules/browserify-rsa": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/browserify-rsa/-/browserify-rsa-4.1.0.tgz", - "integrity": "sha512-AdEER0Hkspgno2aR97SAf6vi0y0k8NuOpGnVH3O99rcA5Q6sh8QxcngtHuJ6uXwnfAXNM4Gn1Gb7/MV1+Ymbog==", - "dev": true, - "dependencies": { - "bn.js": "^5.0.0", - "randombytes": "^2.0.1" - } - }, - "node_modules/browserify-sign": { - "version": "4.2.1", - "resolved": "https://registry.npmjs.org/browserify-sign/-/browserify-sign-4.2.1.tgz", - "integrity": "sha512-/vrA5fguVAKKAVTNJjgSm1tRQDHUU6DbwO9IROu/0WAzC8PKhucDSh18J0RMvVeHAn5puMd+QHC2erPRNf8lmg==", - "dev": true, - "dependencies": { - "bn.js": "^5.1.1", - "browserify-rsa": "^4.0.1", - "create-hash": "^1.2.0", - "create-hmac": "^1.1.7", - "elliptic": "^6.5.3", - "inherits": "^2.0.4", - "parse-asn1": "^5.1.5", - "readable-stream": "^3.6.0", - "safe-buffer": "^5.2.0" - } - }, - "node_modules/browserify-sign/node_modules/readable-stream": { - "version": "3.6.1", - "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.1.tgz", - "integrity": "sha512-+rQmrWMYGA90yenhTYsLWAsLsqVC8osOw6PKE1HDYiO0gdPeKe/xDHNzIAIn4C91YQ6oenEhfYqqc1883qHbjQ==", - "dev": true, - "dependencies": { - "inherits": "^2.0.3", - "string_decoder": "^1.1.1", - "util-deprecate": "^1.0.1" - }, - "engines": { - "node": ">= 6" - } - }, - "node_modules/browserify-sign/node_modules/safe-buffer": { - "version": "5.2.1", - "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", - "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", - "dev": true, - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/feross" - }, - { - "type": "patreon", - "url": "https://www.patreon.com/feross" - }, - { - "type": "consulting", - "url": "https://feross.org/support" - } - ] - }, - "node_modules/browserify-zlib": { - "version": "0.2.0", - "resolved": "https://registry.npmjs.org/browserify-zlib/-/browserify-zlib-0.2.0.tgz", - "integrity": "sha512-Z942RysHXmJrhqk88FmKBVq/v5tqmSkDz7p54G/MGyjMnCFFnC79XWNbg+Vta8W6Wb2qtSZTSxIGkJrRpCFEiA==", - "dev": true, - "dependencies": { - "pako": "~1.0.5" - } - }, - "node_modules/browserslist": { - "version": "4.21.5", - "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.21.5.tgz", - "integrity": "sha512-tUkiguQGW7S3IhB7N+c2MV/HZPSCPAAiYBZXLsBhFB/PCy6ZKKsZrmBayHV9fdGV/ARIfJ14NkxKzRDjvp7L6w==", - "dev": true, - "funding": [ - { - "type": "opencollective", - "url": "https://opencollective.com/browserslist" - }, - { - "type": "tidelift", - "url": "https://tidelift.com/funding/github/npm/browserslist" - } - ], - "dependencies": { - "caniuse-lite": "^1.0.30001449", - "electron-to-chromium": "^1.4.284", - "node-releases": "^2.0.8", - "update-browserslist-db": "^1.0.10" - }, - "bin": { - "browserslist": "cli.js" - }, - "engines": { - "node": "^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7" - } - }, "node_modules/buffer": { "version": "6.0.3", "resolved": "https://registry.npmjs.org/buffer/-/buffer-6.0.3.tgz", @@ -1337,18 +1257,6 @@ "ieee754": "^1.2.1" } }, - "node_modules/buffer-from": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz", - "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", - "dev": true - }, - "node_modules/buffer-xor": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/buffer-xor/-/buffer-xor-1.0.3.tgz", - "integrity": "sha512-571s0T7nZWK6vB67HI5dyUF7wXiNcfaPPPTl6zYCNApANjIvYJTg7hlud/+cJpdAhS7dVzqMLmfhfHR3rAcOjQ==", - "dev": true - }, "node_modules/builtin-modules": { "version": "3.3.0", "resolved": "https://registry.npmjs.org/builtin-modules/-/builtin-modules-3.3.0.tgz", @@ -1361,12 +1269,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/builtin-status-codes": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/builtin-status-codes/-/builtin-status-codes-3.0.0.tgz", - "integrity": "sha512-HpGFw18DgFWlncDfjTa2rcQ4W88O1mC8e8yZ2AvQY5KDaktSTwo+KRf6nHK6FRI5FyRyb/5T6+TSxfP7QyGsmQ==", - "dev": true - }, "node_modules/call-bind": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.2.tgz", @@ -1401,22 +1303,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/caniuse-lite": { - "version": "1.0.30001460", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001460.tgz", - "integrity": "sha512-Bud7abqjvEjipUkpLs4D7gR0l8hBYBHoa+tGtKJHvT2AYzLp1z7EmVkUT4ERpVUfca8S2HGIVs883D8pUH1ZzQ==", - "dev": true, - "funding": [ - { - "type": "opencollective", - "url": "https://opencollective.com/browserslist" - }, - { - "type": "tidelift", - "url": "https://tidelift.com/funding/github/npm/caniuse-lite" - } - ] - }, "node_modules/chalk": { "version": "4.1.2", "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", @@ -1472,15 +1358,6 @@ "node": ">= 6" } }, - "node_modules/chrome-trace-event": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/chrome-trace-event/-/chrome-trace-event-1.0.3.tgz", - "integrity": "sha512-p3KULyQg4S7NIHixdwbGX+nFHkoBiA4YQmyWtjb8XngSKV124nJmRysgAeujbUVb15vh+RvFUfCPqU7rXk+hZg==", - "dev": true, - "engines": { - "node": ">=6.0" - } - }, "node_modules/ci-info": { "version": "3.8.0", "resolved": "https://registry.npmjs.org/ci-info/-/ci-info-3.8.0.tgz", @@ -1496,16 +1373,6 @@ "node": ">=8" } }, - "node_modules/cipher-base": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/cipher-base/-/cipher-base-1.0.4.tgz", - "integrity": "sha512-Kkht5ye6ZGmwv40uUDZztayT2ThLQGfnj/T71N/XzeZeo3nf8foyW7zGTsPYkEya3m5f3cAypH+qe7YOrM1U2Q==", - "dev": true, - "dependencies": { - "inherits": "^2.0.1", - "safe-buffer": "^5.0.1" - } - }, "node_modules/clang-format": { "version": "1.8.0", "resolved": "https://registry.npmjs.org/clang-format/-/clang-format-1.8.0.tgz", @@ -1554,20 +1421,6 @@ "wrap-ansi": "^7.0.0" } }, - "node_modules/clone-deep": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/clone-deep/-/clone-deep-4.0.1.tgz", - "integrity": "sha512-neHB9xuzh/wk0dIHweyAXv2aPGZIVk3pLMe+/RNzINf17fe0OG96QroktYAUm7SM1PBnzTabaLboqqxDyMU+SQ==", - "dev": true, - "dependencies": { - "is-plain-object": "^2.0.4", - "kind-of": "^6.0.2", - "shallow-clone": "^3.0.0" - }, - "engines": { - "node": ">=6" - } - }, "node_modules/color-convert": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", @@ -1595,22 +1448,10 @@ "color-support": "bin.js" } }, - "node_modules/colorette": { - "version": "2.0.19", - "resolved": "https://registry.npmjs.org/colorette/-/colorette-2.0.19.tgz", - "integrity": "sha512-3tlv/dIP7FWvj3BsbHrGLJ6l/oKh1O3TcgBqMn+yyCagOxc23fyzDS6HypQbgxWbkpDnf52p1LuR4eWDQ/K9WQ==", - "dev": true - }, - "node_modules/commander": { - "version": "2.20.3", - "resolved": "https://registry.npmjs.org/commander/-/commander-2.20.3.tgz", - "integrity": "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==", - "dev": true - }, "node_modules/comment-parser": { - "version": "1.3.1", - "resolved": "https://registry.npmjs.org/comment-parser/-/comment-parser-1.3.1.tgz", - "integrity": "sha512-B52sN2VNghyq5ofvUsqZjmk6YkihBX5vMSChmSK9v4ShjKf3Vk5Xcmgpw4o+iIgtrnM/u5FiMpz9VKb8lpBveA==", + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/comment-parser/-/comment-parser-1.4.0.tgz", + "integrity": "sha512-QLyTNiZ2KDOibvFPlZ6ZngVsZ/0gYnE6uTXi5aoDg8ed3AkJAz4sEje3Y8a29hQ1s6A99MZXe47fLAXQ1rTqaw==", "dev": true, "engines": { "node": ">= 12.0.0" @@ -1622,73 +1463,18 @@ "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", "dev": true }, - "node_modules/console-browserify": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/console-browserify/-/console-browserify-1.2.0.tgz", - "integrity": "sha512-ZMkYO/LkF17QvCPqM0gxw8yUzigAOZOSWSHg91FH6orS7vcEj5dVZTidN2fQ14yBSdg97RqhSNwLUXInd52OTA==", - "dev": true - }, "node_modules/console-control-strings": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/console-control-strings/-/console-control-strings-1.1.0.tgz", "integrity": "sha512-ty/fTekppD2fIwRvnZAVdeOiGd1c7YXEixbgJTNzqcxJWKQnjJ/V1bNEEE6hygpM3WjwHFUVK6HTjWSzV4a8sQ==", "dev": true }, - "node_modules/constants-browserify": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/constants-browserify/-/constants-browserify-1.0.0.tgz", - "integrity": "sha512-xFxOwqIzR/e1k1gLiWEophSCMqXcwVHIH7akf7b/vxcUeGunlj3hvZaaqxwHsTgn+IndtkQJgSztIDWeumWJDQ==", - "dev": true - }, "node_modules/core-util-is": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.3.tgz", "integrity": "sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==", "dev": true }, - "node_modules/create-ecdh": { - "version": "4.0.4", - "resolved": "https://registry.npmjs.org/create-ecdh/-/create-ecdh-4.0.4.tgz", - "integrity": "sha512-mf+TCx8wWc9VpuxfP2ht0iSISLZnt0JgWlrOKZiNqyUZWnjIaCIVNQArMHnCZKfEYRg6IM7A+NeJoN8gf/Ws0A==", - "dev": true, - "dependencies": { - "bn.js": "^4.1.0", - "elliptic": "^6.5.3" - } - }, - "node_modules/create-ecdh/node_modules/bn.js": { - "version": "4.12.0", - "resolved": "https://registry.npmjs.org/bn.js/-/bn.js-4.12.0.tgz", - "integrity": "sha512-c98Bf3tPniI+scsdk237ku1Dc3ujXQTSgyiPUDEOe7tRkhrqridvh8klBv0HCEso1OLOYcHuCv/cS6DNxKH+ZA==", - "dev": true - }, - "node_modules/create-hash": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/create-hash/-/create-hash-1.2.0.tgz", - "integrity": "sha512-z00bCGNHDG8mHAkP7CtT1qVu+bFQUPjYq/4Iv3C3kWjTFV10zIjfSoeqXo9Asws8gwSHDGj/hl2u4OGIjapeCg==", - "dev": true, - "dependencies": { - "cipher-base": "^1.0.1", - "inherits": "^2.0.1", - "md5.js": "^1.3.4", - "ripemd160": "^2.0.1", - "sha.js": "^2.4.0" - } - }, - "node_modules/create-hmac": { - "version": "1.1.7", - "resolved": "https://registry.npmjs.org/create-hmac/-/create-hmac-1.1.7.tgz", - "integrity": "sha512-MJG9liiZ+ogc4TzUwuvbER1JRdgvUFSB5+VR/g5h82fGaIRWMWddtKBHi7/sVhfjQZ6SehlyhvQYrcYkaUIpLg==", - "dev": true, - "dependencies": { - "cipher-base": "^1.0.3", - "create-hash": "^1.1.0", - "inherits": "^2.0.1", - "ripemd160": "^2.0.0", - "safe-buffer": "^5.0.1", - "sha.js": "^2.4.8" - } - }, "node_modules/cross-spawn": { "version": "7.0.3", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", @@ -1703,28 +1489,6 @@ "node": ">= 8" } }, - "node_modules/crypto-browserify": { - "version": "3.12.0", - "resolved": "https://registry.npmjs.org/crypto-browserify/-/crypto-browserify-3.12.0.tgz", - "integrity": "sha512-fz4spIh+znjO2VjL+IdhEpRJ3YN6sMzITSBijk6FK2UvTqruSQW+/cCZTSNsMiZNvUeq0CqurF+dAbyiGOY6Wg==", - "dev": true, - "dependencies": { - "browserify-cipher": "^1.0.0", - "browserify-sign": "^4.0.0", - "create-ecdh": "^4.0.0", - "create-hash": "^1.1.0", - "create-hmac": "^1.1.0", - "diffie-hellman": "^5.0.0", - "inherits": "^2.0.1", - "pbkdf2": "^3.0.3", - "public-encrypt": "^4.0.0", - "randombytes": "^2.0.0", - "randomfill": "^1.0.3" - }, - "engines": { - "node": "*" - } - }, "node_modules/debug": { "version": "4.3.4", "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", @@ -1760,6 +1524,20 @@ "integrity": "sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==", "dev": true }, + "node_modules/define-data-property": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.0.tgz", + "integrity": "sha512-UzGwzcjyv3OtAvolTj1GoyNYzfFR+iqbGjcnBEENZVCpM4/Ng1yhGNvS3lR/xDS74Tb2wGG9WzNSNIOS9UVb2g==", + "dev": true, + "dependencies": { + "get-intrinsic": "^1.2.1", + "gopd": "^1.0.1", + "has-property-descriptors": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/define-properties": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/define-properties/-/define-properties-1.2.0.tgz", @@ -1782,16 +1560,6 @@ "integrity": "sha512-bd2L678uiWATM6m5Z1VzNCErI3jiGzt6HGY8OVICs40JQq/HALfbyNJmp0UDakEY4pMMaN0Ly5om/B1VI/+xfQ==", "dev": true }, - "node_modules/des.js": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/des.js/-/des.js-1.0.1.tgz", - "integrity": "sha512-Q0I4pfFrv2VPd34/vfLrFOoRmlYj3OV50i7fskps1jZWK1kApMWWT9G6RRUeYedLcBDIhnSDaUvJMb3AhUlaEA==", - "dev": true, - "dependencies": { - "inherits": "^2.0.1", - "minimalistic-assert": "^1.0.0" - } - }, "node_modules/diff": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/diff/-/diff-5.0.0.tgz", @@ -1801,30 +1569,13 @@ "node": ">=0.3.1" } }, - "node_modules/diffie-hellman": { - "version": "5.0.3", - "resolved": "https://registry.npmjs.org/diffie-hellman/-/diffie-hellman-5.0.3.tgz", - "integrity": "sha512-kqag/Nl+f3GwyK25fhUMYj81BUOrZ9IuJsjIcDE5icNM9FJHAVm3VcUDxdLPoQtTuUylWm6ZIknYJwwaPxsUzg==", - "dev": true, - "dependencies": { - "bn.js": "^4.1.0", - "miller-rabin": "^4.0.0", - "randombytes": "^2.0.0" - } - }, - "node_modules/diffie-hellman/node_modules/bn.js": { - "version": "4.12.0", - "resolved": "https://registry.npmjs.org/bn.js/-/bn.js-4.12.0.tgz", - "integrity": "sha512-c98Bf3tPniI+scsdk237ku1Dc3ujXQTSgyiPUDEOe7tRkhrqridvh8klBv0HCEso1OLOYcHuCv/cS6DNxKH+ZA==", - "dev": true - }, "node_modules/dir-compare": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/dir-compare/-/dir-compare-4.0.0.tgz", - "integrity": "sha512-wC7thVKL3V656tO61rbEDE4LTeeYrUC2pAUL00AaXYghBhjjVNRyBlpH6POzb44ZuK23OSrqF6TbSC/QYeqfAg==", + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/dir-compare/-/dir-compare-4.2.0.tgz", + "integrity": "sha512-2xMCmOoMrdQIPHdsTawECdNPwlVFB9zGcz3kuhmBO6U3oU+UQjsue0i8ayLKpgBcm+hcXPMVSGUN9d+pvJ6+VQ==", "dev": true, "dependencies": { - "minimatch": "^3.0.4", + "minimatch": "^3.0.5", "p-limit": "^3.1.0 " } }, @@ -1852,91 +1603,12 @@ "node": ">=6.0.0" } }, - "node_modules/domain-browser": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/domain-browser/-/domain-browser-4.22.0.tgz", - "integrity": "sha512-IGBwjF7tNk3cwypFNH/7bfzBcgSCbaMOD3GsaY1AU/JRrnHnYgEM0+9kQt52iZxjNsjBtJYtao146V+f8jFZNw==", - "dev": true, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://bevry.me/fund" - } - }, - "node_modules/duplexer": { - "version": "0.1.2", - "resolved": "https://registry.npmjs.org/duplexer/-/duplexer-0.1.2.tgz", - "integrity": "sha512-jtD6YG370ZCIi/9GTaJKQxWTZD045+4R4hTk/x1UyoqadyJ9x9CgSi1RlVDQF8U2sxLLSnFkCaMihqljHIWgMg==", - "dev": true - }, - "node_modules/electron-to-chromium": { - "version": "1.4.320", - "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.4.320.tgz", - "integrity": "sha512-h70iRscrNluMZPVICXYl5SSB+rBKo22XfuIS1ER0OQxQZpKTnFpuS6coj7wY9M/3trv7OR88rRMOlKmRvDty7Q==", - "dev": true - }, - "node_modules/elliptic": { - "version": "6.5.4", - "resolved": "https://registry.npmjs.org/elliptic/-/elliptic-6.5.4.tgz", - "integrity": "sha512-iLhC6ULemrljPZb+QutR5TQGB+pdW6KGD5RSegS+8sorOZT+rdQFbsQFJgvN3eRqNALqJer4oQ16YvJHlU8hzQ==", - "dev": true, - "dependencies": { - "bn.js": "^4.11.9", - "brorand": "^1.1.0", - "hash.js": "^1.0.0", - "hmac-drbg": "^1.0.1", - "inherits": "^2.0.4", - "minimalistic-assert": "^1.0.1", - "minimalistic-crypto-utils": "^1.0.1" - } - }, - "node_modules/elliptic/node_modules/bn.js": { - "version": "4.12.0", - "resolved": "https://registry.npmjs.org/bn.js/-/bn.js-4.12.0.tgz", - "integrity": "sha512-c98Bf3tPniI+scsdk237ku1Dc3ujXQTSgyiPUDEOe7tRkhrqridvh8klBv0HCEso1OLOYcHuCv/cS6DNxKH+ZA==", - "dev": true - }, "node_modules/emoji-regex": { "version": "8.0.0", "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", "dev": true }, - "node_modules/emojis-list": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/emojis-list/-/emojis-list-3.0.0.tgz", - "integrity": "sha512-/kyM18EfinwXZbno9FyUGeFh87KC8HRQBQGildHZbEuRyWFOmv1U10o9BBp8XVZDVNNuQKyIGIu5ZYAAXJ0V2Q==", - "dev": true, - "engines": { - "node": ">= 4" - } - }, - "node_modules/enhanced-resolve": { - "version": "5.12.0", - "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.12.0.tgz", - "integrity": "sha512-QHTXI/sZQmko1cbDoNAa3mJ5qhWUUNAq3vR0/YiD379fWQrcfuoX1+HW2S0MTt7XmoPLapdaDKUtelUSPic7hQ==", - "dev": true, - "dependencies": { - "graceful-fs": "^4.2.4", - "tapable": "^2.2.0" - }, - "engines": { - "node": ">=10.13.0" - } - }, - "node_modules/envinfo": { - "version": "7.8.1", - "resolved": "https://registry.npmjs.org/envinfo/-/envinfo-7.8.1.tgz", - "integrity": "sha512-/o+BXHmB7ocbHEAs6F2EnG0ogybVVUdkRunTT2glZU9XAaGmhqskrvKwqXuDfNjEO0LZKWdejEEpnq8aM0tOaw==", - "dev": true, - "bin": { - "envinfo": "dist/cli.js" - }, - "engines": { - "node": ">=4" - } - }, "node_modules/error-ex": { "version": "1.3.2", "resolved": "https://registry.npmjs.org/error-ex/-/error-ex-1.3.2.tgz", @@ -1947,18 +1619,19 @@ } }, "node_modules/es-abstract": { - "version": "1.21.1", - "resolved": "https://registry.npmjs.org/es-abstract/-/es-abstract-1.21.1.tgz", - "integrity": "sha512-QudMsPOz86xYz/1dG1OuGBKOELjCh99IIWHLzy5znUB6j8xG2yMA7bfTV86VSqKF+Y/H08vQPR+9jyXpuC6hfg==", + "version": "1.22.2", + "resolved": "https://registry.npmjs.org/es-abstract/-/es-abstract-1.22.2.tgz", + "integrity": "sha512-YoxfFcDmhjOgWPWsV13+2RNjq1F6UQnfs+8TftwNqtzlmFzEXvlUwdrNrYeaizfjQzRMxkZ6ElWMOJIFKdVqwA==", "dev": true, "dependencies": { + "array-buffer-byte-length": "^1.0.0", + "arraybuffer.prototype.slice": "^1.0.2", "available-typed-arrays": "^1.0.5", "call-bind": "^1.0.2", "es-set-tostringtag": "^2.0.1", "es-to-primitive": "^1.2.1", - "function-bind": "^1.1.1", - "function.prototype.name": "^1.1.5", - "get-intrinsic": "^1.1.3", + "function.prototype.name": "^1.1.6", + "get-intrinsic": "^1.2.1", "get-symbol-description": "^1.0.0", "globalthis": "^1.0.3", "gopd": "^1.0.1", @@ -1966,25 +1639,30 @@ "has-property-descriptors": "^1.0.0", "has-proto": "^1.0.1", "has-symbols": "^1.0.3", - "internal-slot": "^1.0.4", - "is-array-buffer": "^3.0.1", + "internal-slot": "^1.0.5", + "is-array-buffer": "^3.0.2", "is-callable": "^1.2.7", "is-negative-zero": "^2.0.2", "is-regex": "^1.1.4", "is-shared-array-buffer": "^1.0.2", "is-string": "^1.0.7", - "is-typed-array": "^1.1.10", + "is-typed-array": "^1.1.12", "is-weakref": "^1.0.2", - "object-inspect": "^1.12.2", + "object-inspect": "^1.12.3", "object-keys": "^1.1.1", "object.assign": "^4.1.4", - "regexp.prototype.flags": "^1.4.3", + "regexp.prototype.flags": "^1.5.1", + "safe-array-concat": "^1.0.1", "safe-regex-test": "^1.0.0", - "string.prototype.trimend": "^1.0.6", - "string.prototype.trimstart": "^1.0.6", + "string.prototype.trim": "^1.2.8", + "string.prototype.trimend": "^1.0.7", + "string.prototype.trimstart": "^1.0.7", + "typed-array-buffer": "^1.0.0", + "typed-array-byte-length": "^1.0.0", + "typed-array-byte-offset": "^1.0.0", "typed-array-length": "^1.0.4", "unbox-primitive": "^1.0.2", - "which-typed-array": "^1.1.9" + "which-typed-array": "^1.1.11" }, "engines": { "node": ">= 0.4" @@ -1993,12 +1671,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/es-module-lexer": { - "version": "0.9.3", - "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-0.9.3.tgz", - "integrity": "sha512-1HQ2M2sPtxwnvOvT1ZClHyQDiggdNjURWpY2we6aMKCQiUVxTmVs2UYPLIrD84sS+kMdUwfBSylbJPwNnBrnHQ==", - "dev": true - }, "node_modules/es-set-tostringtag": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.0.1.tgz", @@ -2039,11 +1711,55 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/es6-object-assign": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/es6-object-assign/-/es6-object-assign-1.1.0.tgz", - "integrity": "sha512-MEl9uirslVwqQU369iHNWZXsI8yaZYGg/D65aOgZkeyFJwHYSxilf7rQzXKI7DdDuBPrBXbfk3sl9hJhmd5AUw==", - "dev": true + "node_modules/esbuild": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.19.3.tgz", + "integrity": "sha512-UlJ1qUUA2jL2nNib1JTSkifQTcYTroFqRjwCFW4QYEKEsixXD5Tik9xML7zh2gTxkYTBKGHNH9y7txMwVyPbjw==", + "dev": true, + "hasInstallScript": true, + "bin": { + "esbuild": "bin/esbuild" + }, + "engines": { + "node": ">=12" + }, + "optionalDependencies": { + "@esbuild/android-arm": "0.19.3", + "@esbuild/android-arm64": "0.19.3", + "@esbuild/android-x64": "0.19.3", + "@esbuild/darwin-arm64": "0.19.3", + "@esbuild/darwin-x64": "0.19.3", + "@esbuild/freebsd-arm64": "0.19.3", + "@esbuild/freebsd-x64": "0.19.3", + "@esbuild/linux-arm": "0.19.3", + "@esbuild/linux-arm64": "0.19.3", + "@esbuild/linux-ia32": "0.19.3", + "@esbuild/linux-loong64": "0.19.3", + "@esbuild/linux-mips64el": "0.19.3", + "@esbuild/linux-ppc64": "0.19.3", + "@esbuild/linux-riscv64": "0.19.3", + "@esbuild/linux-s390x": "0.19.3", + "@esbuild/linux-x64": "0.19.3", + "@esbuild/netbsd-x64": "0.19.3", + "@esbuild/openbsd-x64": "0.19.3", + "@esbuild/sunos-x64": "0.19.3", + "@esbuild/win32-arm64": "0.19.3", + "@esbuild/win32-ia32": "0.19.3", + "@esbuild/win32-x64": "0.19.3" + } + }, + "node_modules/esbuild-plugin-polyfill-node": { + "version": "0.3.0", + "resolved": "https://registry.npmjs.org/esbuild-plugin-polyfill-node/-/esbuild-plugin-polyfill-node-0.3.0.tgz", + "integrity": "sha512-SHG6CKUfWfYyYXGpW143NEZtcVVn8S/WHcEOxk62LuDXnY4Zpmc+WmxJKN6GMTgTClXJXhEM5KQlxKY6YjbucQ==", + "dev": true, + "dependencies": { + "@jspm/core": "^2.0.1", + "import-meta-resolve": "^3.0.0" + }, + "peerDependencies": { + "esbuild": "*" + } }, "node_modules/escalade": { "version": "3.1.1", @@ -2067,26 +1783,27 @@ } }, "node_modules/eslint": { - "version": "8.35.0", - "resolved": "https://registry.npmjs.org/eslint/-/eslint-8.35.0.tgz", - "integrity": "sha512-BxAf1fVL7w+JLRQhWl2pzGeSiGqbWumV4WNvc9Rhp6tiCtm4oHnyPBSEtMGZwrQgudFQ+otqzWoPB7x+hxoWsw==", + "version": "8.51.0", + "resolved": "https://registry.npmjs.org/eslint/-/eslint-8.51.0.tgz", + "integrity": "sha512-2WuxRZBrlwnXi+/vFSJyjMqrNjtJqiasMzehF0shoLaW7DzS3/9Yvrmq5JiT66+pNjiX4UBnLDiKHcWAr/OInA==", "dev": true, "dependencies": { - "@eslint/eslintrc": "^2.0.0", - "@eslint/js": "8.35.0", - "@humanwhocodes/config-array": "^0.11.8", + "@eslint-community/eslint-utils": "^4.2.0", + "@eslint-community/regexpp": "^4.6.1", + "@eslint/eslintrc": "^2.1.2", + "@eslint/js": "8.51.0", + "@humanwhocodes/config-array": "^0.11.11", "@humanwhocodes/module-importer": "^1.0.1", "@nodelib/fs.walk": "^1.2.8", - "ajv": "^6.10.0", + "ajv": "^6.12.4", "chalk": "^4.0.0", "cross-spawn": "^7.0.2", "debug": "^4.3.2", "doctrine": "^3.0.0", "escape-string-regexp": "^4.0.0", - "eslint-scope": "^7.1.1", - "eslint-utils": "^3.0.0", - "eslint-visitor-keys": "^3.3.0", - "espree": "^9.4.0", + "eslint-scope": "^7.2.2", + "eslint-visitor-keys": "^3.4.3", + "espree": "^9.6.1", "esquery": "^1.4.2", "esutils": "^2.0.2", "fast-deep-equal": "^3.1.3", @@ -2094,23 +1811,19 @@ "find-up": "^5.0.0", "glob-parent": "^6.0.2", "globals": "^13.19.0", - "grapheme-splitter": "^1.0.4", + "graphemer": "^1.4.0", "ignore": "^5.2.0", - "import-fresh": "^3.0.0", "imurmurhash": "^0.1.4", "is-glob": "^4.0.0", "is-path-inside": "^3.0.3", - "js-sdsl": "^4.1.4", "js-yaml": "^4.1.0", "json-stable-stringify-without-jsonify": "^1.0.1", "levn": "^0.4.1", "lodash.merge": "^4.6.2", "minimatch": "^3.1.2", "natural-compare": "^1.4.0", - "optionator": "^0.9.1", - "regexpp": "^3.2.0", + "optionator": "^0.9.3", "strip-ansi": "^6.0.1", - "strip-json-comments": "^3.1.0", "text-table": "^0.2.0" }, "bin": { @@ -2144,9 +1857,9 @@ } }, "node_modules/eslint-module-utils": { - "version": "2.7.4", - "resolved": "https://registry.npmjs.org/eslint-module-utils/-/eslint-module-utils-2.7.4.tgz", - "integrity": "sha512-j4GT+rqzCoRKHwURX7pddtIPGySnX9Si/cgMI5ztrcqOPtk5dDEeZ34CQVPphnqkJytlc97Vuk05Um2mJ3gEQA==", + "version": "2.8.0", + "resolved": "https://registry.npmjs.org/eslint-module-utils/-/eslint-module-utils-2.8.0.tgz", + "integrity": "sha512-aWajIYfsqCKRDgUfjEXNN/JlrzauMuSEy5sbd7WXbtW3EH6A6MpwEh42c7qD+MqQo9QMJ6fWLAeIJynx0g6OAw==", "dev": true, "dependencies": { "debug": "^3.2.7" @@ -2179,26 +1892,28 @@ } }, "node_modules/eslint-plugin-import": { - "version": "2.27.5", - "resolved": "https://registry.npmjs.org/eslint-plugin-import/-/eslint-plugin-import-2.27.5.tgz", - "integrity": "sha512-LmEt3GVofgiGuiE+ORpnvP+kAm3h6MLZJ4Q5HCyHADofsb4VzXFsRiWj3c0OFiV+3DWFh0qg3v9gcPlfc3zRow==", + "version": "2.28.1", + "resolved": "https://registry.npmjs.org/eslint-plugin-import/-/eslint-plugin-import-2.28.1.tgz", + "integrity": "sha512-9I9hFlITvOV55alzoKBI+K9q74kv0iKMeY6av5+umsNwayt59fz692daGyjR+oStBQgx6nwR9rXldDev3Clw+A==", "dev": true, "dependencies": { "array-includes": "^3.1.6", + "array.prototype.findlastindex": "^1.2.2", "array.prototype.flat": "^1.3.1", "array.prototype.flatmap": "^1.3.1", "debug": "^3.2.7", "doctrine": "^2.1.0", "eslint-import-resolver-node": "^0.3.7", - "eslint-module-utils": "^2.7.4", + "eslint-module-utils": "^2.8.0", "has": "^1.0.3", - "is-core-module": "^2.11.0", + "is-core-module": "^2.13.0", "is-glob": "^4.0.3", "minimatch": "^3.1.2", + "object.fromentries": "^2.0.6", + "object.groupby": "^1.0.0", "object.values": "^1.1.6", - "resolve": "^1.22.1", - "semver": "^6.3.0", - "tsconfig-paths": "^3.14.1" + "semver": "^6.3.1", + "tsconfig-paths": "^3.14.2" }, "engines": { "node": ">=4" @@ -2229,30 +1944,32 @@ } }, "node_modules/eslint-plugin-import/node_modules/semver": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", - "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", "dev": true, "bin": { "semver": "bin/semver.js" } }, "node_modules/eslint-plugin-jsdoc": { - "version": "40.0.1", - "resolved": "https://registry.npmjs.org/eslint-plugin-jsdoc/-/eslint-plugin-jsdoc-40.0.1.tgz", - "integrity": "sha512-KkiRInury7YrjjV5aCHDxwsPy6XFt5p2b2CnpDMITnWs8patNPf5kj24+VXIWw45kP6z/B0GOKfrYczB56OjQQ==", + "version": "46.8.2", + "resolved": "https://registry.npmjs.org/eslint-plugin-jsdoc/-/eslint-plugin-jsdoc-46.8.2.tgz", + "integrity": "sha512-5TSnD018f3tUJNne4s4gDWQflbsgOycIKEUBoCLn6XtBMgNHxQFmV8vVxUtiPxAQq8lrX85OaSG/2gnctxw9uQ==", "dev": true, "dependencies": { - "@es-joy/jsdoccomment": "~0.36.1", - "comment-parser": "1.3.1", + "@es-joy/jsdoccomment": "~0.40.1", + "are-docs-informative": "^0.0.2", + "comment-parser": "1.4.0", "debug": "^4.3.4", "escape-string-regexp": "^4.0.0", - "esquery": "^1.4.0", - "semver": "^7.3.8", + "esquery": "^1.5.0", + "is-builtin-module": "^3.2.1", + "semver": "^7.5.4", "spdx-expression-parse": "^3.0.1" }, "engines": { - "node": "^14 || ^16 || ^17 || ^18 || ^19" + "node": ">=16" }, "peerDependencies": { "eslint": "^7.0.0 || ^8.0.0" @@ -2268,118 +1985,74 @@ } }, "node_modules/eslint-plugin-unicorn": { - "version": "46.0.0", - "resolved": "https://registry.npmjs.org/eslint-plugin-unicorn/-/eslint-plugin-unicorn-46.0.0.tgz", - "integrity": "sha512-j07WkC+PFZwk8J33LYp6JMoHa1lXc1u6R45pbSAipjpfpb7KIGr17VE2D685zCxR5VL4cjrl65kTJflziQWMDA==", + "version": "48.0.1", + "resolved": "https://registry.npmjs.org/eslint-plugin-unicorn/-/eslint-plugin-unicorn-48.0.1.tgz", + "integrity": "sha512-FW+4r20myG/DqFcCSzoumaddKBicIPeFnTrifon2mWIzlfyvzwyqZjqVP7m4Cqr/ZYisS2aiLghkUWaPg6vtCw==", "dev": true, "dependencies": { - "@babel/helper-validator-identifier": "^7.19.1", - "@eslint-community/eslint-utils": "^4.1.2", - "ci-info": "^3.6.1", + "@babel/helper-validator-identifier": "^7.22.5", + "@eslint-community/eslint-utils": "^4.4.0", + "ci-info": "^3.8.0", "clean-regexp": "^1.0.0", - "esquery": "^1.4.0", + "esquery": "^1.5.0", "indent-string": "^4.0.0", - "is-builtin-module": "^3.2.0", + "is-builtin-module": "^3.2.1", "jsesc": "^3.0.2", "lodash": "^4.17.21", "pluralize": "^8.0.0", "read-pkg-up": "^7.0.1", - "regexp-tree": "^0.1.24", - "regjsparser": "^0.9.1", - "safe-regex": "^2.1.1", - "semver": "^7.3.8", + "regexp-tree": "^0.1.27", + "regjsparser": "^0.10.0", + "semver": "^7.5.4", "strip-indent": "^3.0.0" }, "engines": { - "node": ">=14.18" + "node": ">=16" }, "funding": { "url": "https://github.com/sindresorhus/eslint-plugin-unicorn?sponsor=1" }, "peerDependencies": { - "eslint": ">=8.28.0" + "eslint": ">=8.44.0" } }, "node_modules/eslint-scope": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-5.1.1.tgz", - "integrity": "sha512-2NxwbF/hZ0KpepYN0cNbo+FN6XoK7GaHlQhgx/hIZl6Va0bF45RQOOwhLIy8lQDbuCiadSLCBnH2CFYquit5bw==", + "version": "7.2.2", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-7.2.2.tgz", + "integrity": "sha512-dOt21O7lTMhDM+X9mB4GX+DZrZtCUJPL/wlcTqxyrx5IvO0IYtILdtrQGQp+8n5S0gwSVmOf9NQrjMOgfQZlIg==", "dev": true, "dependencies": { "esrecurse": "^4.3.0", - "estraverse": "^4.1.1" - }, - "engines": { - "node": ">=8.0.0" - } - }, - "node_modules/eslint-utils": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/eslint-utils/-/eslint-utils-3.0.0.tgz", - "integrity": "sha512-uuQC43IGctw68pJA1RgbQS8/NP7rch6Cwd4j3ZBtgo4/8Flj4eGE7ZYSZRN3iq5pVUv6GPdW5Z1RFleo84uLDA==", - "dev": true, - "dependencies": { - "eslint-visitor-keys": "^2.0.0" + "estraverse": "^5.2.0" }, "engines": { - "node": "^10.0.0 || ^12.0.0 || >= 14.0.0" + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" }, "funding": { - "url": "https://github.com/sponsors/mysticatea" - }, - "peerDependencies": { - "eslint": ">=5" - } - }, - "node_modules/eslint-utils/node_modules/eslint-visitor-keys": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-2.1.0.tgz", - "integrity": "sha512-0rSmRBzXgDzIsD6mGdJgevzgezI534Cer5L/vyMX0kHzT/jiB43jRhd9YUlMGYLQy2zprNmoT8qasCGtY+QaKw==", - "dev": true, - "engines": { - "node": ">=10" + "url": "https://opencollective.com/eslint" } }, "node_modules/eslint-visitor-keys": { - "version": "3.3.0", - "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.3.0.tgz", - "integrity": "sha512-mQ+suqKJVyeuwGYHAdjMFqjCyfl8+Ldnxuyp3ldiMBFKkvytrXUZWaiPCEav8qDHKty44bD+qV1IP4T+w+xXRA==", + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.4.3.tgz", + "integrity": "sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==", "dev": true, "engines": { "node": "^12.22.0 || ^14.17.0 || >=16.0.0" - } - }, - "node_modules/eslint/node_modules/eslint-scope": { - "version": "7.1.1", - "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-7.1.1.tgz", - "integrity": "sha512-QKQM/UXpIiHcLqJ5AOyIW7XZmzjkzQXYE54n1++wb0u9V/abW3l9uQnxX8Z5Xd18xyKIMTUAyQ0k1e8pz6LUrw==", - "dev": true, - "dependencies": { - "esrecurse": "^4.3.0", - "estraverse": "^5.2.0" }, - "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" - } - }, - "node_modules/eslint/node_modules/estraverse": { - "version": "5.3.0", - "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", - "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", - "dev": true, - "engines": { - "node": ">=4.0" + "funding": { + "url": "https://opencollective.com/eslint" } }, "node_modules/espree": { - "version": "9.4.1", - "resolved": "https://registry.npmjs.org/espree/-/espree-9.4.1.tgz", - "integrity": "sha512-XwctdmTO6SIvCzd9810yyNzIrOrqNYV9Koizx4C/mRhf9uq0o4yHoCEU/670pOxOL/MSraektvSAji79kX90Vg==", + "version": "9.6.1", + "resolved": "https://registry.npmjs.org/espree/-/espree-9.6.1.tgz", + "integrity": "sha512-oruZaFkjorTpF32kDSI5/75ViwGeZginGGy2NoOSg3Q9bnwlnmDm4HLnkl0RE3n+njDXR037aY1+x58Z/zFdwQ==", "dev": true, "dependencies": { - "acorn": "^8.8.0", + "acorn": "^8.9.0", "acorn-jsx": "^5.3.2", - "eslint-visitor-keys": "^3.3.0" + "eslint-visitor-keys": "^3.4.1" }, "engines": { "node": "^12.22.0 || ^14.17.0 || >=16.0.0" @@ -2400,15 +2073,6 @@ "node": ">=0.10" } }, - "node_modules/esquery/node_modules/estraverse": { - "version": "5.3.0", - "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", - "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", - "dev": true, - "engines": { - "node": ">=4.0" - } - }, "node_modules/esrecurse": { "version": "4.3.0", "resolved": "https://registry.npmjs.org/esrecurse/-/esrecurse-4.3.0.tgz", @@ -2421,7 +2085,7 @@ "node": ">=4.0" } }, - "node_modules/esrecurse/node_modules/estraverse": { + "node_modules/estraverse": { "version": "5.3.0", "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", @@ -2430,15 +2094,6 @@ "node": ">=4.0" } }, - "node_modules/estraverse": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-4.3.0.tgz", - "integrity": "sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==", - "dev": true, - "engines": { - "node": ">=4.0" - } - }, "node_modules/esutils": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", @@ -2466,16 +2121,6 @@ "node": ">=0.8.x" } }, - "node_modules/evp_bytestokey": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/evp_bytestokey/-/evp_bytestokey-1.0.3.tgz", - "integrity": "sha512-/f2Go4TognH/KvCISP7OUsHn85hT9nUkxxA9BEWxFn+Oj9o8ZNLm/40hdlgSLyuOimsrTKLUMEorQexp/aPQeA==", - "dev": true, - "dependencies": { - "md5.js": "^1.3.4", - "safe-buffer": "^5.1.1" - } - }, "node_modules/fast-deep-equal": { "version": "3.1.3", "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", @@ -2483,9 +2128,9 @@ "dev": true }, "node_modules/fast-glob": { - "version": "3.2.12", - "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.2.12.tgz", - "integrity": "sha512-DVj4CQIYYow0BlaelwK1pHl5n5cRSJfM60UA0zK891sVInoPri2Ekj7+e1CT3/3qxXenpI+nBBmQAcJPJgaj4w==", + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.1.tgz", + "integrity": "sha512-kNFPyjhh5cKjrUltxs+wFx+ZkbRaxxmZ+X0ZU31SOsxCEtP9VPgtq2teZw1DebupL5GmDaNQ6yKMMVcM41iqDg==", "dev": true, "dependencies": { "@nodelib/fs.stat": "^2.0.2", @@ -2522,15 +2167,6 @@ "integrity": "sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==", "dev": true }, - "node_modules/fastest-levenshtein": { - "version": "1.0.16", - "resolved": "https://registry.npmjs.org/fastest-levenshtein/-/fastest-levenshtein-1.0.16.tgz", - "integrity": "sha512-eRnCtTTtGZFpQCwhJiUOuxPQWRXVKYDn0b2PeHfXL6/Zi53SLAzAHfVhVWK2AryC/WH05kGfxhFIPvTF0SXQzg==", - "dev": true, - "engines": { - "node": ">= 4.9.1" - } - }, "node_modules/fastq": { "version": "1.15.0", "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz", @@ -2564,15 +2200,6 @@ "node": ">=8" } }, - "node_modules/filter-obj": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/filter-obj/-/filter-obj-2.0.2.tgz", - "integrity": "sha512-lO3ttPjHZRfjMcxWKb1j1eDhTFsu4meeR3lnMcnBFhk6RuLhvEiuALu2TlfL310ph4lCYYwgF/ElIjdP739tdg==", - "dev": true, - "engines": { - "node": ">=8" - } - }, "node_modules/find-up": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/find-up/-/find-up-5.0.0.tgz", @@ -2627,9 +2254,9 @@ } }, "node_modules/fs-extra": { - "version": "11.1.0", - "resolved": "https://registry.npmjs.org/fs-extra/-/fs-extra-11.1.0.tgz", - "integrity": "sha512-0rcTq621PD5jM/e0a3EJoGC/1TC5ZBCERW82LQuwfGnCa1V8w7dpYH1yNu+SLb6E5dkeCBzKEyLGlFrnr+dUyw==", + "version": "11.1.1", + "resolved": "https://registry.npmjs.org/fs-extra/-/fs-extra-11.1.1.tgz", + "integrity": "sha512-MGIE4HOvQCeUCzmlHs0vXpih4ysz4wg9qiSAu6cd42lVwPbTM1TjV7RusoyQqMmk/95gdQZX72u+YW+c3eEpFQ==", "dev": true, "dependencies": { "graceful-fs": "^4.2.0", @@ -2667,15 +2294,15 @@ "dev": true }, "node_modules/function.prototype.name": { - "version": "1.1.5", - "resolved": "https://registry.npmjs.org/function.prototype.name/-/function.prototype.name-1.1.5.tgz", - "integrity": "sha512-uN7m/BzVKQnCUF/iW8jYea67v++2u7m5UgENbHRtdDVclOUP+FMPlCNdmk0h/ysGyo2tavMJEDqJAkJdRa1vMA==", + "version": "1.1.6", + "resolved": "https://registry.npmjs.org/function.prototype.name/-/function.prototype.name-1.1.6.tgz", + "integrity": "sha512-Z5kx79swU5P27WEayXM1tBi5Ze/lbIyiNgU3qyXUOf9b2rgXYyF9Dy9Cx+IQv/Lc8WCG6L82zwUPpSS9hGehIg==", "dev": true, "dependencies": { "call-bind": "^1.0.2", - "define-properties": "^1.1.3", - "es-abstract": "^1.19.0", - "functions-have-names": "^1.2.2" + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1", + "functions-have-names": "^1.2.3" }, "engines": { "node": ">= 0.4" @@ -2722,13 +2349,14 @@ } }, "node_modules/get-intrinsic": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.0.tgz", - "integrity": "sha512-L049y6nFOuom5wGyRc3/gdTLO94dySVKRACj1RmJZBQXlbTMhtNIgkWkUHq+jYmZvKf14EW1EoJnnjbmoHij0Q==", + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.1.tgz", + "integrity": "sha512-2DcsyfABl+gVHEfCOaTrWgyt+tb6MSEGmKq+kI5HwLbIYgjgmMcV8KQ41uaKz1xxUcn9tJtgFbQUEVcEbd0FYw==", "dev": true, "dependencies": { "function-bind": "^1.1.1", "has": "^1.0.3", + "has-proto": "^1.0.1", "has-symbols": "^1.0.3" }, "funding": { @@ -2783,16 +2411,10 @@ "node": ">=10.13.0" } }, - "node_modules/glob-to-regexp": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/glob-to-regexp/-/glob-to-regexp-0.4.1.tgz", - "integrity": "sha512-lkX1HJXwyMcprw/5YUZc2s7DrpAiHB21/V+E1rHUrVNokkvB6bqMzT0VfV6/86ZNabt1k14YOIaT7nDvOX3Iiw==", - "dev": true - }, "node_modules/globals": { - "version": "13.20.0", - "resolved": "https://registry.npmjs.org/globals/-/globals-13.20.0.tgz", - "integrity": "sha512-Qg5QtVkCy/kv3FUSlu4ukeZDVf9ee0iXLAUYX13gbR17bnejFTzr4iS9bY7kwCf1NztRNm1t91fjOiyx4CSwPQ==", + "version": "13.23.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-13.23.0.tgz", + "integrity": "sha512-XAmF0RjlrjY23MA51q3HltdlGxUpXPvg0GioKiD9X6HD28iMjo2dKC8Vqwm7lne4GNr78+RHTfliktR6ZH09wA==", "dev": true, "dependencies": { "type-fest": "^0.20.2" @@ -2857,27 +2479,12 @@ "integrity": "sha512-9ByhssR2fPVsNZj478qUUbKfmL0+t5BDVyjShtyZZLiK7ZDAArFFfopyOTj0M05wE2tJPisA4iTnnXl2YoPvOA==", "dev": true }, - "node_modules/grapheme-splitter": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/grapheme-splitter/-/grapheme-splitter-1.0.4.tgz", - "integrity": "sha512-bzh50DW9kTPM00T8y4o8vQg89Di9oLJVLW/KaOGIXJWP/iqCN6WKYkbNOF04vFLJhwcpYUh9ydh/+5vpOqV4YQ==", + "node_modules/graphemer": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/graphemer/-/graphemer-1.4.0.tgz", + "integrity": "sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==", "dev": true }, - "node_modules/gzip-size": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/gzip-size/-/gzip-size-6.0.0.tgz", - "integrity": "sha512-ax7ZYomf6jqPTQ4+XCpUGyXKHk5WweS+e05MBO4/y3WJ5RkmPXNKvX+bx1behVILVwr6JSQvZAku021CHPXG3Q==", - "dev": true, - "dependencies": { - "duplexer": "^0.1.2" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/has": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/has/-/has-1.0.3.tgz", @@ -2945,84 +2552,26 @@ } }, "node_modules/has-tostringtag": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.0.tgz", - "integrity": "sha512-kFjcSNhnlGV1kyoGk7OXKSawH5JOb/LzUc5w9B02hOTO0dfFRjbHQKvg1d6cf3HbeUmtU9VbbV3qzZ2Teh97WQ==", - "dev": true, - "dependencies": { - "has-symbols": "^1.0.2" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/has-unicode": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/has-unicode/-/has-unicode-2.0.1.tgz", - "integrity": "sha512-8Rf9Y83NBReMnx0gFzA8JImQACstCYWUplepDa9xprwwtmgEZUF0h/i5xSA625zB/I37EtrswSST6OXxwaaIJQ==", - "dev": true - }, - "node_modules/hash-base": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/hash-base/-/hash-base-3.1.0.tgz", - "integrity": "sha512-1nmYp/rhMDiE7AYkDw+lLwlAzz0AntGIe51F3RfFfEqyQ3feY2eI/NcwC6umIQVOASPMsWJLJScWKSSvzL9IVA==", - "dev": true, - "dependencies": { - "inherits": "^2.0.4", - "readable-stream": "^3.6.0", - "safe-buffer": "^5.2.0" - }, - "engines": { - "node": ">=4" - } - }, - "node_modules/hash-base/node_modules/readable-stream": { - "version": "3.6.1", - "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.1.tgz", - "integrity": "sha512-+rQmrWMYGA90yenhTYsLWAsLsqVC8osOw6PKE1HDYiO0gdPeKe/xDHNzIAIn4C91YQ6oenEhfYqqc1883qHbjQ==", - "dev": true, - "dependencies": { - "inherits": "^2.0.3", - "string_decoder": "^1.1.1", - "util-deprecate": "^1.0.1" - }, - "engines": { - "node": ">= 6" - } - }, - "node_modules/hash-base/node_modules/safe-buffer": { - "version": "5.2.1", - "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", - "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", - "dev": true, - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/feross" - }, - { - "type": "patreon", - "url": "https://www.patreon.com/feross" - }, - { - "type": "consulting", - "url": "https://feross.org/support" - } - ] - }, - "node_modules/hash.js": { - "version": "1.1.7", - "resolved": "https://registry.npmjs.org/hash.js/-/hash.js-1.1.7.tgz", - "integrity": "sha512-taOaskGt4z4SOANNseOviYDvjEJinIkRgmp7LbKP2YTTmVxWBl87s/uzK9r+44BclBSp2X7K1hqeNfz9JbBeXA==", + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.0.tgz", + "integrity": "sha512-kFjcSNhnlGV1kyoGk7OXKSawH5JOb/LzUc5w9B02hOTO0dfFRjbHQKvg1d6cf3HbeUmtU9VbbV3qzZ2Teh97WQ==", "dev": true, "dependencies": { - "inherits": "^2.0.3", - "minimalistic-assert": "^1.0.1" + "has-symbols": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/has-unicode": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/has-unicode/-/has-unicode-2.0.1.tgz", + "integrity": "sha512-8Rf9Y83NBReMnx0gFzA8JImQACstCYWUplepDa9xprwwtmgEZUF0h/i5xSA625zB/I37EtrswSST6OXxwaaIJQ==", + "dev": true + }, "node_modules/he": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/he/-/he-1.2.0.tgz", @@ -3032,29 +2581,12 @@ "he": "bin/he" } }, - "node_modules/hmac-drbg": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/hmac-drbg/-/hmac-drbg-1.0.1.tgz", - "integrity": "sha512-Tti3gMqLdZfhOQY1Mzf/AanLiqh1WTiJgEj26ZuYQ9fbkLomzGchCws4FyrSd4VkpBfiNhaE1On+lOz894jvXg==", - "dev": true, - "dependencies": { - "hash.js": "^1.0.3", - "minimalistic-assert": "^1.0.0", - "minimalistic-crypto-utils": "^1.0.1" - } - }, "node_modules/hosted-git-info": { "version": "2.8.9", "resolved": "https://registry.npmjs.org/hosted-git-info/-/hosted-git-info-2.8.9.tgz", "integrity": "sha512-mxIDAb9Lsm6DoOJ7xH+5+X4y1LU/4Hi50L9C5sIswK3JzULS4bwk1FvjdBgvYR4bzT4tuUQiC15FE2f5HbLvYw==", "dev": true }, - "node_modules/https-browserify": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/https-browserify/-/https-browserify-1.0.0.tgz", - "integrity": "sha512-J+FkSdyD+0mA0N+81tMotaRMfSL9SGi+xpD3T6YApKsc3bGSXJlfXri3VyFOeYkfLRQisDk1W+jIFFKBeUBbBg==", - "dev": true - }, "node_modules/ieee754": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/ieee754/-/ieee754-1.2.1.tgz", @@ -3106,23 +2638,14 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/import-local": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/import-local/-/import-local-3.1.0.tgz", - "integrity": "sha512-ASB07uLtnDs1o6EHjKpX34BKYDSqnFerfTOJL2HvMqF70LnxpjkzDB8J44oT9pu4AMPkQwf8jl6szgvNd2tRIg==", + "node_modules/import-meta-resolve": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/import-meta-resolve/-/import-meta-resolve-3.0.0.tgz", + "integrity": "sha512-4IwhLhNNA8yy445rPjD/lWh++7hMDOml2eHtd58eG7h+qK3EryMuuRbsHGPikCoAgIkkDnckKfWSk2iDla/ejg==", "dev": true, - "dependencies": { - "pkg-dir": "^4.2.0", - "resolve-cwd": "^3.0.0" - }, - "bin": { - "import-local-fixture": "fixtures/cli.js" - }, - "engines": { - "node": ">=8" - }, "funding": { - "url": "https://github.com/sponsors/sindresorhus" + "type": "github", + "url": "https://github.com/sponsors/wooorm" } }, "node_modules/imurmurhash": { @@ -3173,31 +2696,6 @@ "node": ">= 0.4" } }, - "node_modules/interpret": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/interpret/-/interpret-3.1.1.tgz", - "integrity": "sha512-6xwYfHbajpoF0xLW+iwLkhwgvLoZDfjYfoFNu8ftMoXINzwuymNLd9u/KmwtdT2GbR+/Cz66otEGEVVUHX9QLQ==", - "dev": true, - "engines": { - "node": ">=10.13.0" - } - }, - "node_modules/is-arguments": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/is-arguments/-/is-arguments-1.1.1.tgz", - "integrity": "sha512-8Q7EARjzEnKpt/PCD7e1cgUS0a6X8u5tdSiMqXhojOdoV9TsMsiO+9VLC5vAmO8N7/GmXn7yjR8qnA6bVAEzfA==", - "dev": true, - "dependencies": { - "call-bind": "^1.0.2", - "has-tostringtag": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, "node_modules/is-array-buffer": { "version": "3.0.2", "resolved": "https://registry.npmjs.org/is-array-buffer/-/is-array-buffer-3.0.2.tgz", @@ -3286,9 +2784,9 @@ } }, "node_modules/is-core-module": { - "version": "2.11.0", - "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.11.0.tgz", - "integrity": "sha512-RRjxlvLDkD1YJwDbroBHMb+cukurkDWNyHx7D3oNB5x9rb5ogcksMC5wHCadcXoo67gVr/+3GFySh3134zi6rw==", + "version": "2.13.0", + "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.13.0.tgz", + "integrity": "sha512-Z7dk6Qo8pOCp3l4tsX2C5ZVas4V+UxwQodwZhLopL91TX8UyyHEXafPcyoeeWuLrwzHcr3igO78wNLwHJHsMCQ==", "dev": true, "dependencies": { "has": "^1.0.3" @@ -3330,21 +2828,6 @@ "node": ">=8" } }, - "node_modules/is-generator-function": { - "version": "1.0.10", - "resolved": "https://registry.npmjs.org/is-generator-function/-/is-generator-function-1.0.10.tgz", - "integrity": "sha512-jsEjy9l3yiXEQ+PsXdmBwEPcOxaXWLspKdplFUVI9vq1iZgIekeC0L167qeu86czQaxed3q/Uzuw0swL0irL8A==", - "dev": true, - "dependencies": { - "has-tostringtag": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, "node_modules/is-glob": { "version": "4.0.3", "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", @@ -3357,22 +2840,6 @@ "node": ">=0.10.0" } }, - "node_modules/is-nan": { - "version": "1.3.2", - "resolved": "https://registry.npmjs.org/is-nan/-/is-nan-1.3.2.tgz", - "integrity": "sha512-E+zBKpQ2t6MEo1VsonYmluk9NxGrbzpeeLC2xIViuO2EjU2xsXsBPwTr3Ykv9l08UYEVEdWeRZNouaZqF6RN0w==", - "dev": true, - "dependencies": { - "call-bind": "^1.0.0", - "define-properties": "^1.1.3" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, "node_modules/is-negative-zero": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/is-negative-zero/-/is-negative-zero-2.0.2.tgz", @@ -3427,18 +2894,6 @@ "node": ">=8" } }, - "node_modules/is-plain-object": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/is-plain-object/-/is-plain-object-2.0.4.tgz", - "integrity": "sha512-h5PpgXkWitc38BBMYawTYMWJHFZJVnBquFE57xFpjB8pJFiF6gZ+bU+WyI/yqXiFR5mdLsgYNaPe8uao6Uv9Og==", - "dev": true, - "dependencies": { - "isobject": "^3.0.1" - }, - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/is-regex": { "version": "1.1.4", "resolved": "https://registry.npmjs.org/is-regex/-/is-regex-1.1.4.tgz", @@ -3498,16 +2953,12 @@ } }, "node_modules/is-typed-array": { - "version": "1.1.10", - "resolved": "https://registry.npmjs.org/is-typed-array/-/is-typed-array-1.1.10.tgz", - "integrity": "sha512-PJqgEHiWZvMpaFZ3uTc8kHPM4+4ADTlDniuQL7cU/UDA0Ql7F70yGfHph3cLNe+c9toaigv+DFzTJKhc2CtO6A==", + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/is-typed-array/-/is-typed-array-1.1.12.tgz", + "integrity": "sha512-Z14TF2JNG8Lss5/HMqt0//T9JeHXttXy5pH/DBU4vi98ozO2btxzq9MwYDZYnKwU8nRsz/+GVFVRDq3DkVuSPg==", "dev": true, "dependencies": { - "available-typed-arrays": "^1.0.5", - "call-bind": "^1.0.2", - "for-each": "^0.3.3", - "gopd": "^1.0.1", - "has-tostringtag": "^1.0.0" + "which-typed-array": "^1.1.11" }, "engines": { "node": ">= 0.4" @@ -3552,54 +3003,6 @@ "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", "dev": true }, - "node_modules/isobject": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/isobject/-/isobject-3.0.1.tgz", - "integrity": "sha512-WhB9zCku7EGTj/HQQRz5aUQEUeoQZH2bWcltRErOpymJ4boYE6wL9Tbr23krRPSZ+C5zqNSrSw+Cc7sZZ4b7vg==", - "dev": true, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/jest-worker": { - "version": "27.5.1", - "resolved": "https://registry.npmjs.org/jest-worker/-/jest-worker-27.5.1.tgz", - "integrity": "sha512-7vuh85V5cdDofPyxn58nrPjBktZo0u9x1g8WtjQol+jZDaE+fhN+cIvTj11GndBnMnyfrUOG1sZQxCdjKh+DKg==", - "dev": true, - "dependencies": { - "@types/node": "*", - "merge-stream": "^2.0.0", - "supports-color": "^8.0.0" - }, - "engines": { - "node": ">= 10.13.0" - } - }, - "node_modules/jest-worker/node_modules/supports-color": { - "version": "8.1.1", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz", - "integrity": "sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==", - "dev": true, - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/supports-color?sponsor=1" - } - }, - "node_modules/js-sdsl": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/js-sdsl/-/js-sdsl-4.3.0.tgz", - "integrity": "sha512-mifzlm2+5nZ+lEcLJMoBK0/IH/bDg8XnJfd/Wq6IP+xoCjLZsTOnV2QpxlVbX9bMnkl5PdEjNtBJ9Cj1NjifhQ==", - "dev": true, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/js-sdsl" - } - }, "node_modules/js-tokens": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", @@ -3619,9 +3022,9 @@ } }, "node_modules/jsdoc-type-pratt-parser": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/jsdoc-type-pratt-parser/-/jsdoc-type-pratt-parser-3.1.0.tgz", - "integrity": "sha512-MgtD0ZiCDk9B+eI73BextfRrVQl0oyzRG8B2BjORts6jbunj4ScKPcyXGTbB6eXL4y9TzxCm6hyeLq/2ASzNdw==", + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/jsdoc-type-pratt-parser/-/jsdoc-type-pratt-parser-4.0.0.tgz", + "integrity": "sha512-YtOli5Cmzy3q4dP26GraSOeAhqecewG04hoO8DY56CH4KJ9Fvv5qKWUCCo3HZob7esJQHCv6/+bnTy72xZZaVQ==", "dev": true, "engines": { "node": ">=12.0.0" @@ -3693,15 +3096,6 @@ "setimmediate": "^1.0.5" } }, - "node_modules/kind-of": { - "version": "6.0.3", - "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-6.0.3.tgz", - "integrity": "sha512-dcS1ul+9tmeD95T+x28/ehLgd9mENa3LsvDTtzm3vyBEO7RPptvAD+t44WVXaUjTBRcrpFeFlC8WCruUR456hw==", - "dev": true, - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/levn": { "version": "0.4.1", "resolved": "https://registry.npmjs.org/levn/-/levn-0.4.1.tgz", @@ -3730,41 +3124,6 @@ "integrity": "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==", "dev": true }, - "node_modules/loader-runner": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/loader-runner/-/loader-runner-4.3.0.tgz", - "integrity": "sha512-3R/1M+yS3j5ou80Me59j7F9IMs4PXs3VqRrm0TU3AbKPxlmpoY1TNscJV/oGJXo8qCatFGTfDbY6W6ipGOYXfg==", - "dev": true, - "engines": { - "node": ">=6.11.5" - } - }, - "node_modules/loader-utils": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/loader-utils/-/loader-utils-2.0.4.tgz", - "integrity": "sha512-xXqpXoINfFhgua9xiqD8fPFHgkoq1mmmpE92WlDbm9rNRd/EbRb+Gqf908T2DMfuHjjJlksiK2RbHVOdD/MqSw==", - "dev": true, - "dependencies": { - "big.js": "^5.2.2", - "emojis-list": "^3.0.0", - "json5": "^2.1.2" - }, - "engines": { - "node": ">=8.9.0" - } - }, - "node_modules/loader-utils/node_modules/json5": { - "version": "2.2.3", - "resolved": "https://registry.npmjs.org/json5/-/json5-2.2.3.tgz", - "integrity": "sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==", - "dev": true, - "bin": { - "json5": "lib/cli.js" - }, - "engines": { - "node": ">=6" - } - }, "node_modules/locate-path": { "version": "6.0.0", "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-6.0.0.tgz", @@ -3820,23 +3179,6 @@ "node": ">=10" } }, - "node_modules/md5.js": { - "version": "1.3.5", - "resolved": "https://registry.npmjs.org/md5.js/-/md5.js-1.3.5.tgz", - "integrity": "sha512-xitP+WxNPcTTOgnTJcrhM0xvdPepipPSf3I8EIpGKeFLjt3PlJLIDG3u8EX53ZIubkb+5U2+3rELYpEhHhzdkg==", - "dev": true, - "dependencies": { - "hash-base": "^3.0.0", - "inherits": "^2.0.1", - "safe-buffer": "^5.1.2" - } - }, - "node_modules/merge-stream": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/merge-stream/-/merge-stream-2.0.0.tgz", - "integrity": "sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==", - "dev": true - }, "node_modules/merge2": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", @@ -3859,46 +3201,6 @@ "node": ">=8.6" } }, - "node_modules/miller-rabin": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/miller-rabin/-/miller-rabin-4.0.1.tgz", - "integrity": "sha512-115fLhvZVqWwHPbClyntxEVfVDfl9DLLTuJvq3g2O/Oxi8AiNouAHvDSzHS0viUJc+V5vm3eq91Xwqn9dp4jRA==", - "dev": true, - "dependencies": { - "bn.js": "^4.0.0", - "brorand": "^1.0.1" - }, - "bin": { - "miller-rabin": "bin/miller-rabin" - } - }, - "node_modules/miller-rabin/node_modules/bn.js": { - "version": "4.12.0", - "resolved": "https://registry.npmjs.org/bn.js/-/bn.js-4.12.0.tgz", - "integrity": "sha512-c98Bf3tPniI+scsdk237ku1Dc3ujXQTSgyiPUDEOe7tRkhrqridvh8klBv0HCEso1OLOYcHuCv/cS6DNxKH+ZA==", - "dev": true - }, - "node_modules/mime-db": { - "version": "1.52.0", - "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", - "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", - "dev": true, - "engines": { - "node": ">= 0.6" - } - }, - "node_modules/mime-types": { - "version": "2.1.35", - "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", - "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", - "dev": true, - "dependencies": { - "mime-db": "1.52.0" - }, - "engines": { - "node": ">= 0.6" - } - }, "node_modules/min-indent": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/min-indent/-/min-indent-1.0.1.tgz", @@ -3908,18 +3210,6 @@ "node": ">=4" } }, - "node_modules/minimalistic-assert": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/minimalistic-assert/-/minimalistic-assert-1.0.1.tgz", - "integrity": "sha512-UtJcAD4yEaGtjPezWuO9wC4nwUnVH/8/Im3yEHQP4b67cXlD/Qr9hdITCU1xDbSEXg2XKNaP8jsReV7vQd00/A==", - "dev": true - }, - "node_modules/minimalistic-crypto-utils": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/minimalistic-crypto-utils/-/minimalistic-crypto-utils-1.0.1.tgz", - "integrity": "sha512-JIYlbt6g8i5jKfJ3xz7rF0LXmv2TkDxBLUkiBeZ7bAx4GnnNMr8xFpGnOxn6GhTEHx3SjRrZEoU+j04prX1ktg==", - "dev": true - }, "node_modules/minimatch": { "version": "3.1.2", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", @@ -4064,15 +3354,6 @@ "url": "https://github.com/chalk/supports-color?sponsor=1" } }, - "node_modules/mrmime": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/mrmime/-/mrmime-1.0.1.tgz", - "integrity": "sha512-hzzEagAgDyoU1Q6yg5uI+AorQgdvMCur3FcKf7NhMKWsaYg+RnbTyHRa/9IlLF9rf455MOCtcqqrQQ83pPP7Uw==", - "dev": true, - "engines": { - "node": ">=10" - } - }, "node_modules/ms": { "version": "2.1.2", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", @@ -4080,134 +3361,21 @@ "dev": true }, "node_modules/nanoid": { - "version": "3.3.3", - "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.3.tgz", - "integrity": "sha512-p1sjXuopFs0xg+fPASzQ28agW1oHD7xDsd9Xkf3T15H3c/cifrFHVwrh74PdoklAPi+i7MdRsE47vm2r6JoB+w==", - "dev": true, - "bin": { - "nanoid": "bin/nanoid.cjs" - }, - "engines": { - "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1" - } - }, - "node_modules/natural-compare": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz", - "integrity": "sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==", - "dev": true - }, - "node_modules/natural-compare-lite": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/natural-compare-lite/-/natural-compare-lite-1.4.0.tgz", - "integrity": "sha512-Tj+HTDSJJKaZnfiuw+iaF9skdPpTo2GtEly5JHnWV/hfv2Qj/9RKsGISQtLh2ox3l5EAGw487hnBee0sIJ6v2g==", - "dev": true - }, - "node_modules/neo-async": { - "version": "2.6.2", - "resolved": "https://registry.npmjs.org/neo-async/-/neo-async-2.6.2.tgz", - "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==", - "dev": true - }, - "node_modules/node-polyfill-webpack-plugin": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/node-polyfill-webpack-plugin/-/node-polyfill-webpack-plugin-2.0.1.tgz", - "integrity": "sha512-ZUMiCnZkP1LF0Th2caY6J/eKKoA0TefpoVa68m/LQU1I/mE8rGt4fNYGgNuCcK+aG8P8P43nbeJ2RqJMOL/Y1A==", - "dev": true, - "dependencies": { - "assert": "^2.0.0", - "browserify-zlib": "^0.2.0", - "buffer": "^6.0.3", - "console-browserify": "^1.2.0", - "constants-browserify": "^1.0.0", - "crypto-browserify": "^3.12.0", - "domain-browser": "^4.22.0", - "events": "^3.3.0", - "filter-obj": "^2.0.2", - "https-browserify": "^1.0.0", - "os-browserify": "^0.3.0", - "path-browserify": "^1.0.1", - "process": "^0.11.10", - "punycode": "^2.1.1", - "querystring-es3": "^0.2.1", - "readable-stream": "^4.0.0", - "stream-browserify": "^3.0.0", - "stream-http": "^3.2.0", - "string_decoder": "^1.3.0", - "timers-browserify": "^2.0.12", - "tty-browserify": "^0.0.1", - "type-fest": "^2.14.0", - "url": "^0.11.0", - "util": "^0.12.4", - "vm-browserify": "^1.1.2" - }, - "engines": { - "node": ">=12" - }, - "peerDependencies": { - "webpack": ">=5" - } - }, - "node_modules/node-polyfill-webpack-plugin/node_modules/readable-stream": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.3.0.tgz", - "integrity": "sha512-MuEnA0lbSi7JS8XM+WNJlWZkHAAdm7gETHdFK//Q/mChGyj2akEFtdLZh32jSdkWGbRwCW9pn6g3LWDdDeZnBQ==", - "dev": true, - "dependencies": { - "abort-controller": "^3.0.0", - "buffer": "^6.0.3", - "events": "^3.3.0", - "process": "^0.11.10" - }, - "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" - } - }, - "node_modules/node-polyfill-webpack-plugin/node_modules/safe-buffer": { - "version": "5.2.1", - "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", - "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", - "dev": true, - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/feross" - }, - { - "type": "patreon", - "url": "https://www.patreon.com/feross" - }, - { - "type": "consulting", - "url": "https://feross.org/support" - } - ] - }, - "node_modules/node-polyfill-webpack-plugin/node_modules/string_decoder": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz", - "integrity": "sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==", - "dev": true, - "dependencies": { - "safe-buffer": "~5.2.0" - } - }, - "node_modules/node-polyfill-webpack-plugin/node_modules/type-fest": { - "version": "2.19.0", - "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-2.19.0.tgz", - "integrity": "sha512-RAH822pAdBgcNMAfWnCBU3CFZcfZ/i1eZjwFU/dsLKumyuuP3niueg2UAukXYF0E2AAoc82ZSSf9J0WQBinzHA==", + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.3.tgz", + "integrity": "sha512-p1sjXuopFs0xg+fPASzQ28agW1oHD7xDsd9Xkf3T15H3c/cifrFHVwrh74PdoklAPi+i7MdRsE47vm2r6JoB+w==", "dev": true, - "engines": { - "node": ">=12.20" + "bin": { + "nanoid": "bin/nanoid.cjs" }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" + "engines": { + "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1" } }, - "node_modules/node-releases": { - "version": "2.0.10", - "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.10.tgz", - "integrity": "sha512-5GFldHPXVG/YZmFzJvKK2zDSzPKhEp0+ZR5SVaoSag9fsL5YgHbUHDfnG5494ISANDcK4KwPXAx2xqVEydmd7w==", + "node_modules/natural-compare": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz", + "integrity": "sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==", "dev": true }, "node_modules/normalize-package-data": { @@ -4223,9 +3391,9 @@ } }, "node_modules/normalize-package-data/node_modules/semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true, "bin": { "semver": "bin/semver" @@ -4264,22 +3432,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/object-is": { - "version": "1.1.5", - "resolved": "https://registry.npmjs.org/object-is/-/object-is-1.1.5.tgz", - "integrity": "sha512-3cyDsyHgtmi7I7DfSSI2LDp6SK2lwvtbg0p0R1e0RvTqF5ceGx+K2dfSjm1bKDMVCFEDAQvy+o8c6a7VujOddw==", - "dev": true, - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.1.3" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, "node_modules/object-keys": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/object-keys/-/object-keys-1.1.1.tgz", @@ -4307,6 +3459,35 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/object.fromentries": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/object.fromentries/-/object.fromentries-2.0.7.tgz", + "integrity": "sha512-UPbPHML6sL8PI/mOqPwsH4G6iyXcCGzLin8KvEPenOZN5lpCNBZZQ+V62vdjB1mQHrmqGQt5/OJzemUA+KJmEA==", + "dev": true, + "dependencies": { + "call-bind": "^1.0.2", + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/object.groupby": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/object.groupby/-/object.groupby-1.0.1.tgz", + "integrity": "sha512-HqaQtqLnp/8Bn4GL16cj+CUYbnpe1bh0TtEaWvybszDG4tgxCJuRpV8VGuvNaI1fAnI4lUJzDG55MXcOH4JZcQ==", + "dev": true, + "dependencies": { + "call-bind": "^1.0.2", + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1", + "get-intrinsic": "^1.2.1" + } + }, "node_modules/object.values": { "version": "1.1.6", "resolved": "https://registry.npmjs.org/object.values/-/object.values-1.1.6.tgz", @@ -4333,38 +3514,23 @@ "wrappy": "1" } }, - "node_modules/opener": { - "version": "1.5.2", - "resolved": "https://registry.npmjs.org/opener/-/opener-1.5.2.tgz", - "integrity": "sha512-ur5UIdyw5Y7yEj9wLzhqXiy6GZ3Mwx0yGI+5sMn2r0N0v3cKJvUmFH5yPP+WXh9e0xfyzyJX95D8l088DNFj7A==", - "dev": true, - "bin": { - "opener": "bin/opener-bin.js" - } - }, "node_modules/optionator": { - "version": "0.9.1", - "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.1.tgz", - "integrity": "sha512-74RlY5FCnhq4jRxVUPKDaRwrVNXMqsGsiW6AJw4XK8hmtm10wC0ypZBLw5IIp85NZMr91+qd1RvvENwg7jjRFw==", + "version": "0.9.3", + "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.3.tgz", + "integrity": "sha512-JjCoypp+jKn1ttEFExxhetCKeJt9zhAgAve5FXHixTvFDW/5aEktX9bufBKLRRMdU7bNtpLfcGu94B3cdEJgjg==", "dev": true, "dependencies": { + "@aashutoshrathi/word-wrap": "^1.2.3", "deep-is": "^0.1.3", "fast-levenshtein": "^2.0.6", "levn": "^0.4.1", "prelude-ls": "^1.2.1", - "type-check": "^0.4.0", - "word-wrap": "^1.2.3" + "type-check": "^0.4.0" }, "engines": { "node": ">= 0.8.0" } }, - "node_modules/os-browserify": { - "version": "0.3.0", - "resolved": "https://registry.npmjs.org/os-browserify/-/os-browserify-0.3.0.tgz", - "integrity": "sha512-gjcpUc3clBf9+210TRaDWbf+rZZZEshZ+DlXMRCeAjp0xhTrnQsKHypIy1J3d5hKdUzj69t708EHtU8P6bUn0A==", - "dev": true - }, "node_modules/p-limit": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", @@ -4422,19 +3588,6 @@ "node": ">=6" } }, - "node_modules/parse-asn1": { - "version": "5.1.6", - "resolved": "https://registry.npmjs.org/parse-asn1/-/parse-asn1-5.1.6.tgz", - "integrity": "sha512-RnZRo1EPU6JBnra2vGHj0yhp6ebyjBZpmUCLHWiFhxlzvBCCpAuZ7elsBp1PVAbQN0/04VD/19rfzlBSwLstMw==", - "dev": true, - "dependencies": { - "asn1.js": "^5.2.0", - "browserify-aes": "^1.0.0", - "evp_bytestokey": "^1.0.0", - "pbkdf2": "^3.0.3", - "safe-buffer": "^5.1.1" - } - }, "node_modules/parse-json": { "version": "5.2.0", "resolved": "https://registry.npmjs.org/parse-json/-/parse-json-5.2.0.tgz", @@ -4453,12 +3606,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/path-browserify": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/path-browserify/-/path-browserify-1.0.1.tgz", - "integrity": "sha512-b7uo2UCUOYZcnF/3ID0lulOJi/bafxa1xPe7ZPsammBSpjSWQkjNxlt635YGS2MiR9GjvuXCtz2emr3jbsz98g==", - "dev": true - }, "node_modules/path-exists": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", @@ -4501,28 +3648,6 @@ "node": ">=8" } }, - "node_modules/pbkdf2": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/pbkdf2/-/pbkdf2-3.1.2.tgz", - "integrity": "sha512-iuh7L6jA7JEGu2WxDwtQP1ddOpaJNC4KlDEFfdQajSGgGPNi4OyDc2R7QnbY2bR9QjBVGwgvTdNJZoE7RaxUMA==", - "dev": true, - "dependencies": { - "create-hash": "^1.1.2", - "create-hmac": "^1.1.4", - "ripemd160": "^2.0.1", - "safe-buffer": "^5.0.1", - "sha.js": "^2.4.8" - }, - "engines": { - "node": ">=0.12" - } - }, - "node_modules/picocolors": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.0.tgz", - "integrity": "sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==", - "dev": true - }, "node_modules/picomatch": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", @@ -4535,70 +3660,6 @@ "url": "https://github.com/sponsors/jonschlinkert" } }, - "node_modules/pkg-dir": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/pkg-dir/-/pkg-dir-4.2.0.tgz", - "integrity": "sha512-HRDzbaKjC+AOWVXxAU/x54COGeIv9eb+6CkDSQoNTt4XyWoIJvuPsXizxu/Fr23EiekbtZwmh1IcIG/l/a10GQ==", - "dev": true, - "dependencies": { - "find-up": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/pkg-dir/node_modules/find-up": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/find-up/-/find-up-4.1.0.tgz", - "integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==", - "dev": true, - "dependencies": { - "locate-path": "^5.0.0", - "path-exists": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/pkg-dir/node_modules/locate-path": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-5.0.0.tgz", - "integrity": "sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==", - "dev": true, - "dependencies": { - "p-locate": "^4.1.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/pkg-dir/node_modules/p-limit": { - "version": "2.3.0", - "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-2.3.0.tgz", - "integrity": "sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==", - "dev": true, - "dependencies": { - "p-try": "^2.0.0" - }, - "engines": { - "node": ">=6" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/pkg-dir/node_modules/p-locate": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-4.1.0.tgz", - "integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==", - "dev": true, - "dependencies": { - "p-limit": "^2.2.0" - }, - "engines": { - "node": ">=8" - } - }, "node_modules/pluralize": { "version": "8.0.0", "resolved": "https://registry.npmjs.org/pluralize/-/pluralize-8.0.0.tgz", @@ -4618,9 +3679,9 @@ } }, "node_modules/prettier": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.0.0.tgz", - "integrity": "sha512-zBf5eHpwHOGPC47h0zrPyNn+eAEIdEzfywMoYn2XPi0P44Zp0tSq64rq0xAREh4auw2cJZHo9QUob+NqCQky4g==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.0.3.tgz", + "integrity": "sha512-L/4pUDMxcNa8R/EthV08Zt42WBO4h1rarVtK0K+QJG0X187OLo7l699jWw0GKuwzkPQ//jMFA/8Xm6Fh3J/DAg==", "dev": true, "bin": { "prettier": "bin/prettier.cjs" @@ -4647,26 +3708,6 @@ "integrity": "sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==", "dev": true }, - "node_modules/public-encrypt": { - "version": "4.0.3", - "resolved": "https://registry.npmjs.org/public-encrypt/-/public-encrypt-4.0.3.tgz", - "integrity": "sha512-zVpa8oKZSz5bTMTFClc1fQOnyyEzpl5ozpi1B5YcvBrdohMjH2rfsBtyXcuNuwjsDIXmBYlF2N5FlJYhR29t8Q==", - "dev": true, - "dependencies": { - "bn.js": "^4.1.0", - "browserify-rsa": "^4.0.0", - "create-hash": "^1.1.0", - "parse-asn1": "^5.0.0", - "randombytes": "^2.0.1", - "safe-buffer": "^5.1.2" - } - }, - "node_modules/public-encrypt/node_modules/bn.js": { - "version": "4.12.0", - "resolved": "https://registry.npmjs.org/bn.js/-/bn.js-4.12.0.tgz", - "integrity": "sha512-c98Bf3tPniI+scsdk237ku1Dc3ujXQTSgyiPUDEOe7tRkhrqridvh8klBv0HCEso1OLOYcHuCv/cS6DNxKH+ZA==", - "dev": true - }, "node_modules/punycode": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.0.tgz", @@ -4676,25 +3717,6 @@ "node": ">=6" } }, - "node_modules/querystring": { - "version": "0.2.0", - "resolved": "https://registry.npmjs.org/querystring/-/querystring-0.2.0.tgz", - "integrity": "sha512-X/xY82scca2tau62i9mDyU9K+I+djTMUsvwf7xnUX5GLvVzgJybOJf4Y6o9Zx3oJK/LSXg5tTZBjwzqVPaPO2g==", - "deprecated": "The querystring API is considered Legacy. new code should use the URLSearchParams API instead.", - "dev": true, - "engines": { - "node": ">=0.4.x" - } - }, - "node_modules/querystring-es3": { - "version": "0.2.1", - "resolved": "https://registry.npmjs.org/querystring-es3/-/querystring-es3-0.2.1.tgz", - "integrity": "sha512-773xhDQnZBMFobEiztv8LIl70ch5MSF/jUQVlhwFyBILqq96anmoctVIYz+ZRp0qbCKATTn6ev02M3r7Ga5vqA==", - "dev": true, - "engines": { - "node": ">=0.4.x" - } - }, "node_modules/queue-microtask": { "version": "1.2.3", "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", @@ -4724,16 +3746,6 @@ "safe-buffer": "^5.1.0" } }, - "node_modules/randomfill": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/randomfill/-/randomfill-1.0.4.tgz", - "integrity": "sha512-87lcbR8+MhcWcUiQ+9e+Rwx8MyR2P7qnt15ynUlbm3TU/fjbgz4GsvfSUDTemtCCtVCqb4ZcEFlyPNTh9bBTLw==", - "dev": true, - "dependencies": { - "randombytes": "^2.0.5", - "safe-buffer": "^5.1.0" - } - }, "node_modules/read-pkg": { "version": "5.2.0", "resolved": "https://registry.npmjs.org/read-pkg/-/read-pkg-5.2.0.tgz", @@ -4863,36 +3875,24 @@ "node": ">=8.10.0" } }, - "node_modules/rechoir": { - "version": "0.8.0", - "resolved": "https://registry.npmjs.org/rechoir/-/rechoir-0.8.0.tgz", - "integrity": "sha512-/vxpCXddiX8NGfGO/mTafwjq4aFa/71pvamip0++IQk3zG8cbCj0fifNPrjjF1XMXUne91jL9OoxmdykoEtifQ==", - "dev": true, - "dependencies": { - "resolve": "^1.20.0" - }, - "engines": { - "node": ">= 10.13.0" - } - }, "node_modules/regexp-tree": { - "version": "0.1.24", - "resolved": "https://registry.npmjs.org/regexp-tree/-/regexp-tree-0.1.24.tgz", - "integrity": "sha512-s2aEVuLhvnVJW6s/iPgEGK6R+/xngd2jNQ+xy4bXNDKxZKJH6jpPHY6kVeVv1IeLCHgswRj+Kl3ELaDjG6V1iw==", + "version": "0.1.27", + "resolved": "https://registry.npmjs.org/regexp-tree/-/regexp-tree-0.1.27.tgz", + "integrity": "sha512-iETxpjK6YoRWJG5o6hXLwvjYAoW+FEZn9os0PD/b6AP6xQwsa/Y7lCVgIixBbUPMfhu+i2LtdeAqVTgGlQarfA==", "dev": true, "bin": { "regexp-tree": "bin/regexp-tree" } }, "node_modules/regexp.prototype.flags": { - "version": "1.4.3", - "resolved": "https://registry.npmjs.org/regexp.prototype.flags/-/regexp.prototype.flags-1.4.3.tgz", - "integrity": "sha512-fjggEOO3slI6Wvgjwflkc4NFRCTZAu5CnNfBd5qOMYhWdn67nJBBu34/TkD++eeFmd8C9r9jfXJ27+nSiRkSUA==", + "version": "1.5.1", + "resolved": "https://registry.npmjs.org/regexp.prototype.flags/-/regexp.prototype.flags-1.5.1.tgz", + "integrity": "sha512-sy6TXMN+hnP/wMy+ISxg3krXx7BAtWVO4UouuCN/ziM9UEne0euamVNafDfvC83bRNr95y0V5iijeDQFUNpvrg==", "dev": true, "dependencies": { "call-bind": "^1.0.2", - "define-properties": "^1.1.3", - "functions-have-names": "^1.2.2" + "define-properties": "^1.2.0", + "set-function-name": "^2.0.0" }, "engines": { "node": ">= 0.4" @@ -4901,22 +3901,10 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/regexpp": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/regexpp/-/regexpp-3.2.0.tgz", - "integrity": "sha512-pq2bWo9mVD43nbts2wGv17XLiNLya+GklZ8kaDLV2Z08gDCsGpnKn9BFMepvWuHCbyVvY7J5o5+BVvoQbmlJLg==", - "dev": true, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/sponsors/mysticatea" - } - }, "node_modules/regjsparser": { - "version": "0.9.1", - "resolved": "https://registry.npmjs.org/regjsparser/-/regjsparser-0.9.1.tgz", - "integrity": "sha512-dQUtn90WanSNl+7mQKcXAgZxvUe7Z0SqXlgzv0za4LwiUhyzBC58yQO3liFoUgu8GiJVInAhJjkj1N0EtQ5nkQ==", + "version": "0.10.0", + "resolved": "https://registry.npmjs.org/regjsparser/-/regjsparser-0.10.0.tgz", + "integrity": "sha512-qx+xQGZVsy55CH0a1hiVwHmqjLryfh7wQyF5HO07XJ9f7dQMY/gPQHhlyDkIzJKC+x2fUCpCcUODUUUFrm7SHA==", "dev": true, "dependencies": { "jsesc": "~0.5.0" @@ -4960,27 +3948,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/resolve-cwd": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/resolve-cwd/-/resolve-cwd-3.0.0.tgz", - "integrity": "sha512-OrZaX2Mb+rJCpH/6CpSqt9xFVpN++x01XnN2ie9g6P5/3xelLAkXWVADpdz1IHD/KFfEXyE6V0U01OQ3UO2rEg==", - "dev": true, - "dependencies": { - "resolve-from": "^5.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/resolve-cwd/node_modules/resolve-from": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-5.0.0.tgz", - "integrity": "sha512-qYg9KP24dD5qka9J47d0aVky0N+b4fTU89LN9iDnjB5waksiC49rvMB0PrUJQGoTmH50XPiqOvAjDfaijGxYZw==", - "dev": true, - "engines": { - "node": ">=8" - } - }, "node_modules/resolve-from": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", @@ -5015,16 +3982,6 @@ "url": "https://github.com/sponsors/isaacs" } }, - "node_modules/ripemd160": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/ripemd160/-/ripemd160-2.0.2.tgz", - "integrity": "sha512-ii4iagi25WusVoiC4B4lq7pbXfAp3D9v5CwfkY33vffw2+pkDjY1D8GaN7spsxvCSx8dkPqOZCEZyfxcmJG2IA==", - "dev": true, - "dependencies": { - "hash-base": "^3.0.0", - "inherits": "^2.0.1" - } - }, "node_modules/run-parallel": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", @@ -5048,21 +4005,36 @@ "queue-microtask": "^1.2.2" } }, + "node_modules/safe-array-concat": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/safe-array-concat/-/safe-array-concat-1.0.1.tgz", + "integrity": "sha512-6XbUAseYE2KtOuGueyeobCySj9L4+66Tn6KQMOPQJrAJEowYKW/YR/MGJZl7FdydUdaFu4LYyDZjxf4/Nmo23Q==", + "dev": true, + "dependencies": { + "call-bind": "^1.0.2", + "get-intrinsic": "^1.2.1", + "has-symbols": "^1.0.3", + "isarray": "^2.0.5" + }, + "engines": { + "node": ">=0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/safe-array-concat/node_modules/isarray": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-2.0.5.tgz", + "integrity": "sha512-xHjhDr3cNBK0BzdUJSPXZntQUx/mwMS5Rw4A7lPJ90XGAO6ISP/ePDNuo0vhqOZU+UD5JoodwCAAoZQd3FeAKw==", + "dev": true + }, "node_modules/safe-buffer": { "version": "5.1.2", "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz", "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==", "dev": true }, - "node_modules/safe-regex": { - "version": "2.1.1", - "resolved": "https://registry.npmjs.org/safe-regex/-/safe-regex-2.1.1.tgz", - "integrity": "sha512-rx+x8AMzKb5Q5lQ95Zoi6ZbJqwCLkqi3XuJXp5P3rT8OEc6sZCJG5AE5dU3lsgRr/F4Bs31jSlVN+j5KrsGu9A==", - "dev": true, - "dependencies": { - "regexp-tree": "~0.1.1" - } - }, "node_modules/safe-regex-test": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/safe-regex-test/-/safe-regex-test-1.0.0.tgz", @@ -5077,34 +4049,10 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/safer-buffer": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", - "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==", - "dev": true - }, - "node_modules/schema-utils": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-3.1.1.tgz", - "integrity": "sha512-Y5PQxS4ITlC+EahLuXaY86TXfR7Dc5lw294alXOq86JAHCihAIZfqv8nNCWvaEJvaC51uN9hbLGeV0cFBdH+Fw==", - "dev": true, - "dependencies": { - "@types/json-schema": "^7.0.8", - "ajv": "^6.12.5", - "ajv-keywords": "^3.5.2" - }, - "engines": { - "node": ">= 10.13.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/webpack" - } - }, "node_modules/semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "dependencies": { "lru-cache": "^6.0.0" @@ -5116,52 +4064,32 @@ "node": ">=10" } }, - "node_modules/serialize-javascript": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-6.0.1.tgz", - "integrity": "sha512-owoXEFjWRllis8/M1Q+Cw5k8ZH40e3zhp/ovX+Xr/vi1qj6QesbyXXViFbpNvWvPNAD62SutwEXavefrLJWj7w==", - "dev": true, - "dependencies": { - "randombytes": "^2.1.0" - } - }, "node_modules/set-blocking": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/set-blocking/-/set-blocking-2.0.0.tgz", "integrity": "sha512-KiKBS8AnWGEyLzofFfmvKwpdPzqiy16LvQfK3yv/fVH7Bj13/wl3JSR1J+rfgRE9q7xUJK4qvgS8raSOeLUehw==", "dev": true }, - "node_modules/setimmediate": { - "version": "1.0.5", - "resolved": "https://registry.npmjs.org/setimmediate/-/setimmediate-1.0.5.tgz", - "integrity": "sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA==", - "dev": true - }, - "node_modules/sha.js": { - "version": "2.4.11", - "resolved": "https://registry.npmjs.org/sha.js/-/sha.js-2.4.11.tgz", - "integrity": "sha512-QMEp5B7cftE7APOjk5Y6xgrbWu+WkLVQwk8JNjZ8nKRciZaByEW6MubieAiToS7+dwvrjGhH8jRXz3MVd0AYqQ==", - "dev": true, - "dependencies": { - "inherits": "^2.0.1", - "safe-buffer": "^5.0.1" - }, - "bin": { - "sha.js": "bin.js" - } - }, - "node_modules/shallow-clone": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/shallow-clone/-/shallow-clone-3.0.1.tgz", - "integrity": "sha512-/6KqX+GVUdqPuPPd2LxDDxzX6CAbjJehAAOKlNpqqUpAqPM6HeL8f+o3a+JsyGjn2lv0WY8UsTgUJjU9Ok55NA==", + "node_modules/set-function-name": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/set-function-name/-/set-function-name-2.0.1.tgz", + "integrity": "sha512-tMNCiqYVkXIZgc2Hnoy2IvC/f8ezc5koaRFkCjrpWzGpCd3qbZXPzVy9MAZzK1ch/X0jvSkojys3oqJN0qCmdA==", "dev": true, "dependencies": { - "kind-of": "^6.0.2" + "define-data-property": "^1.0.1", + "functions-have-names": "^1.2.3", + "has-property-descriptors": "^1.0.0" }, "engines": { - "node": ">=8" + "node": ">= 0.4" } }, + "node_modules/setimmediate": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/setimmediate/-/setimmediate-1.0.5.tgz", + "integrity": "sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA==", + "dev": true + }, "node_modules/shebang-command": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", @@ -5203,46 +4131,13 @@ "integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==", "dev": true }, - "node_modules/sirv": { - "version": "1.0.19", - "resolved": "https://registry.npmjs.org/sirv/-/sirv-1.0.19.tgz", - "integrity": "sha512-JuLThK3TnZG1TAKDwNIqNq6QA2afLOCcm+iE8D1Kj3GA40pSPsxQjjJl0J8X3tsR7T+CP1GavpzLwYkgVLWrZQ==", - "dev": true, - "dependencies": { - "@polka/url": "^1.0.0-next.20", - "mrmime": "^1.0.0", - "totalist": "^1.0.0" - }, - "engines": { - "node": ">= 10" - } - }, "node_modules/slash": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/slash/-/slash-3.0.0.tgz", "integrity": "sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==", "dev": true, "engines": { - "node": ">=8" - } - }, - "node_modules/source-map": { - "version": "0.6.1", - "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", - "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", - "dev": true, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/source-map-support": { - "version": "0.5.21", - "resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz", - "integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==", - "dev": true, - "dependencies": { - "buffer-from": "^1.0.0", - "source-map": "^0.6.0" + "node": ">=8" } }, "node_modules/spdx-correct": { @@ -5277,56 +4172,6 @@ "integrity": "sha512-rr+VVSXtRhO4OHbXUiAF7xW3Bo9DuuF6C5jH+q/x15j2jniycgKbxU09Hr0WqlSLUs4i4ltHGXqTe7VHclYWyA==", "dev": true }, - "node_modules/stream-browserify": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/stream-browserify/-/stream-browserify-3.0.0.tgz", - "integrity": "sha512-H73RAHsVBapbim0tU2JwwOiXUj+fikfiaoYAKHF3VJfA0pe2BCzkhAHBlLG6REzE+2WNZcxOXjK7lkso+9euLA==", - "dev": true, - "dependencies": { - "inherits": "~2.0.4", - "readable-stream": "^3.5.0" - } - }, - "node_modules/stream-browserify/node_modules/readable-stream": { - "version": "3.6.1", - "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.1.tgz", - "integrity": "sha512-+rQmrWMYGA90yenhTYsLWAsLsqVC8osOw6PKE1HDYiO0gdPeKe/xDHNzIAIn4C91YQ6oenEhfYqqc1883qHbjQ==", - "dev": true, - "dependencies": { - "inherits": "^2.0.3", - "string_decoder": "^1.1.1", - "util-deprecate": "^1.0.1" - }, - "engines": { - "node": ">= 6" - } - }, - "node_modules/stream-http": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/stream-http/-/stream-http-3.2.0.tgz", - "integrity": "sha512-Oq1bLqisTyK3TSCXpPbT4sdeYNdmyZJv1LxpEm2vu1ZhK89kSE5YXwZc3cWk0MagGaKriBh9mCFbVGtO+vY29A==", - "dev": true, - "dependencies": { - "builtin-status-codes": "^3.0.0", - "inherits": "^2.0.4", - "readable-stream": "^3.6.0", - "xtend": "^4.0.2" - } - }, - "node_modules/stream-http/node_modules/readable-stream": { - "version": "3.6.1", - "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.1.tgz", - "integrity": "sha512-+rQmrWMYGA90yenhTYsLWAsLsqVC8osOw6PKE1HDYiO0gdPeKe/xDHNzIAIn4C91YQ6oenEhfYqqc1883qHbjQ==", - "dev": true, - "dependencies": { - "inherits": "^2.0.3", - "string_decoder": "^1.1.1", - "util-deprecate": "^1.0.1" - }, - "engines": { - "node": ">= 6" - } - }, "node_modules/string_decoder": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.1.1.tgz", @@ -5350,29 +4195,46 @@ "node": ">=8" } }, + "node_modules/string.prototype.trim": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/string.prototype.trim/-/string.prototype.trim-1.2.8.tgz", + "integrity": "sha512-lfjY4HcixfQXOfaqCvcBuOIapyaroTXhbkfJN3gcB1OtyupngWK4sEET9Knd0cXd28kTUqu/kHoV4HKSJdnjiQ==", + "dev": true, + "dependencies": { + "call-bind": "^1.0.2", + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/string.prototype.trimend": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/string.prototype.trimend/-/string.prototype.trimend-1.0.6.tgz", - "integrity": "sha512-JySq+4mrPf9EsDBEDYMOb/lM7XQLulwg5R/m1r0PXEFqrV0qHvl58sdTilSXtKOflCsK2E8jxf+GKC0T07RWwQ==", + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/string.prototype.trimend/-/string.prototype.trimend-1.0.7.tgz", + "integrity": "sha512-Ni79DqeB72ZFq1uH/L6zJ+DKZTkOtPIHovb3YZHQViE+HDouuU4mBrLOLDn5Dde3RF8qw5qVETEjhu9locMLvA==", "dev": true, "dependencies": { "call-bind": "^1.0.2", - "define-properties": "^1.1.4", - "es-abstract": "^1.20.4" + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1" }, "funding": { "url": "https://github.com/sponsors/ljharb" } }, "node_modules/string.prototype.trimstart": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/string.prototype.trimstart/-/string.prototype.trimstart-1.0.6.tgz", - "integrity": "sha512-omqjMDaY92pbn5HOX7f9IccLA+U1tA9GvtU4JrodiXFfYB7jPzzHpRzpglLAjtUV6bB557zwClJezTqnAiYnQA==", + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/string.prototype.trimstart/-/string.prototype.trimstart-1.0.7.tgz", + "integrity": "sha512-NGhtDFu3jCEm7B4Fy0DpLewdJQOZcQ0rGbwQ/+stjnrp2i+rlKeCvos9hOIeCmqwratM47OBxY7uFZzjxHXmrg==", "dev": true, "dependencies": { "call-bind": "^1.0.2", - "define-properties": "^1.1.4", - "es-abstract": "^1.20.4" + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -5447,85 +4309,12 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/tapable": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.1.tgz", - "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", - "dev": true, - "engines": { - "node": ">=6" - } - }, - "node_modules/terser": { - "version": "5.16.5", - "resolved": "https://registry.npmjs.org/terser/-/terser-5.16.5.tgz", - "integrity": "sha512-qcwfg4+RZa3YvlFh0qjifnzBHjKGNbtDo9yivMqMFDy9Q6FSaQWSB/j1xKhsoUFJIqDOM3TsN6D5xbrMrFcHbg==", - "dev": true, - "dependencies": { - "@jridgewell/source-map": "^0.3.2", - "acorn": "^8.5.0", - "commander": "^2.20.0", - "source-map-support": "~0.5.20" - }, - "bin": { - "terser": "bin/terser" - }, - "engines": { - "node": ">=10" - } - }, - "node_modules/terser-webpack-plugin": { - "version": "5.3.6", - "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.6.tgz", - "integrity": "sha512-kfLFk+PoLUQIbLmB1+PZDMRSZS99Mp+/MHqDNmMA6tOItzRt+Npe3E+fsMs5mfcM0wCtrrdU387UnV+vnSffXQ==", - "dev": true, - "dependencies": { - "@jridgewell/trace-mapping": "^0.3.14", - "jest-worker": "^27.4.5", - "schema-utils": "^3.1.1", - "serialize-javascript": "^6.0.0", - "terser": "^5.14.1" - }, - "engines": { - "node": ">= 10.13.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/webpack" - }, - "peerDependencies": { - "webpack": "^5.1.0" - }, - "peerDependenciesMeta": { - "@swc/core": { - "optional": true - }, - "esbuild": { - "optional": true - }, - "uglify-js": { - "optional": true - } - } - }, "node_modules/text-table": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz", "integrity": "sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw==", "dev": true }, - "node_modules/timers-browserify": { - "version": "2.0.12", - "resolved": "https://registry.npmjs.org/timers-browserify/-/timers-browserify-2.0.12.tgz", - "integrity": "sha512-9phl76Cqm6FhSX9Xe1ZUAMLtm1BLkKj2Qd5ApyWkXzsMRaA7dgr81kf4wJmQf/hAvg8EEyJxDo3du/0KlhPiKQ==", - "dev": true, - "dependencies": { - "setimmediate": "^1.0.4" - }, - "engines": { - "node": ">=0.6.0" - } - }, "node_modules/to-regex-range": { "version": "5.0.1", "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", @@ -5538,32 +4327,16 @@ "node": ">=8.0" } }, - "node_modules/totalist": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/totalist/-/totalist-1.1.0.tgz", - "integrity": "sha512-gduQwd1rOdDMGxFG1gEvhV88Oirdo2p+KjoYFU7k2g+i7n6AFFbDQ5kMPUsW0pNbfQsB/cwXvT1i4Bue0s9g5g==", - "dev": true, - "engines": { - "node": ">=6" - } - }, - "node_modules/ts-loader": { - "version": "9.4.2", - "resolved": "https://registry.npmjs.org/ts-loader/-/ts-loader-9.4.2.tgz", - "integrity": "sha512-OmlC4WVmFv5I0PpaxYb+qGeGOdm5giHU7HwDDUjw59emP2UYMHy9fFSDcYgSNoH8sXcj4hGCSEhlDZ9ULeDraA==", + "node_modules/ts-api-utils": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-1.0.3.tgz", + "integrity": "sha512-wNMeqtMz5NtwpT/UZGY5alT+VoKdSsOOP/kqHFcUW1P/VRhH2wJ48+DN2WwUliNbQ976ETwDL0Ifd2VVvgonvg==", "dev": true, - "dependencies": { - "chalk": "^4.1.0", - "enhanced-resolve": "^5.0.0", - "micromatch": "^4.0.0", - "semver": "^7.3.4" - }, "engines": { - "node": ">=12.0.0" + "node": ">=16.13.0" }, "peerDependencies": { - "typescript": "*", - "webpack": "^5.0.0" + "typescript": ">=4.2.0" } }, "node_modules/tsconfig-paths": { @@ -5578,33 +4351,6 @@ "strip-bom": "^3.0.0" } }, - "node_modules/tslib": { - "version": "1.14.1", - "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz", - "integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==", - "dev": true - }, - "node_modules/tsutils": { - "version": "3.21.0", - "resolved": "https://registry.npmjs.org/tsutils/-/tsutils-3.21.0.tgz", - "integrity": "sha512-mHKK3iUXL+3UF6xL5k0PEhKRUBKPBCv/+RkEOpjRWxxx27KKRBmmA60A9pgOUvMi8GKhRMPEmjBRPzs2W7O1OA==", - "dev": true, - "dependencies": { - "tslib": "^1.8.1" - }, - "engines": { - "node": ">= 6" - }, - "peerDependencies": { - "typescript": ">=2.8.0 || >= 3.2.0-dev || >= 3.3.0-dev || >= 3.4.0-dev || >= 3.5.0-dev || >= 3.6.0-dev || >= 3.6.0-beta || >= 3.7.0-dev || >= 3.7.0-beta" - } - }, - "node_modules/tty-browserify": { - "version": "0.0.1", - "resolved": "https://registry.npmjs.org/tty-browserify/-/tty-browserify-0.0.1.tgz", - "integrity": "sha512-C3TaO7K81YvjCgQH9Q1S3R3P3BtN3RIM8n+OvX4il1K1zgE8ZhI0op7kClgkxtutIE8hQrcrHBXvIheqKUUCxw==", - "dev": true - }, "node_modules/type-check": { "version": "0.4.0", "resolved": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", @@ -5629,6 +4375,57 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/typed-array-buffer": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/typed-array-buffer/-/typed-array-buffer-1.0.0.tgz", + "integrity": "sha512-Y8KTSIglk9OZEr8zywiIHG/kmQ7KWyjseXs1CbSo8vC42w7hg2HgYTxSWwP0+is7bWDc1H+Fo026CpHFwm8tkw==", + "dev": true, + "dependencies": { + "call-bind": "^1.0.2", + "get-intrinsic": "^1.2.1", + "is-typed-array": "^1.1.10" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/typed-array-byte-length": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/typed-array-byte-length/-/typed-array-byte-length-1.0.0.tgz", + "integrity": "sha512-Or/+kvLxNpeQ9DtSydonMxCx+9ZXOswtwJn17SNLvhptaXYDJvkFFP5zbfU/uLmvnBJlI4yrnXRxpdWH/M5tNA==", + "dev": true, + "dependencies": { + "call-bind": "^1.0.2", + "for-each": "^0.3.3", + "has-proto": "^1.0.1", + "is-typed-array": "^1.1.10" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/typed-array-byte-offset": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/typed-array-byte-offset/-/typed-array-byte-offset-1.0.0.tgz", + "integrity": "sha512-RD97prjEt9EL8YgAgpOkf3O4IF9lhJFr9g0htQkm0rchFp/Vx7LW5Q8fSXXub7BXAODyUQohRMyOc3faCPd0hg==", + "dev": true, + "dependencies": { + "available-typed-arrays": "^1.0.5", + "call-bind": "^1.0.2", + "for-each": "^0.3.3", + "has-proto": "^1.0.1", + "is-typed-array": "^1.1.10" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/typed-array-length": { "version": "1.0.4", "resolved": "https://registry.npmjs.org/typed-array-length/-/typed-array-length-1.0.4.tgz", @@ -5644,16 +4441,16 @@ } }, "node_modules/typescript": { - "version": "4.9.5", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" }, "engines": { - "node": ">=4.2.0" + "node": ">=14.17" } }, "node_modules/unbox-primitive": { @@ -5680,32 +4477,6 @@ "node": ">= 10.0.0" } }, - "node_modules/update-browserslist-db": { - "version": "1.0.10", - "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.0.10.tgz", - "integrity": "sha512-OztqDenkfFkbSG+tRxBeAnCVPckDBcvibKd35yDONx6OU8N7sqgwc7rCbkJ/WcYtVRZ4ba68d6byhC21GFh7sQ==", - "dev": true, - "funding": [ - { - "type": "opencollective", - "url": "https://opencollective.com/browserslist" - }, - { - "type": "tidelift", - "url": "https://tidelift.com/funding/github/npm/browserslist" - } - ], - "dependencies": { - "escalade": "^3.1.1", - "picocolors": "^1.0.0" - }, - "bin": { - "browserslist-lint": "cli.js" - }, - "peerDependencies": { - "browserslist": ">= 4.21.0" - } - }, "node_modules/uri-js": { "version": "4.4.1", "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", @@ -5715,35 +4486,6 @@ "punycode": "^2.1.0" } }, - "node_modules/url": { - "version": "0.11.0", - "resolved": "https://registry.npmjs.org/url/-/url-0.11.0.tgz", - "integrity": "sha512-kbailJa29QrtXnxgq+DdCEGlbTeYM2eJUxsz6vjZavrCYPMIFHMKQmSKYAIuUK2i7hgPm28a8piX5NTUtM/LKQ==", - "dev": true, - "dependencies": { - "punycode": "1.3.2", - "querystring": "0.2.0" - } - }, - "node_modules/url/node_modules/punycode": { - "version": "1.3.2", - "resolved": "https://registry.npmjs.org/punycode/-/punycode-1.3.2.tgz", - "integrity": "sha512-RofWgt/7fL5wP1Y7fxE7/EmTLzQVnB0ycyibJ0OOHIlJqTNzglYFxVwETOcIoJqJmpDXJ9xImDv+Fq34F/d4Dw==", - "dev": true - }, - "node_modules/util": { - "version": "0.12.5", - "resolved": "https://registry.npmjs.org/util/-/util-0.12.5.tgz", - "integrity": "sha512-kZf/K6hEIrWHI6XqOFUiiMa+79wE/D8Q+NCNAWclkyg3b4d2k7s0QGepNjiABc+aR3N1PAyHL7p6UcLY6LmrnA==", - "dev": true, - "dependencies": { - "inherits": "^2.0.3", - "is-arguments": "^1.0.4", - "is-generator-function": "^1.0.7", - "is-typed-array": "^1.1.3", - "which-typed-array": "^1.1.2" - } - }, "node_modules/util-deprecate": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", @@ -5753,186 +4495,11 @@ "node_modules/validate-npm-package-license": { "version": "3.0.4", "resolved": "https://registry.npmjs.org/validate-npm-package-license/-/validate-npm-package-license-3.0.4.tgz", - "integrity": "sha512-DpKm2Ui/xN7/HQKCtpZxoRWBhZ9Z0kqtygG8XCgNQ8ZlDnxuQmWhj566j8fN4Cu3/JmbhsDo7fcAJq4s9h27Ew==", - "dev": true, - "dependencies": { - "spdx-correct": "^3.0.0", - "spdx-expression-parse": "^3.0.0" - } - }, - "node_modules/vm-browserify": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/vm-browserify/-/vm-browserify-1.1.2.tgz", - "integrity": "sha512-2ham8XPWTONajOR0ohOKOHXkm3+gaBmGut3SRuu75xLd/RRaY6vqgh8NBYYk7+RW3u5AtzPQZG8F10LHkl0lAQ==", - "dev": true - }, - "node_modules/watchpack": { - "version": "2.4.0", - "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.4.0.tgz", - "integrity": "sha512-Lcvm7MGST/4fup+ifyKi2hjyIAwcdI4HRgtvTpIUxBRhB+RFtUh8XtDOxUfctVCnhVi+QQj49i91OyvzkJl6cg==", - "dev": true, - "dependencies": { - "glob-to-regexp": "^0.4.1", - "graceful-fs": "^4.1.2" - }, - "engines": { - "node": ">=10.13.0" - } - }, - "node_modules/webpack": { - "version": "5.76.0", - "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.76.0.tgz", - "integrity": "sha512-l5sOdYBDunyf72HW8dF23rFtWq/7Zgvt/9ftMof71E/yUb1YLOBmTgA2K4vQthB3kotMrSj609txVE0dnr2fjA==", - "dev": true, - "dependencies": { - "@types/eslint-scope": "^3.7.3", - "@types/estree": "^0.0.51", - "@webassemblyjs/ast": "1.11.1", - "@webassemblyjs/wasm-edit": "1.11.1", - "@webassemblyjs/wasm-parser": "1.11.1", - "acorn": "^8.7.1", - "acorn-import-assertions": "^1.7.6", - "browserslist": "^4.14.5", - "chrome-trace-event": "^1.0.2", - "enhanced-resolve": "^5.10.0", - "es-module-lexer": "^0.9.0", - "eslint-scope": "5.1.1", - "events": "^3.2.0", - "glob-to-regexp": "^0.4.1", - "graceful-fs": "^4.2.9", - "json-parse-even-better-errors": "^2.3.1", - "loader-runner": "^4.2.0", - "mime-types": "^2.1.27", - "neo-async": "^2.6.2", - "schema-utils": "^3.1.0", - "tapable": "^2.1.1", - "terser-webpack-plugin": "^5.1.3", - "watchpack": "^2.4.0", - "webpack-sources": "^3.2.3" - }, - "bin": { - "webpack": "bin/webpack.js" - }, - "engines": { - "node": ">=10.13.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/webpack" - }, - "peerDependenciesMeta": { - "webpack-cli": { - "optional": true - } - } - }, - "node_modules/webpack-bundle-analyzer": { - "version": "4.8.0", - "resolved": "https://registry.npmjs.org/webpack-bundle-analyzer/-/webpack-bundle-analyzer-4.8.0.tgz", - "integrity": "sha512-ZzoSBePshOKhr+hd8u6oCkZVwpVaXgpw23ScGLFpR6SjYI7+7iIWYarjN6OEYOfRt8o7ZyZZQk0DuMizJ+LEIg==", - "dev": true, - "dependencies": { - "@discoveryjs/json-ext": "0.5.7", - "acorn": "^8.0.4", - "acorn-walk": "^8.0.0", - "chalk": "^4.1.0", - "commander": "^7.2.0", - "gzip-size": "^6.0.0", - "lodash": "^4.17.20", - "opener": "^1.5.2", - "sirv": "^1.0.7", - "ws": "^7.3.1" - }, - "bin": { - "webpack-bundle-analyzer": "lib/bin/analyzer.js" - }, - "engines": { - "node": ">= 10.13.0" - } - }, - "node_modules/webpack-bundle-analyzer/node_modules/commander": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/commander/-/commander-7.2.0.tgz", - "integrity": "sha512-QrWXB+ZQSVPmIWIhtEO9H+gwHaMGYiF5ChvoJ+K9ZGHG/sVsa6yiesAD1GC/x46sET00Xlwo1u49RVVVzvcSkw==", - "dev": true, - "engines": { - "node": ">= 10" - } - }, - "node_modules/webpack-cli": { - "version": "5.0.1", - "resolved": "https://registry.npmjs.org/webpack-cli/-/webpack-cli-5.0.1.tgz", - "integrity": "sha512-S3KVAyfwUqr0Mo/ur3NzIp6jnerNpo7GUO6so51mxLi1spqsA17YcMXy0WOIJtBSnj748lthxC6XLbNKh/ZC+A==", - "dev": true, - "dependencies": { - "@discoveryjs/json-ext": "^0.5.0", - "@webpack-cli/configtest": "^2.0.1", - "@webpack-cli/info": "^2.0.1", - "@webpack-cli/serve": "^2.0.1", - "colorette": "^2.0.14", - "commander": "^9.4.1", - "cross-spawn": "^7.0.3", - "envinfo": "^7.7.3", - "fastest-levenshtein": "^1.0.12", - "import-local": "^3.0.2", - "interpret": "^3.1.1", - "rechoir": "^0.8.0", - "webpack-merge": "^5.7.3" - }, - "bin": { - "webpack-cli": "bin/cli.js" - }, - "engines": { - "node": ">=14.15.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/webpack" - }, - "peerDependencies": { - "webpack": "5.x.x" - }, - "peerDependenciesMeta": { - "@webpack-cli/generators": { - "optional": true - }, - "webpack-bundle-analyzer": { - "optional": true - }, - "webpack-dev-server": { - "optional": true - } - } - }, - "node_modules/webpack-cli/node_modules/commander": { - "version": "9.5.0", - "resolved": "https://registry.npmjs.org/commander/-/commander-9.5.0.tgz", - "integrity": "sha512-KRs7WVDKg86PWiuAqhDrAQnTXZKraVcCc6vFdL14qrZ/DcWwuRo7VoiYXalXO7S5GKpqYiVEwCbgFDfxNHKJBQ==", - "dev": true, - "engines": { - "node": "^12.20.0 || >=14" - } - }, - "node_modules/webpack-merge": { - "version": "5.8.0", - "resolved": "https://registry.npmjs.org/webpack-merge/-/webpack-merge-5.8.0.tgz", - "integrity": "sha512-/SaI7xY0831XwP6kzuwhKWVKDP9t1QY1h65lAFLbZqMPIuYcD9QAW4u9STIbU9kaJbPBB/geU/gLr1wDjOhQ+Q==", - "dev": true, - "dependencies": { - "clone-deep": "^4.0.1", - "wildcard": "^2.0.0" - }, - "engines": { - "node": ">=10.0.0" - } - }, - "node_modules/webpack-sources": { - "version": "3.2.3", - "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-3.2.3.tgz", - "integrity": "sha512-/DyMEOrDgLKKIG0fmvtz+4dUX/3Ghozwgm6iPp8KRhvn+eQf9+Q7GWxVNMk3+uCPWfdXYC4ExGBckIXdFEfH1w==", + "integrity": "sha512-DpKm2Ui/xN7/HQKCtpZxoRWBhZ9Z0kqtygG8XCgNQ8ZlDnxuQmWhj566j8fN4Cu3/JmbhsDo7fcAJq4s9h27Ew==", "dev": true, - "engines": { - "node": ">=10.13.0" + "dependencies": { + "spdx-correct": "^3.0.0", + "spdx-expression-parse": "^3.0.0" } }, "node_modules/which": { @@ -5967,17 +4534,16 @@ } }, "node_modules/which-typed-array": { - "version": "1.1.9", - "resolved": "https://registry.npmjs.org/which-typed-array/-/which-typed-array-1.1.9.tgz", - "integrity": "sha512-w9c4xkx6mPidwp7180ckYWfMmvxpjlZuIudNtDf4N/tTAUB8VJbX25qZoAsrtGuYNnGw3pa0AXgbGKRB8/EceA==", + "version": "1.1.11", + "resolved": "https://registry.npmjs.org/which-typed-array/-/which-typed-array-1.1.11.tgz", + "integrity": "sha512-qe9UWWpkeG5yzZ0tNYxDmd7vo58HDBc39mZ0xWWpolAGADdFOzkfamWLDxkOWcvHQKVmdTyQdLD4NOfjLWTKew==", "dev": true, "dependencies": { "available-typed-arrays": "^1.0.5", "call-bind": "^1.0.2", "for-each": "^0.3.3", "gopd": "^1.0.1", - "has-tostringtag": "^1.0.0", - "is-typed-array": "^1.1.10" + "has-tostringtag": "^1.0.0" }, "engines": { "node": ">= 0.4" @@ -5995,41 +4561,6 @@ "string-width": "^1.0.2 || 2 || 3 || 4" } }, - "node_modules/wildcard": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/wildcard/-/wildcard-2.0.0.tgz", - "integrity": "sha512-JcKqAHLPxcdb9KM49dufGXn2x3ssnfjbcaQdLlfZsL9rH9wgDQjUtDxbo8NE0F6SFvydeu1VhZe7hZuHsB2/pw==", - "dev": true - }, - "node_modules/word-wrap": { - "version": "1.2.4", - "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.4.tgz", - "integrity": "sha512-2V81OA4ugVo5pRo46hAoD2ivUJx8jXmWXfUkY4KFNw0hEptvN0QfH3K4nHiwzGeKl5rFKedV48QVoqYavy4YpA==", - "dev": true, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/worker-loader": { - "version": "3.0.8", - "resolved": "https://registry.npmjs.org/worker-loader/-/worker-loader-3.0.8.tgz", - "integrity": "sha512-XQyQkIFeRVC7f7uRhFdNMe/iJOdO6zxAaR3EWbDp45v3mDhrTi+++oswKNxShUNjPC/1xUp5DB29YKLhFo129g==", - "dev": true, - "dependencies": { - "loader-utils": "^2.0.0", - "schema-utils": "^3.0.0" - }, - "engines": { - "node": ">= 10.13.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/webpack" - }, - "peerDependencies": { - "webpack": "^4.0.0 || ^5.0.0" - } - }, "node_modules/workerpool": { "version": "6.2.1", "resolved": "https://registry.npmjs.org/workerpool/-/workerpool-6.2.1.tgz", @@ -6059,36 +4590,6 @@ "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", "dev": true }, - "node_modules/ws": { - "version": "7.5.9", - "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.9.tgz", - "integrity": "sha512-F+P9Jil7UiSKSkppIiD94dN07AwvFixvLIj1Og1Rl9GGMuNipJnV9JzjD6XuqmAeiswGvUmNLjr5cFuXwNS77Q==", - "dev": true, - "engines": { - "node": ">=8.3.0" - }, - "peerDependencies": { - "bufferutil": "^4.0.1", - "utf-8-validate": "^5.0.2" - }, - "peerDependenciesMeta": { - "bufferutil": { - "optional": true - }, - "utf-8-validate": { - "optional": true - } - } - }, - "node_modules/xtend": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz", - "integrity": "sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==", - "dev": true, - "engines": { - "node": ">=0.4" - } - }, "node_modules/y18n": { "version": "5.0.8", "resolved": "https://registry.npmjs.org/y18n/-/y18n-5.0.8.tgz", @@ -6160,6 +4661,12 @@ } }, "dependencies": { + "@aashutoshrathi/word-wrap": { + "version": "1.2.6", + "resolved": "https://registry.npmjs.org/@aashutoshrathi/word-wrap/-/word-wrap-1.2.6.tgz", + "integrity": "sha512-1Yjs2SvM8TflER/OD3cOjhWWOZb58A2t7wpE2S9XfBYTiIl+XFhQG2bjy4Pu1I+EAlCNUzRDYDdFwFYUKvXcIA==", + "dev": true + }, "@babel/code-frame": { "version": "7.18.6", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.18.6.tgz", @@ -6170,9 +4677,9 @@ } }, "@babel/helper-validator-identifier": { - "version": "7.19.1", - "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.19.1.tgz", - "integrity": "sha512-awrNfaMtnHUr653GgGEs++LlAvW6w+DcPrOliSMXWCKo597CwL5Acf/wWdNkf/tfEQE3mjkeD1YOVZOUV/od1w==", + "version": "7.22.20", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.22.20.tgz", + "integrity": "sha512-Y4OZ+ytlatR8AI+8KZfKuL5urKp7qey08ha31L8b3BwewJAoJamTzyvxPR/5D+KkdJCGPq/+8TukHBlY10FX9A==", "dev": true }, "@babel/highlight": { @@ -6244,41 +4751,195 @@ } } }, - "@discoveryjs/json-ext": { - "version": "0.5.7", - "resolved": "https://registry.npmjs.org/@discoveryjs/json-ext/-/json-ext-0.5.7.tgz", - "integrity": "sha512-dBVuXR082gk3jsFp7Rd/JI4kytwGHecnCoTtXFb7DB6CNHp4rg5k1bhg0nWdLGLnOV71lmDzGQaLMy8iPLY0pw==", - "dev": true - }, "@es-joy/jsdoccomment": { - "version": "0.36.1", - "resolved": "https://registry.npmjs.org/@es-joy/jsdoccomment/-/jsdoccomment-0.36.1.tgz", - "integrity": "sha512-922xqFsTpHs6D0BUiG4toiyPOMc8/jafnWKxz1KWgS4XzKPy2qXf1Pe6UFuNSCQqt6tOuhAWXBNuuyUhJmw9Vg==", + "version": "0.40.1", + "resolved": "https://registry.npmjs.org/@es-joy/jsdoccomment/-/jsdoccomment-0.40.1.tgz", + "integrity": "sha512-YORCdZSusAlBrFpZ77pJjc5r1bQs5caPWtAu+WWmiSo+8XaUzseapVrfAtiRFbQWnrBxxLLEwF6f6ZG/UgCQCg==", "dev": true, "requires": { - "comment-parser": "1.3.1", - "esquery": "^1.4.0", - "jsdoc-type-pratt-parser": "~3.1.0" + "comment-parser": "1.4.0", + "esquery": "^1.5.0", + "jsdoc-type-pratt-parser": "~4.0.0" } }, + "@esbuild/android-arm": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.19.3.tgz", + "integrity": "sha512-Lemgw4io4VZl9GHJmjiBGzQ7ONXRfRPHcUEerndjwiSkbxzrpq0Uggku5MxxrXdwJ+pTj1qyw4jwTu7hkPsgIA==", + "dev": true, + "optional": true + }, + "@esbuild/android-arm64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.19.3.tgz", + "integrity": "sha512-w+Akc0vv5leog550kjJV9Ru+MXMR2VuMrui3C61mnysim0gkFCPOUTAfzTP0qX+HpN9Syu3YA3p1hf3EPqObRw==", + "dev": true, + "optional": true + }, + "@esbuild/android-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.19.3.tgz", + "integrity": "sha512-FKQJKkK5MXcBHoNZMDNUAg1+WcZlV/cuXrWCoGF/TvdRiYS4znA0m5Il5idUwfxrE20bG/vU1Cr5e1AD6IEIjQ==", + "dev": true, + "optional": true + }, + "@esbuild/darwin-arm64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.19.3.tgz", + "integrity": "sha512-kw7e3FXU+VsJSSSl2nMKvACYlwtvZB8RUIeVShIEY6PVnuZ3c9+L9lWB2nWeeKWNNYDdtL19foCQ0ZyUL7nqGw==", + "dev": true, + "optional": true + }, + "@esbuild/darwin-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.19.3.tgz", + "integrity": "sha512-tPfZiwF9rO0jW6Jh9ipi58N5ZLoSjdxXeSrAYypy4psA2Yl1dAMhM71KxVfmjZhJmxRjSnb29YlRXXhh3GqzYw==", + "dev": true, + "optional": true + }, + "@esbuild/freebsd-arm64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.19.3.tgz", + "integrity": "sha512-ERDyjOgYeKe0Vrlr1iLrqTByB026YLPzTytDTz1DRCYM+JI92Dw2dbpRHYmdqn6VBnQ9Bor6J8ZlNwdZdxjlSg==", + "dev": true, + "optional": true + }, + "@esbuild/freebsd-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.19.3.tgz", + "integrity": "sha512-nXesBZ2Ad1qL+Rm3crN7NmEVJ5uvfLFPLJev3x1j3feCQXfAhoYrojC681RhpdOph8NsvKBBwpYZHR7W0ifTTA==", + "dev": true, + "optional": true + }, + "@esbuild/linux-arm": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.19.3.tgz", + "integrity": "sha512-zr48Cg/8zkzZCzDHNxXO/89bf9e+r4HtzNUPoz4GmgAkF1gFAFmfgOdCbR8zMbzFDGb1FqBBhdXUpcTQRYS1cQ==", + "dev": true, + "optional": true + }, + "@esbuild/linux-arm64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.19.3.tgz", + "integrity": "sha512-qXvYKmXj8GcJgWq3aGvxL/JG1ZM3UR272SdPU4QSTzD0eymrM7leiZH77pvY3UetCy0k1xuXZ+VPvoJNdtrsWQ==", + "dev": true, + "optional": true + }, + "@esbuild/linux-ia32": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.19.3.tgz", + "integrity": "sha512-7XlCKCA0nWcbvYpusARWkFjRQNWNGlt45S+Q18UeS///K6Aw8bB2FKYe9mhVWy/XLShvCweOLZPrnMswIaDXQA==", + "dev": true, + "optional": true + }, + "@esbuild/linux-loong64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.19.3.tgz", + "integrity": "sha512-qGTgjweER5xqweiWtUIDl9OKz338EQqCwbS9c2Bh5jgEH19xQ1yhgGPNesugmDFq+UUSDtWgZ264st26b3de8A==", + "dev": true, + "optional": true + }, + "@esbuild/linux-mips64el": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.19.3.tgz", + "integrity": "sha512-gy1bFskwEyxVMFRNYSvBauDIWNggD6pyxUksc0MV9UOBD138dKTzr8XnM2R4mBsHwVzeuIH8X5JhmNs2Pzrx+A==", + "dev": true, + "optional": true + }, + "@esbuild/linux-ppc64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.19.3.tgz", + "integrity": "sha512-UrYLFu62x1MmmIe85rpR3qou92wB9lEXluwMB/STDzPF9k8mi/9UvNsG07Tt9AqwPQXluMQ6bZbTzYt01+Ue5g==", + "dev": true, + "optional": true + }, + "@esbuild/linux-riscv64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.19.3.tgz", + "integrity": "sha512-9E73TfyMCbE+1AwFOg3glnzZ5fBAFK4aawssvuMgCRqCYzE0ylVxxzjEfut8xjmKkR320BEoMui4o/t9KA96gA==", + "dev": true, + "optional": true + }, + "@esbuild/linux-s390x": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.19.3.tgz", + "integrity": "sha512-LlmsbuBdm1/D66TJ3HW6URY8wO6IlYHf+ChOUz8SUAjVTuaisfuwCOAgcxo3Zsu3BZGxmI7yt//yGOxV+lHcEA==", + "dev": true, + "optional": true + }, + "@esbuild/linux-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.19.3.tgz", + "integrity": "sha512-ogV0+GwEmvwg/8ZbsyfkYGaLACBQWDvO0Kkh8LKBGKj9Ru8VM39zssrnu9Sxn1wbapA2qNS6BiLdwJZGouyCwQ==", + "dev": true, + "optional": true + }, + "@esbuild/netbsd-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.19.3.tgz", + "integrity": "sha512-o1jLNe4uzQv2DKXMlmEzf66Wd8MoIhLNO2nlQBHLtWyh2MitDG7sMpfCO3NTcoTMuqHjfufgUQDFRI5C+xsXQw==", + "dev": true, + "optional": true + }, + "@esbuild/openbsd-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.19.3.tgz", + "integrity": "sha512-AZJCnr5CZgZOdhouLcfRdnk9Zv6HbaBxjcyhq0StNcvAdVZJSKIdOiPB9az2zc06ywl0ePYJz60CjdKsQacp5Q==", + "dev": true, + "optional": true + }, + "@esbuild/sunos-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.19.3.tgz", + "integrity": "sha512-Acsujgeqg9InR4glTRvLKGZ+1HMtDm94ehTIHKhJjFpgVzZG9/pIcWW/HA/DoMfEyXmANLDuDZ2sNrWcjq1lxw==", + "dev": true, + "optional": true + }, + "@esbuild/win32-arm64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.19.3.tgz", + "integrity": "sha512-FSrAfjVVy7TifFgYgliiJOyYynhQmqgPj15pzLyJk8BUsnlWNwP/IAy6GAiB1LqtoivowRgidZsfpoYLZH586A==", + "dev": true, + "optional": true + }, + "@esbuild/win32-ia32": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.19.3.tgz", + "integrity": "sha512-xTScXYi12xLOWZ/sc5RBmMN99BcXp/eEf7scUC0oeiRoiT5Vvo9AycuqCp+xdpDyAU+LkrCqEpUS9fCSZF8J3Q==", + "dev": true, + "optional": true + }, + "@esbuild/win32-x64": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.19.3.tgz", + "integrity": "sha512-FbUN+0ZRXsypPyWE2IwIkVjDkDnJoMJARWOcFZn4KPPli+QnKqF0z1anvfaYe3ev5HFCpRDLLBDHyOALLppWHw==", + "dev": true, + "optional": true + }, "@eslint-community/eslint-utils": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.2.0.tgz", - "integrity": "sha512-gB8T4H4DEfX2IV9zGDJPOBgP1e/DbfCPDTtEqUMckpvzS1OYtva8JdFYBqMwYk7xAQ429WGF/UPqn8uQ//h2vQ==", + "version": "4.4.0", + "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.4.0.tgz", + "integrity": "sha512-1/sA4dwrzBAyeUoQ6oxahHKmrZvsnLCg4RfxW3ZFGGmQkSNQPFNLV9CUEFQP1x9EYXHTo5p6xdhZM1Ne9p/AfA==", "dev": true, "requires": { "eslint-visitor-keys": "^3.3.0" } }, + "@eslint-community/regexpp": { + "version": "4.9.1", + "resolved": "https://registry.npmjs.org/@eslint-community/regexpp/-/regexpp-4.9.1.tgz", + "integrity": "sha512-Y27x+MBLjXa+0JWDhykM3+JE+il3kHKAEqabfEWq3SDhZjLYb6/BHL/JKFnH3fe207JaXkyDo685Oc2Glt6ifA==", + "dev": true + }, "@eslint/eslintrc": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-2.0.0.tgz", - "integrity": "sha512-fluIaaV+GyV24CCu/ggiHdV+j4RNh85yQnAYS/G2mZODZgGmmlrgCydjUcV3YvxCm9x8nMAfThsqTni4KiXT4A==", + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-2.1.2.tgz", + "integrity": "sha512-+wvgpDsrB1YqAMdEUCcnTlpfVBH7Vqn6A/NT3D8WVXFIaKMlErPIZT3oCIAVCOtarRpMtelZLqJeU3t7WY6X6g==", "dev": true, "requires": { "ajv": "^6.12.4", "debug": "^4.3.2", - "espree": "^9.4.0", + "espree": "^9.6.0", "globals": "^13.19.0", "ignore": "^5.2.0", "import-fresh": "^3.2.1", @@ -6288,15 +4949,15 @@ } }, "@eslint/js": { - "version": "8.35.0", - "resolved": "https://registry.npmjs.org/@eslint/js/-/js-8.35.0.tgz", - "integrity": "sha512-JXdzbRiWclLVoD8sNUjR443VVlYqiYmDVT6rGUEIEHU5YJW0gaVZwV2xgM7D4arkvASqD0IlLUVjHiFuxaftRw==", + "version": "8.51.0", + "resolved": "https://registry.npmjs.org/@eslint/js/-/js-8.51.0.tgz", + "integrity": "sha512-HxjQ8Qn+4SI3/AFv6sOrDB+g6PpUTDwSJiQqOrnneEk8L71161srI9gjzzZvYVbzHiVg/BvcH95+cK/zfIt4pg==", "dev": true }, "@humanwhocodes/config-array": { - "version": "0.11.8", - "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.8.tgz", - "integrity": "sha512-UybHIJzJnR5Qc/MsD9Kr+RpO2h+/P1GhOwdiLPXK5TWk5sgTdu88bTD9UP+CKbPPh5Rni1u0GjAdYQLemG8g+g==", + "version": "0.11.11", + "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.11.tgz", + "integrity": "sha512-N2brEuAadi0CcdeMXUkhbZB84eskAc8MEX1By6qEchoVywSgXPIjou4rYsl0V3Hj0ZnuGycGCjdNgockbzeWNA==", "dev": true, "requires": { "@humanwhocodes/object-schema": "^1.2.1", @@ -6316,55 +4977,12 @@ "integrity": "sha512-ZnQMnLV4e7hDlUvw8H+U8ASL02SS2Gn6+9Ac3wGGLIe7+je2AeAOxPY+izIPJDfFDb7eDjev0Us8MO1iFRN8hA==", "dev": true }, - "@jridgewell/gen-mapping": { - "version": "0.3.2", - "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.2.tgz", - "integrity": "sha512-mh65xKQAzI6iBcFzwv28KVWSmCkdRBWoOh+bYQGW3+6OZvbbN3TqMGo5hqYxQniRcH9F2VZIoJCm4pa3BPDK/A==", - "dev": true, - "requires": { - "@jridgewell/set-array": "^1.0.1", - "@jridgewell/sourcemap-codec": "^1.4.10", - "@jridgewell/trace-mapping": "^0.3.9" - } - }, - "@jridgewell/resolve-uri": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.0.tgz", - "integrity": "sha512-F2msla3tad+Mfht5cJq7LSXcdudKTWCVYUgw6pLFOOHSTtZlj6SWNYAp+AhuqLmWdBO2X5hPrLcu8cVP8fy28w==", - "dev": true - }, - "@jridgewell/set-array": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.1.2.tgz", - "integrity": "sha512-xnkseuNADM0gt2bs+BvhO0p78Mk762YnZdsuzFV018NoG1Sj1SCQvpSqa7XUaTam5vAGasABV9qXASMKnFMwMw==", - "dev": true - }, - "@jridgewell/source-map": { - "version": "0.3.2", - "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.2.tgz", - "integrity": "sha512-m7O9o2uR8k2ObDysZYzdfhb08VuEml5oWGiosa1VdaPZ/A6QyPkAJuwN0Q1lhULOf6B7MtQmHENS743hWtCrgw==", - "dev": true, - "requires": { - "@jridgewell/gen-mapping": "^0.3.0", - "@jridgewell/trace-mapping": "^0.3.9" - } - }, - "@jridgewell/sourcemap-codec": { - "version": "1.4.14", - "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.14.tgz", - "integrity": "sha512-XPSJHWmi394fuUuzDnGz1wiKqWfo1yXecHQMRf2l6hztTO+nPru658AyDngaBe7isIxEkRsPR3FZh+s7iVa4Uw==", + "@jspm/core": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/@jspm/core/-/core-2.0.1.tgz", + "integrity": "sha512-Lg3PnLp0QXpxwLIAuuJboLeRaIhrgJjeuh797QADg3xz8wGLugQOS5DpsE8A6i6Adgzf+bacllkKZG3J0tGfDw==", "dev": true }, - "@jridgewell/trace-mapping": { - "version": "0.3.17", - "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.17.tgz", - "integrity": "sha512-MCNzAp77qzKca9+W/+I0+sEpaUnZoeasnghNeVc41VZCEKaCH73Vq3BZZ/SzWIgrqE4H4ceI+p+b6C0mHf9T4g==", - "dev": true, - "requires": { - "@jridgewell/resolve-uri": "3.1.0", - "@jridgewell/sourcemap-codec": "1.4.14" - } - }, "@nodelib/fs.scandir": { "version": "2.1.5", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", @@ -6391,42 +5009,10 @@ "fastq": "^1.6.0" } }, - "@polka/url": { - "version": "1.0.0-next.21", - "resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.21.tgz", - "integrity": "sha512-a5Sab1C4/icpTZVzZc5Ghpz88yQtGOyNqYXcZgOssB2uuAr+wF/MvN6bgtW32q7HHrvBki+BsZ0OuNv6EV3K9g==", - "dev": true - }, - "@types/eslint": { - "version": "8.21.1", - "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.21.1.tgz", - "integrity": "sha512-rc9K8ZpVjNcLs8Fp0dkozd5Pt2Apk1glO4Vgz8ix1u6yFByxfqo5Yavpy65o+93TAe24jr7v+eSBtFLvOQtCRQ==", - "dev": true, - "requires": { - "@types/estree": "*", - "@types/json-schema": "*" - } - }, - "@types/eslint-scope": { - "version": "3.7.4", - "resolved": "https://registry.npmjs.org/@types/eslint-scope/-/eslint-scope-3.7.4.tgz", - "integrity": "sha512-9K4zoImiZc3HlIp6AVUDE4CWYx22a+lhSZMYNpbjW04+YF0KWj4pJXnEMjdnFTiQibFFmElcsasJXDbdI/EPhA==", - "dev": true, - "requires": { - "@types/eslint": "*", - "@types/estree": "*" - } - }, - "@types/estree": { - "version": "0.0.51", - "resolved": "https://registry.npmjs.org/@types/estree/-/estree-0.0.51.tgz", - "integrity": "sha512-CuPgU6f3eT/XgKKPqKd/gLZV1Xmvf1a2R5POBOGQa6uv82xpls89HU5zKeVoyR8XzHd1RGNOlQlvUe3CFkjWNQ==", - "dev": true - }, "@types/fs-extra": { - "version": "11.0.1", - "resolved": "https://registry.npmjs.org/@types/fs-extra/-/fs-extra-11.0.1.tgz", - "integrity": "sha512-MxObHvNl4A69ofaTRU8DFqvgzzv8s9yRtaPPm5gud9HDNvpB3GPQFvNuTWAI59B9huVGV5jXYJwbCsmBsOGYWA==", + "version": "11.0.2", + "resolved": "https://registry.npmjs.org/@types/fs-extra/-/fs-extra-11.0.2.tgz", + "integrity": "sha512-c0hrgAOVYr21EX8J0jBMXGLMgJqVf/v6yxi0dLaJboW9aQPh16Id+z6w2Tx1hm+piJOLv8xPfVKZCLfjPw/IMQ==", "dev": true, "requires": { "@types/jsonfile": "*", @@ -6434,9 +5020,9 @@ } }, "@types/json-schema": { - "version": "7.0.11", - "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.11.tgz", - "integrity": "sha512-wOuvG1SN4Us4rez+tylwwwCV1psiNVOkJeM3AUWUNWg/jDQY2+HE/444y5gc+jBmRqASOm2Oeh5c1axHobwRKQ==", + "version": "7.0.13", + "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.13.tgz", + "integrity": "sha512-RbSSoHliUbnXj3ny0CNFOoxrIDV6SUGyStHsvDqosw6CkdPV8TtWGlfecuK4ToyMEAql6pzNxgCFKanovUzlgQ==", "dev": true }, "@types/json5": { @@ -6455,9 +5041,9 @@ } }, "@types/mocha": { - "version": "10.0.1", - "resolved": "https://registry.npmjs.org/@types/mocha/-/mocha-10.0.1.tgz", - "integrity": "sha512-/fvYntiO1GeICvqbQ3doGDIP97vWmvFt83GKguJ6prmQM2iXZfFcq6YE8KteFyRtX2/h5Hf91BYvPodJKFYv5Q==", + "version": "10.0.2", + "resolved": "https://registry.npmjs.org/@types/mocha/-/mocha-10.0.2.tgz", + "integrity": "sha512-NaHL0+0lLNhX6d9rs+NSt97WH/gIlRHmszXbQ/8/MV/eVcFNdeJ/GYhrFuUc8K7WuPhRhTSdMkCp8VMzhUq85w==", "dev": true }, "@types/node": { @@ -6479,289 +5065,111 @@ "dev": true }, "@types/semver": { - "version": "7.3.13", - "resolved": "https://registry.npmjs.org/@types/semver/-/semver-7.3.13.tgz", - "integrity": "sha512-21cFJr9z3g5dW8B0CVI9g2O9beqaThGQ6ZFBqHfwhzLDKUxaqTIy3vnfah/UPkfOiF2pLq+tGz+W8RyCskuslw==", + "version": "7.5.3", + "resolved": "https://registry.npmjs.org/@types/semver/-/semver-7.5.3.tgz", + "integrity": "sha512-OxepLK9EuNEIPxWNME+C6WwbRAOOI2o2BaQEGzz5Lu2e4Z5eDnEo+/aVEDMIXywoJitJ7xWd641wrGLZdtwRyw==", "dev": true }, "@typescript-eslint/eslint-plugin": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.54.1.tgz", - "integrity": "sha512-a2RQAkosH3d3ZIV08s3DcL/mcGc2M/UC528VkPULFxR9VnVPT8pBu0IyBAJJmVsCmhVfwQX1v6q+QGnmSe1bew==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-6.7.4.tgz", + "integrity": "sha512-DAbgDXwtX+pDkAHwiGhqP3zWUGpW49B7eqmgpPtg+BKJXwdct79ut9+ifqOFPJGClGKSHXn2PTBatCnldJRUoA==", "dev": true, "requires": { - "@typescript-eslint/scope-manager": "5.54.1", - "@typescript-eslint/type-utils": "5.54.1", - "@typescript-eslint/utils": "5.54.1", + "@eslint-community/regexpp": "^4.5.1", + "@typescript-eslint/scope-manager": "6.7.4", + "@typescript-eslint/type-utils": "6.7.4", + "@typescript-eslint/utils": "6.7.4", + "@typescript-eslint/visitor-keys": "6.7.4", "debug": "^4.3.4", - "grapheme-splitter": "^1.0.4", - "ignore": "^5.2.0", - "natural-compare-lite": "^1.4.0", - "regexpp": "^3.2.0", - "semver": "^7.3.7", - "tsutils": "^3.21.0" + "graphemer": "^1.4.0", + "ignore": "^5.2.4", + "natural-compare": "^1.4.0", + "semver": "^7.5.4", + "ts-api-utils": "^1.0.1" } }, "@typescript-eslint/parser": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-5.54.1.tgz", - "integrity": "sha512-8zaIXJp/nG9Ff9vQNh7TI+C3nA6q6iIsGJ4B4L6MhZ7mHnTMR4YP5vp2xydmFXIy8rpyIVbNAG44871LMt6ujg==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-6.7.4.tgz", + "integrity": "sha512-I5zVZFY+cw4IMZUeNCU7Sh2PO5O57F7Lr0uyhgCJmhN/BuTlnc55KxPonR4+EM3GBdfiCyGZye6DgMjtubQkmA==", "dev": true, "requires": { - "@typescript-eslint/scope-manager": "5.54.1", - "@typescript-eslint/types": "5.54.1", - "@typescript-eslint/typescript-estree": "5.54.1", + "@typescript-eslint/scope-manager": "6.7.4", + "@typescript-eslint/types": "6.7.4", + "@typescript-eslint/typescript-estree": "6.7.4", + "@typescript-eslint/visitor-keys": "6.7.4", "debug": "^4.3.4" } }, "@typescript-eslint/scope-manager": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-5.54.1.tgz", - "integrity": "sha512-zWKuGliXxvuxyM71UA/EcPxaviw39dB2504LqAmFDjmkpO8qNLHcmzlh6pbHs1h/7YQ9bnsO8CCcYCSA8sykUg==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-6.7.4.tgz", + "integrity": "sha512-SdGqSLUPTXAXi7c3Ob7peAGVnmMoGzZ361VswK2Mqf8UOYcODiYvs8rs5ILqEdfvX1lE7wEZbLyELCW+Yrql1A==", "dev": true, "requires": { - "@typescript-eslint/types": "5.54.1", - "@typescript-eslint/visitor-keys": "5.54.1" + "@typescript-eslint/types": "6.7.4", + "@typescript-eslint/visitor-keys": "6.7.4" } }, "@typescript-eslint/type-utils": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-5.54.1.tgz", - "integrity": "sha512-WREHsTz0GqVYLIbzIZYbmUUr95DKEKIXZNH57W3s+4bVnuF1TKe2jH8ZNH8rO1CeMY3U4j4UQeqPNkHMiGem3g==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-6.7.4.tgz", + "integrity": "sha512-n+g3zi1QzpcAdHFP9KQF+rEFxMb2KxtnJGID3teA/nxKHOVi3ylKovaqEzGBbVY2pBttU6z85gp0D00ufLzViQ==", "dev": true, "requires": { - "@typescript-eslint/typescript-estree": "5.54.1", - "@typescript-eslint/utils": "5.54.1", + "@typescript-eslint/typescript-estree": "6.7.4", + "@typescript-eslint/utils": "6.7.4", "debug": "^4.3.4", - "tsutils": "^3.21.0" + "ts-api-utils": "^1.0.1" } }, "@typescript-eslint/types": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-5.54.1.tgz", - "integrity": "sha512-G9+1vVazrfAfbtmCapJX8jRo2E4MDXxgm/IMOF4oGh3kq7XuK3JRkOg6y2Qu1VsTRmWETyTkWt1wxy7X7/yLkw==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-6.7.4.tgz", + "integrity": "sha512-o9XWK2FLW6eSS/0r/tgjAGsYasLAnOWg7hvZ/dGYSSNjCh+49k5ocPN8OmG5aZcSJ8pclSOyVKP2x03Sj+RrCA==", "dev": true }, "@typescript-eslint/typescript-estree": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-5.54.1.tgz", - "integrity": "sha512-bjK5t+S6ffHnVwA0qRPTZrxKSaFYocwFIkZx5k7pvWfsB1I57pO/0M0Skatzzw1sCkjJ83AfGTL0oFIFiDX3bg==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-6.7.4.tgz", + "integrity": "sha512-ty8b5qHKatlNYd9vmpHooQz3Vki3gG+3PchmtsA4TgrZBKWHNjWfkQid7K7xQogBqqc7/BhGazxMD5vr6Ha+iQ==", "dev": true, "requires": { - "@typescript-eslint/types": "5.54.1", - "@typescript-eslint/visitor-keys": "5.54.1", + "@typescript-eslint/types": "6.7.4", + "@typescript-eslint/visitor-keys": "6.7.4", "debug": "^4.3.4", "globby": "^11.1.0", "is-glob": "^4.0.3", - "semver": "^7.3.7", - "tsutils": "^3.21.0" + "semver": "^7.5.4", + "ts-api-utils": "^1.0.1" } }, "@typescript-eslint/utils": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-5.54.1.tgz", - "integrity": "sha512-IY5dyQM8XD1zfDe5X8jegX6r2EVU5o/WJnLu/znLPWCBF7KNGC+adacXnt5jEYS9JixDcoccI6CvE4RCjHMzCQ==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-6.7.4.tgz", + "integrity": "sha512-PRQAs+HUn85Qdk+khAxsVV+oULy3VkbH3hQ8hxLRJXWBEd7iI+GbQxH5SEUSH7kbEoTp6oT1bOwyga24ELALTA==", "dev": true, "requires": { - "@types/json-schema": "^7.0.9", - "@types/semver": "^7.3.12", - "@typescript-eslint/scope-manager": "5.54.1", - "@typescript-eslint/types": "5.54.1", - "@typescript-eslint/typescript-estree": "5.54.1", - "eslint-scope": "^5.1.1", - "eslint-utils": "^3.0.0", - "semver": "^7.3.7" + "@eslint-community/eslint-utils": "^4.4.0", + "@types/json-schema": "^7.0.12", + "@types/semver": "^7.5.0", + "@typescript-eslint/scope-manager": "6.7.4", + "@typescript-eslint/types": "6.7.4", + "@typescript-eslint/typescript-estree": "6.7.4", + "semver": "^7.5.4" } }, "@typescript-eslint/visitor-keys": { - "version": "5.54.1", - "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-5.54.1.tgz", - "integrity": "sha512-q8iSoHTgwCfgcRJ2l2x+xCbu8nBlRAlsQ33k24Adj8eoVBE0f8dUeI+bAa8F84Mv05UGbAx57g2zrRsYIooqQg==", - "dev": true, - "requires": { - "@typescript-eslint/types": "5.54.1", - "eslint-visitor-keys": "^3.3.0" - } - }, - "@webassemblyjs/ast": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.11.1.tgz", - "integrity": "sha512-ukBh14qFLjxTQNTXocdyksN5QdM28S1CxHt2rdskFyL+xFV7VremuBLVbmCePj+URalXBENx/9Lm7lnhihtCSw==", - "dev": true, - "requires": { - "@webassemblyjs/helper-numbers": "1.11.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.1" - } - }, - "@webassemblyjs/floating-point-hex-parser": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.11.1.tgz", - "integrity": "sha512-iGRfyc5Bq+NnNuX8b5hwBrRjzf0ocrJPI6GWFodBFzmFnyvrQ83SHKhmilCU/8Jv67i4GJZBMhEzltxzcNagtQ==", - "dev": true - }, - "@webassemblyjs/helper-api-error": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-api-error/-/helper-api-error-1.11.1.tgz", - "integrity": "sha512-RlhS8CBCXfRUR/cwo2ho9bkheSXG0+NwooXcc3PAILALf2QLdFyj7KGsKRbVc95hZnhnERon4kW/D3SZpp6Tcg==", - "dev": true - }, - "@webassemblyjs/helper-buffer": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.11.1.tgz", - "integrity": "sha512-gwikF65aDNeeXa8JxXa2BAk+REjSyhrNC9ZwdT0f8jc4dQQeDQ7G4m0f2QCLPJiMTTO6wfDmRmj/pW0PsUvIcA==", - "dev": true - }, - "@webassemblyjs/helper-numbers": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-numbers/-/helper-numbers-1.11.1.tgz", - "integrity": "sha512-vDkbxiB8zfnPdNK9Rajcey5C0w+QJugEglN0of+kmO8l7lDb77AnlKYQF7aarZuCrv+l0UvqL+68gSDr3k9LPQ==", - "dev": true, - "requires": { - "@webassemblyjs/floating-point-hex-parser": "1.11.1", - "@webassemblyjs/helper-api-error": "1.11.1", - "@xtuc/long": "4.2.2" - } - }, - "@webassemblyjs/helper-wasm-bytecode": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.11.1.tgz", - "integrity": "sha512-PvpoOGiJwXeTrSf/qfudJhwlvDQxFgelbMqtq52WWiXC6Xgg1IREdngmPN3bs4RoO83PnL/nFrxucXj1+BX62Q==", - "dev": true - }, - "@webassemblyjs/helper-wasm-section": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.11.1.tgz", - "integrity": "sha512-10P9No29rYX1j7F3EVPX3JvGPQPae+AomuSTPiF9eBQeChHI6iqjMIwR9JmOJXwpnn/oVGDk7I5IlskuMwU/pg==", - "dev": true, - "requires": { - "@webassemblyjs/ast": "1.11.1", - "@webassemblyjs/helper-buffer": "1.11.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.1", - "@webassemblyjs/wasm-gen": "1.11.1" - } - }, - "@webassemblyjs/ieee754": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/ieee754/-/ieee754-1.11.1.tgz", - "integrity": "sha512-hJ87QIPtAMKbFq6CGTkZYJivEwZDbQUgYd3qKSadTNOhVY7p+gfP6Sr0lLRVTaG1JjFj+r3YchoqRYxNH3M0GQ==", - "dev": true, - "requires": { - "@xtuc/ieee754": "^1.2.0" - } - }, - "@webassemblyjs/leb128": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/leb128/-/leb128-1.11.1.tgz", - "integrity": "sha512-BJ2P0hNZ0u+Th1YZXJpzW6miwqQUGcIHT1G/sf72gLVD9DZ5AdYTqPNbHZh6K1M5VmKvFXwGSWZADz+qBWxeRw==", - "dev": true, - "requires": { - "@xtuc/long": "4.2.2" - } - }, - "@webassemblyjs/utf8": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/utf8/-/utf8-1.11.1.tgz", - "integrity": "sha512-9kqcxAEdMhiwQkHpkNiorZzqpGrodQQ2IGrHHxCy+Ozng0ofyMA0lTqiLkVs1uzTRejX+/O0EOT7KxqVPuXosQ==", - "dev": true - }, - "@webassemblyjs/wasm-edit": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.11.1.tgz", - "integrity": "sha512-g+RsupUC1aTHfR8CDgnsVRVZFJqdkFHpsHMfJuWQzWU3tvnLC07UqHICfP+4XyL2tnr1amvl1Sdp06TnYCmVkA==", - "dev": true, - "requires": { - "@webassemblyjs/ast": "1.11.1", - "@webassemblyjs/helper-buffer": "1.11.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.1", - "@webassemblyjs/helper-wasm-section": "1.11.1", - "@webassemblyjs/wasm-gen": "1.11.1", - "@webassemblyjs/wasm-opt": "1.11.1", - "@webassemblyjs/wasm-parser": "1.11.1", - "@webassemblyjs/wast-printer": "1.11.1" - } - }, - "@webassemblyjs/wasm-gen": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.11.1.tgz", - "integrity": "sha512-F7QqKXwwNlMmsulj6+O7r4mmtAlCWfO/0HdgOxSklZfQcDu0TpLiD1mRt/zF25Bk59FIjEuGAIyn5ei4yMfLhA==", - "dev": true, - "requires": { - "@webassemblyjs/ast": "1.11.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.1", - "@webassemblyjs/ieee754": "1.11.1", - "@webassemblyjs/leb128": "1.11.1", - "@webassemblyjs/utf8": "1.11.1" - } - }, - "@webassemblyjs/wasm-opt": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.11.1.tgz", - "integrity": "sha512-VqnkNqnZlU5EB64pp1l7hdm3hmQw7Vgqa0KF/KCNO9sIpI6Fk6brDEiX+iCOYrvMuBWDws0NkTOxYEb85XQHHw==", - "dev": true, - "requires": { - "@webassemblyjs/ast": "1.11.1", - "@webassemblyjs/helper-buffer": "1.11.1", - "@webassemblyjs/wasm-gen": "1.11.1", - "@webassemblyjs/wasm-parser": "1.11.1" - } - }, - "@webassemblyjs/wasm-parser": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.11.1.tgz", - "integrity": "sha512-rrBujw+dJu32gYB7/Lup6UhdkPx9S9SnobZzRVL7VcBH9Bt9bCBLEuX/YXOOtBsOZ4NQrRykKhffRWHvigQvOA==", - "dev": true, - "requires": { - "@webassemblyjs/ast": "1.11.1", - "@webassemblyjs/helper-api-error": "1.11.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.1", - "@webassemblyjs/ieee754": "1.11.1", - "@webassemblyjs/leb128": "1.11.1", - "@webassemblyjs/utf8": "1.11.1" - } - }, - "@webassemblyjs/wast-printer": { - "version": "1.11.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.11.1.tgz", - "integrity": "sha512-IQboUWM4eKzWW+N/jij2sRatKMh99QEelo3Eb2q0qXkvPRISAj8Qxtmw5itwqK+TTkBuUIE45AxYPToqPtL5gg==", + "version": "6.7.4", + "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-6.7.4.tgz", + "integrity": "sha512-pOW37DUhlTZbvph50x5zZCkFn3xzwkGtNoJHzIM3svpiSkJzwOYr/kVBaXmf+RAQiUDs1AHEZVNPg6UJCJpwRA==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.11.1", - "@xtuc/long": "4.2.2" + "@typescript-eslint/types": "6.7.4", + "eslint-visitor-keys": "^3.4.1" } }, - "@webpack-cli/configtest": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/@webpack-cli/configtest/-/configtest-2.0.1.tgz", - "integrity": "sha512-njsdJXJSiS2iNbQVS0eT8A/KPnmyH4pv1APj2K0d1wrZcBLw+yppxOy4CGqa0OxDJkzfL/XELDhD8rocnIwB5A==", - "dev": true, - "requires": {} - }, - "@webpack-cli/info": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/@webpack-cli/info/-/info-2.0.1.tgz", - "integrity": "sha512-fE1UEWTwsAxRhrJNikE7v4EotYflkEhBL7EbajfkPlf6E37/2QshOy/D48Mw8G5XMFlQtS6YV42vtbG9zBpIQA==", - "dev": true, - "requires": {} - }, - "@webpack-cli/serve": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/@webpack-cli/serve/-/serve-2.0.1.tgz", - "integrity": "sha512-0G7tNyS+yW8TdgHwZKlDWYXFA6OJQnoLCQvYKkQP0Q2X205PSQ6RNUj0M+1OB/9gRQaUZ/ccYfaxd0nhaWKfjw==", - "dev": true, - "requires": {} - }, - "@xtuc/ieee754": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/@xtuc/ieee754/-/ieee754-1.2.0.tgz", - "integrity": "sha512-DX8nKgqcGwsc0eJSqYt5lwP4DH5FlHnmuWWBRy7X0NcaGR0ZtuyeESgMwTYVEtxmsNGY+qit4QYT/MIYTOTPeA==", - "dev": true - }, - "@xtuc/long": { - "version": "4.2.2", - "resolved": "https://registry.npmjs.org/@xtuc/long/-/long-4.2.2.tgz", - "integrity": "sha512-NuHqBY1PB/D8xU6s/thBgOAiAP7HOYDQ32+BFZILJ8ivkUkAHQnWfn6WhL79Owj1qmUnoN/YPhktdIoucipkAQ==", - "dev": true - }, "abort-controller": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/abort-controller/-/abort-controller-3.0.0.tgz", @@ -6772,18 +5180,11 @@ } }, "acorn": { - "version": "8.8.2", - "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.2.tgz", - "integrity": "sha512-xjIYgE8HBrkpd/sJqOGNspf8uHG+NOHGOw6a/Urj8taM2EXfdNAH2oFcPeIFfsv3+kz/mJrS5VuMqbNLjCa2vw==", + "version": "8.10.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.10.0.tgz", + "integrity": "sha512-F0SAmZ8iUtS//m8DmCTA0jlh6TDKkHQyK6xc6V4KDTyZKA9dnvX9/3sRTVQrWm79glUAZbnmmNcdYwUIHWVybw==", "dev": true }, - "acorn-import-assertions": { - "version": "1.8.0", - "resolved": "https://registry.npmjs.org/acorn-import-assertions/-/acorn-import-assertions-1.8.0.tgz", - "integrity": "sha512-m7VZ3jwz4eK6A4Vtt8Ew1/mNbP24u0FhdyfA7fSvnJR6LMdfOYnmuIrrJAgrYfYJ10F/otaHTtrtrtmHdMNzEw==", - "dev": true, - "requires": {} - }, "acorn-jsx": { "version": "5.3.2", "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.2.tgz", @@ -6791,12 +5192,6 @@ "dev": true, "requires": {} }, - "acorn-walk": { - "version": "8.2.0", - "resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz", - "integrity": "sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA==", - "dev": true - }, "ajv": { "version": "6.12.6", "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", @@ -6809,13 +5204,6 @@ "uri-js": "^4.2.2" } }, - "ajv-keywords": { - "version": "3.5.2", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-3.5.2.tgz", - "integrity": "sha512-5p6WTN0DdTGVQk6VjcEju19IgaHudalcfabD7yhDGeA6bcQnmL+CpveLJq/3hvfwd1aof6L386Ougkx6RfyMIQ==", - "dev": true, - "requires": {} - }, "ansi-colors": { "version": "4.1.1", "resolved": "https://registry.npmjs.org/ansi-colors/-/ansi-colors-4.1.1.tgz", @@ -6853,6 +5241,12 @@ "integrity": "sha512-lYe4Gx7QT+MKGbDsA+Z+he/Wtef0BiwDOlK/XkBrdfsh9J/jPPXbX0tE9x9cl27Tmu5gg3QUbUrQYa/y+KOHPQ==", "dev": true }, + "are-docs-informative": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/are-docs-informative/-/are-docs-informative-0.0.2.tgz", + "integrity": "sha512-ixiS0nLNNG5jNQzgZJNoUpBKdo9yTYZMGJ+QgT2jmjR7G7+QHRCc4v6LQ3NgE7EBJq+o0ams3waJwkrlBom8Ig==", + "dev": true + }, "are-we-there-yet": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/are-we-there-yet/-/are-we-there-yet-4.0.0.tgz", @@ -6883,6 +5277,16 @@ "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==", "dev": true }, + "array-buffer-byte-length": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/array-buffer-byte-length/-/array-buffer-byte-length-1.0.0.tgz", + "integrity": "sha512-LPuwb2P+NrQw3XhxGc36+XSvuBPopovXYTR9Ew++Du9Yb/bx5AzBfrIsBoj0EZUifjQU+sHL21sseZ3jerWO/A==", + "dev": true, + "requires": { + "call-bind": "^1.0.2", + "is-array-buffer": "^3.0.1" + } + }, "array-includes": { "version": "3.1.6", "resolved": "https://registry.npmjs.org/array-includes/-/array-includes-3.1.6.tgz", @@ -6902,6 +5306,19 @@ "integrity": "sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw==", "dev": true }, + "array.prototype.findlastindex": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/array.prototype.findlastindex/-/array.prototype.findlastindex-1.2.3.tgz", + "integrity": "sha512-LzLoiOMAxvy+Gd3BAq3B7VeIgPdo+Q8hthvKtXybMvRV0jrXfJM/t8mw7nNlpEcVlVUnCnM2KSX4XU5HmpodOA==", + "dev": true, + "requires": { + "call-bind": "^1.0.2", + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1", + "es-shim-unscopables": "^1.0.0", + "get-intrinsic": "^1.2.1" + } + }, "array.prototype.flat": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/array.prototype.flat/-/array.prototype.flat-1.3.1.tgz", @@ -6926,36 +5343,19 @@ "es-shim-unscopables": "^1.0.0" } }, - "asn1.js": { - "version": "5.4.1", - "resolved": "https://registry.npmjs.org/asn1.js/-/asn1.js-5.4.1.tgz", - "integrity": "sha512-+I//4cYPccV8LdmBLiX8CYvf9Sp3vQsrqu2QNXRcrbiWvcx/UdlFiqUJJzxRQxgsZmvhXhn4cSKeSmoFjVdupA==", - "dev": true, - "requires": { - "bn.js": "^4.0.0", - "inherits": "^2.0.1", - "minimalistic-assert": "^1.0.0", - "safer-buffer": "^2.1.0" - }, - "dependencies": { - "bn.js": { - "version": "4.12.0", - "resolved": "https://registry.npmjs.org/bn.js/-/bn.js-4.12.0.tgz", - "integrity": "sha512-c98Bf3tPniI+scsdk237ku1Dc3ujXQTSgyiPUDEOe7tRkhrqridvh8klBv0HCEso1OLOYcHuCv/cS6DNxKH+ZA==", - "dev": true - } - } - }, - "assert": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/assert/-/assert-2.0.0.tgz", - "integrity": "sha512-se5Cd+js9dXJnu6Ag2JFc00t+HmHOen+8Q+L7O9zI0PqQXr20uk2J0XQqMxZEeo5U50o8Nvmmx7dZrl+Ufr35A==", + "arraybuffer.prototype.slice": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/arraybuffer.prototype.slice/-/arraybuffer.prototype.slice-1.0.2.tgz", + "integrity": "sha512-yMBKppFur/fbHu9/6USUe03bZ4knMYiwFBcyiaXB8Go0qNehwX6inYPzK9U0NeQvGxKthcmHcaR8P5MStSRBAw==", "dev": true, "requires": { - "es6-object-assign": "^1.1.0", - "is-nan": "^1.2.1", - "object-is": "^1.0.1", - "util": "^0.12.0" + "array-buffer-byte-length": "^1.0.0", + "call-bind": "^1.0.2", + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1", + "get-intrinsic": "^1.2.1", + "is-array-buffer": "^3.0.2", + "is-shared-array-buffer": "^1.0.2" } }, "async": { @@ -6982,24 +5382,12 @@ "integrity": "sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==", "dev": true }, - "big.js": { - "version": "5.2.2", - "resolved": "https://registry.npmjs.org/big.js/-/big.js-5.2.2.tgz", - "integrity": "sha512-vyL2OymJxmarO8gxMr0mhChsO9QGwhynfuu4+MHTAW6czfq9humCB7rKpUjDd9YUiDPU4mzpyupFSvOClAwbmQ==", - "dev": true - }, "binary-extensions": { "version": "2.2.0", "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", "integrity": "sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA==", "dev": true }, - "bn.js": { - "version": "5.2.1", - "resolved": "https://registry.npmjs.org/bn.js/-/bn.js-5.2.1.tgz", - "integrity": "sha512-eXRvHzWyYPBuB4NBy0cmYQjGitUrtqwbvlzP3G6VFnNRbsZQIxQ10PbKKHt8gZ/HW/D/747aDl+QkDqg3KQLMQ==", - "dev": true - }, "brace-expansion": { "version": "1.1.11", "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", @@ -7019,122 +5407,12 @@ "fill-range": "^7.0.1" } }, - "brorand": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/brorand/-/brorand-1.1.0.tgz", - "integrity": "sha512-cKV8tMCEpQs4hK/ik71d6LrPOnpkpGBR0wzxqr68g2m/LB2GxVYQroAjMJZRVM1Y4BCjCKc3vAamxSzOY2RP+w==", - "dev": true - }, "browser-stdout": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/browser-stdout/-/browser-stdout-1.3.1.tgz", "integrity": "sha512-qhAVI1+Av2X7qelOfAIYwXONood6XlZE/fXaBSmW/T5SzLAmCgzi+eiWE7fUvbHaeNBQH13UftjpXxsfLkMpgw==", "dev": true }, - "browserify-aes": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/browserify-aes/-/browserify-aes-1.2.0.tgz", - "integrity": "sha512-+7CHXqGuspUn/Sl5aO7Ea0xWGAtETPXNSAjHo48JfLdPWcMng33Xe4znFvQweqc/uzk5zSOI3H52CYnjCfb5hA==", - "dev": true, - "requires": { - "buffer-xor": "^1.0.3", - "cipher-base": "^1.0.0", - "create-hash": "^1.1.0", - "evp_bytestokey": "^1.0.3", - "inherits": "^2.0.1", - "safe-buffer": "^5.0.1" - } - }, - "browserify-cipher": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/browserify-cipher/-/browserify-cipher-1.0.1.tgz", - "integrity": "sha512-sPhkz0ARKbf4rRQt2hTpAHqn47X3llLkUGn+xEJzLjwY8LRs2p0v7ljvI5EyoRO/mexrNunNECisZs+gw2zz1w==", - "dev": true, - "requires": { - "browserify-aes": "^1.0.4", - "browserify-des": "^1.0.0", - "evp_bytestokey": "^1.0.0" - } - }, - "browserify-des": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/browserify-des/-/browserify-des-1.0.2.tgz", - "integrity": "sha512-BioO1xf3hFwz4kc6iBhI3ieDFompMhrMlnDFC4/0/vd5MokpuAc3R+LYbwTA9A5Yc9pq9UYPqffKpW2ObuwX5A==", - "dev": true, - "requires": { - "cipher-base": "^1.0.1", - "des.js": "^1.0.0", - "inherits": "^2.0.1", - "safe-buffer": "^5.1.2" - } - }, - "browserify-rsa": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/browserify-rsa/-/browserify-rsa-4.1.0.tgz", - "integrity": "sha512-AdEER0Hkspgno2aR97SAf6vi0y0k8NuOpGnVH3O99rcA5Q6sh8QxcngtHuJ6uXwnfAXNM4Gn1Gb7/MV1+Ymbog==", - "dev": true, - "requires": { - "bn.js": "^5.0.0", - "randombytes": "^2.0.1" - } - }, - "browserify-sign": { - "version": "4.2.1", - "resolved": "https://registry.npmjs.org/browserify-sign/-/browserify-sign-4.2.1.tgz", - "integrity": "sha512-/vrA5fguVAKKAVTNJjgSm1tRQDHUU6DbwO9IROu/0WAzC8PKhucDSh18J0RMvVeHAn5puMd+QHC2erPRNf8lmg==", - "dev": true, - "requires": { - "bn.js": "^5.1.1", - "browserify-rsa": "^4.0.1", - "create-hash": "^1.2.0", - "create-hmac": "^1.1.7", - "elliptic": "^6.5.3", - "inherits": "^2.0.4", - "parse-asn1": "^5.1.5", - "readable-stream": "^3.6.0", - "safe-buffer": "^5.2.0" - }, - "dependencies": { - "readable-stream": { - "version": "3.6.1", - "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.1.tgz", - "integrity": "sha512-+rQmrWMYGA90yenhTYsLWAsLsqVC8osOw6PKE1HDYiO0gdPeKe/xDHNzIAIn4C91YQ6oenEhfYqqc1883qHbjQ==", - "dev": true, - "requires": { - "inherits": "^2.0.3", - "string_decoder": "^1.1.1", - "util-deprecate": "^1.0.1" - } - }, - "safe-buffer": { - "version": "5.2.1", - "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", - "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", - "dev": true - } - } - }, - "browserify-zlib": { - "version": "0.2.0", - "resolved": "https://registry.npmjs.org/browserify-zlib/-/browserify-zlib-0.2.0.tgz", - "integrity": "sha512-Z942RysHXmJrhqk88FmKBVq/v5tqmSkDz7p54G/MGyjMnCFFnC79XWNbg+Vta8W6Wb2qtSZTSxIGkJrRpCFEiA==", - "dev": true, - "requires": { - "pako": "~1.0.5" - } - }, - "browserslist": { - "version": "4.21.5", - "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.21.5.tgz", - "integrity": "sha512-tUkiguQGW7S3IhB7N+c2MV/HZPSCPAAiYBZXLsBhFB/PCy6ZKKsZrmBayHV9fdGV/ARIfJ14NkxKzRDjvp7L6w==", - "dev": true, - "requires": { - "caniuse-lite": "^1.0.30001449", - "electron-to-chromium": "^1.4.284", - "node-releases": "^2.0.8", - "update-browserslist-db": "^1.0.10" - } - }, "buffer": { "version": "6.0.3", "resolved": "https://registry.npmjs.org/buffer/-/buffer-6.0.3.tgz", @@ -7145,30 +5423,12 @@ "ieee754": "^1.2.1" } }, - "buffer-from": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz", - "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", - "dev": true - }, - "buffer-xor": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/buffer-xor/-/buffer-xor-1.0.3.tgz", - "integrity": "sha512-571s0T7nZWK6vB67HI5dyUF7wXiNcfaPPPTl6zYCNApANjIvYJTg7hlud/+cJpdAhS7dVzqMLmfhfHR3rAcOjQ==", - "dev": true - }, "builtin-modules": { "version": "3.3.0", "resolved": "https://registry.npmjs.org/builtin-modules/-/builtin-modules-3.3.0.tgz", "integrity": "sha512-zhaCDicdLuWN5UbN5IMnFqNMhNfo919sH85y2/ea+5Yg9TsTkeZxpL+JLbp6cgYFS4sRLp3YV4S6yDuqVWHYOw==", "dev": true }, - "builtin-status-codes": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/builtin-status-codes/-/builtin-status-codes-3.0.0.tgz", - "integrity": "sha512-HpGFw18DgFWlncDfjTa2rcQ4W88O1mC8e8yZ2AvQY5KDaktSTwo+KRf6nHK6FRI5FyRyb/5T6+TSxfP7QyGsmQ==", - "dev": true - }, "call-bind": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.2.tgz", @@ -7191,12 +5451,6 @@ "integrity": "sha512-Gmy6FhYlCY7uOElZUSbxo2UCDH8owEk996gkbrpsgGtrJLM3J7jGxl9Ic7Qwwj4ivOE5AWZWRMecDdF7hqGjFA==", "dev": true }, - "caniuse-lite": { - "version": "1.0.30001460", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001460.tgz", - "integrity": "sha512-Bud7abqjvEjipUkpLs4D7gR0l8hBYBHoa+tGtKJHvT2AYzLp1z7EmVkUT4ERpVUfca8S2HGIVs883D8pUH1ZzQ==", - "dev": true - }, "chalk": { "version": "4.1.2", "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", @@ -7234,28 +5488,12 @@ } } }, - "chrome-trace-event": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/chrome-trace-event/-/chrome-trace-event-1.0.3.tgz", - "integrity": "sha512-p3KULyQg4S7NIHixdwbGX+nFHkoBiA4YQmyWtjb8XngSKV124nJmRysgAeujbUVb15vh+RvFUfCPqU7rXk+hZg==", - "dev": true - }, "ci-info": { "version": "3.8.0", "resolved": "https://registry.npmjs.org/ci-info/-/ci-info-3.8.0.tgz", "integrity": "sha512-eXTggHWSooYhq49F2opQhuHWgzucfF2YgODK4e1566GQs5BIfP30B0oenwBJHfWxAs2fyPB1s7Mg949zLf61Yw==", "dev": true }, - "cipher-base": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/cipher-base/-/cipher-base-1.0.4.tgz", - "integrity": "sha512-Kkht5ye6ZGmwv40uUDZztayT2ThLQGfnj/T71N/XzeZeo3nf8foyW7zGTsPYkEya3m5f3cAypH+qe7YOrM1U2Q==", - "dev": true, - "requires": { - "inherits": "^2.0.1", - "safe-buffer": "^5.0.1" - } - }, "clang-format": { "version": "1.8.0", "resolved": "https://registry.npmjs.org/clang-format/-/clang-format-1.8.0.tgz", @@ -7295,17 +5533,6 @@ "wrap-ansi": "^7.0.0" } }, - "clone-deep": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/clone-deep/-/clone-deep-4.0.1.tgz", - "integrity": "sha512-neHB9xuzh/wk0dIHweyAXv2aPGZIVk3pLMe+/RNzINf17fe0OG96QroktYAUm7SM1PBnzTabaLboqqxDyMU+SQ==", - "dev": true, - "requires": { - "is-plain-object": "^2.0.4", - "kind-of": "^6.0.2", - "shallow-clone": "^3.0.0" - } - }, "color-convert": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", @@ -7327,22 +5554,10 @@ "integrity": "sha512-qiBjkpbMLO/HL68y+lh4q0/O1MZFj2RX6X/KmMa3+gJD3z+WwI1ZzDHysvqHGS3mP6mznPckpXmw1nI9cJjyRg==", "dev": true }, - "colorette": { - "version": "2.0.19", - "resolved": "https://registry.npmjs.org/colorette/-/colorette-2.0.19.tgz", - "integrity": "sha512-3tlv/dIP7FWvj3BsbHrGLJ6l/oKh1O3TcgBqMn+yyCagOxc23fyzDS6HypQbgxWbkpDnf52p1LuR4eWDQ/K9WQ==", - "dev": true - }, - "commander": { - "version": "2.20.3", - "resolved": "https://registry.npmjs.org/commander/-/commander-2.20.3.tgz", - "integrity": "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==", - "dev": true - }, "comment-parser": { - "version": "1.3.1", - "resolved": "https://registry.npmjs.org/comment-parser/-/comment-parser-1.3.1.tgz", - "integrity": "sha512-B52sN2VNghyq5ofvUsqZjmk6YkihBX5vMSChmSK9v4ShjKf3Vk5Xcmgpw4o+iIgtrnM/u5FiMpz9VKb8lpBveA==", + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/comment-parser/-/comment-parser-1.4.0.tgz", + "integrity": "sha512-QLyTNiZ2KDOibvFPlZ6ZngVsZ/0gYnE6uTXi5aoDg8ed3AkJAz4sEje3Y8a29hQ1s6A99MZXe47fLAXQ1rTqaw==", "dev": true }, "concat-map": { @@ -7351,75 +5566,18 @@ "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", "dev": true }, - "console-browserify": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/console-browserify/-/console-browserify-1.2.0.tgz", - "integrity": "sha512-ZMkYO/LkF17QvCPqM0gxw8yUzigAOZOSWSHg91FH6orS7vcEj5dVZTidN2fQ14yBSdg97RqhSNwLUXInd52OTA==", - "dev": true - }, "console-control-strings": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/console-control-strings/-/console-control-strings-1.1.0.tgz", "integrity": "sha512-ty/fTekppD2fIwRvnZAVdeOiGd1c7YXEixbgJTNzqcxJWKQnjJ/V1bNEEE6hygpM3WjwHFUVK6HTjWSzV4a8sQ==", "dev": true }, - "constants-browserify": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/constants-browserify/-/constants-browserify-1.0.0.tgz", - "integrity": "sha512-xFxOwqIzR/e1k1gLiWEophSCMqXcwVHIH7akf7b/vxcUeGunlj3hvZaaqxwHsTgn+IndtkQJgSztIDWeumWJDQ==", - "dev": true - }, "core-util-is": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.3.tgz", "integrity": "sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==", "dev": true }, - "create-ecdh": { - "version": "4.0.4", - "resolved": "https://registry.npmjs.org/create-ecdh/-/create-ecdh-4.0.4.tgz", - "integrity": "sha512-mf+TCx8wWc9VpuxfP2ht0iSISLZnt0JgWlrOKZiNqyUZWnjIaCIVNQArMHnCZKfEYRg6IM7A+NeJoN8gf/Ws0A==", - "dev": true, - "requires": { - "bn.js": "^4.1.0", - "elliptic": "^6.5.3" - }, - "dependencies": { - "bn.js": { - "version": "4.12.0", - "resolved": "https://registry.npmjs.org/bn.js/-/bn.js-4.12.0.tgz", - "integrity": "sha512-c98Bf3tPniI+scsdk237ku1Dc3ujXQTSgyiPUDEOe7tRkhrqridvh8klBv0HCEso1OLOYcHuCv/cS6DNxKH+ZA==", - "dev": true - } - } - }, - "create-hash": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/create-hash/-/create-hash-1.2.0.tgz", - "integrity": "sha512-z00bCGNHDG8mHAkP7CtT1qVu+bFQUPjYq/4Iv3C3kWjTFV10zIjfSoeqXo9Asws8gwSHDGj/hl2u4OGIjapeCg==", - "dev": true, - "requires": { - "cipher-base": "^1.0.1", - "inherits": "^2.0.1", - "md5.js": "^1.3.4", - "ripemd160": "^2.0.1", - "sha.js": "^2.4.0" - } - }, - "create-hmac": { - "version": "1.1.7", - "resolved": "https://registry.npmjs.org/create-hmac/-/create-hmac-1.1.7.tgz", - "integrity": "sha512-MJG9liiZ+ogc4TzUwuvbER1JRdgvUFSB5+VR/g5h82fGaIRWMWddtKBHi7/sVhfjQZ6SehlyhvQYrcYkaUIpLg==", - "dev": true, - "requires": { - "cipher-base": "^1.0.3", - "create-hash": "^1.1.0", - "inherits": "^2.0.1", - "ripemd160": "^2.0.0", - "safe-buffer": "^5.0.1", - "sha.js": "^2.4.8" - } - }, "cross-spawn": { "version": "7.0.3", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", @@ -7431,25 +5589,6 @@ "which": "^2.0.1" } }, - "crypto-browserify": { - "version": "3.12.0", - "resolved": "https://registry.npmjs.org/crypto-browserify/-/crypto-browserify-3.12.0.tgz", - "integrity": "sha512-fz4spIh+znjO2VjL+IdhEpRJ3YN6sMzITSBijk6FK2UvTqruSQW+/cCZTSNsMiZNvUeq0CqurF+dAbyiGOY6Wg==", - "dev": true, - "requires": { - "browserify-cipher": "^1.0.0", - "browserify-sign": "^4.0.0", - "create-ecdh": "^4.0.0", - "create-hash": "^1.1.0", - "create-hmac": "^1.1.0", - "diffie-hellman": "^5.0.0", - "inherits": "^2.0.1", - "pbkdf2": "^3.0.3", - "public-encrypt": "^4.0.0", - "randombytes": "^2.0.0", - "randomfill": "^1.0.3" - } - }, "debug": { "version": "4.3.4", "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", @@ -7471,6 +5610,17 @@ "integrity": "sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==", "dev": true }, + "define-data-property": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.0.tgz", + "integrity": "sha512-UzGwzcjyv3OtAvolTj1GoyNYzfFR+iqbGjcnBEENZVCpM4/Ng1yhGNvS3lR/xDS74Tb2wGG9WzNSNIOS9UVb2g==", + "dev": true, + "requires": { + "get-intrinsic": "^1.2.1", + "gopd": "^1.0.1", + "has-property-descriptors": "^1.0.0" + } + }, "define-properties": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/define-properties/-/define-properties-1.2.0.tgz", @@ -7487,48 +5637,19 @@ "integrity": "sha512-bd2L678uiWATM6m5Z1VzNCErI3jiGzt6HGY8OVICs40JQq/HALfbyNJmp0UDakEY4pMMaN0Ly5om/B1VI/+xfQ==", "dev": true }, - "des.js": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/des.js/-/des.js-1.0.1.tgz", - "integrity": "sha512-Q0I4pfFrv2VPd34/vfLrFOoRmlYj3OV50i7fskps1jZWK1kApMWWT9G6RRUeYedLcBDIhnSDaUvJMb3AhUlaEA==", - "dev": true, - "requires": { - "inherits": "^2.0.1", - "minimalistic-assert": "^1.0.0" - } - }, "diff": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/diff/-/diff-5.0.0.tgz", "integrity": "sha512-/VTCrvm5Z0JGty/BWHljh+BAiw3IK+2j87NGMu8Nwc/f48WoDAC395uomO9ZD117ZOBaHmkX1oyLvkVM/aIT3w==", "dev": true }, - "diffie-hellman": { - "version": "5.0.3", - "resolved": "https://registry.npmjs.org/diffie-hellman/-/diffie-hellman-5.0.3.tgz", - "integrity": "sha512-kqag/Nl+f3GwyK25fhUMYj81BUOrZ9IuJsjIcDE5icNM9FJHAVm3VcUDxdLPoQtTuUylWm6ZIknYJwwaPxsUzg==", - "dev": true, - "requires": { - "bn.js": "^4.1.0", - "miller-rabin": "^4.0.0", - "randombytes": "^2.0.0" - }, - "dependencies": { - "bn.js": { - "version": "4.12.0", - "resolved": "https://registry.npmjs.org/bn.js/-/bn.js-4.12.0.tgz", - "integrity": "sha512-c98Bf3tPniI+scsdk237ku1Dc3ujXQTSgyiPUDEOe7tRkhrqridvh8klBv0HCEso1OLOYcHuCv/cS6DNxKH+ZA==", - "dev": true - } - } - }, "dir-compare": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/dir-compare/-/dir-compare-4.0.0.tgz", - "integrity": "sha512-wC7thVKL3V656tO61rbEDE4LTeeYrUC2pAUL00AaXYghBhjjVNRyBlpH6POzb44ZuK23OSrqF6TbSC/QYeqfAg==", + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/dir-compare/-/dir-compare-4.2.0.tgz", + "integrity": "sha512-2xMCmOoMrdQIPHdsTawECdNPwlVFB9zGcz3kuhmBO6U3oU+UQjsue0i8ayLKpgBcm+hcXPMVSGUN9d+pvJ6+VQ==", "dev": true, "requires": { - "minimatch": "^3.0.4", + "minimatch": "^3.0.5", "p-limit": "^3.1.0 " } }, @@ -7550,75 +5671,12 @@ "esutils": "^2.0.2" } }, - "domain-browser": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/domain-browser/-/domain-browser-4.22.0.tgz", - "integrity": "sha512-IGBwjF7tNk3cwypFNH/7bfzBcgSCbaMOD3GsaY1AU/JRrnHnYgEM0+9kQt52iZxjNsjBtJYtao146V+f8jFZNw==", - "dev": true - }, - "duplexer": { - "version": "0.1.2", - "resolved": "https://registry.npmjs.org/duplexer/-/duplexer-0.1.2.tgz", - "integrity": "sha512-jtD6YG370ZCIi/9GTaJKQxWTZD045+4R4hTk/x1UyoqadyJ9x9CgSi1RlVDQF8U2sxLLSnFkCaMihqljHIWgMg==", - "dev": true - }, - "electron-to-chromium": { - "version": "1.4.320", - "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.4.320.tgz", - "integrity": "sha512-h70iRscrNluMZPVICXYl5SSB+rBKo22XfuIS1ER0OQxQZpKTnFpuS6coj7wY9M/3trv7OR88rRMOlKmRvDty7Q==", - "dev": true - }, - "elliptic": { - "version": "6.5.4", - "resolved": "https://registry.npmjs.org/elliptic/-/elliptic-6.5.4.tgz", - "integrity": "sha512-iLhC6ULemrljPZb+QutR5TQGB+pdW6KGD5RSegS+8sorOZT+rdQFbsQFJgvN3eRqNALqJer4oQ16YvJHlU8hzQ==", - "dev": true, - "requires": { - "bn.js": "^4.11.9", - "brorand": "^1.1.0", - "hash.js": "^1.0.0", - "hmac-drbg": "^1.0.1", - "inherits": "^2.0.4", - "minimalistic-assert": "^1.0.1", - "minimalistic-crypto-utils": "^1.0.1" - }, - "dependencies": { - "bn.js": { - "version": "4.12.0", - "resolved": "https://registry.npmjs.org/bn.js/-/bn.js-4.12.0.tgz", - "integrity": "sha512-c98Bf3tPniI+scsdk237ku1Dc3ujXQTSgyiPUDEOe7tRkhrqridvh8klBv0HCEso1OLOYcHuCv/cS6DNxKH+ZA==", - "dev": true - } - } - }, "emoji-regex": { "version": "8.0.0", "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", "dev": true }, - "emojis-list": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/emojis-list/-/emojis-list-3.0.0.tgz", - "integrity": "sha512-/kyM18EfinwXZbno9FyUGeFh87KC8HRQBQGildHZbEuRyWFOmv1U10o9BBp8XVZDVNNuQKyIGIu5ZYAAXJ0V2Q==", - "dev": true - }, - "enhanced-resolve": { - "version": "5.12.0", - "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.12.0.tgz", - "integrity": "sha512-QHTXI/sZQmko1cbDoNAa3mJ5qhWUUNAq3vR0/YiD379fWQrcfuoX1+HW2S0MTt7XmoPLapdaDKUtelUSPic7hQ==", - "dev": true, - "requires": { - "graceful-fs": "^4.2.4", - "tapable": "^2.2.0" - } - }, - "envinfo": { - "version": "7.8.1", - "resolved": "https://registry.npmjs.org/envinfo/-/envinfo-7.8.1.tgz", - "integrity": "sha512-/o+BXHmB7ocbHEAs6F2EnG0ogybVVUdkRunTT2glZU9XAaGmhqskrvKwqXuDfNjEO0LZKWdejEEpnq8aM0tOaw==", - "dev": true - }, "error-ex": { "version": "1.3.2", "resolved": "https://registry.npmjs.org/error-ex/-/error-ex-1.3.2.tgz", @@ -7629,18 +5687,19 @@ } }, "es-abstract": { - "version": "1.21.1", - "resolved": "https://registry.npmjs.org/es-abstract/-/es-abstract-1.21.1.tgz", - "integrity": "sha512-QudMsPOz86xYz/1dG1OuGBKOELjCh99IIWHLzy5znUB6j8xG2yMA7bfTV86VSqKF+Y/H08vQPR+9jyXpuC6hfg==", + "version": "1.22.2", + "resolved": "https://registry.npmjs.org/es-abstract/-/es-abstract-1.22.2.tgz", + "integrity": "sha512-YoxfFcDmhjOgWPWsV13+2RNjq1F6UQnfs+8TftwNqtzlmFzEXvlUwdrNrYeaizfjQzRMxkZ6ElWMOJIFKdVqwA==", "dev": true, "requires": { + "array-buffer-byte-length": "^1.0.0", + "arraybuffer.prototype.slice": "^1.0.2", "available-typed-arrays": "^1.0.5", "call-bind": "^1.0.2", "es-set-tostringtag": "^2.0.1", "es-to-primitive": "^1.2.1", - "function-bind": "^1.1.1", - "function.prototype.name": "^1.1.5", - "get-intrinsic": "^1.1.3", + "function.prototype.name": "^1.1.6", + "get-intrinsic": "^1.2.1", "get-symbol-description": "^1.0.0", "globalthis": "^1.0.3", "gopd": "^1.0.1", @@ -7648,33 +5707,32 @@ "has-property-descriptors": "^1.0.0", "has-proto": "^1.0.1", "has-symbols": "^1.0.3", - "internal-slot": "^1.0.4", - "is-array-buffer": "^3.0.1", + "internal-slot": "^1.0.5", + "is-array-buffer": "^3.0.2", "is-callable": "^1.2.7", "is-negative-zero": "^2.0.2", "is-regex": "^1.1.4", "is-shared-array-buffer": "^1.0.2", "is-string": "^1.0.7", - "is-typed-array": "^1.1.10", + "is-typed-array": "^1.1.12", "is-weakref": "^1.0.2", - "object-inspect": "^1.12.2", + "object-inspect": "^1.12.3", "object-keys": "^1.1.1", "object.assign": "^4.1.4", - "regexp.prototype.flags": "^1.4.3", + "regexp.prototype.flags": "^1.5.1", + "safe-array-concat": "^1.0.1", "safe-regex-test": "^1.0.0", - "string.prototype.trimend": "^1.0.6", - "string.prototype.trimstart": "^1.0.6", + "string.prototype.trim": "^1.2.8", + "string.prototype.trimend": "^1.0.7", + "string.prototype.trimstart": "^1.0.7", + "typed-array-buffer": "^1.0.0", + "typed-array-byte-length": "^1.0.0", + "typed-array-byte-offset": "^1.0.0", "typed-array-length": "^1.0.4", "unbox-primitive": "^1.0.2", - "which-typed-array": "^1.1.9" + "which-typed-array": "^1.1.11" } }, - "es-module-lexer": { - "version": "0.9.3", - "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-0.9.3.tgz", - "integrity": "sha512-1HQ2M2sPtxwnvOvT1ZClHyQDiggdNjURWpY2we6aMKCQiUVxTmVs2UYPLIrD84sS+kMdUwfBSylbJPwNnBrnHQ==", - "dev": true - }, "es-set-tostringtag": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.0.1.tgz", @@ -7706,11 +5764,45 @@ "is-symbol": "^1.0.2" } }, - "es6-object-assign": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/es6-object-assign/-/es6-object-assign-1.1.0.tgz", - "integrity": "sha512-MEl9uirslVwqQU369iHNWZXsI8yaZYGg/D65aOgZkeyFJwHYSxilf7rQzXKI7DdDuBPrBXbfk3sl9hJhmd5AUw==", - "dev": true + "esbuild": { + "version": "0.19.3", + "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.19.3.tgz", + "integrity": "sha512-UlJ1qUUA2jL2nNib1JTSkifQTcYTroFqRjwCFW4QYEKEsixXD5Tik9xML7zh2gTxkYTBKGHNH9y7txMwVyPbjw==", + "dev": true, + "requires": { + "@esbuild/android-arm": "0.19.3", + "@esbuild/android-arm64": "0.19.3", + "@esbuild/android-x64": "0.19.3", + "@esbuild/darwin-arm64": "0.19.3", + "@esbuild/darwin-x64": "0.19.3", + "@esbuild/freebsd-arm64": "0.19.3", + "@esbuild/freebsd-x64": "0.19.3", + "@esbuild/linux-arm": "0.19.3", + "@esbuild/linux-arm64": "0.19.3", + "@esbuild/linux-ia32": "0.19.3", + "@esbuild/linux-loong64": "0.19.3", + "@esbuild/linux-mips64el": "0.19.3", + "@esbuild/linux-ppc64": "0.19.3", + "@esbuild/linux-riscv64": "0.19.3", + "@esbuild/linux-s390x": "0.19.3", + "@esbuild/linux-x64": "0.19.3", + "@esbuild/netbsd-x64": "0.19.3", + "@esbuild/openbsd-x64": "0.19.3", + "@esbuild/sunos-x64": "0.19.3", + "@esbuild/win32-arm64": "0.19.3", + "@esbuild/win32-ia32": "0.19.3", + "@esbuild/win32-x64": "0.19.3" + } + }, + "esbuild-plugin-polyfill-node": { + "version": "0.3.0", + "resolved": "https://registry.npmjs.org/esbuild-plugin-polyfill-node/-/esbuild-plugin-polyfill-node-0.3.0.tgz", + "integrity": "sha512-SHG6CKUfWfYyYXGpW143NEZtcVVn8S/WHcEOxk62LuDXnY4Zpmc+WmxJKN6GMTgTClXJXhEM5KQlxKY6YjbucQ==", + "dev": true, + "requires": { + "@jspm/core": "^2.0.1", + "import-meta-resolve": "^3.0.0" + } }, "escalade": { "version": "3.1.1", @@ -7725,26 +5817,27 @@ "dev": true }, "eslint": { - "version": "8.35.0", - "resolved": "https://registry.npmjs.org/eslint/-/eslint-8.35.0.tgz", - "integrity": "sha512-BxAf1fVL7w+JLRQhWl2pzGeSiGqbWumV4WNvc9Rhp6tiCtm4oHnyPBSEtMGZwrQgudFQ+otqzWoPB7x+hxoWsw==", + "version": "8.51.0", + "resolved": "https://registry.npmjs.org/eslint/-/eslint-8.51.0.tgz", + "integrity": "sha512-2WuxRZBrlwnXi+/vFSJyjMqrNjtJqiasMzehF0shoLaW7DzS3/9Yvrmq5JiT66+pNjiX4UBnLDiKHcWAr/OInA==", "dev": true, "requires": { - "@eslint/eslintrc": "^2.0.0", - "@eslint/js": "8.35.0", - "@humanwhocodes/config-array": "^0.11.8", + "@eslint-community/eslint-utils": "^4.2.0", + "@eslint-community/regexpp": "^4.6.1", + "@eslint/eslintrc": "^2.1.2", + "@eslint/js": "8.51.0", + "@humanwhocodes/config-array": "^0.11.11", "@humanwhocodes/module-importer": "^1.0.1", "@nodelib/fs.walk": "^1.2.8", - "ajv": "^6.10.0", + "ajv": "^6.12.4", "chalk": "^4.0.0", "cross-spawn": "^7.0.2", "debug": "^4.3.2", "doctrine": "^3.0.0", "escape-string-regexp": "^4.0.0", - "eslint-scope": "^7.1.1", - "eslint-utils": "^3.0.0", - "eslint-visitor-keys": "^3.3.0", - "espree": "^9.4.0", + "eslint-scope": "^7.2.2", + "eslint-visitor-keys": "^3.4.3", + "espree": "^9.6.1", "esquery": "^1.4.2", "esutils": "^2.0.2", "fast-deep-equal": "^3.1.3", @@ -7752,42 +5845,20 @@ "find-up": "^5.0.0", "glob-parent": "^6.0.2", "globals": "^13.19.0", - "grapheme-splitter": "^1.0.4", + "graphemer": "^1.4.0", "ignore": "^5.2.0", - "import-fresh": "^3.0.0", "imurmurhash": "^0.1.4", "is-glob": "^4.0.0", "is-path-inside": "^3.0.3", - "js-sdsl": "^4.1.4", "js-yaml": "^4.1.0", "json-stable-stringify-without-jsonify": "^1.0.1", "levn": "^0.4.1", "lodash.merge": "^4.6.2", "minimatch": "^3.1.2", "natural-compare": "^1.4.0", - "optionator": "^0.9.1", - "regexpp": "^3.2.0", + "optionator": "^0.9.3", "strip-ansi": "^6.0.1", - "strip-json-comments": "^3.1.0", "text-table": "^0.2.0" - }, - "dependencies": { - "eslint-scope": { - "version": "7.1.1", - "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-7.1.1.tgz", - "integrity": "sha512-QKQM/UXpIiHcLqJ5AOyIW7XZmzjkzQXYE54n1++wb0u9V/abW3l9uQnxX8Z5Xd18xyKIMTUAyQ0k1e8pz6LUrw==", - "dev": true, - "requires": { - "esrecurse": "^4.3.0", - "estraverse": "^5.2.0" - } - }, - "estraverse": { - "version": "5.3.0", - "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", - "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", - "dev": true - } } }, "eslint-import-resolver-node": { @@ -7813,9 +5884,9 @@ } }, "eslint-module-utils": { - "version": "2.7.4", - "resolved": "https://registry.npmjs.org/eslint-module-utils/-/eslint-module-utils-2.7.4.tgz", - "integrity": "sha512-j4GT+rqzCoRKHwURX7pddtIPGySnX9Si/cgMI5ztrcqOPtk5dDEeZ34CQVPphnqkJytlc97Vuk05Um2mJ3gEQA==", + "version": "2.8.0", + "resolved": "https://registry.npmjs.org/eslint-module-utils/-/eslint-module-utils-2.8.0.tgz", + "integrity": "sha512-aWajIYfsqCKRDgUfjEXNN/JlrzauMuSEy5sbd7WXbtW3EH6A6MpwEh42c7qD+MqQo9QMJ6fWLAeIJynx0g6OAw==", "dev": true, "requires": { "debug": "^3.2.7" @@ -7840,26 +5911,28 @@ "requires": {} }, "eslint-plugin-import": { - "version": "2.27.5", - "resolved": "https://registry.npmjs.org/eslint-plugin-import/-/eslint-plugin-import-2.27.5.tgz", - "integrity": "sha512-LmEt3GVofgiGuiE+ORpnvP+kAm3h6MLZJ4Q5HCyHADofsb4VzXFsRiWj3c0OFiV+3DWFh0qg3v9gcPlfc3zRow==", + "version": "2.28.1", + "resolved": "https://registry.npmjs.org/eslint-plugin-import/-/eslint-plugin-import-2.28.1.tgz", + "integrity": "sha512-9I9hFlITvOV55alzoKBI+K9q74kv0iKMeY6av5+umsNwayt59fz692daGyjR+oStBQgx6nwR9rXldDev3Clw+A==", "dev": true, "requires": { "array-includes": "^3.1.6", + "array.prototype.findlastindex": "^1.2.2", "array.prototype.flat": "^1.3.1", "array.prototype.flatmap": "^1.3.1", "debug": "^3.2.7", "doctrine": "^2.1.0", "eslint-import-resolver-node": "^0.3.7", - "eslint-module-utils": "^2.7.4", + "eslint-module-utils": "^2.8.0", "has": "^1.0.3", - "is-core-module": "^2.11.0", + "is-core-module": "^2.13.0", "is-glob": "^4.0.3", "minimatch": "^3.1.2", + "object.fromentries": "^2.0.6", + "object.groupby": "^1.0.0", "object.values": "^1.1.6", - "resolve": "^1.22.1", - "semver": "^6.3.0", - "tsconfig-paths": "^3.14.1" + "semver": "^6.3.1", + "tsconfig-paths": "^3.14.2" }, "dependencies": { "debug": { @@ -7881,25 +5954,27 @@ } }, "semver": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", - "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", "dev": true } } }, "eslint-plugin-jsdoc": { - "version": "40.0.1", - "resolved": "https://registry.npmjs.org/eslint-plugin-jsdoc/-/eslint-plugin-jsdoc-40.0.1.tgz", - "integrity": "sha512-KkiRInury7YrjjV5aCHDxwsPy6XFt5p2b2CnpDMITnWs8patNPf5kj24+VXIWw45kP6z/B0GOKfrYczB56OjQQ==", + "version": "46.8.2", + "resolved": "https://registry.npmjs.org/eslint-plugin-jsdoc/-/eslint-plugin-jsdoc-46.8.2.tgz", + "integrity": "sha512-5TSnD018f3tUJNne4s4gDWQflbsgOycIKEUBoCLn6XtBMgNHxQFmV8vVxUtiPxAQq8lrX85OaSG/2gnctxw9uQ==", "dev": true, "requires": { - "@es-joy/jsdoccomment": "~0.36.1", - "comment-parser": "1.3.1", + "@es-joy/jsdoccomment": "~0.40.1", + "are-docs-informative": "^0.0.2", + "comment-parser": "1.4.0", "debug": "^4.3.4", "escape-string-regexp": "^4.0.0", - "esquery": "^1.4.0", - "semver": "^7.3.8", + "esquery": "^1.5.0", + "is-builtin-module": "^3.2.1", + "semver": "^7.5.4", "spdx-expression-parse": "^3.0.1" } }, @@ -7911,71 +5986,53 @@ "requires": {} }, "eslint-plugin-unicorn": { - "version": "46.0.0", - "resolved": "https://registry.npmjs.org/eslint-plugin-unicorn/-/eslint-plugin-unicorn-46.0.0.tgz", - "integrity": "sha512-j07WkC+PFZwk8J33LYp6JMoHa1lXc1u6R45pbSAipjpfpb7KIGr17VE2D685zCxR5VL4cjrl65kTJflziQWMDA==", + "version": "48.0.1", + "resolved": "https://registry.npmjs.org/eslint-plugin-unicorn/-/eslint-plugin-unicorn-48.0.1.tgz", + "integrity": "sha512-FW+4r20myG/DqFcCSzoumaddKBicIPeFnTrifon2mWIzlfyvzwyqZjqVP7m4Cqr/ZYisS2aiLghkUWaPg6vtCw==", "dev": true, "requires": { - "@babel/helper-validator-identifier": "^7.19.1", - "@eslint-community/eslint-utils": "^4.1.2", - "ci-info": "^3.6.1", + "@babel/helper-validator-identifier": "^7.22.5", + "@eslint-community/eslint-utils": "^4.4.0", + "ci-info": "^3.8.0", "clean-regexp": "^1.0.0", - "esquery": "^1.4.0", + "esquery": "^1.5.0", "indent-string": "^4.0.0", - "is-builtin-module": "^3.2.0", + "is-builtin-module": "^3.2.1", "jsesc": "^3.0.2", "lodash": "^4.17.21", "pluralize": "^8.0.0", "read-pkg-up": "^7.0.1", - "regexp-tree": "^0.1.24", - "regjsparser": "^0.9.1", - "safe-regex": "^2.1.1", - "semver": "^7.3.8", + "regexp-tree": "^0.1.27", + "regjsparser": "^0.10.0", + "semver": "^7.5.4", "strip-indent": "^3.0.0" } }, "eslint-scope": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-5.1.1.tgz", - "integrity": "sha512-2NxwbF/hZ0KpepYN0cNbo+FN6XoK7GaHlQhgx/hIZl6Va0bF45RQOOwhLIy8lQDbuCiadSLCBnH2CFYquit5bw==", + "version": "7.2.2", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-7.2.2.tgz", + "integrity": "sha512-dOt21O7lTMhDM+X9mB4GX+DZrZtCUJPL/wlcTqxyrx5IvO0IYtILdtrQGQp+8n5S0gwSVmOf9NQrjMOgfQZlIg==", "dev": true, "requires": { "esrecurse": "^4.3.0", - "estraverse": "^4.1.1" - } - }, - "eslint-utils": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/eslint-utils/-/eslint-utils-3.0.0.tgz", - "integrity": "sha512-uuQC43IGctw68pJA1RgbQS8/NP7rch6Cwd4j3ZBtgo4/8Flj4eGE7ZYSZRN3iq5pVUv6GPdW5Z1RFleo84uLDA==", - "dev": true, - "requires": { - "eslint-visitor-keys": "^2.0.0" - }, - "dependencies": { - "eslint-visitor-keys": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-2.1.0.tgz", - "integrity": "sha512-0rSmRBzXgDzIsD6mGdJgevzgezI534Cer5L/vyMX0kHzT/jiB43jRhd9YUlMGYLQy2zprNmoT8qasCGtY+QaKw==", - "dev": true - } + "estraverse": "^5.2.0" } }, "eslint-visitor-keys": { - "version": "3.3.0", - "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.3.0.tgz", - "integrity": "sha512-mQ+suqKJVyeuwGYHAdjMFqjCyfl8+Ldnxuyp3ldiMBFKkvytrXUZWaiPCEav8qDHKty44bD+qV1IP4T+w+xXRA==", + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.4.3.tgz", + "integrity": "sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==", "dev": true }, "espree": { - "version": "9.4.1", - "resolved": "https://registry.npmjs.org/espree/-/espree-9.4.1.tgz", - "integrity": "sha512-XwctdmTO6SIvCzd9810yyNzIrOrqNYV9Koizx4C/mRhf9uq0o4yHoCEU/670pOxOL/MSraektvSAji79kX90Vg==", + "version": "9.6.1", + "resolved": "https://registry.npmjs.org/espree/-/espree-9.6.1.tgz", + "integrity": "sha512-oruZaFkjorTpF32kDSI5/75ViwGeZginGGy2NoOSg3Q9bnwlnmDm4HLnkl0RE3n+njDXR037aY1+x58Z/zFdwQ==", "dev": true, "requires": { - "acorn": "^8.8.0", + "acorn": "^8.9.0", "acorn-jsx": "^5.3.2", - "eslint-visitor-keys": "^3.3.0" + "eslint-visitor-keys": "^3.4.1" } }, "esquery": { @@ -7985,14 +6042,6 @@ "dev": true, "requires": { "estraverse": "^5.1.0" - }, - "dependencies": { - "estraverse": { - "version": "5.3.0", - "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", - "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", - "dev": true - } } }, "esrecurse": { @@ -8002,20 +6051,12 @@ "dev": true, "requires": { "estraverse": "^5.2.0" - }, - "dependencies": { - "estraverse": { - "version": "5.3.0", - "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", - "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", - "dev": true - } } }, "estraverse": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-4.3.0.tgz", - "integrity": "sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==", + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", + "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", "dev": true }, "esutils": { @@ -8036,16 +6077,6 @@ "integrity": "sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==", "dev": true }, - "evp_bytestokey": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/evp_bytestokey/-/evp_bytestokey-1.0.3.tgz", - "integrity": "sha512-/f2Go4TognH/KvCISP7OUsHn85hT9nUkxxA9BEWxFn+Oj9o8ZNLm/40hdlgSLyuOimsrTKLUMEorQexp/aPQeA==", - "dev": true, - "requires": { - "md5.js": "^1.3.4", - "safe-buffer": "^5.1.1" - } - }, "fast-deep-equal": { "version": "3.1.3", "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", @@ -8053,9 +6084,9 @@ "dev": true }, "fast-glob": { - "version": "3.2.12", - "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.2.12.tgz", - "integrity": "sha512-DVj4CQIYYow0BlaelwK1pHl5n5cRSJfM60UA0zK891sVInoPri2Ekj7+e1CT3/3qxXenpI+nBBmQAcJPJgaj4w==", + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.1.tgz", + "integrity": "sha512-kNFPyjhh5cKjrUltxs+wFx+ZkbRaxxmZ+X0ZU31SOsxCEtP9VPgtq2teZw1DebupL5GmDaNQ6yKMMVcM41iqDg==", "dev": true, "requires": { "@nodelib/fs.stat": "^2.0.2", @@ -8088,12 +6119,6 @@ "integrity": "sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==", "dev": true }, - "fastest-levenshtein": { - "version": "1.0.16", - "resolved": "https://registry.npmjs.org/fastest-levenshtein/-/fastest-levenshtein-1.0.16.tgz", - "integrity": "sha512-eRnCtTTtGZFpQCwhJiUOuxPQWRXVKYDn0b2PeHfXL6/Zi53SLAzAHfVhVWK2AryC/WH05kGfxhFIPvTF0SXQzg==", - "dev": true - }, "fastq": { "version": "1.15.0", "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz", @@ -8121,12 +6146,6 @@ "to-regex-range": "^5.0.1" } }, - "filter-obj": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/filter-obj/-/filter-obj-2.0.2.tgz", - "integrity": "sha512-lO3ttPjHZRfjMcxWKb1j1eDhTFsu4meeR3lnMcnBFhk6RuLhvEiuALu2TlfL310ph4lCYYwgF/ElIjdP739tdg==", - "dev": true - }, "find-up": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/find-up/-/find-up-5.0.0.tgz", @@ -8169,9 +6188,9 @@ } }, "fs-extra": { - "version": "11.1.0", - "resolved": "https://registry.npmjs.org/fs-extra/-/fs-extra-11.1.0.tgz", - "integrity": "sha512-0rcTq621PD5jM/e0a3EJoGC/1TC5ZBCERW82LQuwfGnCa1V8w7dpYH1yNu+SLb6E5dkeCBzKEyLGlFrnr+dUyw==", + "version": "11.1.1", + "resolved": "https://registry.npmjs.org/fs-extra/-/fs-extra-11.1.1.tgz", + "integrity": "sha512-MGIE4HOvQCeUCzmlHs0vXpih4ysz4wg9qiSAu6cd42lVwPbTM1TjV7RusoyQqMmk/95gdQZX72u+YW+c3eEpFQ==", "dev": true, "requires": { "graceful-fs": "^4.2.0", @@ -8199,15 +6218,15 @@ "dev": true }, "function.prototype.name": { - "version": "1.1.5", - "resolved": "https://registry.npmjs.org/function.prototype.name/-/function.prototype.name-1.1.5.tgz", - "integrity": "sha512-uN7m/BzVKQnCUF/iW8jYea67v++2u7m5UgENbHRtdDVclOUP+FMPlCNdmk0h/ysGyo2tavMJEDqJAkJdRa1vMA==", + "version": "1.1.6", + "resolved": "https://registry.npmjs.org/function.prototype.name/-/function.prototype.name-1.1.6.tgz", + "integrity": "sha512-Z5kx79swU5P27WEayXM1tBi5Ze/lbIyiNgU3qyXUOf9b2rgXYyF9Dy9Cx+IQv/Lc8WCG6L82zwUPpSS9hGehIg==", "dev": true, "requires": { "call-bind": "^1.0.2", - "define-properties": "^1.1.3", - "es-abstract": "^1.19.0", - "functions-have-names": "^1.2.2" + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1", + "functions-have-names": "^1.2.3" } }, "functions-have-names": { @@ -8239,13 +6258,14 @@ "dev": true }, "get-intrinsic": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.0.tgz", - "integrity": "sha512-L049y6nFOuom5wGyRc3/gdTLO94dySVKRACj1RmJZBQXlbTMhtNIgkWkUHq+jYmZvKf14EW1EoJnnjbmoHij0Q==", + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.1.tgz", + "integrity": "sha512-2DcsyfABl+gVHEfCOaTrWgyt+tb6MSEGmKq+kI5HwLbIYgjgmMcV8KQ41uaKz1xxUcn9tJtgFbQUEVcEbd0FYw==", "dev": true, "requires": { "function-bind": "^1.1.1", "has": "^1.0.3", + "has-proto": "^1.0.1", "has-symbols": "^1.0.3" } }, @@ -8282,16 +6302,10 @@ "is-glob": "^4.0.3" } }, - "glob-to-regexp": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/glob-to-regexp/-/glob-to-regexp-0.4.1.tgz", - "integrity": "sha512-lkX1HJXwyMcprw/5YUZc2s7DrpAiHB21/V+E1rHUrVNokkvB6bqMzT0VfV6/86ZNabt1k14YOIaT7nDvOX3Iiw==", - "dev": true - }, "globals": { - "version": "13.20.0", - "resolved": "https://registry.npmjs.org/globals/-/globals-13.20.0.tgz", - "integrity": "sha512-Qg5QtVkCy/kv3FUSlu4ukeZDVf9ee0iXLAUYX13gbR17bnejFTzr4iS9bY7kwCf1NztRNm1t91fjOiyx4CSwPQ==", + "version": "13.23.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-13.23.0.tgz", + "integrity": "sha512-XAmF0RjlrjY23MA51q3HltdlGxUpXPvg0GioKiD9X6HD28iMjo2dKC8Vqwm7lne4GNr78+RHTfliktR6ZH09wA==", "dev": true, "requires": { "type-fest": "^0.20.2" @@ -8335,21 +6349,12 @@ "integrity": "sha512-9ByhssR2fPVsNZj478qUUbKfmL0+t5BDVyjShtyZZLiK7ZDAArFFfopyOTj0M05wE2tJPisA4iTnnXl2YoPvOA==", "dev": true }, - "grapheme-splitter": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/grapheme-splitter/-/grapheme-splitter-1.0.4.tgz", - "integrity": "sha512-bzh50DW9kTPM00T8y4o8vQg89Di9oLJVLW/KaOGIXJWP/iqCN6WKYkbNOF04vFLJhwcpYUh9ydh/+5vpOqV4YQ==", + "graphemer": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/graphemer/-/graphemer-1.4.0.tgz", + "integrity": "sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==", "dev": true }, - "gzip-size": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/gzip-size/-/gzip-size-6.0.0.tgz", - "integrity": "sha512-ax7ZYomf6jqPTQ4+XCpUGyXKHk5WweS+e05MBO4/y3WJ5RkmPXNKvX+bx1behVILVwr6JSQvZAku021CHPXG3Q==", - "dev": true, - "requires": { - "duplexer": "^0.1.2" - } - }, "has": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/has/-/has-1.0.3.tgz", @@ -8387,95 +6392,38 @@ "dev": true }, "has-symbols": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.0.3.tgz", - "integrity": "sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A==", - "dev": true - }, - "has-tostringtag": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.0.tgz", - "integrity": "sha512-kFjcSNhnlGV1kyoGk7OXKSawH5JOb/LzUc5w9B02hOTO0dfFRjbHQKvg1d6cf3HbeUmtU9VbbV3qzZ2Teh97WQ==", - "dev": true, - "requires": { - "has-symbols": "^1.0.2" - } - }, - "has-unicode": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/has-unicode/-/has-unicode-2.0.1.tgz", - "integrity": "sha512-8Rf9Y83NBReMnx0gFzA8JImQACstCYWUplepDa9xprwwtmgEZUF0h/i5xSA625zB/I37EtrswSST6OXxwaaIJQ==", - "dev": true - }, - "hash-base": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/hash-base/-/hash-base-3.1.0.tgz", - "integrity": "sha512-1nmYp/rhMDiE7AYkDw+lLwlAzz0AntGIe51F3RfFfEqyQ3feY2eI/NcwC6umIQVOASPMsWJLJScWKSSvzL9IVA==", - "dev": true, - "requires": { - "inherits": "^2.0.4", - "readable-stream": "^3.6.0", - "safe-buffer": "^5.2.0" - }, - "dependencies": { - "readable-stream": { - "version": "3.6.1", - "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.1.tgz", - "integrity": "sha512-+rQmrWMYGA90yenhTYsLWAsLsqVC8osOw6PKE1HDYiO0gdPeKe/xDHNzIAIn4C91YQ6oenEhfYqqc1883qHbjQ==", - "dev": true, - "requires": { - "inherits": "^2.0.3", - "string_decoder": "^1.1.1", - "util-deprecate": "^1.0.1" - } - }, - "safe-buffer": { - "version": "5.2.1", - "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", - "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", - "dev": true - } - } + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.0.3.tgz", + "integrity": "sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A==", + "dev": true }, - "hash.js": { - "version": "1.1.7", - "resolved": "https://registry.npmjs.org/hash.js/-/hash.js-1.1.7.tgz", - "integrity": "sha512-taOaskGt4z4SOANNseOviYDvjEJinIkRgmp7LbKP2YTTmVxWBl87s/uzK9r+44BclBSp2X7K1hqeNfz9JbBeXA==", + "has-tostringtag": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.0.tgz", + "integrity": "sha512-kFjcSNhnlGV1kyoGk7OXKSawH5JOb/LzUc5w9B02hOTO0dfFRjbHQKvg1d6cf3HbeUmtU9VbbV3qzZ2Teh97WQ==", "dev": true, "requires": { - "inherits": "^2.0.3", - "minimalistic-assert": "^1.0.1" + "has-symbols": "^1.0.2" } }, + "has-unicode": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/has-unicode/-/has-unicode-2.0.1.tgz", + "integrity": "sha512-8Rf9Y83NBReMnx0gFzA8JImQACstCYWUplepDa9xprwwtmgEZUF0h/i5xSA625zB/I37EtrswSST6OXxwaaIJQ==", + "dev": true + }, "he": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/he/-/he-1.2.0.tgz", "integrity": "sha512-F/1DnUGPopORZi0ni+CvrCgHQ5FyEAHRLSApuYWMmrbSwoN2Mn/7k+Gl38gJnR7yyDZk6WLXwiGod1JOWNDKGw==", "dev": true }, - "hmac-drbg": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/hmac-drbg/-/hmac-drbg-1.0.1.tgz", - "integrity": "sha512-Tti3gMqLdZfhOQY1Mzf/AanLiqh1WTiJgEj26ZuYQ9fbkLomzGchCws4FyrSd4VkpBfiNhaE1On+lOz894jvXg==", - "dev": true, - "requires": { - "hash.js": "^1.0.3", - "minimalistic-assert": "^1.0.0", - "minimalistic-crypto-utils": "^1.0.1" - } - }, "hosted-git-info": { "version": "2.8.9", "resolved": "https://registry.npmjs.org/hosted-git-info/-/hosted-git-info-2.8.9.tgz", "integrity": "sha512-mxIDAb9Lsm6DoOJ7xH+5+X4y1LU/4Hi50L9C5sIswK3JzULS4bwk1FvjdBgvYR4bzT4tuUQiC15FE2f5HbLvYw==", "dev": true }, - "https-browserify": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/https-browserify/-/https-browserify-1.0.0.tgz", - "integrity": "sha512-J+FkSdyD+0mA0N+81tMotaRMfSL9SGi+xpD3T6YApKsc3bGSXJlfXri3VyFOeYkfLRQisDk1W+jIFFKBeUBbBg==", - "dev": true - }, "ieee754": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/ieee754/-/ieee754-1.2.1.tgz", @@ -8504,15 +6452,11 @@ "resolve-from": "^4.0.0" } }, - "import-local": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/import-local/-/import-local-3.1.0.tgz", - "integrity": "sha512-ASB07uLtnDs1o6EHjKpX34BKYDSqnFerfTOJL2HvMqF70LnxpjkzDB8J44oT9pu4AMPkQwf8jl6szgvNd2tRIg==", - "dev": true, - "requires": { - "pkg-dir": "^4.2.0", - "resolve-cwd": "^3.0.0" - } + "import-meta-resolve": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/import-meta-resolve/-/import-meta-resolve-3.0.0.tgz", + "integrity": "sha512-4IwhLhNNA8yy445rPjD/lWh++7hMDOml2eHtd58eG7h+qK3EryMuuRbsHGPikCoAgIkkDnckKfWSk2iDla/ejg==", + "dev": true }, "imurmurhash": { "version": "0.1.4", @@ -8553,22 +6497,6 @@ "side-channel": "^1.0.4" } }, - "interpret": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/interpret/-/interpret-3.1.1.tgz", - "integrity": "sha512-6xwYfHbajpoF0xLW+iwLkhwgvLoZDfjYfoFNu8ftMoXINzwuymNLd9u/KmwtdT2GbR+/Cz66otEGEVVUHX9QLQ==", - "dev": true - }, - "is-arguments": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/is-arguments/-/is-arguments-1.1.1.tgz", - "integrity": "sha512-8Q7EARjzEnKpt/PCD7e1cgUS0a6X8u5tdSiMqXhojOdoV9TsMsiO+9VLC5vAmO8N7/GmXn7yjR8qnA6bVAEzfA==", - "dev": true, - "requires": { - "call-bind": "^1.0.2", - "has-tostringtag": "^1.0.0" - } - }, "is-array-buffer": { "version": "3.0.2", "resolved": "https://registry.npmjs.org/is-array-buffer/-/is-array-buffer-3.0.2.tgz", @@ -8630,9 +6558,9 @@ "dev": true }, "is-core-module": { - "version": "2.11.0", - "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.11.0.tgz", - "integrity": "sha512-RRjxlvLDkD1YJwDbroBHMb+cukurkDWNyHx7D3oNB5x9rb5ogcksMC5wHCadcXoo67gVr/+3GFySh3134zi6rw==", + "version": "2.13.0", + "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.13.0.tgz", + "integrity": "sha512-Z7dk6Qo8pOCp3l4tsX2C5ZVas4V+UxwQodwZhLopL91TX8UyyHEXafPcyoeeWuLrwzHcr3igO78wNLwHJHsMCQ==", "dev": true, "requires": { "has": "^1.0.3" @@ -8659,15 +6587,6 @@ "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", "dev": true }, - "is-generator-function": { - "version": "1.0.10", - "resolved": "https://registry.npmjs.org/is-generator-function/-/is-generator-function-1.0.10.tgz", - "integrity": "sha512-jsEjy9l3yiXEQ+PsXdmBwEPcOxaXWLspKdplFUVI9vq1iZgIekeC0L167qeu86czQaxed3q/Uzuw0swL0irL8A==", - "dev": true, - "requires": { - "has-tostringtag": "^1.0.0" - } - }, "is-glob": { "version": "4.0.3", "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", @@ -8677,16 +6596,6 @@ "is-extglob": "^2.1.1" } }, - "is-nan": { - "version": "1.3.2", - "resolved": "https://registry.npmjs.org/is-nan/-/is-nan-1.3.2.tgz", - "integrity": "sha512-E+zBKpQ2t6MEo1VsonYmluk9NxGrbzpeeLC2xIViuO2EjU2xsXsBPwTr3Ykv9l08UYEVEdWeRZNouaZqF6RN0w==", - "dev": true, - "requires": { - "call-bind": "^1.0.0", - "define-properties": "^1.1.3" - } - }, "is-negative-zero": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/is-negative-zero/-/is-negative-zero-2.0.2.tgz", @@ -8720,15 +6629,6 @@ "integrity": "sha512-YWnfyRwxL/+SsrWYfOpUtz5b3YD+nyfkHvjbcanzk8zgyO4ASD67uVMRt8k5bM4lLMDnXfriRhOpemw+NfT1eA==", "dev": true }, - "is-plain-object": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/is-plain-object/-/is-plain-object-2.0.4.tgz", - "integrity": "sha512-h5PpgXkWitc38BBMYawTYMWJHFZJVnBquFE57xFpjB8pJFiF6gZ+bU+WyI/yqXiFR5mdLsgYNaPe8uao6Uv9Og==", - "dev": true, - "requires": { - "isobject": "^3.0.1" - } - }, "is-regex": { "version": "1.1.4", "resolved": "https://registry.npmjs.org/is-regex/-/is-regex-1.1.4.tgz", @@ -8767,16 +6667,12 @@ } }, "is-typed-array": { - "version": "1.1.10", - "resolved": "https://registry.npmjs.org/is-typed-array/-/is-typed-array-1.1.10.tgz", - "integrity": "sha512-PJqgEHiWZvMpaFZ3uTc8kHPM4+4ADTlDniuQL7cU/UDA0Ql7F70yGfHph3cLNe+c9toaigv+DFzTJKhc2CtO6A==", + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/is-typed-array/-/is-typed-array-1.1.12.tgz", + "integrity": "sha512-Z14TF2JNG8Lss5/HMqt0//T9JeHXttXy5pH/DBU4vi98ozO2btxzq9MwYDZYnKwU8nRsz/+GVFVRDq3DkVuSPg==", "dev": true, "requires": { - "available-typed-arrays": "^1.0.5", - "call-bind": "^1.0.2", - "for-each": "^0.3.3", - "gopd": "^1.0.1", - "has-tostringtag": "^1.0.0" + "which-typed-array": "^1.1.11" } }, "is-unicode-supported": { @@ -8806,40 +6702,6 @@ "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", "dev": true }, - "isobject": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/isobject/-/isobject-3.0.1.tgz", - "integrity": "sha512-WhB9zCku7EGTj/HQQRz5aUQEUeoQZH2bWcltRErOpymJ4boYE6wL9Tbr23krRPSZ+C5zqNSrSw+Cc7sZZ4b7vg==", - "dev": true - }, - "jest-worker": { - "version": "27.5.1", - "resolved": "https://registry.npmjs.org/jest-worker/-/jest-worker-27.5.1.tgz", - "integrity": "sha512-7vuh85V5cdDofPyxn58nrPjBktZo0u9x1g8WtjQol+jZDaE+fhN+cIvTj11GndBnMnyfrUOG1sZQxCdjKh+DKg==", - "dev": true, - "requires": { - "@types/node": "*", - "merge-stream": "^2.0.0", - "supports-color": "^8.0.0" - }, - "dependencies": { - "supports-color": { - "version": "8.1.1", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz", - "integrity": "sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==", - "dev": true, - "requires": { - "has-flag": "^4.0.0" - } - } - } - }, - "js-sdsl": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/js-sdsl/-/js-sdsl-4.3.0.tgz", - "integrity": "sha512-mifzlm2+5nZ+lEcLJMoBK0/IH/bDg8XnJfd/Wq6IP+xoCjLZsTOnV2QpxlVbX9bMnkl5PdEjNtBJ9Cj1NjifhQ==", - "dev": true - }, "js-tokens": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", @@ -8856,9 +6718,9 @@ } }, "jsdoc-type-pratt-parser": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/jsdoc-type-pratt-parser/-/jsdoc-type-pratt-parser-3.1.0.tgz", - "integrity": "sha512-MgtD0ZiCDk9B+eI73BextfRrVQl0oyzRG8B2BjORts6jbunj4ScKPcyXGTbB6eXL4y9TzxCm6hyeLq/2ASzNdw==", + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/jsdoc-type-pratt-parser/-/jsdoc-type-pratt-parser-4.0.0.tgz", + "integrity": "sha512-YtOli5Cmzy3q4dP26GraSOeAhqecewG04hoO8DY56CH4KJ9Fvv5qKWUCCo3HZob7esJQHCv6/+bnTy72xZZaVQ==", "dev": true }, "jsesc": { @@ -8916,12 +6778,6 @@ "setimmediate": "^1.0.5" } }, - "kind-of": { - "version": "6.0.3", - "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-6.0.3.tgz", - "integrity": "sha512-dcS1ul+9tmeD95T+x28/ehLgd9mENa3LsvDTtzm3vyBEO7RPptvAD+t44WVXaUjTBRcrpFeFlC8WCruUR456hw==", - "dev": true - }, "levn": { "version": "0.4.1", "resolved": "https://registry.npmjs.org/levn/-/levn-0.4.1.tgz", @@ -8947,31 +6803,6 @@ "integrity": "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==", "dev": true }, - "loader-runner": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/loader-runner/-/loader-runner-4.3.0.tgz", - "integrity": "sha512-3R/1M+yS3j5ou80Me59j7F9IMs4PXs3VqRrm0TU3AbKPxlmpoY1TNscJV/oGJXo8qCatFGTfDbY6W6ipGOYXfg==", - "dev": true - }, - "loader-utils": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/loader-utils/-/loader-utils-2.0.4.tgz", - "integrity": "sha512-xXqpXoINfFhgua9xiqD8fPFHgkoq1mmmpE92WlDbm9rNRd/EbRb+Gqf908T2DMfuHjjJlksiK2RbHVOdD/MqSw==", - "dev": true, - "requires": { - "big.js": "^5.2.2", - "emojis-list": "^3.0.0", - "json5": "^2.1.2" - }, - "dependencies": { - "json5": { - "version": "2.2.3", - "resolved": "https://registry.npmjs.org/json5/-/json5-2.2.3.tgz", - "integrity": "sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==", - "dev": true - } - } - }, "locate-path": { "version": "6.0.0", "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-6.0.0.tgz", @@ -9012,23 +6843,6 @@ "yallist": "^4.0.0" } }, - "md5.js": { - "version": "1.3.5", - "resolved": "https://registry.npmjs.org/md5.js/-/md5.js-1.3.5.tgz", - "integrity": "sha512-xitP+WxNPcTTOgnTJcrhM0xvdPepipPSf3I8EIpGKeFLjt3PlJLIDG3u8EX53ZIubkb+5U2+3rELYpEhHhzdkg==", - "dev": true, - "requires": { - "hash-base": "^3.0.0", - "inherits": "^2.0.1", - "safe-buffer": "^5.1.2" - } - }, - "merge-stream": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/merge-stream/-/merge-stream-2.0.0.tgz", - "integrity": "sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==", - "dev": true - }, "merge2": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", @@ -9045,57 +6859,12 @@ "picomatch": "^2.3.1" } }, - "miller-rabin": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/miller-rabin/-/miller-rabin-4.0.1.tgz", - "integrity": "sha512-115fLhvZVqWwHPbClyntxEVfVDfl9DLLTuJvq3g2O/Oxi8AiNouAHvDSzHS0viUJc+V5vm3eq91Xwqn9dp4jRA==", - "dev": true, - "requires": { - "bn.js": "^4.0.0", - "brorand": "^1.0.1" - }, - "dependencies": { - "bn.js": { - "version": "4.12.0", - "resolved": "https://registry.npmjs.org/bn.js/-/bn.js-4.12.0.tgz", - "integrity": "sha512-c98Bf3tPniI+scsdk237ku1Dc3ujXQTSgyiPUDEOe7tRkhrqridvh8klBv0HCEso1OLOYcHuCv/cS6DNxKH+ZA==", - "dev": true - } - } - }, - "mime-db": { - "version": "1.52.0", - "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", - "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", - "dev": true - }, - "mime-types": { - "version": "2.1.35", - "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", - "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", - "dev": true, - "requires": { - "mime-db": "1.52.0" - } - }, "min-indent": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/min-indent/-/min-indent-1.0.1.tgz", "integrity": "sha512-I9jwMn07Sy/IwOj3zVkVik2JTvgpaykDZEigL6Rx6N9LbMywwUSMtxET+7lVoDLLd3O3IXwJwvuuns8UB/HeAg==", "dev": true }, - "minimalistic-assert": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/minimalistic-assert/-/minimalistic-assert-1.0.1.tgz", - "integrity": "sha512-UtJcAD4yEaGtjPezWuO9wC4nwUnVH/8/Im3yEHQP4b67cXlD/Qr9hdITCU1xDbSEXg2XKNaP8jsReV7vQd00/A==", - "dev": true - }, - "minimalistic-crypto-utils": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/minimalistic-crypto-utils/-/minimalistic-crypto-utils-1.0.1.tgz", - "integrity": "sha512-JIYlbt6g8i5jKfJ3xz7rF0LXmv2TkDxBLUkiBeZ7bAx4GnnNMr8xFpGnOxn6GhTEHx3SjRrZEoU+j04prX1ktg==", - "dev": true - }, "minimatch": { "version": "3.1.2", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", @@ -9211,12 +6980,6 @@ } } }, - "mrmime": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/mrmime/-/mrmime-1.0.1.tgz", - "integrity": "sha512-hzzEagAgDyoU1Q6yg5uI+AorQgdvMCur3FcKf7NhMKWsaYg+RnbTyHRa/9IlLF9rf455MOCtcqqrQQ83pPP7Uw==", - "dev": true - }, "ms": { "version": "2.1.2", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", @@ -9235,92 +6998,6 @@ "integrity": "sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==", "dev": true }, - "natural-compare-lite": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/natural-compare-lite/-/natural-compare-lite-1.4.0.tgz", - "integrity": "sha512-Tj+HTDSJJKaZnfiuw+iaF9skdPpTo2GtEly5JHnWV/hfv2Qj/9RKsGISQtLh2ox3l5EAGw487hnBee0sIJ6v2g==", - "dev": true - }, - "neo-async": { - "version": "2.6.2", - "resolved": "https://registry.npmjs.org/neo-async/-/neo-async-2.6.2.tgz", - "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==", - "dev": true - }, - "node-polyfill-webpack-plugin": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/node-polyfill-webpack-plugin/-/node-polyfill-webpack-plugin-2.0.1.tgz", - "integrity": "sha512-ZUMiCnZkP1LF0Th2caY6J/eKKoA0TefpoVa68m/LQU1I/mE8rGt4fNYGgNuCcK+aG8P8P43nbeJ2RqJMOL/Y1A==", - "dev": true, - "requires": { - "assert": "^2.0.0", - "browserify-zlib": "^0.2.0", - "buffer": "^6.0.3", - "console-browserify": "^1.2.0", - "constants-browserify": "^1.0.0", - "crypto-browserify": "^3.12.0", - "domain-browser": "^4.22.0", - "events": "^3.3.0", - "filter-obj": "^2.0.2", - "https-browserify": "^1.0.0", - "os-browserify": "^0.3.0", - "path-browserify": "^1.0.1", - "process": "^0.11.10", - "punycode": "^2.1.1", - "querystring-es3": "^0.2.1", - "readable-stream": "^4.0.0", - "stream-browserify": "^3.0.0", - "stream-http": "^3.2.0", - "string_decoder": "^1.3.0", - "timers-browserify": "^2.0.12", - "tty-browserify": "^0.0.1", - "type-fest": "^2.14.0", - "url": "^0.11.0", - "util": "^0.12.4", - "vm-browserify": "^1.1.2" - }, - "dependencies": { - "readable-stream": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.3.0.tgz", - "integrity": "sha512-MuEnA0lbSi7JS8XM+WNJlWZkHAAdm7gETHdFK//Q/mChGyj2akEFtdLZh32jSdkWGbRwCW9pn6g3LWDdDeZnBQ==", - "dev": true, - "requires": { - "abort-controller": "^3.0.0", - "buffer": "^6.0.3", - "events": "^3.3.0", - "process": "^0.11.10" - } - }, - "safe-buffer": { - "version": "5.2.1", - "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", - "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", - "dev": true - }, - "string_decoder": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz", - "integrity": "sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==", - "dev": true, - "requires": { - "safe-buffer": "~5.2.0" - } - }, - "type-fest": { - "version": "2.19.0", - "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-2.19.0.tgz", - "integrity": "sha512-RAH822pAdBgcNMAfWnCBU3CFZcfZ/i1eZjwFU/dsLKumyuuP3niueg2UAukXYF0E2AAoc82ZSSf9J0WQBinzHA==", - "dev": true - } - } - }, - "node-releases": { - "version": "2.0.10", - "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.10.tgz", - "integrity": "sha512-5GFldHPXVG/YZmFzJvKK2zDSzPKhEp0+ZR5SVaoSag9fsL5YgHbUHDfnG5494ISANDcK4KwPXAx2xqVEydmd7w==", - "dev": true - }, "normalize-package-data": { "version": "2.5.0", "resolved": "https://registry.npmjs.org/normalize-package-data/-/normalize-package-data-2.5.0.tgz", @@ -9334,9 +7011,9 @@ }, "dependencies": { "semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true } } @@ -9365,16 +7042,6 @@ "integrity": "sha512-geUvdk7c+eizMNUDkRpW1wJwgfOiOeHbxBR/hLXK1aT6zmVSO0jsQcs7fj6MGw89jC/cjGfLcNOrtMYtGqm81g==", "dev": true }, - "object-is": { - "version": "1.1.5", - "resolved": "https://registry.npmjs.org/object-is/-/object-is-1.1.5.tgz", - "integrity": "sha512-3cyDsyHgtmi7I7DfSSI2LDp6SK2lwvtbg0p0R1e0RvTqF5ceGx+K2dfSjm1bKDMVCFEDAQvy+o8c6a7VujOddw==", - "dev": true, - "requires": { - "call-bind": "^1.0.2", - "define-properties": "^1.1.3" - } - }, "object-keys": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/object-keys/-/object-keys-1.1.1.tgz", @@ -9393,6 +7060,29 @@ "object-keys": "^1.1.1" } }, + "object.fromentries": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/object.fromentries/-/object.fromentries-2.0.7.tgz", + "integrity": "sha512-UPbPHML6sL8PI/mOqPwsH4G6iyXcCGzLin8KvEPenOZN5lpCNBZZQ+V62vdjB1mQHrmqGQt5/OJzemUA+KJmEA==", + "dev": true, + "requires": { + "call-bind": "^1.0.2", + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1" + } + }, + "object.groupby": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/object.groupby/-/object.groupby-1.0.1.tgz", + "integrity": "sha512-HqaQtqLnp/8Bn4GL16cj+CUYbnpe1bh0TtEaWvybszDG4tgxCJuRpV8VGuvNaI1fAnI4lUJzDG55MXcOH4JZcQ==", + "dev": true, + "requires": { + "call-bind": "^1.0.2", + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1", + "get-intrinsic": "^1.2.1" + } + }, "object.values": { "version": "1.1.6", "resolved": "https://registry.npmjs.org/object.values/-/object.values-1.1.6.tgz", @@ -9413,32 +7103,20 @@ "wrappy": "1" } }, - "opener": { - "version": "1.5.2", - "resolved": "https://registry.npmjs.org/opener/-/opener-1.5.2.tgz", - "integrity": "sha512-ur5UIdyw5Y7yEj9wLzhqXiy6GZ3Mwx0yGI+5sMn2r0N0v3cKJvUmFH5yPP+WXh9e0xfyzyJX95D8l088DNFj7A==", - "dev": true - }, "optionator": { - "version": "0.9.1", - "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.1.tgz", - "integrity": "sha512-74RlY5FCnhq4jRxVUPKDaRwrVNXMqsGsiW6AJw4XK8hmtm10wC0ypZBLw5IIp85NZMr91+qd1RvvENwg7jjRFw==", + "version": "0.9.3", + "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.3.tgz", + "integrity": "sha512-JjCoypp+jKn1ttEFExxhetCKeJt9zhAgAve5FXHixTvFDW/5aEktX9bufBKLRRMdU7bNtpLfcGu94B3cdEJgjg==", "dev": true, "requires": { + "@aashutoshrathi/word-wrap": "^1.2.3", "deep-is": "^0.1.3", "fast-levenshtein": "^2.0.6", "levn": "^0.4.1", "prelude-ls": "^1.2.1", - "type-check": "^0.4.0", - "word-wrap": "^1.2.3" + "type-check": "^0.4.0" } }, - "os-browserify": { - "version": "0.3.0", - "resolved": "https://registry.npmjs.org/os-browserify/-/os-browserify-0.3.0.tgz", - "integrity": "sha512-gjcpUc3clBf9+210TRaDWbf+rZZZEshZ+DlXMRCeAjp0xhTrnQsKHypIy1J3d5hKdUzj69t708EHtU8P6bUn0A==", - "dev": true - }, "p-limit": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", @@ -9478,19 +7156,6 @@ "callsites": "^3.0.0" } }, - "parse-asn1": { - "version": "5.1.6", - "resolved": "https://registry.npmjs.org/parse-asn1/-/parse-asn1-5.1.6.tgz", - "integrity": "sha512-RnZRo1EPU6JBnra2vGHj0yhp6ebyjBZpmUCLHWiFhxlzvBCCpAuZ7elsBp1PVAbQN0/04VD/19rfzlBSwLstMw==", - "dev": true, - "requires": { - "asn1.js": "^5.2.0", - "browserify-aes": "^1.0.0", - "evp_bytestokey": "^1.0.0", - "pbkdf2": "^3.0.3", - "safe-buffer": "^5.1.1" - } - }, "parse-json": { "version": "5.2.0", "resolved": "https://registry.npmjs.org/parse-json/-/parse-json-5.2.0.tgz", @@ -9503,12 +7168,6 @@ "lines-and-columns": "^1.1.6" } }, - "path-browserify": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/path-browserify/-/path-browserify-1.0.1.tgz", - "integrity": "sha512-b7uo2UCUOYZcnF/3ID0lulOJi/bafxa1xPe7ZPsammBSpjSWQkjNxlt635YGS2MiR9GjvuXCtz2emr3jbsz98g==", - "dev": true - }, "path-exists": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", @@ -9537,80 +7196,13 @@ "version": "4.0.0", "resolved": "https://registry.npmjs.org/path-type/-/path-type-4.0.0.tgz", "integrity": "sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==", - "dev": true - }, - "pbkdf2": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/pbkdf2/-/pbkdf2-3.1.2.tgz", - "integrity": "sha512-iuh7L6jA7JEGu2WxDwtQP1ddOpaJNC4KlDEFfdQajSGgGPNi4OyDc2R7QnbY2bR9QjBVGwgvTdNJZoE7RaxUMA==", - "dev": true, - "requires": { - "create-hash": "^1.1.2", - "create-hmac": "^1.1.4", - "ripemd160": "^2.0.1", - "safe-buffer": "^5.0.1", - "sha.js": "^2.4.8" - } - }, - "picocolors": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.0.tgz", - "integrity": "sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==", - "dev": true - }, - "picomatch": { - "version": "2.3.1", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", - "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", - "dev": true - }, - "pkg-dir": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/pkg-dir/-/pkg-dir-4.2.0.tgz", - "integrity": "sha512-HRDzbaKjC+AOWVXxAU/x54COGeIv9eb+6CkDSQoNTt4XyWoIJvuPsXizxu/Fr23EiekbtZwmh1IcIG/l/a10GQ==", - "dev": true, - "requires": { - "find-up": "^4.0.0" - }, - "dependencies": { - "find-up": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/find-up/-/find-up-4.1.0.tgz", - "integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==", - "dev": true, - "requires": { - "locate-path": "^5.0.0", - "path-exists": "^4.0.0" - } - }, - "locate-path": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-5.0.0.tgz", - "integrity": "sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==", - "dev": true, - "requires": { - "p-locate": "^4.1.0" - } - }, - "p-limit": { - "version": "2.3.0", - "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-2.3.0.tgz", - "integrity": "sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==", - "dev": true, - "requires": { - "p-try": "^2.0.0" - } - }, - "p-locate": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-4.1.0.tgz", - "integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==", - "dev": true, - "requires": { - "p-limit": "^2.2.0" - } - } - } + "dev": true + }, + "picomatch": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", + "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "dev": true }, "pluralize": { "version": "8.0.0", @@ -9625,9 +7217,9 @@ "dev": true }, "prettier": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.0.0.tgz", - "integrity": "sha512-zBf5eHpwHOGPC47h0zrPyNn+eAEIdEzfywMoYn2XPi0P44Zp0tSq64rq0xAREh4auw2cJZHo9QUob+NqCQky4g==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.0.3.tgz", + "integrity": "sha512-L/4pUDMxcNa8R/EthV08Zt42WBO4h1rarVtK0K+QJG0X187OLo7l699jWw0GKuwzkPQ//jMFA/8Xm6Fh3J/DAg==", "dev": true }, "process": { @@ -9642,46 +7234,12 @@ "integrity": "sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==", "dev": true }, - "public-encrypt": { - "version": "4.0.3", - "resolved": "https://registry.npmjs.org/public-encrypt/-/public-encrypt-4.0.3.tgz", - "integrity": "sha512-zVpa8oKZSz5bTMTFClc1fQOnyyEzpl5ozpi1B5YcvBrdohMjH2rfsBtyXcuNuwjsDIXmBYlF2N5FlJYhR29t8Q==", - "dev": true, - "requires": { - "bn.js": "^4.1.0", - "browserify-rsa": "^4.0.0", - "create-hash": "^1.1.0", - "parse-asn1": "^5.0.0", - "randombytes": "^2.0.1", - "safe-buffer": "^5.1.2" - }, - "dependencies": { - "bn.js": { - "version": "4.12.0", - "resolved": "https://registry.npmjs.org/bn.js/-/bn.js-4.12.0.tgz", - "integrity": "sha512-c98Bf3tPniI+scsdk237ku1Dc3ujXQTSgyiPUDEOe7tRkhrqridvh8klBv0HCEso1OLOYcHuCv/cS6DNxKH+ZA==", - "dev": true - } - } - }, "punycode": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.0.tgz", "integrity": "sha512-rRV+zQD8tVFys26lAGR9WUuS4iUAngJScM+ZRSKtvl5tKeZ2t5bvdNFdNHBW9FWR4guGHlgmsZ1G7BSm2wTbuA==", "dev": true }, - "querystring": { - "version": "0.2.0", - "resolved": "https://registry.npmjs.org/querystring/-/querystring-0.2.0.tgz", - "integrity": "sha512-X/xY82scca2tau62i9mDyU9K+I+djTMUsvwf7xnUX5GLvVzgJybOJf4Y6o9Zx3oJK/LSXg5tTZBjwzqVPaPO2g==", - "dev": true - }, - "querystring-es3": { - "version": "0.2.1", - "resolved": "https://registry.npmjs.org/querystring-es3/-/querystring-es3-0.2.1.tgz", - "integrity": "sha512-773xhDQnZBMFobEiztv8LIl70ch5MSF/jUQVlhwFyBILqq96anmoctVIYz+ZRp0qbCKATTn6ev02M3r7Ga5vqA==", - "dev": true - }, "queue-microtask": { "version": "1.2.3", "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", @@ -9697,16 +7255,6 @@ "safe-buffer": "^5.1.0" } }, - "randomfill": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/randomfill/-/randomfill-1.0.4.tgz", - "integrity": "sha512-87lcbR8+MhcWcUiQ+9e+Rwx8MyR2P7qnt15ynUlbm3TU/fjbgz4GsvfSUDTemtCCtVCqb4ZcEFlyPNTh9bBTLw==", - "dev": true, - "requires": { - "randombytes": "^2.0.5", - "safe-buffer": "^5.1.0" - } - }, "read-pkg": { "version": "5.2.0", "resolved": "https://registry.npmjs.org/read-pkg/-/read-pkg-5.2.0.tgz", @@ -9807,42 +7355,27 @@ "picomatch": "^2.2.1" } }, - "rechoir": { - "version": "0.8.0", - "resolved": "https://registry.npmjs.org/rechoir/-/rechoir-0.8.0.tgz", - "integrity": "sha512-/vxpCXddiX8NGfGO/mTafwjq4aFa/71pvamip0++IQk3zG8cbCj0fifNPrjjF1XMXUne91jL9OoxmdykoEtifQ==", - "dev": true, - "requires": { - "resolve": "^1.20.0" - } - }, "regexp-tree": { - "version": "0.1.24", - "resolved": "https://registry.npmjs.org/regexp-tree/-/regexp-tree-0.1.24.tgz", - "integrity": "sha512-s2aEVuLhvnVJW6s/iPgEGK6R+/xngd2jNQ+xy4bXNDKxZKJH6jpPHY6kVeVv1IeLCHgswRj+Kl3ELaDjG6V1iw==", + "version": "0.1.27", + "resolved": "https://registry.npmjs.org/regexp-tree/-/regexp-tree-0.1.27.tgz", + "integrity": "sha512-iETxpjK6YoRWJG5o6hXLwvjYAoW+FEZn9os0PD/b6AP6xQwsa/Y7lCVgIixBbUPMfhu+i2LtdeAqVTgGlQarfA==", "dev": true }, "regexp.prototype.flags": { - "version": "1.4.3", - "resolved": "https://registry.npmjs.org/regexp.prototype.flags/-/regexp.prototype.flags-1.4.3.tgz", - "integrity": "sha512-fjggEOO3slI6Wvgjwflkc4NFRCTZAu5CnNfBd5qOMYhWdn67nJBBu34/TkD++eeFmd8C9r9jfXJ27+nSiRkSUA==", + "version": "1.5.1", + "resolved": "https://registry.npmjs.org/regexp.prototype.flags/-/regexp.prototype.flags-1.5.1.tgz", + "integrity": "sha512-sy6TXMN+hnP/wMy+ISxg3krXx7BAtWVO4UouuCN/ziM9UEne0euamVNafDfvC83bRNr95y0V5iijeDQFUNpvrg==", "dev": true, "requires": { "call-bind": "^1.0.2", - "define-properties": "^1.1.3", - "functions-have-names": "^1.2.2" + "define-properties": "^1.2.0", + "set-function-name": "^2.0.0" } }, - "regexpp": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/regexpp/-/regexpp-3.2.0.tgz", - "integrity": "sha512-pq2bWo9mVD43nbts2wGv17XLiNLya+GklZ8kaDLV2Z08gDCsGpnKn9BFMepvWuHCbyVvY7J5o5+BVvoQbmlJLg==", - "dev": true - }, "regjsparser": { - "version": "0.9.1", - "resolved": "https://registry.npmjs.org/regjsparser/-/regjsparser-0.9.1.tgz", - "integrity": "sha512-dQUtn90WanSNl+7mQKcXAgZxvUe7Z0SqXlgzv0za4LwiUhyzBC58yQO3liFoUgu8GiJVInAhJjkj1N0EtQ5nkQ==", + "version": "0.10.0", + "resolved": "https://registry.npmjs.org/regjsparser/-/regjsparser-0.10.0.tgz", + "integrity": "sha512-qx+xQGZVsy55CH0a1hiVwHmqjLryfh7wQyF5HO07XJ9f7dQMY/gPQHhlyDkIzJKC+x2fUCpCcUODUUUFrm7SHA==", "dev": true, "requires": { "jsesc": "~0.5.0" @@ -9873,23 +7406,6 @@ "supports-preserve-symlinks-flag": "^1.0.0" } }, - "resolve-cwd": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/resolve-cwd/-/resolve-cwd-3.0.0.tgz", - "integrity": "sha512-OrZaX2Mb+rJCpH/6CpSqt9xFVpN++x01XnN2ie9g6P5/3xelLAkXWVADpdz1IHD/KFfEXyE6V0U01OQ3UO2rEg==", - "dev": true, - "requires": { - "resolve-from": "^5.0.0" - }, - "dependencies": { - "resolve-from": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-5.0.0.tgz", - "integrity": "sha512-qYg9KP24dD5qka9J47d0aVky0N+b4fTU89LN9iDnjB5waksiC49rvMB0PrUJQGoTmH50XPiqOvAjDfaijGxYZw==", - "dev": true - } - } - }, "resolve-from": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", @@ -9911,16 +7427,6 @@ "glob": "^7.1.3" } }, - "ripemd160": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/ripemd160/-/ripemd160-2.0.2.tgz", - "integrity": "sha512-ii4iagi25WusVoiC4B4lq7pbXfAp3D9v5CwfkY33vffw2+pkDjY1D8GaN7spsxvCSx8dkPqOZCEZyfxcmJG2IA==", - "dev": true, - "requires": { - "hash-base": "^3.0.0", - "inherits": "^2.0.1" - } - }, "run-parallel": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", @@ -9930,21 +7436,32 @@ "queue-microtask": "^1.2.2" } }, + "safe-array-concat": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/safe-array-concat/-/safe-array-concat-1.0.1.tgz", + "integrity": "sha512-6XbUAseYE2KtOuGueyeobCySj9L4+66Tn6KQMOPQJrAJEowYKW/YR/MGJZl7FdydUdaFu4LYyDZjxf4/Nmo23Q==", + "dev": true, + "requires": { + "call-bind": "^1.0.2", + "get-intrinsic": "^1.2.1", + "has-symbols": "^1.0.3", + "isarray": "^2.0.5" + }, + "dependencies": { + "isarray": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-2.0.5.tgz", + "integrity": "sha512-xHjhDr3cNBK0BzdUJSPXZntQUx/mwMS5Rw4A7lPJ90XGAO6ISP/ePDNuo0vhqOZU+UD5JoodwCAAoZQd3FeAKw==", + "dev": true + } + } + }, "safe-buffer": { "version": "5.1.2", "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz", "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==", "dev": true }, - "safe-regex": { - "version": "2.1.1", - "resolved": "https://registry.npmjs.org/safe-regex/-/safe-regex-2.1.1.tgz", - "integrity": "sha512-rx+x8AMzKb5Q5lQ95Zoi6ZbJqwCLkqi3XuJXp5P3rT8OEc6sZCJG5AE5dU3lsgRr/F4Bs31jSlVN+j5KrsGu9A==", - "dev": true, - "requires": { - "regexp-tree": "~0.1.1" - } - }, "safe-regex-test": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/safe-regex-test/-/safe-regex-test-1.0.0.tgz", @@ -9956,72 +7473,38 @@ "is-regex": "^1.1.4" } }, - "safer-buffer": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", - "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==", - "dev": true - }, - "schema-utils": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-3.1.1.tgz", - "integrity": "sha512-Y5PQxS4ITlC+EahLuXaY86TXfR7Dc5lw294alXOq86JAHCihAIZfqv8nNCWvaEJvaC51uN9hbLGeV0cFBdH+Fw==", - "dev": true, - "requires": { - "@types/json-schema": "^7.0.8", - "ajv": "^6.12.5", - "ajv-keywords": "^3.5.2" - } - }, "semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "requires": { "lru-cache": "^6.0.0" } }, - "serialize-javascript": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-6.0.1.tgz", - "integrity": "sha512-owoXEFjWRllis8/M1Q+Cw5k8ZH40e3zhp/ovX+Xr/vi1qj6QesbyXXViFbpNvWvPNAD62SutwEXavefrLJWj7w==", - "dev": true, - "requires": { - "randombytes": "^2.1.0" - } - }, "set-blocking": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/set-blocking/-/set-blocking-2.0.0.tgz", "integrity": "sha512-KiKBS8AnWGEyLzofFfmvKwpdPzqiy16LvQfK3yv/fVH7Bj13/wl3JSR1J+rfgRE9q7xUJK4qvgS8raSOeLUehw==", "dev": true }, + "set-function-name": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/set-function-name/-/set-function-name-2.0.1.tgz", + "integrity": "sha512-tMNCiqYVkXIZgc2Hnoy2IvC/f8ezc5koaRFkCjrpWzGpCd3qbZXPzVy9MAZzK1ch/X0jvSkojys3oqJN0qCmdA==", + "dev": true, + "requires": { + "define-data-property": "^1.0.1", + "functions-have-names": "^1.2.3", + "has-property-descriptors": "^1.0.0" + } + }, "setimmediate": { "version": "1.0.5", "resolved": "https://registry.npmjs.org/setimmediate/-/setimmediate-1.0.5.tgz", "integrity": "sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA==", "dev": true }, - "sha.js": { - "version": "2.4.11", - "resolved": "https://registry.npmjs.org/sha.js/-/sha.js-2.4.11.tgz", - "integrity": "sha512-QMEp5B7cftE7APOjk5Y6xgrbWu+WkLVQwk8JNjZ8nKRciZaByEW6MubieAiToS7+dwvrjGhH8jRXz3MVd0AYqQ==", - "dev": true, - "requires": { - "inherits": "^2.0.1", - "safe-buffer": "^5.0.1" - } - }, - "shallow-clone": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/shallow-clone/-/shallow-clone-3.0.1.tgz", - "integrity": "sha512-/6KqX+GVUdqPuPPd2LxDDxzX6CAbjJehAAOKlNpqqUpAqPM6HeL8f+o3a+JsyGjn2lv0WY8UsTgUJjU9Ok55NA==", - "dev": true, - "requires": { - "kind-of": "^6.0.2" - } - }, "shebang-command": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", @@ -10054,39 +7537,12 @@ "integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==", "dev": true }, - "sirv": { - "version": "1.0.19", - "resolved": "https://registry.npmjs.org/sirv/-/sirv-1.0.19.tgz", - "integrity": "sha512-JuLThK3TnZG1TAKDwNIqNq6QA2afLOCcm+iE8D1Kj3GA40pSPsxQjjJl0J8X3tsR7T+CP1GavpzLwYkgVLWrZQ==", - "dev": true, - "requires": { - "@polka/url": "^1.0.0-next.20", - "mrmime": "^1.0.0", - "totalist": "^1.0.0" - } - }, "slash": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/slash/-/slash-3.0.0.tgz", "integrity": "sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==", "dev": true }, - "source-map": { - "version": "0.6.1", - "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", - "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", - "dev": true - }, - "source-map-support": { - "version": "0.5.21", - "resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz", - "integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==", - "dev": true, - "requires": { - "buffer-from": "^1.0.0", - "source-map": "^0.6.0" - } - }, "spdx-correct": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/spdx-correct/-/spdx-correct-3.1.1.tgz", @@ -10119,54 +7575,6 @@ "integrity": "sha512-rr+VVSXtRhO4OHbXUiAF7xW3Bo9DuuF6C5jH+q/x15j2jniycgKbxU09Hr0WqlSLUs4i4ltHGXqTe7VHclYWyA==", "dev": true }, - "stream-browserify": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/stream-browserify/-/stream-browserify-3.0.0.tgz", - "integrity": "sha512-H73RAHsVBapbim0tU2JwwOiXUj+fikfiaoYAKHF3VJfA0pe2BCzkhAHBlLG6REzE+2WNZcxOXjK7lkso+9euLA==", - "dev": true, - "requires": { - "inherits": "~2.0.4", - "readable-stream": "^3.5.0" - }, - "dependencies": { - "readable-stream": { - "version": "3.6.1", - "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.1.tgz", - "integrity": "sha512-+rQmrWMYGA90yenhTYsLWAsLsqVC8osOw6PKE1HDYiO0gdPeKe/xDHNzIAIn4C91YQ6oenEhfYqqc1883qHbjQ==", - "dev": true, - "requires": { - "inherits": "^2.0.3", - "string_decoder": "^1.1.1", - "util-deprecate": "^1.0.1" - } - } - } - }, - "stream-http": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/stream-http/-/stream-http-3.2.0.tgz", - "integrity": "sha512-Oq1bLqisTyK3TSCXpPbT4sdeYNdmyZJv1LxpEm2vu1ZhK89kSE5YXwZc3cWk0MagGaKriBh9mCFbVGtO+vY29A==", - "dev": true, - "requires": { - "builtin-status-codes": "^3.0.0", - "inherits": "^2.0.4", - "readable-stream": "^3.6.0", - "xtend": "^4.0.2" - }, - "dependencies": { - "readable-stream": { - "version": "3.6.1", - "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.1.tgz", - "integrity": "sha512-+rQmrWMYGA90yenhTYsLWAsLsqVC8osOw6PKE1HDYiO0gdPeKe/xDHNzIAIn4C91YQ6oenEhfYqqc1883qHbjQ==", - "dev": true, - "requires": { - "inherits": "^2.0.3", - "string_decoder": "^1.1.1", - "util-deprecate": "^1.0.1" - } - } - } - }, "string_decoder": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.1.1.tgz", @@ -10187,26 +7595,37 @@ "strip-ansi": "^6.0.1" } }, + "string.prototype.trim": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/string.prototype.trim/-/string.prototype.trim-1.2.8.tgz", + "integrity": "sha512-lfjY4HcixfQXOfaqCvcBuOIapyaroTXhbkfJN3gcB1OtyupngWK4sEET9Knd0cXd28kTUqu/kHoV4HKSJdnjiQ==", + "dev": true, + "requires": { + "call-bind": "^1.0.2", + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1" + } + }, "string.prototype.trimend": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/string.prototype.trimend/-/string.prototype.trimend-1.0.6.tgz", - "integrity": "sha512-JySq+4mrPf9EsDBEDYMOb/lM7XQLulwg5R/m1r0PXEFqrV0qHvl58sdTilSXtKOflCsK2E8jxf+GKC0T07RWwQ==", + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/string.prototype.trimend/-/string.prototype.trimend-1.0.7.tgz", + "integrity": "sha512-Ni79DqeB72ZFq1uH/L6zJ+DKZTkOtPIHovb3YZHQViE+HDouuU4mBrLOLDn5Dde3RF8qw5qVETEjhu9locMLvA==", "dev": true, "requires": { "call-bind": "^1.0.2", - "define-properties": "^1.1.4", - "es-abstract": "^1.20.4" + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1" } }, "string.prototype.trimstart": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/string.prototype.trimstart/-/string.prototype.trimstart-1.0.6.tgz", - "integrity": "sha512-omqjMDaY92pbn5HOX7f9IccLA+U1tA9GvtU4JrodiXFfYB7jPzzHpRzpglLAjtUV6bB557zwClJezTqnAiYnQA==", + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/string.prototype.trimstart/-/string.prototype.trimstart-1.0.7.tgz", + "integrity": "sha512-NGhtDFu3jCEm7B4Fy0DpLewdJQOZcQ0rGbwQ/+stjnrp2i+rlKeCvos9hOIeCmqwratM47OBxY7uFZzjxHXmrg==", "dev": true, "requires": { "call-bind": "^1.0.2", - "define-properties": "^1.1.4", - "es-abstract": "^1.20.4" + "define-properties": "^1.2.0", + "es-abstract": "^1.22.1" } }, "strip-ansi": { @@ -10254,52 +7673,12 @@ "integrity": "sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==", "dev": true }, - "tapable": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.1.tgz", - "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", - "dev": true - }, - "terser": { - "version": "5.16.5", - "resolved": "https://registry.npmjs.org/terser/-/terser-5.16.5.tgz", - "integrity": "sha512-qcwfg4+RZa3YvlFh0qjifnzBHjKGNbtDo9yivMqMFDy9Q6FSaQWSB/j1xKhsoUFJIqDOM3TsN6D5xbrMrFcHbg==", - "dev": true, - "requires": { - "@jridgewell/source-map": "^0.3.2", - "acorn": "^8.5.0", - "commander": "^2.20.0", - "source-map-support": "~0.5.20" - } - }, - "terser-webpack-plugin": { - "version": "5.3.6", - "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.6.tgz", - "integrity": "sha512-kfLFk+PoLUQIbLmB1+PZDMRSZS99Mp+/MHqDNmMA6tOItzRt+Npe3E+fsMs5mfcM0wCtrrdU387UnV+vnSffXQ==", - "dev": true, - "requires": { - "@jridgewell/trace-mapping": "^0.3.14", - "jest-worker": "^27.4.5", - "schema-utils": "^3.1.1", - "serialize-javascript": "^6.0.0", - "terser": "^5.14.1" - } - }, "text-table": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz", "integrity": "sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw==", "dev": true }, - "timers-browserify": { - "version": "2.0.12", - "resolved": "https://registry.npmjs.org/timers-browserify/-/timers-browserify-2.0.12.tgz", - "integrity": "sha512-9phl76Cqm6FhSX9Xe1ZUAMLtm1BLkKj2Qd5ApyWkXzsMRaA7dgr81kf4wJmQf/hAvg8EEyJxDo3du/0KlhPiKQ==", - "dev": true, - "requires": { - "setimmediate": "^1.0.4" - } - }, "to-regex-range": { "version": "5.0.1", "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", @@ -10309,23 +7688,12 @@ "is-number": "^7.0.0" } }, - "totalist": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/totalist/-/totalist-1.1.0.tgz", - "integrity": "sha512-gduQwd1rOdDMGxFG1gEvhV88Oirdo2p+KjoYFU7k2g+i7n6AFFbDQ5kMPUsW0pNbfQsB/cwXvT1i4Bue0s9g5g==", - "dev": true - }, - "ts-loader": { - "version": "9.4.2", - "resolved": "https://registry.npmjs.org/ts-loader/-/ts-loader-9.4.2.tgz", - "integrity": "sha512-OmlC4WVmFv5I0PpaxYb+qGeGOdm5giHU7HwDDUjw59emP2UYMHy9fFSDcYgSNoH8sXcj4hGCSEhlDZ9ULeDraA==", + "ts-api-utils": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-1.0.3.tgz", + "integrity": "sha512-wNMeqtMz5NtwpT/UZGY5alT+VoKdSsOOP/kqHFcUW1P/VRhH2wJ48+DN2WwUliNbQ976ETwDL0Ifd2VVvgonvg==", "dev": true, - "requires": { - "chalk": "^4.1.0", - "enhanced-resolve": "^5.0.0", - "micromatch": "^4.0.0", - "semver": "^7.3.4" - } + "requires": {} }, "tsconfig-paths": { "version": "3.14.2", @@ -10339,27 +7707,6 @@ "strip-bom": "^3.0.0" } }, - "tslib": { - "version": "1.14.1", - "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz", - "integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==", - "dev": true - }, - "tsutils": { - "version": "3.21.0", - "resolved": "https://registry.npmjs.org/tsutils/-/tsutils-3.21.0.tgz", - "integrity": "sha512-mHKK3iUXL+3UF6xL5k0PEhKRUBKPBCv/+RkEOpjRWxxx27KKRBmmA60A9pgOUvMi8GKhRMPEmjBRPzs2W7O1OA==", - "dev": true, - "requires": { - "tslib": "^1.8.1" - } - }, - "tty-browserify": { - "version": "0.0.1", - "resolved": "https://registry.npmjs.org/tty-browserify/-/tty-browserify-0.0.1.tgz", - "integrity": "sha512-C3TaO7K81YvjCgQH9Q1S3R3P3BtN3RIM8n+OvX4il1K1zgE8ZhI0op7kClgkxtutIE8hQrcrHBXvIheqKUUCxw==", - "dev": true - }, "type-check": { "version": "0.4.0", "resolved": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", @@ -10375,6 +7722,42 @@ "integrity": "sha512-Ne+eE4r0/iWnpAxD852z3A+N0Bt5RN//NjJwRd2VFHEmrywxf5vsZlh4R6lixl6B+wz/8d+maTSAkN1FIkI3LQ==", "dev": true }, + "typed-array-buffer": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/typed-array-buffer/-/typed-array-buffer-1.0.0.tgz", + "integrity": "sha512-Y8KTSIglk9OZEr8zywiIHG/kmQ7KWyjseXs1CbSo8vC42w7hg2HgYTxSWwP0+is7bWDc1H+Fo026CpHFwm8tkw==", + "dev": true, + "requires": { + "call-bind": "^1.0.2", + "get-intrinsic": "^1.2.1", + "is-typed-array": "^1.1.10" + } + }, + "typed-array-byte-length": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/typed-array-byte-length/-/typed-array-byte-length-1.0.0.tgz", + "integrity": "sha512-Or/+kvLxNpeQ9DtSydonMxCx+9ZXOswtwJn17SNLvhptaXYDJvkFFP5zbfU/uLmvnBJlI4yrnXRxpdWH/M5tNA==", + "dev": true, + "requires": { + "call-bind": "^1.0.2", + "for-each": "^0.3.3", + "has-proto": "^1.0.1", + "is-typed-array": "^1.1.10" + } + }, + "typed-array-byte-offset": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/typed-array-byte-offset/-/typed-array-byte-offset-1.0.0.tgz", + "integrity": "sha512-RD97prjEt9EL8YgAgpOkf3O4IF9lhJFr9g0htQkm0rchFp/Vx7LW5Q8fSXXub7BXAODyUQohRMyOc3faCPd0hg==", + "dev": true, + "requires": { + "available-typed-arrays": "^1.0.5", + "call-bind": "^1.0.2", + "for-each": "^0.3.3", + "has-proto": "^1.0.1", + "is-typed-array": "^1.1.10" + } + }, "typed-array-length": { "version": "1.0.4", "resolved": "https://registry.npmjs.org/typed-array-length/-/typed-array-length-1.0.4.tgz", @@ -10387,9 +7770,9 @@ } }, "typescript": { - "version": "4.9.5", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true }, "unbox-primitive": { @@ -10410,16 +7793,6 @@ "integrity": "sha512-hAZsKq7Yy11Zu1DE0OzWjw7nnLZmJZYTDZZyEFHZdUhV8FkH5MCfoU1XMaxXovpyW5nq5scPqq0ZDP9Zyl04oQ==", "dev": true }, - "update-browserslist-db": { - "version": "1.0.10", - "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.0.10.tgz", - "integrity": "sha512-OztqDenkfFkbSG+tRxBeAnCVPckDBcvibKd35yDONx6OU8N7sqgwc7rCbkJ/WcYtVRZ4ba68d6byhC21GFh7sQ==", - "dev": true, - "requires": { - "escalade": "^3.1.1", - "picocolors": "^1.0.0" - } - }, "uri-js": { "version": "4.4.1", "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", @@ -10429,37 +7802,6 @@ "punycode": "^2.1.0" } }, - "url": { - "version": "0.11.0", - "resolved": "https://registry.npmjs.org/url/-/url-0.11.0.tgz", - "integrity": "sha512-kbailJa29QrtXnxgq+DdCEGlbTeYM2eJUxsz6vjZavrCYPMIFHMKQmSKYAIuUK2i7hgPm28a8piX5NTUtM/LKQ==", - "dev": true, - "requires": { - "punycode": "1.3.2", - "querystring": "0.2.0" - }, - "dependencies": { - "punycode": { - "version": "1.3.2", - "resolved": "https://registry.npmjs.org/punycode/-/punycode-1.3.2.tgz", - "integrity": "sha512-RofWgt/7fL5wP1Y7fxE7/EmTLzQVnB0ycyibJ0OOHIlJqTNzglYFxVwETOcIoJqJmpDXJ9xImDv+Fq34F/d4Dw==", - "dev": true - } - } - }, - "util": { - "version": "0.12.5", - "resolved": "https://registry.npmjs.org/util/-/util-0.12.5.tgz", - "integrity": "sha512-kZf/K6hEIrWHI6XqOFUiiMa+79wE/D8Q+NCNAWclkyg3b4d2k7s0QGepNjiABc+aR3N1PAyHL7p6UcLY6LmrnA==", - "dev": true, - "requires": { - "inherits": "^2.0.3", - "is-arguments": "^1.0.4", - "is-generator-function": "^1.0.7", - "is-typed-array": "^1.1.3", - "which-typed-array": "^1.1.2" - } - }, "util-deprecate": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", @@ -10476,125 +7818,6 @@ "spdx-expression-parse": "^3.0.0" } }, - "vm-browserify": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/vm-browserify/-/vm-browserify-1.1.2.tgz", - "integrity": "sha512-2ham8XPWTONajOR0ohOKOHXkm3+gaBmGut3SRuu75xLd/RRaY6vqgh8NBYYk7+RW3u5AtzPQZG8F10LHkl0lAQ==", - "dev": true - }, - "watchpack": { - "version": "2.4.0", - "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.4.0.tgz", - "integrity": "sha512-Lcvm7MGST/4fup+ifyKi2hjyIAwcdI4HRgtvTpIUxBRhB+RFtUh8XtDOxUfctVCnhVi+QQj49i91OyvzkJl6cg==", - "dev": true, - "requires": { - "glob-to-regexp": "^0.4.1", - "graceful-fs": "^4.1.2" - } - }, - "webpack": { - "version": "5.76.0", - "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.76.0.tgz", - "integrity": "sha512-l5sOdYBDunyf72HW8dF23rFtWq/7Zgvt/9ftMof71E/yUb1YLOBmTgA2K4vQthB3kotMrSj609txVE0dnr2fjA==", - "dev": true, - "requires": { - "@types/eslint-scope": "^3.7.3", - "@types/estree": "^0.0.51", - "@webassemblyjs/ast": "1.11.1", - "@webassemblyjs/wasm-edit": "1.11.1", - "@webassemblyjs/wasm-parser": "1.11.1", - "acorn": "^8.7.1", - "acorn-import-assertions": "^1.7.6", - "browserslist": "^4.14.5", - "chrome-trace-event": "^1.0.2", - "enhanced-resolve": "^5.10.0", - "es-module-lexer": "^0.9.0", - "eslint-scope": "5.1.1", - "events": "^3.2.0", - "glob-to-regexp": "^0.4.1", - "graceful-fs": "^4.2.9", - "json-parse-even-better-errors": "^2.3.1", - "loader-runner": "^4.2.0", - "mime-types": "^2.1.27", - "neo-async": "^2.6.2", - "schema-utils": "^3.1.0", - "tapable": "^2.1.1", - "terser-webpack-plugin": "^5.1.3", - "watchpack": "^2.4.0", - "webpack-sources": "^3.2.3" - } - }, - "webpack-bundle-analyzer": { - "version": "4.8.0", - "resolved": "https://registry.npmjs.org/webpack-bundle-analyzer/-/webpack-bundle-analyzer-4.8.0.tgz", - "integrity": "sha512-ZzoSBePshOKhr+hd8u6oCkZVwpVaXgpw23ScGLFpR6SjYI7+7iIWYarjN6OEYOfRt8o7ZyZZQk0DuMizJ+LEIg==", - "dev": true, - "requires": { - "@discoveryjs/json-ext": "0.5.7", - "acorn": "^8.0.4", - "acorn-walk": "^8.0.0", - "chalk": "^4.1.0", - "commander": "^7.2.0", - "gzip-size": "^6.0.0", - "lodash": "^4.17.20", - "opener": "^1.5.2", - "sirv": "^1.0.7", - "ws": "^7.3.1" - }, - "dependencies": { - "commander": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/commander/-/commander-7.2.0.tgz", - "integrity": "sha512-QrWXB+ZQSVPmIWIhtEO9H+gwHaMGYiF5ChvoJ+K9ZGHG/sVsa6yiesAD1GC/x46sET00Xlwo1u49RVVVzvcSkw==", - "dev": true - } - } - }, - "webpack-cli": { - "version": "5.0.1", - "resolved": "https://registry.npmjs.org/webpack-cli/-/webpack-cli-5.0.1.tgz", - "integrity": "sha512-S3KVAyfwUqr0Mo/ur3NzIp6jnerNpo7GUO6so51mxLi1spqsA17YcMXy0WOIJtBSnj748lthxC6XLbNKh/ZC+A==", - "dev": true, - "requires": { - "@discoveryjs/json-ext": "^0.5.0", - "@webpack-cli/configtest": "^2.0.1", - "@webpack-cli/info": "^2.0.1", - "@webpack-cli/serve": "^2.0.1", - "colorette": "^2.0.14", - "commander": "^9.4.1", - "cross-spawn": "^7.0.3", - "envinfo": "^7.7.3", - "fastest-levenshtein": "^1.0.12", - "import-local": "^3.0.2", - "interpret": "^3.1.1", - "rechoir": "^0.8.0", - "webpack-merge": "^5.7.3" - }, - "dependencies": { - "commander": { - "version": "9.5.0", - "resolved": "https://registry.npmjs.org/commander/-/commander-9.5.0.tgz", - "integrity": "sha512-KRs7WVDKg86PWiuAqhDrAQnTXZKraVcCc6vFdL14qrZ/DcWwuRo7VoiYXalXO7S5GKpqYiVEwCbgFDfxNHKJBQ==", - "dev": true - } - } - }, - "webpack-merge": { - "version": "5.8.0", - "resolved": "https://registry.npmjs.org/webpack-merge/-/webpack-merge-5.8.0.tgz", - "integrity": "sha512-/SaI7xY0831XwP6kzuwhKWVKDP9t1QY1h65lAFLbZqMPIuYcD9QAW4u9STIbU9kaJbPBB/geU/gLr1wDjOhQ+Q==", - "dev": true, - "requires": { - "clone-deep": "^4.0.1", - "wildcard": "^2.0.0" - } - }, - "webpack-sources": { - "version": "3.2.3", - "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-3.2.3.tgz", - "integrity": "sha512-/DyMEOrDgLKKIG0fmvtz+4dUX/3Ghozwgm6iPp8KRhvn+eQf9+Q7GWxVNMk3+uCPWfdXYC4ExGBckIXdFEfH1w==", - "dev": true - }, "which": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", @@ -10618,17 +7841,16 @@ } }, "which-typed-array": { - "version": "1.1.9", - "resolved": "https://registry.npmjs.org/which-typed-array/-/which-typed-array-1.1.9.tgz", - "integrity": "sha512-w9c4xkx6mPidwp7180ckYWfMmvxpjlZuIudNtDf4N/tTAUB8VJbX25qZoAsrtGuYNnGw3pa0AXgbGKRB8/EceA==", + "version": "1.1.11", + "resolved": "https://registry.npmjs.org/which-typed-array/-/which-typed-array-1.1.11.tgz", + "integrity": "sha512-qe9UWWpkeG5yzZ0tNYxDmd7vo58HDBc39mZ0xWWpolAGADdFOzkfamWLDxkOWcvHQKVmdTyQdLD4NOfjLWTKew==", "dev": true, "requires": { "available-typed-arrays": "^1.0.5", "call-bind": "^1.0.2", "for-each": "^0.3.3", "gopd": "^1.0.1", - "has-tostringtag": "^1.0.0", - "is-typed-array": "^1.1.10" + "has-tostringtag": "^1.0.0" } }, "wide-align": { @@ -10640,28 +7862,6 @@ "string-width": "^1.0.2 || 2 || 3 || 4" } }, - "wildcard": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/wildcard/-/wildcard-2.0.0.tgz", - "integrity": "sha512-JcKqAHLPxcdb9KM49dufGXn2x3ssnfjbcaQdLlfZsL9rH9wgDQjUtDxbo8NE0F6SFvydeu1VhZe7hZuHsB2/pw==", - "dev": true - }, - "word-wrap": { - "version": "1.2.4", - "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.4.tgz", - "integrity": "sha512-2V81OA4ugVo5pRo46hAoD2ivUJx8jXmWXfUkY4KFNw0hEptvN0QfH3K4nHiwzGeKl5rFKedV48QVoqYavy4YpA==", - "dev": true - }, - "worker-loader": { - "version": "3.0.8", - "resolved": "https://registry.npmjs.org/worker-loader/-/worker-loader-3.0.8.tgz", - "integrity": "sha512-XQyQkIFeRVC7f7uRhFdNMe/iJOdO6zxAaR3EWbDp45v3mDhrTi+++oswKNxShUNjPC/1xUp5DB29YKLhFo129g==", - "dev": true, - "requires": { - "loader-utils": "^2.0.0", - "schema-utils": "^3.0.0" - } - }, "workerpool": { "version": "6.2.1", "resolved": "https://registry.npmjs.org/workerpool/-/workerpool-6.2.1.tgz", @@ -10685,19 +7885,6 @@ "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", "dev": true }, - "ws": { - "version": "7.5.9", - "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.9.tgz", - "integrity": "sha512-F+P9Jil7UiSKSkppIiD94dN07AwvFixvLIj1Og1Rl9GGMuNipJnV9JzjD6XuqmAeiswGvUmNLjr5cFuXwNS77Q==", - "dev": true, - "requires": {} - }, - "xtend": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz", - "integrity": "sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==", - "dev": true - }, "y18n": { "version": "5.0.8", "resolved": "https://registry.npmjs.org/y18n/-/y18n-5.0.8.tgz", diff --git a/js/package.json b/js/package.json index 82d644ae6570f..d2c689265c6ad 100644 --- a/js/package.json +++ b/js/package.json @@ -1,32 +1,27 @@ { "devDependencies": { - "@types/fs-extra": "^11.0.1", - "@types/mocha": "^10.0.1", + "@types/fs-extra": "^11.0.2", + "@types/mocha": "^10.0.2", "@types/node": "^18.14.6", "@types/npmlog": "^4.1.4", - "@typescript-eslint/eslint-plugin": "^5.54.1", - "@typescript-eslint/parser": "^5.54.1", + "@typescript-eslint/eslint-plugin": "^6.7.4", + "@typescript-eslint/parser": "^6.7.4", "clang-format": "^1.8.0", - "dir-compare": "^4.0.0", - "eslint": "^8.35.0", + "dir-compare": "^4.2.0", + "esbuild": "^0.19.3", + "esbuild-plugin-polyfill-node": "^0.3.0", + "eslint": "^8.51.0", "eslint-plugin-header": "^3.1.1", - "eslint-plugin-import": "^2.27.5", - "eslint-plugin-jsdoc": "^40.0.1", + "eslint-plugin-import": "^2.28.1", + "eslint-plugin-jsdoc": "^46.8.2", "eslint-plugin-prefer-arrow": "^1.2.3", - "eslint-plugin-unicorn": "^46.0.0", - "fs-extra": "^11.1.0", + "eslint-plugin-unicorn": "^48.0.1", + "fs-extra": "^11.1.1", "jszip": "^3.10.1", "mocha": "^10.2.0", - "node-polyfill-webpack-plugin": "^2.0.1", "npmlog": "^7.0.1", - "prettier": "^3.0.0", - "terser": "^5.16.5", - "ts-loader": "^9.4.2", - "typescript": "^4.9.5", - "webpack": "^5.76.0", - "webpack-bundle-analyzer": "^4.8.0", - "webpack-cli": "^5.0.1", - "worker-loader": "^3.0.8" + "prettier": "^3.0.3", + "typescript": "^5.2.2" }, "scripts": { "prepare": "tsc --build scripts", diff --git a/js/react_native/e2e/package.json b/js/react_native/e2e/package.json index 969c70c110123..cd97ec1d099e4 100644 --- a/js/react_native/e2e/package.json +++ b/js/react_native/e2e/package.json @@ -10,7 +10,8 @@ }, "dependencies": { "react": "^18.1.0", - "react-native": "^0.69.1" + "react-native": "^0.69.1", + "react-native-fs": "^2.20.0" }, "devDependencies": { "@babel/core": "^7.17.0", diff --git a/js/react_native/e2e/src/App.tsx b/js/react_native/e2e/src/App.tsx index f3e415f0c5a55..8a76edabc613e 100644 --- a/js/react_native/e2e/src/App.tsx +++ b/js/react_native/e2e/src/App.tsx @@ -8,6 +8,7 @@ import { Image, Text, TextInput, View } from 'react-native'; import { InferenceSession, Tensor } from 'onnxruntime-react-native'; import MNIST, { MNISTInput, MNISTOutput, MNISTResult, } from './mnist-data-handler'; import { Buffer } from 'buffer'; +import { readFile } from 'react-native-fs'; interface State { session: @@ -39,10 +40,21 @@ export default class App extends React.PureComponent<{}, State> { this.setState({ imagePath }); const modelPath = await MNIST.getLocalModelPath(); - const session: InferenceSession = await InferenceSession.create(modelPath); + + // test creating session with path + console.log('Creating with path'); + const pathSession: InferenceSession = await InferenceSession.create(modelPath); + pathSession.release(); + + // and with bytes + console.log('Creating with bytes'); + const base64Str = await readFile(modelPath, 'base64'); + const bytes = Buffer.from(base64Str, 'base64'); + const session: InferenceSession = await InferenceSession.create(bytes); this.setState({ session }); - void this.infer(); + console.log('Test session created'); + void await this.infer(); } catch (err) { console.log(err.message); } diff --git a/js/react_native/e2e/yarn.lock b/js/react_native/e2e/yarn.lock index aaa35ae7895b9..9e20a286c4e27 100644 --- a/js/react_native/e2e/yarn.lock +++ b/js/react_native/e2e/yarn.lock @@ -39,6 +39,14 @@ dependencies: "@babel/highlight" "^7.18.6" +"@babel/code-frame@^7.22.13": + version "7.22.13" + resolved "https://registry.yarnpkg.com/@babel/code-frame/-/code-frame-7.22.13.tgz#e3c1c099402598483b7a8c46a721d1038803755e" + integrity sha512-XktuhWlJ5g+3TJXc5upd9Ks1HutSArik6jf2eAjYFyIOf4ej3RN+184cZbzDvbPnuTJIUhPKKJE3cIsYTiAT3w== + dependencies: + "@babel/highlight" "^7.22.13" + chalk "^2.4.2" + "@babel/compat-data@^7.13.11", "@babel/compat-data@^7.17.10": version "7.18.5" resolved "https://registry.yarnpkg.com/@babel/compat-data/-/compat-data-7.18.5.tgz#acac0c839e317038c73137fbb6ef71a1d6238471" @@ -100,15 +108,6 @@ "@jridgewell/gen-mapping" "^0.3.0" jsesc "^2.5.1" -"@babel/generator@^7.18.7": - version "7.18.7" - resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.18.7.tgz#2aa78da3c05aadfc82dbac16c99552fc802284bd" - integrity sha512-shck+7VLlY72a2w9c3zYWuE1pwOKEiQHV7GTUbSnhyl5eu3i04t30tBY82ZRWrDfo3gkakCFtevExnxbkf2a3A== - dependencies: - "@babel/types" "^7.18.7" - "@jridgewell/gen-mapping" "^0.3.2" - jsesc "^2.5.1" - "@babel/generator@^7.21.4", "@babel/generator@^7.7.2": version "7.21.4" resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.21.4.tgz#64a94b7448989f421f919d5239ef553b37bb26bc" @@ -119,6 +118,16 @@ "@jridgewell/trace-mapping" "^0.3.17" jsesc "^2.5.1" +"@babel/generator@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.23.0.tgz#df5c386e2218be505b34837acbcb874d7a983420" + integrity sha512-lN85QRR+5IbYrMWM6Y4pE/noaQtg4pNiqeNGX60eqOfo6gtEj6uw/JagelB8vVztSd7R6M5n1+PQkDbHbBRU4g== + dependencies: + "@babel/types" "^7.23.0" + "@jridgewell/gen-mapping" "^0.3.2" + "@jridgewell/trace-mapping" "^0.3.17" + jsesc "^2.5.1" + "@babel/helper-annotate-as-pure@^7.16.7": version "7.16.7" resolved "https://registry.yarnpkg.com/@babel/helper-annotate-as-pure/-/helper-annotate-as-pure-7.16.7.tgz#bb2339a7534a9c128e3102024c60760a3a7f3862" @@ -220,6 +229,11 @@ resolved "https://registry.yarnpkg.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.18.9.tgz#0c0cee9b35d2ca190478756865bb3528422f51be" integrity sha512-3r/aACDJ3fhQ/EVgFy0hpj8oHyHpQc+LPtJoY9SzTThAsStm4Ptegq92vqKoE3vD706ZVFWITnMnxucw+S9Ipg== +"@babel/helper-environment-visitor@^7.22.20": + version "7.22.20" + resolved "https://registry.yarnpkg.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.22.20.tgz#96159db61d34a29dba454c959f5ae4a649ba9167" + integrity sha512-zfedSIzFhat/gFhWfHtgWvlec0nqB9YEIVrpuwjruLlXfUSnA8cJB0miHKwqDnQ7d32aKo2xt88/xZptwxbfhA== + "@babel/helper-explode-assignable-expression@^7.16.7": version "7.16.7" resolved "https://registry.yarnpkg.com/@babel/helper-explode-assignable-expression/-/helper-explode-assignable-expression-7.16.7.tgz#12a6d8522fdd834f194e868af6354e8650242b7a" @@ -243,27 +257,20 @@ "@babel/template" "^7.18.6" "@babel/types" "^7.18.6" -"@babel/helper-function-name@^7.21.0": - version "7.21.0" - resolved "https://registry.yarnpkg.com/@babel/helper-function-name/-/helper-function-name-7.21.0.tgz#d552829b10ea9f120969304023cd0645fa00b1b4" - integrity sha512-HfK1aMRanKHpxemaY2gqBmL04iAPOPRj7DxtNbiDOrJK+gdwkiNRVpCpUJYbUT+aZyemKN8brqTOxzCaG6ExRg== - dependencies: - "@babel/template" "^7.20.7" - "@babel/types" "^7.21.0" - -"@babel/helper-hoist-variables@^7.16.7": - version "7.16.7" - resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.16.7.tgz#86bcb19a77a509c7b77d0e22323ef588fa58c246" - integrity sha512-m04d/0Op34H5v7pbZw6pSKP7weA6lsMvfiIAMeIvkY/R4xQtBSMFEigu9QTZ2qB/9l22vsxtM8a+Q8CzD255fg== +"@babel/helper-function-name@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/helper-function-name/-/helper-function-name-7.23.0.tgz#1f9a3cdbd5b2698a670c30d2735f9af95ed52759" + integrity sha512-OErEqsrxjZTJciZ4Oo+eoZqeW9UIiOcuYKRJA4ZAgV9myA+pOXhhmpfNCKjEH/auVfEYVFJ6y1Tc4r0eIApqiw== dependencies: - "@babel/types" "^7.16.7" + "@babel/template" "^7.22.15" + "@babel/types" "^7.23.0" -"@babel/helper-hoist-variables@^7.18.6": - version "7.18.6" - resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.18.6.tgz#d4d2c8fb4baeaa5c68b99cc8245c56554f926678" - integrity sha512-UlJQPkFqFULIcyW5sbzgbkxn2FKRgwWiRexcuaR8RNJRy8+LLveqPjwZV/bwrLZCN0eUHD/x8D0heK1ozuoo6Q== +"@babel/helper-hoist-variables@^7.22.5": + version "7.22.5" + resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.22.5.tgz#c01a007dac05c085914e8fb652b339db50d823bb" + integrity sha512-wGjk9QZVzvknA6yKIUURb8zY3grXCcOZt+/7Wcy8O2uctxhplmUPkOdlgoNhmdVee2c92JXbf1xpMtVNbfoxRw== dependencies: - "@babel/types" "^7.18.6" + "@babel/types" "^7.22.5" "@babel/helper-member-expression-to-functions@^7.17.7": version "7.17.7" @@ -401,11 +408,23 @@ dependencies: "@babel/types" "^7.18.6" +"@babel/helper-split-export-declaration@^7.22.6": + version "7.22.6" + resolved "https://registry.yarnpkg.com/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.22.6.tgz#322c61b7310c0997fe4c323955667f18fcefb91c" + integrity sha512-AsUnxuLhRYsisFiaJwvp1QF+I3KjD5FOxut14q/GzovUe6orHLesW2C7d754kRm53h5gqrz6sFl6sxc4BVtE/g== + dependencies: + "@babel/types" "^7.22.5" + "@babel/helper-string-parser@^7.19.4": version "7.19.4" resolved "https://registry.yarnpkg.com/@babel/helper-string-parser/-/helper-string-parser-7.19.4.tgz#38d3acb654b4701a9b77fb0615a96f775c3a9e63" integrity sha512-nHtDoQcuqFmwYNYPz3Rah5ph2p8PFeFCsZk9A/48dPc/rGocJ5J3hAAZ7pb76VWX3fZKu+uEr/FhH5jLx7umrw== +"@babel/helper-string-parser@^7.22.5": + version "7.22.5" + resolved "https://registry.yarnpkg.com/@babel/helper-string-parser/-/helper-string-parser-7.22.5.tgz#533f36457a25814cf1df6488523ad547d784a99f" + integrity sha512-mM4COjgZox8U+JcXQwPijIZLElkgEpO5rsERVDJTc2qfCDfERyob6k5WegS14SX18IIjv+XD+GrqNumY5JRCDw== + "@babel/helper-validator-identifier@^7.16.7": version "7.16.7" resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.16.7.tgz#e8c602438c4a8195751243da9031d1607d247cad" @@ -421,6 +440,11 @@ resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.19.1.tgz#7eea834cf32901ffdc1a7ee555e2f9c27e249ca2" integrity sha512-awrNfaMtnHUr653GgGEs++LlAvW6w+DcPrOliSMXWCKo597CwL5Acf/wWdNkf/tfEQE3mjkeD1YOVZOUV/od1w== +"@babel/helper-validator-identifier@^7.22.20": + version "7.22.20" + resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.22.20.tgz#c4ae002c61d2879e724581d96665583dbc1dc0e0" + integrity sha512-Y4OZ+ytlatR8AI+8KZfKuL5urKp7qey08ha31L8b3BwewJAoJamTzyvxPR/5D+KkdJCGPq/+8TukHBlY10FX9A== + "@babel/helper-validator-option@^7.16.7": version "7.16.7" resolved "https://registry.yarnpkg.com/@babel/helper-validator-option/-/helper-validator-option-7.16.7.tgz#b203ce62ce5fe153899b617c08957de860de4d23" @@ -487,6 +511,15 @@ chalk "^2.0.0" js-tokens "^4.0.0" +"@babel/highlight@^7.22.13": + version "7.22.20" + resolved "https://registry.yarnpkg.com/@babel/highlight/-/highlight-7.22.20.tgz#4ca92b71d80554b01427815e06f2df965b9c1f54" + integrity sha512-dkdMCN3py0+ksCgYmGG8jKeGA/8Tk+gJwSYYlFGxG5lmhfKNoAy004YpLxpS1W2J8m/EK2Ew+yOs9pVRwO89mg== + dependencies: + "@babel/helper-validator-identifier" "^7.22.20" + chalk "^2.4.2" + js-tokens "^4.0.0" + "@babel/parser@^7.1.0", "@babel/parser@^7.14.7", "@babel/parser@^7.20.7", "@babel/parser@^7.21.4": version "7.21.4" resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.21.4.tgz#94003fdfc520bbe2875d4ae557b43ddb6d880f17" @@ -497,11 +530,16 @@ resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.18.5.tgz#337062363436a893a2d22faa60be5bb37091c83c" integrity sha512-YZWVaglMiplo7v8f1oMQ5ZPQr0vn7HPeZXxXWsxXJRjGVrzUFn9OxFQl1sb5wzfootjA/yChhW84BV+383FSOw== -"@babel/parser@^7.18.6", "@babel/parser@^7.18.8": +"@babel/parser@^7.18.6": version "7.18.8" resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.18.8.tgz#822146080ac9c62dac0823bb3489622e0bc1cbdf" integrity sha512-RSKRfYX20dyH+elbJK2uqAkVyucL+xXzhqlMD5/ZXx+dAAwpyB7HsvnHe/ZUGOF+xLr5Wx9/JoXVTj6BQE2/oA== +"@babel/parser@^7.22.15", "@babel/parser@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.23.0.tgz#da950e622420bf96ca0d0f2909cdddac3acd8719" + integrity sha512-vvPKKdMemU85V9WE/l5wZEmImpCtLqbnTvqDS2U1fJ96KrxoW7KrXhNsNCblQlg8Ck4b85yxdTyelsMUgFUXiw== + "@babel/plugin-proposal-async-generator-functions@^7.0.0": version "7.18.6" resolved "https://registry.yarnpkg.com/@babel/plugin-proposal-async-generator-functions/-/plugin-proposal-async-generator-functions-7.18.6.tgz#aedac81e6fc12bb643374656dd5f2605bf743d17" @@ -1016,51 +1054,28 @@ "@babel/parser" "^7.20.7" "@babel/types" "^7.20.7" -"@babel/traverse@^7.13.0", "@babel/traverse@^7.14.0", "@babel/traverse@^7.16.8", "@babel/traverse@^7.18.0", "@babel/traverse@^7.18.2", "@babel/traverse@^7.18.5": - version "7.18.5" - resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.18.5.tgz#94a8195ad9642801837988ab77f36e992d9a20cd" - integrity sha512-aKXj1KT66sBj0vVzk6rEeAO6Z9aiiQ68wfDgge3nHhA/my6xMM/7HGQUNumKZaoa2qUPQ5whJG9aAifsxUKfLA== - dependencies: - "@babel/code-frame" "^7.16.7" - "@babel/generator" "^7.18.2" - "@babel/helper-environment-visitor" "^7.18.2" - "@babel/helper-function-name" "^7.17.9" - "@babel/helper-hoist-variables" "^7.16.7" - "@babel/helper-split-export-declaration" "^7.16.7" - "@babel/parser" "^7.18.5" - "@babel/types" "^7.18.4" - debug "^4.1.0" - globals "^11.1.0" - -"@babel/traverse@^7.18.6": - version "7.18.8" - resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.18.8.tgz#f095e62ab46abf1da35e5a2011f43aee72d8d5b0" - integrity sha512-UNg/AcSySJYR/+mIcJQDCv00T+AqRO7j/ZEJLzpaYtgM48rMg5MnkJgyNqkzo88+p4tfRvZJCEiwwfG6h4jkRg== - dependencies: - "@babel/code-frame" "^7.18.6" - "@babel/generator" "^7.18.7" - "@babel/helper-environment-visitor" "^7.18.6" - "@babel/helper-function-name" "^7.18.6" - "@babel/helper-hoist-variables" "^7.18.6" - "@babel/helper-split-export-declaration" "^7.18.6" - "@babel/parser" "^7.18.8" - "@babel/types" "^7.18.8" - debug "^4.1.0" - globals "^11.1.0" - -"@babel/traverse@^7.21.0", "@babel/traverse@^7.21.2", "@babel/traverse@^7.21.4", "@babel/traverse@^7.7.2": - version "7.21.4" - resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.21.4.tgz#a836aca7b116634e97a6ed99976236b3282c9d36" - integrity sha512-eyKrRHKdyZxqDm+fV1iqL9UAHMoIg0nDaGqfIOd8rKH17m5snv7Gn4qgjBoFfLz9APvjFU/ICT00NVCv1Epp8Q== - dependencies: - "@babel/code-frame" "^7.21.4" - "@babel/generator" "^7.21.4" - "@babel/helper-environment-visitor" "^7.18.9" - "@babel/helper-function-name" "^7.21.0" - "@babel/helper-hoist-variables" "^7.18.6" - "@babel/helper-split-export-declaration" "^7.18.6" - "@babel/parser" "^7.21.4" - "@babel/types" "^7.21.4" +"@babel/template@^7.22.15": + version "7.22.15" + resolved "https://registry.yarnpkg.com/@babel/template/-/template-7.22.15.tgz#09576efc3830f0430f4548ef971dde1350ef2f38" + integrity sha512-QPErUVm4uyJa60rkI73qneDacvdvzxshT3kksGqlGWYdOTIUOwJ7RDUL8sGqslY1uXWSL6xMFKEXDS3ox2uF0w== + dependencies: + "@babel/code-frame" "^7.22.13" + "@babel/parser" "^7.22.15" + "@babel/types" "^7.22.15" + +"@babel/traverse@^7.13.0", "@babel/traverse@^7.14.0", "@babel/traverse@^7.16.8", "@babel/traverse@^7.18.0", "@babel/traverse@^7.18.2", "@babel/traverse@^7.18.5", "@babel/traverse@^7.18.6", "@babel/traverse@^7.21.0", "@babel/traverse@^7.21.2", "@babel/traverse@^7.21.4", "@babel/traverse@^7.7.2": + version "7.23.2" + resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.23.2.tgz#329c7a06735e144a506bdb2cad0268b7f46f4ad8" + integrity sha512-azpe59SQ48qG6nu2CzcMLbxUudtN+dOM9kDbUqGq3HXUJRlo7i8fvPoxQUzYgLZ4cMVmuZgm8vvBpNeRhd6XSw== + dependencies: + "@babel/code-frame" "^7.22.13" + "@babel/generator" "^7.23.0" + "@babel/helper-environment-visitor" "^7.22.20" + "@babel/helper-function-name" "^7.23.0" + "@babel/helper-hoist-variables" "^7.22.5" + "@babel/helper-split-export-declaration" "^7.22.6" + "@babel/parser" "^7.23.0" + "@babel/types" "^7.23.0" debug "^4.1.0" globals "^11.1.0" @@ -1072,7 +1087,7 @@ "@babel/helper-validator-identifier" "^7.16.7" to-fast-properties "^2.0.0" -"@babel/types@^7.18.6", "@babel/types@^7.18.7", "@babel/types@^7.18.8": +"@babel/types@^7.18.6": version "7.18.8" resolved "https://registry.yarnpkg.com/@babel/types/-/types-7.18.8.tgz#c5af199951bf41ba4a6a9a6d0d8ad722b30cd42f" integrity sha512-qwpdsmraq0aJ3osLJRApsc2ouSJCdnMeZwB0DhbtHAtRpZNZCdlbRnHIgcRKzdE1g0iOGg644fzjOBcdOz9cPw== @@ -1089,6 +1104,15 @@ "@babel/helper-validator-identifier" "^7.19.1" to-fast-properties "^2.0.0" +"@babel/types@^7.22.15", "@babel/types@^7.22.5", "@babel/types@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/types/-/types-7.23.0.tgz#8c1f020c9df0e737e4e247c0619f58c68458aaeb" + integrity sha512-0oIyUfKoI3mSqMvsxBdclDwxXKXAUA8v/apZbc+iSyARYou1o8ZGDxbUYyLFoW2arqS2jDGqJuZvv1d/io1axg== + dependencies: + "@babel/helper-string-parser" "^7.22.5" + "@babel/helper-validator-identifier" "^7.22.20" + to-fast-properties "^2.0.0" + "@bcoe/v8-coverage@^0.2.3": version "0.2.3" resolved "https://registry.yarnpkg.com/@bcoe/v8-coverage/-/v8-coverage-0.2.3.tgz#75a2e8b51cb758a7553d6804a5932d7aace75c39" @@ -2050,6 +2074,11 @@ balanced-match@^1.0.0: resolved "https://registry.yarnpkg.com/balanced-match/-/balanced-match-1.0.2.tgz#e83e3a7e3f300b34cb9d87f615fa0cbf357690ee" integrity sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw== +base-64@^0.1.0: + version "0.1.0" + resolved "https://registry.yarnpkg.com/base-64/-/base-64-0.1.0.tgz#780a99c84e7d600260361511c4877613bf24f6bb" + integrity sha512-Y5gU45svrR5tI2Vt/X9GPd3L0HNIKzGu202EjxrXMpuc2V2CiKgemAbUUsqYmZJvPtCXoUKjNZwBJzsNScUbXA== + base64-js@^1.1.2, base64-js@^1.3.1, base64-js@^1.5.1: version "1.5.1" resolved "https://registry.yarnpkg.com/base64-js/-/base64-js-1.5.1.tgz#1b1b440160a5bf7ad40b650f095963481903930a" @@ -2260,7 +2289,7 @@ caniuse-lite@^1.0.30001449: resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001478.tgz#0ef8a1cf8b16be47a0f9fc4ecfc952232724b32a" integrity sha512-gMhDyXGItTHipJj2ApIvR+iVB5hd0KP3svMWWXDvZOmjzJJassGLMfxRkQCSYgGd2gtdL/ReeiyvMSFD1Ss6Mw== -chalk@^2.0.0: +chalk@^2.0.0, chalk@^2.4.2: version "2.4.2" resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.2.tgz#cd42541677a54333cf541a49108c1432b44c9424" integrity sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ== @@ -5205,6 +5234,14 @@ react-native-codegen@^0.69.2: jscodeshift "^0.13.1" nullthrows "^1.1.1" +react-native-fs@^2.20.0: + version "2.20.0" + resolved "https://registry.yarnpkg.com/react-native-fs/-/react-native-fs-2.20.0.tgz#05a9362b473bfc0910772c0acbb73a78dbc810f6" + integrity sha512-VkTBzs7fIDUiy/XajOSNk0XazFE9l+QlMAce7lGuebZcag5CnjszB+u4BdqzwaQOdcYb5wsJIsqq4kxInIRpJQ== + dependencies: + base-64 "^0.1.0" + utf8 "^3.0.0" + react-native-gradle-plugin@^0.0.7: version "0.0.7" resolved "https://registry.yarnpkg.com/react-native-gradle-plugin/-/react-native-gradle-plugin-0.0.7.tgz#96602f909745239deab7b589443f14fce5da2056" @@ -6167,6 +6204,11 @@ utf8-byte-length@^1.0.1: resolved "https://registry.yarnpkg.com/utf8-byte-length/-/utf8-byte-length-1.0.4.tgz#f45f150c4c66eee968186505ab93fcbb8ad6bf61" integrity sha512-4+wkEYLBbWxqTahEsWrhxepcoVOJ+1z5PGIjPZxRkytcdSUaNjIjBM7Xn8E+pdSuV7SzvWovBFA54FO0JSoqhA== +utf8@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/utf8/-/utf8-3.0.0.tgz#f052eed1364d696e769ef058b183df88c87f69d1" + integrity sha512-E8VjFIQ/TyQgp+TZfS6l8yp/xWppSAHzidGiRrqe4bK4XP9pTRyKFgGJpO3SN7zdX4DeomTrwaseCHovfpFcqQ== + util-deprecate@^1.0.1, util-deprecate@~1.0.1: version "1.0.2" resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" diff --git a/js/react_native/lib/backend.ts b/js/react_native/lib/backend.ts index b3f0c466308a5..3d3569028e636 100644 --- a/js/react_native/lib/backend.ts +++ b/js/react_native/lib/backend.ts @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Backend, InferenceSession, SessionHandler, Tensor,} from 'onnxruntime-common'; +import {type Backend, InferenceSession, type InferenceSessionHandler, type SessionHandler, Tensor} from 'onnxruntime-common'; import {Platform} from 'react-native'; -import {binding, Binding, JSIBlob, jsiHelper} from './binding'; +import {binding, type Binding, type JSIBlob, jsiHelper} from './binding'; type SupportedTypedArray = Exclude; @@ -43,7 +43,7 @@ const normalizePath = (path: string): string => { return path; }; -class OnnxruntimeSessionHandler implements SessionHandler { +class OnnxruntimeSessionHandler implements InferenceSessionHandler { #inferenceSession: Binding.InferenceSession; #key: string; @@ -66,12 +66,14 @@ class OnnxruntimeSessionHandler implements SessionHandler { let results: Binding.ModelLoadInfoType; // load a model if (typeof this.#pathOrBuffer === 'string') { + // load model from model path results = await this.#inferenceSession.loadModel(normalizePath(this.#pathOrBuffer), options); } else { + // load model from buffer if (!this.#inferenceSession.loadModelFromBlob) { throw new Error('Native module method "loadModelFromBlob" is not defined'); } - const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer); + const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer.buffer); results = await this.#inferenceSession.loadModelFromBlob(modelBlob, options); } // resolve promise if onnxruntime session is successfully created @@ -163,8 +165,8 @@ class OnnxruntimeBackend implements Backend { return Promise.resolve(); } - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { const handler = new OnnxruntimeSessionHandler(pathOrBuffer); await handler.loadModel(options || {}); return handler; diff --git a/js/react_native/tsconfig.json b/js/react_native/tsconfig.json index 9e929c4530982..2817d5512a186 100644 --- a/js/react_native/tsconfig.json +++ b/js/react_native/tsconfig.json @@ -7,7 +7,6 @@ "allowUnreachableCode": false, "allowUnusedLabels": false, "esModuleInterop": true, - "importsNotUsedAsValues": "error", "forceConsistentCasingInFileNames": true, "jsx": "react", "lib": ["esnext"], diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock index 0b47158ab705b..ff9be7fbe3a5b 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -24,6 +24,14 @@ dependencies: "@babel/highlight" "^7.18.6" +"@babel/code-frame@^7.22.13": + version "7.22.13" + resolved "https://registry.yarnpkg.com/@babel/code-frame/-/code-frame-7.22.13.tgz#e3c1c099402598483b7a8c46a721d1038803755e" + integrity sha512-XktuhWlJ5g+3TJXc5upd9Ks1HutSArik6jf2eAjYFyIOf4ej3RN+184cZbzDvbPnuTJIUhPKKJE3cIsYTiAT3w== + dependencies: + "@babel/highlight" "^7.22.13" + chalk "^2.4.2" + "@babel/code-frame@~7.10.4": version "7.10.4" resolved "https://registry.yarnpkg.com/@babel/code-frame/-/code-frame-7.10.4.tgz#168da1a36e90da68ae8d49c0f1b48c7c6249213a" @@ -66,13 +74,14 @@ "@jridgewell/gen-mapping" "^0.3.0" jsesc "^2.5.1" -"@babel/generator@^7.18.7": - version "7.18.7" - resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.18.7.tgz#2aa78da3c05aadfc82dbac16c99552fc802284bd" - integrity sha512-shck+7VLlY72a2w9c3zYWuE1pwOKEiQHV7GTUbSnhyl5eu3i04t30tBY82ZRWrDfo3gkakCFtevExnxbkf2a3A== +"@babel/generator@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.23.0.tgz#df5c386e2218be505b34837acbcb874d7a983420" + integrity sha512-lN85QRR+5IbYrMWM6Y4pE/noaQtg4pNiqeNGX60eqOfo6gtEj6uw/JagelB8vVztSd7R6M5n1+PQkDbHbBRU4g== dependencies: - "@babel/types" "^7.18.7" + "@babel/types" "^7.23.0" "@jridgewell/gen-mapping" "^0.3.2" + "@jridgewell/trace-mapping" "^0.3.17" jsesc "^2.5.1" "@babel/helper-annotate-as-pure@^7.16.7": @@ -160,6 +169,11 @@ resolved "https://registry.yarnpkg.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.18.6.tgz#b7eee2b5b9d70602e59d1a6cad7dd24de7ca6cd7" integrity sha512-8n6gSfn2baOY+qlp+VSzsosjCVGFqWKmDF0cCWOybh52Dw3SEyoWR1KrhMJASjLwIEkkAufZ0xvr+SxLHSpy2Q== +"@babel/helper-environment-visitor@^7.22.20": + version "7.22.20" + resolved "https://registry.yarnpkg.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.22.20.tgz#96159db61d34a29dba454c959f5ae4a649ba9167" + integrity sha512-zfedSIzFhat/gFhWfHtgWvlec0nqB9YEIVrpuwjruLlXfUSnA8cJB0miHKwqDnQ7d32aKo2xt88/xZptwxbfhA== + "@babel/helper-explode-assignable-expression@^7.16.7": version "7.16.7" resolved "https://registry.yarnpkg.com/@babel/helper-explode-assignable-expression/-/helper-explode-assignable-expression-7.16.7.tgz#12a6d8522fdd834f194e868af6354e8650242b7a" @@ -183,6 +197,14 @@ "@babel/template" "^7.18.6" "@babel/types" "^7.18.6" +"@babel/helper-function-name@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/helper-function-name/-/helper-function-name-7.23.0.tgz#1f9a3cdbd5b2698a670c30d2735f9af95ed52759" + integrity sha512-OErEqsrxjZTJciZ4Oo+eoZqeW9UIiOcuYKRJA4ZAgV9myA+pOXhhmpfNCKjEH/auVfEYVFJ6y1Tc4r0eIApqiw== + dependencies: + "@babel/template" "^7.22.15" + "@babel/types" "^7.23.0" + "@babel/helper-hoist-variables@^7.16.7": version "7.16.7" resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.16.7.tgz#86bcb19a77a509c7b77d0e22323ef588fa58c246" @@ -190,12 +212,12 @@ dependencies: "@babel/types" "^7.16.7" -"@babel/helper-hoist-variables@^7.18.6": - version "7.18.6" - resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.18.6.tgz#d4d2c8fb4baeaa5c68b99cc8245c56554f926678" - integrity sha512-UlJQPkFqFULIcyW5sbzgbkxn2FKRgwWiRexcuaR8RNJRy8+LLveqPjwZV/bwrLZCN0eUHD/x8D0heK1ozuoo6Q== +"@babel/helper-hoist-variables@^7.22.5": + version "7.22.5" + resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.22.5.tgz#c01a007dac05c085914e8fb652b339db50d823bb" + integrity sha512-wGjk9QZVzvknA6yKIUURb8zY3grXCcOZt+/7Wcy8O2uctxhplmUPkOdlgoNhmdVee2c92JXbf1xpMtVNbfoxRw== dependencies: - "@babel/types" "^7.18.6" + "@babel/types" "^7.22.5" "@babel/helper-member-expression-to-functions@^7.17.7": version "7.17.7" @@ -293,12 +315,17 @@ dependencies: "@babel/types" "^7.16.7" -"@babel/helper-split-export-declaration@^7.18.6": - version "7.18.6" - resolved "https://registry.yarnpkg.com/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.18.6.tgz#7367949bc75b20c6d5a5d4a97bba2824ae8ef075" - integrity sha512-bde1etTx6ZyTmobl9LLMMQsaizFVZrquTEHOqKeQESMKo4PlObf+8+JA25ZsIpZhT/WEd39+vOdLXAFG/nELpA== +"@babel/helper-split-export-declaration@^7.22.6": + version "7.22.6" + resolved "https://registry.yarnpkg.com/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.22.6.tgz#322c61b7310c0997fe4c323955667f18fcefb91c" + integrity sha512-AsUnxuLhRYsisFiaJwvp1QF+I3KjD5FOxut14q/GzovUe6orHLesW2C7d754kRm53h5gqrz6sFl6sxc4BVtE/g== dependencies: - "@babel/types" "^7.18.6" + "@babel/types" "^7.22.5" + +"@babel/helper-string-parser@^7.22.5": + version "7.22.5" + resolved "https://registry.yarnpkg.com/@babel/helper-string-parser/-/helper-string-parser-7.22.5.tgz#533f36457a25814cf1df6488523ad547d784a99f" + integrity sha512-mM4COjgZox8U+JcXQwPijIZLElkgEpO5rsERVDJTc2qfCDfERyob6k5WegS14SX18IIjv+XD+GrqNumY5JRCDw== "@babel/helper-validator-identifier@^7.16.7": version "7.19.1" @@ -310,6 +337,11 @@ resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.18.6.tgz#9c97e30d31b2b8c72a1d08984f2ca9b574d7a076" integrity sha512-MmetCkz9ej86nJQV+sFCxoGGrUbU3q02kgLciwkrt9QqEB7cP39oKEY0PakknEO0Gu20SskMRi+AYZ3b1TpN9g== +"@babel/helper-validator-identifier@^7.22.20": + version "7.22.20" + resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.22.20.tgz#c4ae002c61d2879e724581d96665583dbc1dc0e0" + integrity sha512-Y4OZ+ytlatR8AI+8KZfKuL5urKp7qey08ha31L8b3BwewJAoJamTzyvxPR/5D+KkdJCGPq/+8TukHBlY10FX9A== + "@babel/helper-validator-option@^7.16.7": version "7.16.7" resolved "https://registry.yarnpkg.com/@babel/helper-validator-option/-/helper-validator-option-7.16.7.tgz#b203ce62ce5fe153899b617c08957de860de4d23" @@ -362,16 +394,30 @@ chalk "^2.0.0" js-tokens "^4.0.0" +"@babel/highlight@^7.22.13": + version "7.22.20" + resolved "https://registry.yarnpkg.com/@babel/highlight/-/highlight-7.22.20.tgz#4ca92b71d80554b01427815e06f2df965b9c1f54" + integrity sha512-dkdMCN3py0+ksCgYmGG8jKeGA/8Tk+gJwSYYlFGxG5lmhfKNoAy004YpLxpS1W2J8m/EK2Ew+yOs9pVRwO89mg== + dependencies: + "@babel/helper-validator-identifier" "^7.22.20" + chalk "^2.4.2" + js-tokens "^4.0.0" + "@babel/parser@^7.1.0", "@babel/parser@^7.13.16", "@babel/parser@^7.14.0", "@babel/parser@^7.14.7", "@babel/parser@^7.16.7", "@babel/parser@^7.18.0": version "7.18.3" resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.18.3.tgz#39e99c7b0c4c56cef4d1eed8de9f506411c2ebc2" integrity sha512-rL50YcEuHbbauAFAysNsJA4/f89fGTOBRNs9P81sniKnKAr4xULe5AecolcsKbi88xu0ByWYDj/S1AJ3FSFuSQ== -"@babel/parser@^7.18.6", "@babel/parser@^7.18.8": +"@babel/parser@^7.18.6": version "7.18.8" resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.18.8.tgz#822146080ac9c62dac0823bb3489622e0bc1cbdf" integrity sha512-RSKRfYX20dyH+elbJK2uqAkVyucL+xXzhqlMD5/ZXx+dAAwpyB7HsvnHe/ZUGOF+xLr5Wx9/JoXVTj6BQE2/oA== +"@babel/parser@^7.22.15", "@babel/parser@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.23.0.tgz#da950e622420bf96ca0d0f2909cdddac3acd8719" + integrity sha512-vvPKKdMemU85V9WE/l5wZEmImpCtLqbnTvqDS2U1fJ96KrxoW7KrXhNsNCblQlg8Ck4b85yxdTyelsMUgFUXiw== + "@babel/plugin-bugfix-safari-id-destructuring-collision-in-function-expression@^7.17.12": version "7.17.12" resolved "https://registry.yarnpkg.com/@babel/plugin-bugfix-safari-id-destructuring-collision-in-function-expression/-/plugin-bugfix-safari-id-destructuring-collision-in-function-expression-7.17.12.tgz#1dca338caaefca368639c9ffb095afbd4d420b1e" @@ -1182,35 +1228,28 @@ "@babel/parser" "^7.18.6" "@babel/types" "^7.18.6" -"@babel/traverse@^7.13.0", "@babel/traverse@^7.14.0", "@babel/traverse@^7.16.8", "@babel/traverse@^7.18.0", "@babel/traverse@^7.18.2", "@babel/traverse@^7.7.2": - version "7.18.2" - resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.18.2.tgz#b77a52604b5cc836a9e1e08dca01cba67a12d2e8" - integrity sha512-9eNwoeovJ6KH9zcCNnENY7DMFwTU9JdGCFtqNLfUAqtUHRCOsTOqWoffosP8vKmNYeSBUv3yVJXjfd8ucwOjUA== - dependencies: - "@babel/code-frame" "^7.16.7" - "@babel/generator" "^7.18.2" - "@babel/helper-environment-visitor" "^7.18.2" - "@babel/helper-function-name" "^7.17.9" - "@babel/helper-hoist-variables" "^7.16.7" - "@babel/helper-split-export-declaration" "^7.16.7" - "@babel/parser" "^7.18.0" - "@babel/types" "^7.18.2" - debug "^4.1.0" - globals "^11.1.0" - -"@babel/traverse@^7.18.6": - version "7.18.8" - resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.18.8.tgz#f095e62ab46abf1da35e5a2011f43aee72d8d5b0" - integrity sha512-UNg/AcSySJYR/+mIcJQDCv00T+AqRO7j/ZEJLzpaYtgM48rMg5MnkJgyNqkzo88+p4tfRvZJCEiwwfG6h4jkRg== - dependencies: - "@babel/code-frame" "^7.18.6" - "@babel/generator" "^7.18.7" - "@babel/helper-environment-visitor" "^7.18.6" - "@babel/helper-function-name" "^7.18.6" - "@babel/helper-hoist-variables" "^7.18.6" - "@babel/helper-split-export-declaration" "^7.18.6" - "@babel/parser" "^7.18.8" - "@babel/types" "^7.18.8" +"@babel/template@^7.22.15": + version "7.22.15" + resolved "https://registry.yarnpkg.com/@babel/template/-/template-7.22.15.tgz#09576efc3830f0430f4548ef971dde1350ef2f38" + integrity sha512-QPErUVm4uyJa60rkI73qneDacvdvzxshT3kksGqlGWYdOTIUOwJ7RDUL8sGqslY1uXWSL6xMFKEXDS3ox2uF0w== + dependencies: + "@babel/code-frame" "^7.22.13" + "@babel/parser" "^7.22.15" + "@babel/types" "^7.22.15" + +"@babel/traverse@^7.13.0", "@babel/traverse@^7.14.0", "@babel/traverse@^7.16.8", "@babel/traverse@^7.18.0", "@babel/traverse@^7.18.2", "@babel/traverse@^7.18.6", "@babel/traverse@^7.7.2": + version "7.23.2" + resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.23.2.tgz#329c7a06735e144a506bdb2cad0268b7f46f4ad8" + integrity sha512-azpe59SQ48qG6nu2CzcMLbxUudtN+dOM9kDbUqGq3HXUJRlo7i8fvPoxQUzYgLZ4cMVmuZgm8vvBpNeRhd6XSw== + dependencies: + "@babel/code-frame" "^7.22.13" + "@babel/generator" "^7.23.0" + "@babel/helper-environment-visitor" "^7.22.20" + "@babel/helper-function-name" "^7.23.0" + "@babel/helper-hoist-variables" "^7.22.5" + "@babel/helper-split-export-declaration" "^7.22.6" + "@babel/parser" "^7.23.0" + "@babel/types" "^7.23.0" debug "^4.1.0" globals "^11.1.0" @@ -1222,7 +1261,7 @@ "@babel/helper-validator-identifier" "^7.16.7" to-fast-properties "^2.0.0" -"@babel/types@^7.18.6", "@babel/types@^7.18.7", "@babel/types@^7.18.8": +"@babel/types@^7.18.6": version "7.18.8" resolved "https://registry.yarnpkg.com/@babel/types/-/types-7.18.8.tgz#c5af199951bf41ba4a6a9a6d0d8ad722b30cd42f" integrity sha512-qwpdsmraq0aJ3osLJRApsc2ouSJCdnMeZwB0DhbtHAtRpZNZCdlbRnHIgcRKzdE1g0iOGg644fzjOBcdOz9cPw== @@ -1230,6 +1269,15 @@ "@babel/helper-validator-identifier" "^7.18.6" to-fast-properties "^2.0.0" +"@babel/types@^7.22.15", "@babel/types@^7.22.5", "@babel/types@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/types/-/types-7.23.0.tgz#8c1f020c9df0e737e4e247c0619f58c68458aaeb" + integrity sha512-0oIyUfKoI3mSqMvsxBdclDwxXKXAUA8v/apZbc+iSyARYou1o8ZGDxbUYyLFoW2arqS2jDGqJuZvv1d/io1axg== + dependencies: + "@babel/helper-string-parser" "^7.22.5" + "@babel/helper-validator-identifier" "^7.22.20" + to-fast-properties "^2.0.0" + "@bcoe/v8-coverage@^0.2.3": version "0.2.3" resolved "https://registry.yarnpkg.com/@bcoe/v8-coverage/-/v8-coverage-0.2.3.tgz#75a2e8b51cb758a7553d6804a5932d7aace75c39" @@ -1530,6 +1578,11 @@ resolved "https://registry.yarnpkg.com/@jridgewell/resolve-uri/-/resolve-uri-3.0.7.tgz#30cd49820a962aff48c8fffc5cd760151fca61fe" integrity sha512-8cXDaBBHOr2pQ7j77Y6Vp5VDT2sIqWyWQ56TjEq4ih/a4iST3dItRe8Q9fp0rrIl9DoKhWQtUQz/YpOxLkXbNA== +"@jridgewell/resolve-uri@^3.1.0": + version "3.1.1" + resolved "https://registry.yarnpkg.com/@jridgewell/resolve-uri/-/resolve-uri-3.1.1.tgz#c08679063f279615a3326583ba3a90d1d82cc721" + integrity sha512-dSYZh7HhCDtCKm4QakX0xFpsRDqjjtZf/kjI/v3T3Nwt5r8/qz/M19F9ySyOqU94SXBmeG9ttTul+YnR4LOxFA== + "@jridgewell/set-array@^1.0.0": version "1.1.1" resolved "https://registry.yarnpkg.com/@jridgewell/set-array/-/set-array-1.1.1.tgz#36a6acc93987adcf0ba50c66908bd0b70de8afea" @@ -1545,6 +1598,19 @@ resolved "https://registry.yarnpkg.com/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.13.tgz#b6461fb0c2964356c469e115f504c95ad97ab88c" integrity sha512-GryiOJmNcWbovBxTfZSF71V/mXbgcV3MewDe3kIMCLyIh5e7SKAeUZs+rMnJ8jkMolZ/4/VsdBmMrw3l+VdZ3w== +"@jridgewell/sourcemap-codec@^1.4.14": + version "1.4.15" + resolved "https://registry.yarnpkg.com/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.15.tgz#d7c6e6755c78567a951e04ab52ef0fd26de59f32" + integrity sha512-eF2rxCRulEKXHTRiDrDy6erMYWqNw4LPdQ8UQA4huuxaQsVeRPFl2oM8oDGxMFhJUWZf9McpLtJasDDZb/Bpeg== + +"@jridgewell/trace-mapping@^0.3.17": + version "0.3.19" + resolved "https://registry.yarnpkg.com/@jridgewell/trace-mapping/-/trace-mapping-0.3.19.tgz#f8a3249862f91be48d3127c3cfe992f79b4b8811" + integrity sha512-kf37QtfW+Hwx/buWGMPcR60iF9ziHa6r/CZJIHbmcm4+0qrXiVdxegAH0F6yddEVQ7zdkjcGCgCzUu+BcbhQxw== + dependencies: + "@jridgewell/resolve-uri" "^3.1.0" + "@jridgewell/sourcemap-codec" "^1.4.14" + "@jridgewell/trace-mapping@^0.3.9": version "0.3.13" resolved "https://registry.yarnpkg.com/@jridgewell/trace-mapping/-/trace-mapping-0.3.13.tgz#dcfe3e95f224c8fe97a87a5235defec999aa92ea" @@ -2470,7 +2536,7 @@ caniuse-lite@^1.0.30001332: resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001342.tgz#87152b1e3b950d1fbf0093e23f00b6c8e8f1da96" integrity sha512-bn6sOCu7L7jcbBbyNhLg0qzXdJ/PMbybZTH/BA6Roet9wxYRm6Tr9D0s0uhLkOZ6MSG+QU6txUgdpr3MXIVqjA== -chalk@^2.0.0: +chalk@^2.0.0, chalk@^2.4.2: version "2.4.2" resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.2.tgz#cd42541677a54333cf541a49108c1432b44c9424" integrity sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ== diff --git a/js/tsconfig.json b/js/tsconfig.json index d199cb9d0a6f5..faf9066b3d96e 100644 --- a/js/tsconfig.json +++ b/js/tsconfig.json @@ -1,7 +1,7 @@ { "compilerOptions": { - "module": "ES2015", - "moduleResolution": "node16", + "module": "Node16", + "moduleResolution": "Node16", "esModuleInterop": true, "target": "ES2020", "lib": ["ES2020", "ESNext.BigInt", "dom"], @@ -10,7 +10,7 @@ "noImplicitAny": true, "noImplicitReturns": true, "noImplicitThis": true, - "noUnusedParameters": false, + "noUnusedParameters": true, "alwaysStrict": true, "strictNullChecks": true, "pretty": true, diff --git a/js/tsconfig.tools.json b/js/tsconfig.tools.json index e55ef7e3eb57c..a70ca0388034d 100644 --- a/js/tsconfig.tools.json +++ b/js/tsconfig.tools.json @@ -1,7 +1,6 @@ { "extends": "./tsconfig.json", "compilerOptions": { - "module": "CommonJS", "declaration": false, "sourceMap": false } diff --git a/js/web/.npmignore b/js/web/.npmignore index 8e08db5917149..0f018f525a8d6 100644 --- a/js/web/.npmignore +++ b/js/web/.npmignore @@ -4,11 +4,30 @@ /dist/**/*.report.html +# We remove some of the files in NPM packages because restrictions in jsdelivr: +# +# "Packages larger than 150 MB or single files larger than 20 MB (in the case of GitHub) are not supported" +# +# from https://www.jsdelivr.com/documentation +# +# We only include development build in the NPM package for the following targets: +# - /dist/ort.js +# - /dist/ort.all.js +# +/dist/cjs/ort.js +/dist/esm/ort.js +/dist/cjs/ort.all.js +/dist/esm/ort.all.js +/dist/**/ort.wasm.js +/dist/**/ort.wasm-core.js +/dist/**/ort.webgl.js +/dist/**/ort.webgpu.js +/dist/**/ort.training.wasm.js + /types/ karma.conf.js tsconfig.json tsconfig.tsbuildinfo -webpack.config.js *.tgz diff --git a/js/web/docs/webgl-operators.md b/js/web/docs/webgl-operators.md index de84134ddbb3f..7c129b66bfa3d 100644 --- a/js/web/docs/webgl-operators.md +++ b/js/web/docs/webgl-operators.md @@ -12,6 +12,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Acos](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Acos) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Acos-7) | | [Acosh](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Acosh) | | | [Add](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Add) | [7-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-7), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-14) | +| [AffineGrid](https://github.com/onnx/onnx/blob/main/docs/Operators.md#AffineGrid) | | | [And](https://github.com/onnx/onnx/blob/main/docs/Operators.md#And) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#And-7) | | [ArgMax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ArgMax) | | | [ArgMin](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ArgMin) | | @@ -67,6 +68,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Gather](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-13) | | [GatherElements](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements) | | | [GatherND](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND) | | +| [Gelu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gelu) | | | [Gemm](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-7), [9-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-9), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-13) | | [GlobalAveragePool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalAveragePool) | [1+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GlobalAveragePool-1) | | [GlobalLpPool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalLpPool) | | @@ -82,6 +84,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Hardmax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Hardmax) | | | [Identity](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13), [14-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14), [16-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-19) | | [If](https://github.com/onnx/onnx/blob/main/docs/Operators.md#If) | | +| [ImageDecoder](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ImageDecoder) | | | [InstanceNormalization](https://github.com/onnx/onnx/blob/main/docs/Operators.md#InstanceNormalization) | [6+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#InstanceNormalization-6) | | [IsInf](https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsInf) | | | [IsNaN](https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsNaN) | | @@ -137,12 +140,13 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [ReduceL2](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceL2) | | | [ReduceLogSum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceLogSum) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-18) | | [ReduceLogSumExp](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceLogSumExp) | | -| [ReduceMax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMax) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-18) | +| [ReduceMax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMax) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-13), [18-19](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-18), [20+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-20) | | [ReduceMean](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMean) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-18) | -| [ReduceMin](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMin) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-18) | +| [ReduceMin](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMin) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-13), [18-19](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-18), [20+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-20) | | [ReduceProd](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceProd) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-18) | | [ReduceSum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSum-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSum-11) | | [ReduceSumSquare](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSumSquare) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-18) | +| [RegexFullMatch](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RegexFullMatch) | | | [Relu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Relu) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-6), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-14) | | [Reshape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape) | [5-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-5), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-13), [14-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-14), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-19) | | [Resize](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize) | [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-10), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-18), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-19) | @@ -179,7 +183,9 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [SplitToSequence](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SplitToSequence) | | | [Sqrt](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-13) | | [Squeeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Squeeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-13) | +| [StringConcat](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringConcat) | | | [StringNormalizer](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringNormalizer) | | +| [StringSplit](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringSplit) | | | [Sub](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sub) | [7-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-7), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-14) | | [Sum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sum) | [6-7](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sum-6), [8-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sum-8), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sum-13) | | [Tan](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Tan) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Tan-7) | diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index a87a894e3b3c5..00c27fe3ab034 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -20,13 +20,17 @@ Do not modify directly.* | Asinh | ai.onnx(9+) | | | Atan | ai.onnx(7+) | | | Atanh | ai.onnx(9+) | | -| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(11+) | need perf optimization; need implementing activation | +| Attention | com.microsoft(1+) | need implementing mask and past/present | +| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation | +| BatchNormalization | ai.onnx(7-8,9-13,14,15+); com.ms.internal.nhwc(7-8,9-13,14,15+) | | +| BiasAdd | com.microsoft(1+) | | +| BiasSplitGelu | com.microsoft(1+) | | | Cast | ai.onnx(6-8,9-12,13-18,19+) | | | Ceil | ai.onnx(6-12,13+) | | | Clip | ai.onnx(6-10,11,12,13+) | | | Concat | ai.onnx(1-3,4-10,11-12,13+) | | -| Conv | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; conv3d is not supported; need implementing activation | -| ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; ConvTranspose3d is not supported; need implementing activation | +| Conv | ai.onnx(1-10,11+); com.ms.internal.nhwc(1-10,11+) | need perf optimization; conv3d is not supported; need implementing activation | +| ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(1-10,11+) | need perf optimization; ConvTranspose3d is not supported; need implementing activation | | Cos | ai.onnx(7+) | | | Cosh | ai.onnx(9+) | | | Div | ai.onnx(7-12,13,14+) | | @@ -38,6 +42,7 @@ Do not modify directly.* | Expand | ai.onnx(8-12,13+) | | | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | +| FusedConv | com.microsoft(1+) | | | Gather | ai.onnx(1-10,11-12,13+) | | | GatherElements | ai.onnx(11-12,13+) | | | Gelu | com.microsoft(1+) | | @@ -54,14 +59,16 @@ Do not modify directly.* | LessOrEqual | ai.onnx(12-15,16+) | | | Log | ai.onnx(6-12,13+) | | | MatMul | ai.onnx(1-12,13+) | | -| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(11,12+) | need perf optimization; need implementing activation | +| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation | | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | | Mul | ai.onnx(7-12,13,14+) | | +| MultiHeadAttention | com.microsoft(1+) | need implementing mask and past/present | | Neg | ai.onnx(6-12,13+) | | | Not | ai.onnx(1+) | | | Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | | Pow | ai.onnx(7-11,12,13-14,15+) | | +| Range | ai.onnx(11+) | | | Reciprocal | ai.onnx(6-12,13+) | | | ReduceL1 | ai.onnx(1-10,11-12,13-17,18+) | | | ReduceL2 | ai.onnx(1-10,11-12,13-17,18+) | | @@ -75,7 +82,7 @@ Do not modify directly.* | ReduceSumSquare | ai.onnx(1-10,11-12,13-17,18+) | | | Relu | ai.onnx(6-12,13,14+) | | | Reshape | ai.onnx(5-12,13,14+) | no GPU kernel | -| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling | +| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling | | Shape | ai.onnx(1-12,13-14,15+) | no GPU kernel; an ORT warning is generated - need to fix | | Sigmoid | ai.onnx(6-12,13+) | | | Sin | ai.onnx(7+) | | @@ -93,3 +100,4 @@ Do not modify directly.* | Tile | ai.onnx(6-12,13+) | | | Transpose | ai.onnx(1-12,13+) | need perf optimization | | Unsqueeze | ai.onnx(1-10,11-12,13+) | | +| Where | ai.onnx(9-15,16+) | | diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js index 35f782d1fdca3..8fce79843f617 100644 --- a/js/web/karma.conf.js +++ b/js/web/karma.conf.js @@ -4,7 +4,7 @@ 'use strict'; const args = require('minimist')(process.argv, {}); -const bundleMode = args['bundle-mode'] || 'dev'; // 'dev'|'perf'|undefined; +const bundleMode = args['bundle-mode'] || 'dev'; // 'dev'|'perf' const karmaPlugins = args['karma-plugins'] || undefined; const timeoutMocha = args['timeout-mocha'] || 60000; const forceLocalHost = !!args['force-localhost']; @@ -19,8 +19,8 @@ if (!chromiumFlags) { throw new Error(`Invalid command line arg: --chromium-flags: ${chromiumFlags}`); } -const commonFile = bundleMode === 'dev' ? '../common/dist/ort-common.js' : '../common/dist/ort-common.min.js' -const mainFile = bundleMode === 'dev' ? 'test/ort.dev.js' : 'test/ort.perf.js'; +const ORT_FILE = bundleMode === 'dev' ? 'dist/ort.all.js' : 'dist/ort.all.min.js'; +const TEST_FILE = bundleMode === 'dev' ? 'test/ort.test.js' : 'test/ort.test.min.js'; // it's a known issue that Safari does not work with "localhost" in BrowserStack: // https://www.browserstack.com/question/663 @@ -67,30 +67,14 @@ module.exports = function(config) { }, frameworks: ['mocha'], files: [ - {pattern: commonFile}, - {pattern: mainFile}, - {pattern: 'test/testdata-file-cache-*.json', included: false}, - {pattern: 'test/data/**/*', included: false, nocache: true}, - {pattern: 'dist/ort-wasm.wasm', included: false}, - {pattern: 'dist/ort-wasm-threaded.wasm', included: false}, - {pattern: 'dist/ort-wasm-simd.wasm', included: false}, - {pattern: 'dist/ort-wasm-simd-threaded.wasm', included: false}, - {pattern: 'dist/ort-wasm-simd.jsep.wasm', included: false}, - {pattern: 'dist/ort-wasm-simd-threaded.jsep.wasm', included: false}, - {pattern: 'dist/ort-wasm-threaded.worker.js', included: false}, + {pattern: ORT_FILE}, + {pattern: TEST_FILE}, + {pattern: 'test/testdata-file-cache-*.json', included: false, watched: false}, + {pattern: 'test/data/**/*', included: false, nocache: true, watched: false}, + {pattern: 'dist/*.wasm', included: false, watched: false}, ], - proxies: { - '/base/test/ort-wasm.wasm': '/base/dist/ort-wasm.wasm', - '/base/test/ort-wasm-threaded.wasm': '/base/dist/ort-wasm-threaded.wasm', - '/base/test/ort-wasm-simd.wasm': '/base/dist/ort-wasm-simd.wasm', - '/base/test/ort-wasm-simd-threaded.wasm': '/base/dist/ort-wasm-simd-threaded.wasm', - '/base/test/ort-wasm-simd.jsep.wasm': '/base/dist/ort-wasm-simd.jsep.wasm', - '/base/test/ort-wasm-simd-threaded.jsep.wasm': '/base/dist/ort-wasm-simd-threaded.jsep.wasm', - '/base/test/ort-wasm-threaded.worker.js': '/base/dist/ort-wasm-threaded.worker.js', - }, plugins: karmaPlugins, client: {captureConsole: true, mocha: {expose: ['body'], timeout: timeoutMocha}}, - preprocessors: {mainFile: ['sourcemap']}, reporters: ['mocha', 'BrowserStack'], browsers: [], captureTimeout: 120000, diff --git a/js/web/lib/backend-onnxjs.ts b/js/web/lib/backend-onnxjs.ts index 18a068e0ced8b..7176823c9bf13 100644 --- a/js/web/lib/backend-onnxjs.ts +++ b/js/web/lib/backend-onnxjs.ts @@ -2,17 +2,17 @@ // Licensed under the MIT License. /* eslint-disable import/no-internal-modules */ -import {Backend, InferenceSession, SessionHandler} from 'onnxruntime-common'; +import {Backend, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {Session} from './onnxjs/session'; -import {OnnxjsSessionHandler} from './onnxjs/session-handler'; +import {OnnxjsSessionHandler} from './onnxjs/session-handler-inference'; class OnnxjsBackend implements Backend { // eslint-disable-next-line @typescript-eslint/no-empty-function async init(): Promise {} - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { // NOTE: Session.Config(from onnx.js) is not compatible with InferenceSession.SessionOptions(from // onnxruntime-common). // In future we should remove Session.Config and use InferenceSession.SessionOptions. diff --git a/js/web/lib/backend-wasm-inference.ts b/js/web/lib/backend-wasm-inference.ts new file mode 100644 index 0000000000000..475a0243ebd3d --- /dev/null +++ b/js/web/lib/backend-wasm-inference.ts @@ -0,0 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; +export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts new file mode 100644 index 0000000000000..09dac3a85311c --- /dev/null +++ b/js/web/lib/backend-wasm-training.ts @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; + +import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; +import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-training'; + +class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { + async createTrainingSessionHandler( + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions): Promise { + const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); + await handler.createTrainingSession( + checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options); + return Promise.resolve(handler); + } +} + +export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index ceb20044d97b6..78edcc90f55f9 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Backend, env, InferenceSession, SessionHandler} from 'onnxruntime-common'; -import {cpus} from 'os'; +import {cpus} from 'node:os'; +import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper'; -import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler'; +import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference'; /** * This function initializes all flags for WebAssembly. @@ -32,7 +32,7 @@ export const initializeFlags = (): void => { } }; -class OnnxruntimeWebAssemblyBackend implements Backend { +export class OnnxruntimeWebAssemblyBackend implements Backend { async init(): Promise { // populate wasm flags initializeFlags(); @@ -40,14 +40,14 @@ class OnnxruntimeWebAssemblyBackend implements Backend { // init wasm await initializeWebAssemblyInstance(); } - createSessionHandler(path: string, options?: InferenceSession.SessionOptions): Promise; - createSessionHandler(buffer: Uint8Array, options?: InferenceSession.SessionOptions): Promise; - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + createInferenceSessionHandler(path: string, options?: InferenceSession.SessionOptions): + Promise; + createInferenceSessionHandler(buffer: Uint8Array, options?: InferenceSession.SessionOptions): + Promise; + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { const handler = new OnnxruntimeWebAssemblySessionHandler(); await handler.loadModel(pathOrBuffer, options); return Promise.resolve(handler); } } - -export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index 2049b2663ead3..fb714bf5996f1 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -1,35 +1,39 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -/* eslint-disable @typescript-eslint/no-unused-vars, @typescript-eslint/naming-convention */ +/* eslint-disable @typescript-eslint/naming-convention */ /** * The interface BuildDefinitions contains a set of flags which are defined at build time. * - * Those flags are processed in terser for tree shaking to remove unused code. + * Those flags are processed in bundler for tree shaking to remove unused code. * No flags in this file should present in production build. */ interface BuildDefinitions { /** * defines whether to disable the whole WebGL backend in the build. */ - DISABLE_WEBGL: boolean; + readonly DISABLE_WEBGL: boolean; /** * defines whether to disable the whole WebGpu backend in the build. */ - DISABLE_WEBGPU: boolean; + readonly DISABLE_WEBGPU: boolean; /** * defines whether to disable the whole WebAssembly backend in the build. */ - DISABLE_WASM: boolean; + readonly DISABLE_WASM: boolean; /** * defines whether to disable proxy feature in WebAssembly backend in the build. */ - DISABLE_WASM_PROXY: boolean; + readonly DISABLE_WASM_PROXY: boolean; /** * defines whether to disable multi-threading feature in WebAssembly backend in the build. */ - DISABLE_WASM_THREAD: boolean; + readonly DISABLE_WASM_THREAD: boolean; + /** + * defines whether to disable training APIs in WebAssembly backend. + */ + readonly DISABLE_TRAINING: boolean; } -declare let BUILD_DEFS: BuildDefinitions; +declare const BUILD_DEFS: BuildDefinitions; diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index d5ed536034f3e..6060271ced156 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -3,10 +3,13 @@ /* eslint-disable @typescript-eslint/no-var-requires, @typescript-eslint/no-require-imports */ // We use "require" instead of "import" here because import statement must be put in top level. Our current code does -// not allow terser to tree-shaking code as expected because some codes are treated as having side effects. -// So we import code inside the if-clause to allow terser remove the code safely. +// not allow bundler to tree-shaking code as expected because some codes are treated as having side effects. +// So we import code inside the if-clause to allow bundler remove the code safely. export * from 'onnxruntime-common'; +import * as ort from 'onnxruntime-common'; +export default ort; + import {registerBackend, env} from 'onnxruntime-common'; import {version} from './version'; @@ -16,14 +19,17 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { } if (!BUILD_DEFS.DISABLE_WASM) { - const wasmBackend = require('./backend-wasm').wasmBackend; + const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend : + require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) { registerBackend('webgpu', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); - registerBackend('xnnpack', wasmBackend, 9); - registerBackend('webnn', wasmBackend, 9); + if (BUILD_DEFS.DISABLE_TRAINING) { + registerBackend('xnnpack', wasmBackend, 9); + registerBackend('webnn', wasmBackend, 9); + } } Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); diff --git a/js/web/lib/onnxjs/attribute-with-cache-key.ts b/js/web/lib/onnxjs/attribute-with-cache-key.ts index 6608b00471e77..5d47570f267a6 100644 --- a/js/web/lib/onnxjs/attribute-with-cache-key.ts +++ b/js/web/lib/onnxjs/attribute-with-cache-key.ts @@ -6,13 +6,13 @@ class AttributeWithCacheKeyImpl { Object.assign(this, attribute); } - private _cacheKey: string; + private key: string; public get cacheKey(): string { - if (!this._cacheKey) { - this._cacheKey = + if (!this.key) { + this.key = Object.getOwnPropertyNames(this).sort().map(name => `${(this as Record)[name]}`).join(';'); } - return this._cacheKey; + return this.key; } } diff --git a/js/web/lib/onnxjs/backends/backend-webgl.ts b/js/web/lib/onnxjs/backends/backend-webgl.ts index 74716ca0edcb3..21ed7e38b9f86 100644 --- a/js/web/lib/onnxjs/backends/backend-webgl.ts +++ b/js/web/lib/onnxjs/backends/backend-webgl.ts @@ -72,7 +72,9 @@ export class WebGLBackend implements Backend { Logger.setWithEnv(env); - Object.defineProperty(env.webgl, 'context', {value: this.glContext.gl}); + if (!env.webgl.context) { + Object.defineProperty(env.webgl, 'context', {value: this.glContext.gl}); + } Logger.verbose( 'WebGLBackend', diff --git a/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts b/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts index dd3f1b30dfb46..1f2b27c7bdea8 100644 --- a/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts +++ b/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts @@ -186,7 +186,7 @@ export class CoordsGlslLib extends GlslLib { /** * 1D packed output coordinates. */ - protected getOutputPacked1DCoords(shape: [number], texShape: [number, number]): GlslLibRoutine { + protected getOutputPacked1DCoords(_shape: [number], texShape: [number, number]): GlslLibRoutine { const packedTexShape = texShape; let source = ''; if (packedTexShape[0] === 1) { @@ -331,7 +331,7 @@ export class CoordsGlslLib extends GlslLib { /** * Unpacked 1D output coordinates. */ - protected getOutputUnpacked1DCoords(shape: [number], texShape: [number, number]): GlslLibRoutine { + protected getOutputUnpacked1DCoords(_shape: [number], texShape: [number, number]): GlslLibRoutine { const source = ` int getOutputCoords() { ivec2 resTexRC = ivec2(TexCoords.xy * @@ -641,7 +641,7 @@ export class CoordsGlslLib extends GlslLib { if (outRank < 2 && inRank > 0) { unpackedCoordsSnippet = 'coords'; } else { - unpackedCoordsSnippet = inShape.map((s, i) => `coords.${fields[i + rankDiff]}`).join(', '); + unpackedCoordsSnippet = inShape.map((_s, i) => `coords.${fields[i + rankDiff]}`).join(', '); } let output = 'return outputValue;'; @@ -734,7 +734,7 @@ export class CoordsGlslLib extends GlslLib { if (outRank < 2 && inRank > 0) { unpackedCoordsSnippet = 'coords'; } else { - unpackedCoordsSnippet = inputLayout.unpackedShape.map((s, i) => `coords.${fields[i + rankDiff]}`).join(', '); + unpackedCoordsSnippet = inputLayout.unpackedShape.map((_s, i) => `coords.${fields[i + rankDiff]}`).join(', '); } const source = ` float ${funcName}() { diff --git a/js/web/lib/onnxjs/backends/webgl/inference-handler.ts b/js/web/lib/onnxjs/backends/webgl/inference-handler.ts index afb39e84f0060..0a51ff7c4029e 100644 --- a/js/web/lib/onnxjs/backends/webgl/inference-handler.ts +++ b/js/web/lib/onnxjs/backends/webgl/inference-handler.ts @@ -11,7 +11,7 @@ import {createPackedReshape3DProgramInfoLoader, isReshapeCheap, processDims3D} f import {encodeAsUint8} from './ops/uint8-encode'; import {createUnpackProgramInfoLoader} from './ops/unpack'; import {WebGLSessionHandler} from './session-handler'; -import {Encoder} from './texture-data-encoder'; +import {EncoderUsage} from './texture-data-encoder'; import {calculateTextureWidthAndHeight, createTextureLayoutFromShape, createTextureLayoutFromTextureType} from './texture-layout'; import {Artifact, ProgramInfo, ProgramInfoLoader, TextureData, TextureLayout, TextureType} from './types'; @@ -101,7 +101,7 @@ export class WebGLInferenceHandler implements InferenceHandler { /** * Create a TextureData object from a tensor. - * Usage = Encoder.Usage.UploadOnly. + * Usage = EncoderUsage.UploadOnly. * If a related texture data is found in cache, returns it; * Otherwise: * Creates a new texture layout if not provided; @@ -156,7 +156,7 @@ export class WebGLInferenceHandler implements InferenceHandler { buffer.set(tensor.numberData.subarray(oldOffset, oldOffset + oldRowSize), newOffset); } } - return this.createTextureData(adjustedLayout, tensor.type, buffer, tensor, Encoder.Usage.UploadOnly); + return this.createTextureData(adjustedLayout, tensor.type, buffer, tensor, EncoderUsage.UploadOnly); } } @@ -164,10 +164,10 @@ export class WebGLInferenceHandler implements InferenceHandler { const unpackedTextureLayout = createTextureLayoutFromShape(this.session.layoutStrategy, tensor.dims, 1, [], {reverseWH: true}); const unpackedTextureData = this.createTextureData( - unpackedTextureLayout, tensor.type, tensor.numberData, tensor, Encoder.Usage.UploadOnly); + unpackedTextureLayout, tensor.type, tensor.numberData, tensor, EncoderUsage.UploadOnly); td = this.pack(unpackedTextureData); } else { - td = this.createTextureData(layout, tensor.type, tensor.numberData, tensor, Encoder.Usage.UploadOnly); + td = this.createTextureData(layout, tensor.type, tensor.numberData, tensor, EncoderUsage.UploadOnly); } } return td; @@ -175,7 +175,7 @@ export class WebGLInferenceHandler implements InferenceHandler { /** * Create a TextureData object using the given data and bind to the given tensor. - * Usage = Encoder.Usage.UploadOnly. + * Usage = EncoderUsage.UploadOnly. * NOTE: this function is a hack for Conv implementation. should remove this function, after rewriting Conv * implementation by Graph.Transformer * @param dataType the tensor data type @@ -184,12 +184,12 @@ export class WebGLInferenceHandler implements InferenceHandler { */ createTextureDataFromLayoutBindTensor( layout: TextureLayout, dataType: Tensor.DataType, data: Tensor.NumberType, tensor: Tensor): TextureData { - return this.createTextureData(layout, dataType, data, tensor, Encoder.Usage.UploadOnly); + return this.createTextureData(layout, dataType, data, tensor, EncoderUsage.UploadOnly); } private createTextureData( layout: TextureLayout, dataType: Tensor.DataType, data?: Tensor.NumberType, tensor?: Tensor, - usage?: Encoder.Usage): TextureData { + usage?: EncoderUsage): TextureData { Logger.verbose('InferenceHandler', `Creating TextureData: layout:[${JSON.stringify(layout)}]`); const texture = this.session.textureManager.createTextureFromLayout(dataType, layout, data, usage); return this.createTextureDataFromTexture(layout, dataType, texture, tensor); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts b/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts index 709f883ae12c9..d0e589a428825 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts @@ -12,7 +12,7 @@ import {getChannels, unpackFromChannel} from './packing-utils'; const createPackedConcatProgramMetadata = (inputCount: number, cacheHint: string) => ({ name: 'Concat (packed)', - inputNames: Array.from({length: inputCount}, (v, i) => `X${i}`), + inputNames: Array.from({length: inputCount}, (_v, i) => `X${i}`), inputTypes: Array(inputCount).fill(TextureType.packed), cacheHint }); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/concat.ts b/js/web/lib/onnxjs/backends/webgl/ops/concat.ts index c2b18ef86f814..f85f4032feae1 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/concat.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/concat.ts @@ -30,13 +30,13 @@ export const concat: OperatorImplementation = const createUnpackedConcatProgramMetadata = (inputCount: number, cacheHint: string) => ({ name: 'Concat', - inputNames: Array.from({length: inputCount}, (v, i) => `X${i}`), + inputNames: Array.from({length: inputCount}, (_v, i) => `X${i}`), inputTypes: Array(inputCount).fill(TextureType.unpacked), cacheHint }); const createUnpackedConcatProgramInfo = - (handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => { + (_handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => { const inputShape = inputs[0].dims.slice(); if (axis >= inputShape.length || axis < (-1 * inputShape.length)) { throw new Error('axis specified for concat doesn\'t match input dimensionality'); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/gather.ts b/js/web/lib/onnxjs/backends/webgl/ops/gather.ts index 54b6ccd1a3685..bb44a20d75f34 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/gather.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/gather.ts @@ -30,7 +30,7 @@ const gatherProgramMetadata = { }; const createGatherProgramInfo = - (handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => { + (_handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => { const inputShape = inputs[0].dims.slice(); const indexDataShape = inputs[1].dims.slice(); const outputShape = new Array(inputShape.length + indexDataShape.length - 1); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts b/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts index f74c35b612665..a1da13ec48d70 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts @@ -15,7 +15,7 @@ const createIm2ColProgramMetadata = (cacheHint: string) => ({ }); const createIm2ColProgramInfo = - (inferenceHandler: WebGLInferenceHandler, metadata: ProgramMetadata, x: Tensor, w: Tensor, + (_inferenceHandler: WebGLInferenceHandler, metadata: ProgramMetadata, x: Tensor, w: Tensor, outputShape: readonly number[], attributes: ConvAttributes): ProgramInfo => { const xshape = x.dims; const wshape = w.dims; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts b/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts index 1cd5288251433..efc79f686c960 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts @@ -35,7 +35,7 @@ const imageScalerProgramMetadata = { }; const createImageScalerProgramInfo = - (handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], attributes: ImageScalerAttributes): + (_handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], attributes: ImageScalerAttributes): ProgramInfo => { const outputShape = inputs[0].dims.slice(); const rank = outputShape.length; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts index fb3c2357ae8fe..0be6d1ba8bcd2 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts @@ -107,10 +107,10 @@ function getBcastSamplerForMatmul( const rankADiff = outRank - inARank; const rankBDiff = outRank - inBRank; - unpackedACoordsSnippet = inAShape.map((s, i) => `coords.${allGlChannels[i + rankADiff]}`); + unpackedACoordsSnippet = inAShape.map((_s, i) => `coords.${allGlChannels[i + rankADiff]}`); unpackedACoordsSnippet[inARank - 1] = 'i*2'; unpackedACoordsSnippet.join(', '); - unpackedBCoordsSnippet = inBShape.map((s, i) => `coords.${allGlChannels[i + rankBDiff]}`); + unpackedBCoordsSnippet = inBShape.map((_s, i) => `coords.${allGlChannels[i + rankBDiff]}`); unpackedBCoordsSnippet[inBRank - 2] = 'i*2'; unpackedBCoordsSnippet.join(', '); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts b/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts index 704128fb4858e..523165f29f852 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts @@ -117,7 +117,7 @@ export function getBiasForMatmul( if (outRank < 2 && inRank > 0) { unpackedCoordsSnippet = 'coords'; } else { - unpackedCoordsSnippet = inShape.map((s, i) => `coords.${allGlChannels[i + rankDiff]}`).join(', '); + unpackedCoordsSnippet = inShape.map((_s, i) => `coords.${allGlChannels[i + rankDiff]}`).join(', '); } const broadcastDims = BroadcastUtil.getBroadcastDims(inShape, outShape); const coordsSnippet = broadcastDims.map(d => `coords.${allGlChannels[d + rankDiff]} = 0;`).join('\n'); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts b/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts index 1a2bc7422c833..c9ea460a6f1fc 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts @@ -46,7 +46,7 @@ export const parseReduceAttributes: OperatorInitialization = ( }; const createReduceProgramInfo = - (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes, name: string, reduceOp: ReduceOp, + (_handler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes, _name: string, reduceOp: ReduceOp, reduceProgramMetadata: ProgramMetadata): ProgramInfo => { const outputShape: number[] = []; const iRank = inputs[0].dims.length || 1; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/shape.ts b/js/web/lib/onnxjs/backends/webgl/ops/shape.ts index 51acf5042d8bd..c2d703ed04fa0 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/shape.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/shape.ts @@ -4,7 +4,7 @@ import {Tensor} from '../../../tensor'; import {WebGLInferenceHandler} from '../inference-handler'; -export const shape = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => { +export const shape = (_inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => { validateInputs(inputs); return [new Tensor([inputs[0].dims.length], 'int32', undefined, undefined, new Int32Array(inputs[0].dims))]; }; @@ -13,4 +13,4 @@ const validateInputs = (inputs: Tensor[]): void => { if (!inputs || inputs.length !== 1) { throw new Error('Shape requires 1 input.'); } -}; \ No newline at end of file +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/slice.ts b/js/web/lib/onnxjs/backends/webgl/ops/slice.ts index d32a76bbc8628..81fc1b7076fdb 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/slice.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/slice.ts @@ -42,8 +42,8 @@ export const parseSliceAttributes: OperatorInitialization = (no }; const createSliceProgramInfo = - (inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: SliceAttributes): ProgramInfo => { - const axes = (attributes.axes.length === 0) ? input.dims.slice(0).map((val, i) => i) : attributes.axes; + (_inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: SliceAttributes): ProgramInfo => { + const axes = (attributes.axes.length === 0) ? input.dims.slice(0).map((_val, i) => i) : attributes.axes; const normalizedAxes = ShapeUtil.normalizeAxes(axes, input.dims.length); const starts = attributes.starts.map((start, i) => { if (start > input.dims[normalizedAxes[i]] - 1) { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/split.ts b/js/web/lib/onnxjs/backends/webgl/ops/split.ts index d1bd00d47eebd..2ab14563d80e2 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/split.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/split.ts @@ -49,13 +49,13 @@ export const parseSplitAttributes: OperatorInitialization = (no }; const getProgramCount = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], axis: number, attributes: SplitAttributes): number => { + (_inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], axis: number, attributes: SplitAttributes): number => { const [, offsets] = SplitUtil.splitShape(inputs[0].dims, axis, attributes.split, attributes.numOutputs); return offsets.length; }; const createSplitProgramInfo = - (inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: SplitAttributes, axis: number, index: number): + (_inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: SplitAttributes, axis: number, index: number): ProgramInfo => { const [shapes, offsets] = SplitUtil.splitShape(input.dims, axis, attributes.split, attributes.numOutputs); const offset = offsets[index]; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/sum.ts b/js/web/lib/onnxjs/backends/webgl/ops/sum.ts index c05286d16f936..2c25b10c5872c 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/sum.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/sum.ts @@ -11,7 +11,7 @@ export const sum = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): const sumProgramMetadata = { name: 'Sum', - inputNames: inputs.map((v, i) => `X${i}`), + inputNames: inputs.map((_v, i) => `X${i}`), inputTypes: new Array(inputs.length).fill(TextureType.unpacked) }; @@ -24,7 +24,7 @@ const createSumProgramInfo = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], sumProgramMetadata: ProgramMetadata): ProgramInfo => { const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); const outputShape = inputs[0].dims.slice(); - const sumLine = inputs.map((v, i) => `${glsl.texture2D}(X${i},TexCoords)`).join(' + '); + const sumLine = inputs.map((_v, i) => `${glsl.texture2D}(X${i},TexCoords)`).join(' + '); const shaderSource = ` void main() { vec4 result = ${sumLine}; @@ -65,4 +65,4 @@ const validateInputs = (inputs: Tensor[]): void => { throw new Error('Input types are not matched.'); } } -}; \ No newline at end of file +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/tile.ts b/js/web/lib/onnxjs/backends/webgl/ops/tile.ts index 42128c7abc48c..1d2cba7d9d75f 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/tile.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/tile.ts @@ -22,7 +22,7 @@ export const tile = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): }; const createTileProgramInfo = - (handler: WebGLInferenceHandler, inputs: Tensor[], tileProgramMetadata: ProgramMetadata): ProgramInfo => { + (_handler: WebGLInferenceHandler, inputs: Tensor[], tileProgramMetadata: ProgramMetadata): ProgramInfo => { const inputShape = inputs[0].dims.slice(); const outputShape = new Array(inputShape.length); @@ -63,4 +63,4 @@ const validateInputs = (inputs: Tensor[]): void => { if (inputs[1].type !== 'int32' && inputs[1].type !== 'int16') { throw new Error('Invalid repeat type.'); } -}; \ No newline at end of file +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts b/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts index 815ff13f1f925..d3e7b3c0823be 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts @@ -36,7 +36,7 @@ export const parseTransposeAttributes: OperatorInitialization createAttributeWithCacheKey({perm: node.attributes.getInts('perm', [])}); const createTransposeProgramInfo = - (inferenceHandler: WebGLInferenceHandler, input: Tensor, perm: number[]): ProgramInfo => { + (_inferenceHandler: WebGLInferenceHandler, input: Tensor, perm: number[]): ProgramInfo => { const inputShape = input.dims; perm = getAdjustedPerm(inputShape, perm); const unpackedOutputShape = getOutputShape(inputShape, perm); diff --git a/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts b/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts index 6ddd420c6fa81..4b0cf3f037921 100644 --- a/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts +++ b/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts @@ -11,14 +11,15 @@ export declare namespace Encoder { } export type DataType = keyof DataTypeMap; type DataArrayType = DataTypeMap[DataType]; +} - /* eslint-disable @typescript-eslint/naming-convention */ - export const enum Usage { - Default = 0, - UploadOnly, - Download4BytesAsFloat32, - } +/* eslint-disable @typescript-eslint/naming-convention */ +export const enum EncoderUsage { + Default = 0, + UploadOnly, + Download4BytesAsFloat32, } +/* eslint-enable @typescript-eslint/naming-convention */ /** * Abstraction for mapping data types to texture texlets @@ -81,7 +82,7 @@ export class RedFloat32DataEncoder implements DataEncoder { } decode(buffer: Encoder.DataArrayType, dataSize: number): Float32Array { if (this.channelSize === 1) { - const filteredData = (buffer as Float32Array).filter((value, index) => index % 4 === 0).subarray(0, dataSize); + const filteredData = (buffer as Float32Array).filter((_value, index) => index % 4 === 0).subarray(0, dataSize); return filteredData; } return buffer.subarray(0, dataSize) as Float32Array; @@ -118,7 +119,7 @@ export class RGBAFloatDataEncoder implements DataEncoder { } decode(buffer: Encoder.DataArrayType, dataSize: number): Float32Array { if (this.channelSize === 1) { - const filteredData = (buffer as Float32Array).filter((value, index) => index % 4 === 0).subarray(0, dataSize); + const filteredData = (buffer as Float32Array).filter((_value, index) => index % 4 === 0).subarray(0, dataSize); return filteredData; } return buffer.subarray(0, dataSize) as Float32Array; diff --git a/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts b/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts index c89ef3d23638d..f8e370747928c 100644 --- a/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts +++ b/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts @@ -105,7 +105,7 @@ export class PreferLogicalStrategy implements TextureLayoutStrategy { // tensor has 3 rows, we pretend it has 4 rows in order to account for the // fact that the texels containing the third row are half empty. logShape = logShape.map( - (d, i) => i >= logShape.length - 2 ? (logShape[i] % 2 === 0 ? logShape[i] : logShape[i] + 1) : logShape[i]); + (_d, i) => i >= logShape.length - 2 ? (logShape[i] % 2 === 0 ? logShape[i] : logShape[i] + 1) : logShape[i]); // Packed texture height is at least 2 (the channel height of a single // texel). @@ -182,7 +182,7 @@ export function parseAxisParam(axis: number|number[], shape: number[]): number[] const rank = shape.length; // Normalize input - axis = axis == null ? shape.map((s, i) => i) : ([] as number[]).concat(axis); + axis = axis == null ? shape.map((_s, i) => i) : ([] as number[]).concat(axis); // Check for valid range assert( diff --git a/js/web/lib/onnxjs/backends/webgl/texture-manager.ts b/js/web/lib/onnxjs/backends/webgl/texture-manager.ts index 36c7fe7603aa0..effb65288dc1c 100644 --- a/js/web/lib/onnxjs/backends/webgl/texture-manager.ts +++ b/js/web/lib/onnxjs/backends/webgl/texture-manager.ts @@ -4,7 +4,7 @@ import {Logger, Profiler} from '../../instrument'; import {Tensor} from '../../tensor'; -import {Encoder} from './texture-data-encoder'; +import {Encoder, EncoderUsage} from './texture-data-encoder'; import {TextureLayoutStrategy} from './texture-layout-strategy'; import {TextureData, TextureLayout} from './types'; import {WebGLContext} from './webgl-context'; @@ -39,11 +39,11 @@ export class TextureManager { } } createTextureFromLayout( - dataType: Tensor.DataType, layout: TextureLayout, data?: Tensor.NumberType, usage?: Encoder.Usage) { + dataType: Tensor.DataType, layout: TextureLayout, data?: Tensor.NumberType, usage?: EncoderUsage) { const textureDataType = this.toEncoderType(dataType); const encoder = this.glContext.getEncoder(textureDataType, layout.channels || 1, usage); - if (layout.isPacked && usage === Encoder.Usage.UploadOnly) { + if (layout.isPacked && usage === EncoderUsage.UploadOnly) { throw new Error('not implemented'); } const width = layout.width; @@ -63,7 +63,7 @@ export class TextureManager { if (idleTextures && idleTextures.length > 0) { const texture = idleTextures.pop()!; inUseTextures.push(texture); - if (usage === Encoder.Usage.UploadOnly) { + if (usage === EncoderUsage.UploadOnly) { this.glContext.updateTexture(texture, width, height, encoder, this.toTextureData(dataType, data)!); } return texture; @@ -172,7 +172,7 @@ export class TextureManager { throw new Error(`TensorData type ${dataType} is not supported`); } } - toTextureData(dataType: Tensor.DataType, data: Tensor.NumberType|undefined): Encoder.DataArrayType|undefined { + toTextureData(_dataType: Tensor.DataType, data: Tensor.NumberType|undefined): Encoder.DataArrayType|undefined { if (!data) { return undefined; } diff --git a/js/web/lib/onnxjs/backends/webgl/webgl-context.ts b/js/web/lib/onnxjs/backends/webgl/webgl-context.ts index 2d5b844c57f60..744f206e38334 100644 --- a/js/web/lib/onnxjs/backends/webgl/webgl-context.ts +++ b/js/web/lib/onnxjs/backends/webgl/webgl-context.ts @@ -4,7 +4,7 @@ import {env} from 'onnxruntime-common'; import * as DataEncoders from './texture-data-encoder'; -import {DataEncoder, Encoder} from './texture-data-encoder'; +import {DataEncoder, Encoder, EncoderUsage} from './texture-data-encoder'; import {repeatedTry} from './utils'; export interface FenceContext { @@ -257,14 +257,14 @@ ${shaderSource}`); deleteProgram(program: WebGLProgram): void { this.gl.deleteProgram(program); } - getEncoder(dataType: Encoder.DataType, channels: number, usage: Encoder.Usage = Encoder.Usage.Default): DataEncoder { + getEncoder(dataType: Encoder.DataType, channels: number, usage: EncoderUsage = EncoderUsage.Default): DataEncoder { if (this.version === 2) { return new DataEncoders.RedFloat32DataEncoder(this.gl as WebGL2RenderingContext, channels); } switch (dataType) { case 'float': - if (usage === Encoder.Usage.UploadOnly || this.isRenderFloat32Supported) { + if (usage === EncoderUsage.UploadOnly || this.isRenderFloat32Supported) { return new DataEncoders.RGBAFloatDataEncoder(this.gl, channels); } else { return new DataEncoders.RGBAFloatDataEncoder( diff --git a/js/web/lib/onnxjs/execution-plan.ts b/js/web/lib/onnxjs/execution-plan.ts index b95e639817dbf..5599087ab46f5 100644 --- a/js/web/lib/onnxjs/execution-plan.ts +++ b/js/web/lib/onnxjs/execution-plan.ts @@ -114,7 +114,7 @@ export class ExecutionPlan { // resolve downstream nodes const downstreamNodes = new Set(); - outputList.forEach((output, i) => { + outputList.forEach((_output, i) => { const j = thisOp.node.outputs[i]; for (const currentDownstreamNodeIndex of graphValues[j].to) { const currentDownstreamNode = graphNodes[currentDownstreamNodeIndex]; diff --git a/js/web/lib/onnxjs/instrument.ts b/js/web/lib/onnxjs/instrument.ts index 4c543cab157d7..4f865503d50ec 100644 --- a/js/web/lib/onnxjs/instrument.ts +++ b/js/web/lib/onnxjs/instrument.ts @@ -176,7 +176,7 @@ function createCategorizedLogger(category: string): Logger.CategorizedLogger { // NOTE: argument 'category' is put the last parameter beacause typescript // doesn't allow optional argument put in front of required argument. This // order is different from a usual logging API. -function logInternal(severity: Logger.Severity, content: string, stack: number, category?: string) { +function logInternal(severity: Logger.Severity, content: string, _stack: number, category?: string) { const config = LOGGER_CONFIG_MAP[category || ''] || LOGGER_CONFIG_MAP['']; if (SEVERITY_VALUE[severity] < SEVERITY_VALUE[config.minimalSeverity]) { return; diff --git a/js/web/lib/onnxjs/session-handler.ts b/js/web/lib/onnxjs/session-handler-inference.ts similarity index 88% rename from js/web/lib/onnxjs/session-handler.ts rename to js/web/lib/onnxjs/session-handler-inference.ts index 0b06a7a747a44..47e50aeab673a 100644 --- a/js/web/lib/onnxjs/session-handler.ts +++ b/js/web/lib/onnxjs/session-handler-inference.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {InferenceSession, SessionHandler, Tensor} from 'onnxruntime-common'; +import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {Session} from './session'; import {Tensor as OnnxjsTensor} from './tensor'; -export class OnnxjsSessionHandler implements SessionHandler { +export class OnnxjsSessionHandler implements InferenceSessionHandler { constructor(private session: Session) { this.inputNames = this.session.inputNames; this.outputNames = this.session.outputNames; diff --git a/js/web/lib/onnxjs/session.ts b/js/web/lib/onnxjs/session.ts index 790be3c740cd5..cf8793e1f26f5 100644 --- a/js/web/lib/onnxjs/session.ts +++ b/js/web/lib/onnxjs/session.ts @@ -1,8 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {readFile} from 'fs'; -import {promisify} from 'util'; +import {readFile} from 'node:fs/promises'; import {resolveBackend, SessionHandlerType} from './backend'; import {ExecutionPlan} from './execution-plan'; @@ -62,7 +61,7 @@ export class Session { const isOrtFormat = arg.endsWith('.ort'); if (typeof process !== 'undefined' && process.versions && process.versions.node) { // node - const buf = await promisify(readFile)(arg); + const buf = await readFile(arg); this.initialize(buf, isOrtFormat); } else { // browser diff --git a/js/web/lib/onnxjs/util.ts b/js/web/lib/onnxjs/util.ts index 0a76d75e79bbf..d697a8b3138cf 100644 --- a/js/web/lib/onnxjs/util.ts +++ b/js/web/lib/onnxjs/util.ts @@ -967,7 +967,7 @@ export class ReduceUtil { const dims = a.dims.slice(0); // if axes is not set, perform reduce on all axes if (axes.length === 0) { - dims.forEach((d, ind) => axes.push(ind)); + dims.forEach((_d, ind) => axes.push(ind)); } // get a temporary broadcastable output shape const outputDims = ReduceUtil.calcReduceShape(dims, axes, true); diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 59da1369e152e..00431a4e86d5b 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import type {Tensor} from 'onnxruntime-common'; + export declare namespace JSEP { type BackendType = unknown; type AllocFunction = (size: number) => number; @@ -9,11 +11,8 @@ export declare namespace JSEP { type DownloadFunction = (gpuDataId: number, dataOffset: number, size: number) => Promise; type CreateKernelFunction = (name: string, kernel: number, attribute: unknown) => void; type ReleaseKernelFunction = (kernel: number) => void; - type RunFunction = (kernel: number, contextDataOffset: number, sessionState: SessionState) => number; - export interface SessionState { - sessionId: number; - errors: Array>; - } + type RunFunction = + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; } export interface OrtWasmModule extends EmscriptenModule { @@ -40,14 +39,23 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtFree(stringHandle: number): void; - _OrtCreateTensor(dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number): - number; + _OrtCreateTensor( + dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number, + dataLocation: number): number; _OrtGetTensorData(tensorHandle: number, dataType: number, dataOffset: number, dimsOffset: number, dimsLength: number): number; _OrtReleaseTensor(tensorHandle: number): void; + _OrtCreateBinding(sessionHandle: number): number; + _OrtBindInput(bindingHandle: number, nameOffset: number, tensorHandle: number): Promise; + _OrtBindOutput(bindingHandle: number, nameOffset: number, tensorHandle: number, location: number): number; + _OrtClearBoundOutputs(ioBindingHandle: number): void; + _OrtReleaseBinding(ioBindingHandle: number): void; + _OrtRunWithBinding( + sessionHandle: number, ioBindingHandle: number, outputCount: number, outputsOffset: number, + runOptionsHandle: number): Promise; _OrtRun( sessionHandle: number, inputNamesOffset: number, inputsOffset: number, inputCount: number, - outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): number; + outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): Promise; _OrtCreateSessionOptions( graphOptimizationLevel: number, enableCpuMemArena: boolean, enableMemPattern: boolean, executionMode: number, @@ -94,6 +102,11 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; + _OrtTrainingGetModelInputOutputCount? + (trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; + _OrtTrainingGetModelInputOutputName? + (trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; + _OrtTrainingReleaseSession?(trainingHandle: number): void; // #endregion @@ -102,17 +115,67 @@ export interface OrtWasmModule extends EmscriptenModule { // #endregion // #region JSEP + /** + * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime. + * This function initializes WebGPU backend and registers a few callbacks that will be called in C++ code. + */ jsepInit? (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction, download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void; + /** + * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). + * + * @param context - specify the kernel context pointer. + * @param index - specify the index of the output. + * @param data - specify the pointer to encoded data of type and dims. + */ _JsepOutput(context: number, index: number, data: number): number; + /** + * [exported from wasm] Get name of an operator node. + * + * @param kernel - specify the kernel pointer. + * @returns the pointer to a C-style UTF8 encoded string representing the node name. + */ _JsepGetNodeName(kernel: number): number; - jsepOnRunStart?(sessionId: number): void; - jsepOnRunEnd?(sessionId: number): Promise; - jsepRunPromise?: Promise; + /** + * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output. + * + * @param sessionId - specify the session ID. + * @param index - specify an integer to represent which input/output it is registering for. For input, it is the + * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index + * corresponding to the session's ouputNames. + * @param buffer - specify the GPU buffer to register. + * @param size - specify the original data size in byte. + * @returns the GPU data ID for the registered GPU buffer. + */ + jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; + /** + * [exported from js_internal_api.js] Unregister all user GPU buffers for a session. + * + * @param sessionId - specify the session ID. + */ + jsepUnregisterBuffers?: (sessionId: number) => void; + /** + * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. + * + * @param dataId - specify the GPU data ID + * @returns the GPU buffer. + */ + jsepGetBuffer: (dataId: number) => GPUBuffer; + /** + * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor. + * + * @param gpuBuffer - specify the GPU buffer + * @param size - specify the original data size in byte. + * @param type - specify the tensor type. + * @returns the generated downloader function. + */ + jsepCreateDownloader: + (gpuBuffer: GPUBuffer, size: number, + type: Tensor.GpuBufferDataTypes) => () => Promise; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 5e77a0343b4ee..e2c2bc8deccf4 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -1,14 +1,51 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Env} from 'onnxruntime-common'; +import {Env, Tensor} from 'onnxruntime-common'; import {configureLogger, LOG_DEBUG} from './log'; -import {TensorView} from './tensor-view'; -import {createGpuDataManager, GpuDataManager} from './webgpu/gpu-data-manager'; +import {createView, TensorView} from './tensor-view'; +import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; -import {ComputeContext, GpuData, ProgramInfo, ProgramInfoLoader} from './webgpu/types'; +import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency} from './webgpu/types'; + +const getProgramInputTensorInfoDependencyKey = + (inputTensors: readonly TensorView[], inputDependencies: readonly ProgramInputTensorInfoDependency[]): string => { + if (inputDependencies.length !== inputTensors.length) { + throw new Error(`inputDependencies length ${inputDependencies.length} is not equal to inputTensors length ${ + inputTensors.length}.`); + } + + const inputInfos: string[] = []; + for (let i = 0; i < inputTensors.length; ++i) { + const type = inputTensors[i].dataType; + switch (inputDependencies[i]) { + case 'none': { + inputInfos.push(''); + break; + } + case 'type': { + inputInfos.push(`${type}`); + break; + } + case 'rank': { + const rank = inputTensors[i].dims.length; + inputInfos.push(`${type};${rank}`); + break; + } + case 'dims': { + const dims = inputTensors[i].dims.join(','); + inputInfos.push(`${type};${dims}`); + break; + } + default: + throw new Error(`unsupported input dependency: ${inputDependencies[i]}`); + } + } + + return inputInfos.join('|'); + }; /** * get a unique key representing the program from the program info, input shapes and types. @@ -18,15 +55,19 @@ import {ComputeContext, GpuData, ProgramInfo, ProgramInfoLoader} from './webgpu/ * */ const getProgramInfoUniqueKey = - (programInfo: ProgramInfo|ProgramInfoLoader, inputTensors: readonly TensorView[]): string => { + (programInfo: ProgramInfo, inputTensors: readonly TensorView[], is1DimensionDispatch: boolean): string => { // final key format: - // []:||... - const inputInfos = inputTensors.map(tensor => `${tensor.dataType};${tensor.dims.join(',')}`).join('|'); + // []:is1DimensionDispatch:||... let key = programInfo.name; - if (programInfo.cacheHint) { - key += '[' + programInfo.cacheHint + ']'; + if (programInfo.shaderCache?.hint) { + key += '[' + programInfo.shaderCache.hint + ']'; } - key += ':' + inputInfos; + key += ':' + is1DimensionDispatch + + `:${ + getProgramInputTensorInfoDependencyKey( + inputTensors, + programInfo.shaderCache?.inputDependencies ?? + new Array(inputTensors.length).fill('dims'))}`; return key; }; @@ -87,17 +128,22 @@ export class WebGpuBackend { */ kernels: Map unknown) | undefined, unknown]]>; - commandEncoder: GPUCommandEncoder|null = null; - computePassEncoder: GPUComputePassEncoder|null = null; + private commandEncoder: GPUCommandEncoder|null = null; + private computePassEncoder: GPUComputePassEncoder|null = null; pendingDispatchNumber = 0; - supportTimestampQuery = false; - profilingQuerySet: GPUQuerySet; - profilingQueryData: GpuData; - profilingTimeBase?: bigint; + queryData?: GpuData; + querySet?: GPUQuerySet; + querySetCount = 2; + queryTimeBase?: bigint; env: Env; + /** + * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. + */ + sessionExternalDataMapping: Map> = new Map(); + async initialize(env: Env): Promise { if (!navigator.gpu) { // WebGPU is not available. @@ -124,11 +170,9 @@ export class WebGpuBackend { }, requiredFeatures, }; - // WebGPU Spec: Timestamp Queries Inside Passes - // https://github.com/gpuweb/gpuweb/blob/main/proposals/timestamp-query-inside-passes.md - if (adapter.features.has('timestamp-query-inside-passes')) { - this.supportTimestampQuery = true; - requiredFeatures.push('timestamp-query-inside-passes' as GPUFeatureName); + + if (adapter.features.has('timestamp-query')) { + requiredFeatures.push('timestamp-query'); } if (adapter.features.has('shader-f16')) { requiredFeatures.push('shader-f16'); @@ -153,21 +197,14 @@ export class WebGpuBackend { } }; - if (this.supportTimestampQuery) { - this.profilingQuerySet = this.device.createQuerySet({ - type: 'timestamp', - count: 2, - }); - } - Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); } dispose(): void { - // currently, we do not do anything in this function. In all known use cases, we don't have the requirement to - // actually dispose the WebGpuBackend instance, because it's always used as a singleton. - // - // revisit this place if we get real requirement to dispose the instance. + if (typeof this.querySet !== 'undefined') { + this.querySet.destroy(); + } + this.gpuDataManager.dispose(); } getCommandEncoder(): GPUCommandEncoder { @@ -179,7 +216,22 @@ export class WebGpuBackend { getComputePassEncoder(): GPUComputePassEncoder { if (!this.computePassEncoder) { - this.computePassEncoder = this.getCommandEncoder().beginComputePass(); + const computePassDescriptor: GPUComputePassDescriptor = {}; + if (this.isQueryEnabled()) { + if (typeof this.querySet === 'undefined') { + this.querySet = this.device.createQuerySet({ + type: 'timestamp', + count: this.querySetCount, + }); + } + computePassDescriptor.timestampWrites = { + querySet: this.querySet, + beginningOfPassWriteIndex: 0, + endOfPassWriteIndex: 1, + }; + } + + this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor); } return this.computePassEncoder; } @@ -192,18 +244,27 @@ export class WebGpuBackend { } flush(): void { - this.endComputePass(); - this.device.queue.submit([this.getCommandEncoder().finish()]); - this.gpuDataManager.refreshPendingBuffers(); - this.commandEncoder = null; - this.pendingDispatchNumber = 0; + if (this.commandEncoder) { + this.endComputePass(); + this.device.queue.submit([this.getCommandEncoder().finish()]); + this.gpuDataManager.refreshPendingBuffers(); + this.commandEncoder = null; + this.pendingDispatchNumber = 0; + } + } + + isQueryEnabled(): boolean { + if (this.device.features.has('timestamp-query') && this.env.webgpu.profilingMode === 'default') { + return true; + } else { + return false; + } } /** * run a WebGPU program. - * @param program either a ProgramInfo instance containing metadata including the shader code, or a function that - * can be called and return a ProgramInfo instance - * @param inputs a TensorView array. each element represents a value already exists in GPU. + * @param program a ProgramInfo instance + * @param inputTensorViews a TensorView array. each element represents a value already exists in GPU. * @param outputIndices an indices array. each element can be either -1 (temporary data), -2 (persistent data) or an * index to the kernel's output. * @param createKernelOutput a callback function that create a value to kernel's output with the given index @@ -211,45 +272,36 @@ export class WebGpuBackend { * or persistent (owned by the current kernel) * @returns a TensorView array representing the result. */ - run(program: ProgramInfoLoader|ProgramInfo, inputs: readonly TensorView[], outputIndices: readonly number[], + run(program: ProgramInfo, inputTensorViews: readonly TensorView[], outputIndices: readonly number[], createKernelOutput: (index: number, dataType: number, dims: readonly number[]) => TensorView, createIntermediateOutput: (dataType: number, dims: readonly number[]) => TensorView): TensorView[] { - if (inputs.length !== program.inputTypes.length) { - throw new Error(`Input size must be equal to ${program.inputTypes.length}.`); - } - // create info for inputs const inputDatas: GpuData[] = []; - for (let i = 0; i < inputs.length; ++i) { - const gpuData = this.gpuDataManager.get(inputs[i].data); + for (let i = 0; i < inputTensorViews.length; ++i) { + const gpuData = this.gpuDataManager.get(inputTensorViews[i].data); if (!gpuData) { - throw new Error(`no GPU data for input: ${inputs[i].data}`); + throw new Error(`no GPU data for input: ${inputTensorViews[i].data}`); } inputDatas[i] = gpuData; } - const key = getProgramInfoUniqueKey(program, inputs); - let artifact = this.programManager.getArtifact(key); - const programInfo = artifact ? - artifact.programInfo : - (typeof (program as ProgramInfoLoader).get === 'function' ? (program as ProgramInfoLoader).get() : - (program as ProgramInfo)); + const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews); // check output indices - const validatedOutputIndices = outputIndices.length === 0 ? programInfo.outputs.map((_, i) => i) : outputIndices; - if (validatedOutputIndices.length !== programInfo.outputs.length) { - throw new Error(`Output size ${validatedOutputIndices.length} must be equal to ${programInfo.outputs.length}.`); + const validatedOutputIndices = outputIndices.length === 0 ? outputs.map((_, i) => i) : outputIndices; + if (validatedOutputIndices.length !== outputs.length) { + throw new Error(`Output size ${validatedOutputIndices.length} must be equal to ${outputs.length}.`); } // create info for outputs const outputTensorViews: TensorView[] = []; const outputDatas: GpuData[] = []; - for (let i = 0; i < programInfo.outputs.length; ++i) { + for (let i = 0; i < outputs.length; ++i) { // value -1 and -2 are used for creating temporary and persistent outputs. // value -3 is used for placeholder output. So -3, -2, -1 and 0, 1, 2, ... are valid // output indices. see type definition of ComputeContextInputsOutputsMapping for more details. if (!Number.isInteger(validatedOutputIndices[i]) || validatedOutputIndices[i] < -3 || - validatedOutputIndices[i] >= programInfo.outputs.length) { + validatedOutputIndices[i] >= outputs.length) { throw new Error(`Invalid output index: ${validatedOutputIndices[i]}`); } if (validatedOutputIndices[i] === -3) { @@ -258,8 +310,8 @@ export class WebGpuBackend { const isTemporary = validatedOutputIndices[i] === -1; const isPersistent = validatedOutputIndices[i] === -2; const tensorView = (isTemporary || isPersistent) ? - createIntermediateOutput(programInfo.outputs[i].dataType, programInfo.outputs[i].dims) : - createKernelOutput(validatedOutputIndices[i], programInfo.outputs[i].dataType, programInfo.outputs[i].dims); + createIntermediateOutput(outputs[i].dataType, outputs[i].dims) : + createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims); const gpuData = this.gpuDataManager.get(tensorView.data); if (!gpuData) { throw new Error(`no GPU data for output: ${tensorView.data}`); @@ -279,18 +331,97 @@ export class WebGpuBackend { outputDatas.push(gpuData); } - const normalizedDispatchGroup = this.programManager.normalizeDispatchGroupSize(programInfo.dispatchGroup(inputs)); + // load uniforms + // TODO: add cache for uniform (is it necessary?) + // + let uniformBufferBinding: GPUBindingResource|undefined; + if (programUniforms) { + let currentOffset = 0; + let preLength = 0; + const offsets: number[] = []; + let maxAlignmentOfField = 1; + programUniforms.forEach(v => { + const data = typeof v.data === 'number' ? [v.data] : v.data; + if (data.length === 0) { + return; + } + // https://www.w3.org/TR/WGSL/#alignof + let baseAlignment: number; + switch (data.length) { + case 1: + baseAlignment = 4; + break; + case 2: + baseAlignment = 8; + break; + case 3: + baseAlignment = 16; + break; + case 4: + baseAlignment = 16; + break; + case 5: + baseAlignment = 16; + break; + case 6: + baseAlignment = 16; + break; + default: + throw new Error(`unsupported data length: ${data.length}`); + } + + if (preLength === 5 || preLength === 6) { + baseAlignment = 16; + } + if (baseAlignment > maxAlignmentOfField) { + maxAlignmentOfField = baseAlignment; + } + currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment; + preLength = data.length; + offsets.push(currentOffset); + currentOffset += data.length * 4; + }); + + currentOffset = Math.ceil(currentOffset / maxAlignmentOfField) * maxAlignmentOfField; + const arrayBuffer = new ArrayBuffer(currentOffset); + programUniforms.forEach((v, i) => { + const offset = offsets[i]; + const data = typeof v.data === 'number' ? [v.data] : v.data; + if (v.type === 'int32') { + new Int32Array(arrayBuffer, offset, data.length).set(data); + } else if (v.type === 'uint32') { + new Uint32Array(arrayBuffer, offset, data.length).set(data); + } else { + new Float32Array(arrayBuffer, offset, data.length).set(data); + } + }); + + const uniformBufferData = + // eslint-disable-next-line no-bitwise + this.gpuDataManager.create(currentOffset, GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM); + this.device.queue.writeBuffer(uniformBufferData.buffer, 0, arrayBuffer, 0, currentOffset); + this.gpuDataManager.release(uniformBufferData.id); + uniformBufferBinding = {offset: 0, size: currentOffset, buffer: uniformBufferData.buffer}; + } + + const normalizedDispatchGroup = this.programManager.normalizeDispatchGroupSize(dispatchGroup); + const is1DimensionDispatch = normalizedDispatchGroup[1] === 1 && normalizedDispatchGroup[2] === 1; + // get program info + const key = getProgramInfoUniqueKey(program, inputTensorViews, is1DimensionDispatch); + let artifact = this.programManager.getArtifact(key); if (!artifact) { - artifact = this.programManager.build(programInfo, normalizedDispatchGroup); + artifact = this.programManager.build(program, normalizedDispatchGroup); this.programManager.setArtifact(key, artifact); } LOG_DEBUG( 'info', - () => `[ProgramManager] run "${programInfo.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${ + () => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${ normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`); - this.programManager.run(artifact, inputs, inputDatas, outputDatas, normalizedDispatchGroup); + this.programManager.run( + artifact, inputTensorViews, outputTensorViews, inputDatas, outputDatas, normalizedDispatchGroup, + uniformBufferBinding); return outputTensorViews; } @@ -304,12 +435,9 @@ export class WebGpuBackend { } async download(gpuDataId: number, getTargetBuffer: () => Uint8Array): Promise { - const arrayBuffer = await this.gpuDataManager.download(gpuDataId); - // the underlying buffer may be changed after the async function is called. so we use a getter function to make sure // the buffer is up-to-date. - const data = getTargetBuffer(); - data.set(new Uint8Array(arrayBuffer, 0, data.byteLength)); + await this.gpuDataManager.download(gpuDataId, getTargetBuffer); } alloc(size: number): number { @@ -372,7 +500,7 @@ export class WebGpuBackend { kernelEntry(context, attributes[1]); return 0; // ORT_OK } catch (e) { - LOG_DEBUG('warning', `[WebGPU] Kernel "[${opType}] ${nodeName}" failed. Error: ${e}`); + errors.push(Promise.resolve(`[WebGPU] Kernel "[${opType}] ${nodeName}" failed. ${e}`)); return 1; // ORT_FAIL } finally { if (useErrorScope) { @@ -387,4 +515,40 @@ export class WebGpuBackend { this.currentKernelId = null; } } + + // #region external buffer + registerBuffer(sessionId: number, index: number, buffer: GPUBuffer, size: number): number { + let sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId); + if (!sessionInputOutputMapping) { + sessionInputOutputMapping = new Map(); + this.sessionExternalDataMapping.set(sessionId, sessionInputOutputMapping); + } + + const previousBuffer = sessionInputOutputMapping.get(index); + const id = this.gpuDataManager.registerExternalBuffer(buffer, size, previousBuffer?.[1]); + sessionInputOutputMapping.set(index, [id, buffer]); + return id; + } + unregisterBuffers(sessionId: number): void { + const sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId); + if (sessionInputOutputMapping) { + sessionInputOutputMapping.forEach(bufferInfo => this.gpuDataManager.unregisterExternalBuffer(bufferInfo[1])); + this.sessionExternalDataMapping.delete(sessionId); + } + } + getBuffer(gpuDataId: number): GPUBuffer { + const gpuData = this.gpuDataManager.get(gpuDataId); + if (!gpuData) { + throw new Error(`no GPU data for buffer: ${gpuDataId}`); + } + return gpuData.buffer; + } + createDownloader(gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes): + () => Promise { + return async () => { + const data = await downloadGpuData(this, gpuBuffer, size); + return createView(data.buffer, type); + }; + } + // #endregion } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 78316cbe1c825..d66357e729d5d 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -3,14 +3,14 @@ import {Env} from 'onnxruntime-common'; -import {JSEP, OrtWasmModule} from '../binding/ort-wasm'; +import {OrtWasmModule} from '../binding/ort-wasm'; import {DataType, getTensorElementSize} from '../wasm-common'; import {WebGpuBackend} from './backend-webgpu'; import {LOG_DEBUG} from './log'; import {TensorView} from './tensor-view'; import {ShapeUtil} from './util'; -import {ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo, ProgramInfoLoader} from './webgpu/types'; +import {ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types'; /* eslint-disable no-bitwise */ @@ -90,8 +90,7 @@ class ComputeContextImpl implements ComputeContext { this.inputs = inputs; } - compute(program: ProgramInfoLoader|ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): - TensorView[] { + compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] { // prepare inputs. inputs should always be valid data. const mappedInputs = inputsOutputsMapping?.inputs?.map(i => typeof i === 'number' ? this.inputs[i] : i) ?? this.inputs; @@ -120,6 +119,11 @@ class ComputeContextImpl implements ComputeContext { this.module.HEAPU32[offset++] = dims[i]; } return this.module._JsepOutput(this.opKernelContext, index, data); + } catch (e) { + throw new Error( + `Failed to generate kernel's output[${index}] with dims [${dims}]. ` + + 'If you are running with pre-allocated output, please make sure the output type/dims are correct. ' + + `Error: ${e}`); } finally { this.module.stackRestore(stack); } @@ -138,7 +142,7 @@ export const init = async(module: OrtWasmModule, env: Env): Promise => { init( // backend - {backend}, + backend, // jsepAlloc() (size: number) => backend.alloc(size), @@ -178,13 +182,13 @@ export const init = async(module: OrtWasmModule, env: Env): Promise => { (kernel: number) => backend.releaseKernel(kernel), // jsepRun - (kernel: number, contextDataOffset: number, sessionState: JSEP.SessionState) => { + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { LOG_DEBUG( 'verbose', - () => `[WebGPU] jsepRun: sessionId=${sessionState.sessionId}, kernel=${kernel}, contextDataOffset=${ + () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${ contextDataOffset}`); const context = new ComputeContextImpl(module, backend, contextDataOffset); - return backend.computeKernel(kernel, context, sessionState.errors); + return backend.computeKernel(kernel, context, errors); }); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts index adba0fb9d022d..ad56b92c1d869 100644 --- a/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts +++ b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts @@ -6,13 +6,13 @@ class AttributeWithCacheKeyImpl { Object.assign(this, attribute); } - private _cacheKey: string; + private key: string; public get cacheKey(): string { - if (!this._cacheKey) { - this._cacheKey = + if (!this.key) { + this.key = Object.getOwnPropertyNames(this).sort().map(name => `${(this as Record)[name]}`).join(';'); } - return this._cacheKey; + return this.key; } } diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 92fdd5abc3892..6f3d9a52d9f5d 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -35,7 +35,7 @@ export interface GpuDataManager { /** * copy data from GPU to CPU. */ - download(id: GpuDataId): Promise; + download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise; /** * refresh the buffers that marked for release. @@ -46,6 +46,19 @@ export interface GpuDataManager { */ refreshPendingBuffers(): void; + /** + * register an external buffer for IO Binding. If the buffer is already registered, return the existing GPU data ID. + * + * GPU data manager only manages a mapping between the buffer and the GPU data ID. It will not manage the lifecycle of + * the external buffer. + */ + registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number; + + /** + * unregister an external buffer for IO Binding. + */ + unregisterExternalBuffer(buffer: GPUBuffer): void; + /** * destroy all gpu buffers. Call this when the session.release is called. */ @@ -62,12 +75,56 @@ interface StorageCacheValue { */ const calcNormalizedBufferSize = (size: number) => Math.ceil(size / 16) * 16; -let guid = 0; +let guid = 1; const createNewGpuDataId = () => guid++; +/** + * exported standard download function. This function is used by the session to download the data from GPU, and also by + * factory to create GPU tensors with the capacity of downloading data from GPU. + * + * @param backend - the WebGPU backend + * @param gpuBuffer - the GPU buffer to download + * @param originalSize - the original size of the data + * @param getTargetBuffer - optional. If provided, the data will be copied to the target buffer. Otherwise, a new buffer + * will be created and returned. + */ +export const downloadGpuData = + async(backend: WebGpuBackend, gpuBuffer: GPUBuffer, originalSize: number, getTargetBuffer?: () => Uint8Array): + Promise => { + const bufferSize = calcNormalizedBufferSize(originalSize); + const gpuReadBuffer = backend.device.createBuffer( + // eslint-disable-next-line no-bitwise + {size: bufferSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ}); + try { + const commandEncoder = backend.getCommandEncoder(); + backend.endComputePass(); + commandEncoder.copyBufferToBuffer( + gpuBuffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */, + 0 /* destination offset */, bufferSize /* size */ + ); + backend.flush(); + + await gpuReadBuffer.mapAsync(GPUMapMode.READ); + + const arrayBuffer = gpuReadBuffer.getMappedRange(); + if (getTargetBuffer) { + // if we already have a CPU buffer to accept the data, no need to clone the ArrayBuffer. + const targetBuffer = getTargetBuffer(); + targetBuffer.set(new Uint8Array(arrayBuffer, 0, originalSize)); + return targetBuffer; + } else { + // the mapped ArrayBuffer will be released when the GPU buffer is destroyed. Need to clone the + // ArrayBuffer. + return new Uint8Array(arrayBuffer.slice(0, originalSize)); + } + } finally { + gpuReadBuffer.destroy(); + } + }; + class GpuDataManagerImpl implements GpuDataManager { // GPU Data ID => GPU Data ( storage buffer ) - storageCache: Map; + private storageCache: Map; // pending buffers for uploading ( data is unmapped ) private buffersForUploadingPending: GPUBuffer[]; @@ -76,12 +133,19 @@ class GpuDataManagerImpl implements GpuDataManager { // The reusable storage buffers for computing. private freeBuffers: Map; + // The reusable uniform buffers + private freeUniformBuffers: Map; + + // The external buffers registered users for IO Binding. + private externalBuffers: Map; constructor(private backend: WebGpuBackend) { this.storageCache = new Map(); this.freeBuffers = new Map(); + this.freeUniformBuffers = new Map(); this.buffersForUploadingPending = []; this.buffersPending = []; + this.externalBuffers = new Map(); } upload(id: GpuDataId, data: Uint8Array): void { @@ -143,6 +207,42 @@ class GpuDataManagerImpl implements GpuDataManager { sourceGpuDataCache.gpuData.buffer, 0, destinationGpuDataCache.gpuData.buffer, 0, size); } + registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number { + let id: number|undefined; + if (previousBuffer) { + id = this.externalBuffers.get(previousBuffer); + if (id === undefined) { + throw new Error('previous buffer is not registered'); + } + if (buffer === previousBuffer) { + LOG_DEBUG( + 'verbose', + () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ + id}, buffer is the same, skip.`); + return id; + } + this.externalBuffers.delete(previousBuffer); + } else { + id = createNewGpuDataId(); + } + + this.storageCache.set(id, {gpuData: {id, type: GpuDataType.default, buffer}, originalSize}); + this.externalBuffers.set(buffer, id); + LOG_DEBUG( + 'verbose', + () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, registered.`); + return id; + } + + unregisterExternalBuffer(buffer: GPUBuffer): void { + const id = this.externalBuffers.get(buffer); + if (id !== undefined) { + this.storageCache.delete(id); + this.externalBuffers.delete(buffer); + LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.unregisterExternalBuffer() => id=${id}`); + } + } + // eslint-disable-next-line no-bitwise create(size: number, usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST): GpuData { const bufferSize = calcNormalizedBufferSize(size); @@ -150,11 +250,15 @@ class GpuDataManagerImpl implements GpuDataManager { let gpuBuffer; // Currently, only storage buffers are reused. // eslint-disable-next-line no-bitwise - if ((usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { - let buffers = this.freeBuffers.get(bufferSize); + const isStorage = (usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE; + // eslint-disable-next-line no-bitwise + const isUniform = (usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM; + if (isStorage || isUniform) { + const freeBuffers = isStorage ? this.freeBuffers : this.freeUniformBuffers; + let buffers = freeBuffers.get(bufferSize); if (!buffers) { buffers = []; - this.freeBuffers.set(bufferSize, buffers); + freeBuffers.set(bufferSize, buffers); } if (buffers.length > 0) { gpuBuffer = buffers.pop() as GPUBuffer; @@ -193,31 +297,13 @@ class GpuDataManagerImpl implements GpuDataManager { return cachedData.originalSize; } - async download(id: GpuDataId): Promise { + async download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise { const cachedData = this.storageCache.get(id); if (!cachedData) { throw new Error('data does not exist'); } - const commandEncoder = this.backend.getCommandEncoder(); - this.backend.endComputePass(); - const bufferSize = calcNormalizedBufferSize(cachedData.originalSize); - const gpuReadBuffer = this.backend.device.createBuffer( - // eslint-disable-next-line no-bitwise - {size: bufferSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ}); - commandEncoder.copyBufferToBuffer( - cachedData.gpuData.buffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */, - 0 /* destination offset */, bufferSize /* size */ - ); - this.backend.flush(); - - return new Promise((resolve) => { - gpuReadBuffer.mapAsync(GPUMapMode.READ).then(() => { - const data = gpuReadBuffer.getMappedRange().slice(0); - gpuReadBuffer.destroy(); - resolve(data); - }); - }); + await downloadGpuData(this.backend, cachedData.gpuData.buffer, cachedData.originalSize, getTargetBuffer); } refreshPendingBuffers(): void { @@ -231,6 +317,10 @@ class GpuDataManagerImpl implements GpuDataManager { if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. this.freeBuffers.get(buffer.size)!.push(buffer); + // eslint-disable-next-line no-bitwise + } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { + // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. + this.freeUniformBuffers.get(buffer.size)!.push(buffer); } else { buffer.destroy(); } @@ -244,6 +334,11 @@ class GpuDataManagerImpl implements GpuDataManager { buffer.destroy(); }); }); + this.freeUniformBuffers.forEach((buffers) => { + buffers.forEach(buffer => { + buffer.destroy(); + }); + }); this.storageCache.forEach((storage) => { storage.gpuData.buffer.destroy(); @@ -251,6 +346,7 @@ class GpuDataManagerImpl implements GpuDataManager { this.storageCache = new Map(); this.freeBuffers = new Map(); + this.freeUniformBuffers = new Map(); } } diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index e92e6696d9a78..80f6e3bc11195 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -2,6 +2,10 @@ // Licensed under the MIT License. import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; +import {attention, parseAttentionAttributes} from './ops/attention'; +import {batchNorm} from './ops/batch-norm'; +import {biasAdd} from './ops/bias-add'; +import {biasSplitGelu} from './ops/bias-split-gelu'; import * as binaryOps from './ops/binary-op'; import {concat, parseConcatAttributes} from './ops/concat'; import {conv, parseConvAttributes} from './ops/conv'; @@ -14,8 +18,10 @@ import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm'; import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; import {matMul} from './ops/matmul'; +import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad, parsePadAttributes} from './ops/pad'; import * as pool from './ops/pool'; +import {range} from './ops/range'; import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm'; @@ -25,6 +31,7 @@ import {parseSplitAttributes, split} from './ops/split'; import {tile} from './ops/tile'; import {parseTransposeAttributes, transpose} from './ops/transpose'; import * as unaryOps from './ops/unary-op'; +import {where} from './ops/where'; import {ComputeContext} from './types'; export type RunFunction = (context: ComputeContext, attribute?: unknown) => void; @@ -42,11 +49,14 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Asinh', [unaryOps.asinh]], ['Atan', [unaryOps.atan]], ['Atanh', [unaryOps.atanh]], + ['Attention', [attention, parseAttentionAttributes]], // TODO: support new attributes for AveragePool-10 ['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]], + ['BatchNormalization', [batchNorm]], + ['BiasAdd', [biasAdd]], + ['BiasSplitGelu', [biasSplitGelu]], ['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]], ['Ceil', [unaryOps.ceil]], - ['ClipV10', [unaryOps.clipV10]], ['Clip', [unaryOps.clip]], ['Concat', [concat, parseConcatAttributes]], ['Conv', [conv, parseConvAttributes]], @@ -61,6 +71,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Exp', [unaryOps.exp]], ['Expand', [expand]], ['Floor', [unaryOps.floor]], + ['FusedConv', [conv, parseConvAttributes]], ['Gather', [gather, parseGatherAttributes]], ['GatherElements', [gatherElements, parseGatherElementsAttributes]], ['Gelu', [unaryOps.gelu]], @@ -79,10 +90,12 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new // TODO: support new attributes for MaxPool-8 and MaxPool-10 ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]], ['Mul', [binaryOps.mul]], + ['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]], ['Neg', [unaryOps.neg]], ['Not', [unaryOps.not]], ['Pad', [pad, parsePadAttributes]], ['Pow', [binaryOps.pow]], + ['Range', [range]], ['Reciprocal', [unaryOps.reciprocal]], ['ReduceMin', [reduceMin, parseReduceAttributes]], ['ReduceMean', [reduceMean, parseReduceAttributes]], @@ -110,4 +123,5 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]], ['Tile', [tile]], ['Transpose', [transpose, parseTransposeAttributes]], + ['Where', [where]], ]); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts index dd4f13e76ee04..a121bf3892a32 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts @@ -19,34 +19,21 @@ // // modified to fit the needs of the project -export declare type Activation = 'linear' | 'relu' | 'prelu' | 'elu' | 'relu6' | 'leakyrelu' | 'sigmoid' | 'gelu'; - -export const typeSnippet = (component: number) => { +export const typeSnippet = (component: number, dataType: string) => { switch (component) { case 1: - return 'f32'; + return dataType; case 2: - return 'vec2'; + return `vec2<${dataType}>`; case 3: - return 'vec3'; + return `vec3<${dataType}>`; case 4: - return 'vec4'; + return `vec4<${dataType}>`; default: throw new Error(`${component}-component is not supported.`); } }; -export const activationFnSnippet = - (activation?: Activation, _hasPreluActivationWeights = false, _packed = false, _coordsLength = 3): string => { - if (!activation) { - return ''; - } - - // TODO: add implementations - return ''; - }; - -export const biasActivationSnippet = (hasBias: boolean, activation?: Activation): string => ` +export const biasSnippet = (hasBias: boolean): string => ` ${hasBias ? 'value = value + getBiasByOutputCoords(coords);' : ''} - ${activation ? 'value = activation(value, coords);' : ''} `; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 08b1d1f30b233..3638938df7dbe 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -21,24 +21,25 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ShapeUtil} from '../../../util'; -import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; +import {ProgramInfo, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {ConvAttributes} from '../conv'; +import {getActivationSnippet} from '../fuse-utils'; -import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; +import {biasSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; const conv2dCommonSnippet = (isChannelsLast: boolean, fitAOuter: boolean, fitBOuter: boolean, fitInner: boolean, addBias = false, - activation?: Activation, hasPreluActivationWeights = false, innerElementSizeX = 4, innerElementSizeW = 4, - innerElementSize = 4): string => { + attributes: ConvAttributes, innerElementSizeX = 4, innerElementSizeW = 4, innerElementSize = 4, + dataType = 'f32'): string => { const getXSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: return 'resData = x[xIndex];'; case 3: - return 'resData = vec3(x[xIndex], x[xIndex + 1], x[xIndex + 2]);'; + return `resData = vec3<${dataType}>(x[xIndex], x[xIndex + 1], x[xIndex + 2]);`; case 4: return 'resData = x[xIndex / 4];'; default: @@ -48,9 +49,9 @@ const conv2dCommonSnippet = const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: - return 'return w[row * wShape[3] + colIn];'; + return 'return w[row * i32(uniforms.w_shape[3]) + colIn];'; case 4: - return 'return w[row * wShape[3] / 4 + colIn];'; + return 'return w[row * i32(uniforms.w_shape[3]) / 4 + colIn];'; default: throw new Error(`innerElementSize ${innerElementSize} is not supported.`); } @@ -77,13 +78,13 @@ const conv2dCommonSnippet = col % outWidth); `; - const xHeight = isChannelsLast ? 'xShape[1]' : 'xShape[2]'; - const xWidth = isChannelsLast ? 'xShape[2]' : 'xShape[3]'; + const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])'; + const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])'; const row = isChannelsLast ? 'row' : 'col'; const col = isChannelsLast ? 'col' : 'row'; const readXSnippet = ` - let inChannels = wShape[2]; - let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let inChannels = i32(uniforms.w_shape[2]); + let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; @@ -92,12 +93,12 @@ const conv2dCommonSnippet = let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0]; let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1]; let xCh = ${col} % inChannels; - var resData = ${typeSnippet(innerElementSizeX)}(0.0); + var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0); // The bounds checking is always needed since we use it to pad zero for // the 'same' padding type. if (xRow >= 0 && xRow < ${xHeight} && xCol >= 0 && xCol < ${xWidth}) { ${coordASnippet} - let xIndex = getIndexFromCoords4D(coord, xShape); + let xIndex = getIndexFromCoords4D(coord, vec4(uniforms.x_shape)); ${getXSnippet(innerElementSizeX)} } return resData;`; @@ -107,27 +108,30 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < dimAOuter && col < dimInner) { + if (row < uniforms.dimAOuter && col < uniforms.dimInner) { ${readXSnippet} } - return ${typeSnippet(innerElementSizeX)}(0.0);`) : + return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : (fitInner && fitBOuter ? ` let col = colIn * ${innerElementSizeX}; ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < dimInner && col < dimBOuter) { + if (row < uniforms.dimInner && col < uniforms.dimBOuter) { ${readXSnippet} } - return ${typeSnippet(innerElementSizeX)}(0.0);`); + return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); const sampleW = `${getWSnippet(innerElementSizeW)}`; - const resType = typeSnippet(innerElementSize); - const aType = isChannelsLast ? typeSnippet(innerElementSizeX) : typeSnippet(innerElementSizeW); - const bType = isChannelsLast ? typeSnippet(innerElementSizeW) : typeSnippet(innerElementSizeX); + const resType = typeSnippet(innerElementSize, dataType); + const aType = + isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); + const bType = + isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); + const {activationFunction, applyActivation} = getActivationSnippet(attributes, resType); const userCode = ` - ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} + ${activationFunction} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { ${isChannelsLast ? sampleX : sampleW} } @@ -138,12 +142,13 @@ const conv2dCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) { let col = colIn * ${innerElementSize}; - if (row < dimAOuter && col < dimBOuter) + if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { var value = valueIn; - let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; ${coordResSnippet} - ${biasActivationSnippet(addBias, activation)} + ${biasSnippet(addBias)} + ${applyActivation} setOutputAtCoords(coords[0], coords[1], coords[2], coords[3], value); } }`; @@ -151,26 +156,22 @@ const conv2dCommonSnippet = }; export const createConv2DMatMulProgramInfo = - (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvAttributes, - outputShape: readonly number[], dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean, - sequentialAccessByThreads: boolean): ProgramInfo => { + (inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[], dimAOuter: number, + dimBOuter: number, dimInner: number, hasBias: boolean, sequentialAccessByThreads: boolean): ProgramInfo => { const isChannelsLast = attributes.format === 'NHWC'; const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; const batchSize = outputShape[0]; const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; - const isVec4 = (((inChannels % 4 === 0 || inChannels % 3 === 0) && isChannelsLast) || - (outWidth % 4 === 0 && !isChannelsLast)) && - outChannels % 4 === 0; + // TODO: enable vec4 for NCHW + const isVec4 = isChannelsLast && (inChannels % 4 === 0 || inChannels % 3 === 0) && outChannels % 4 === 0; // TODO: fine tune size const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; - const workGroupSize: [number, number, number] = - isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; - const elementsPerThread = - isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1]; + const workGroupSize: [number, number, number] = [8, 8, 1]; + const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; const dispatch = [ Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), @@ -179,7 +180,7 @@ export const createConv2DMatMulProgramInfo = LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); - const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : elementsPerThread[0]; + const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; const tileAOuter = workGroupSize[1] * elementsPerThread[1]; const tileBOuter = workGroupSize[0] * elementsPerThread[0]; @@ -190,62 +191,73 @@ export const createConv2DMatMulProgramInfo = const fitInner = dimInner % tileInner === 0; const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; + const t = tensorTypeToWsglStorageType(inputs[0].dataType); + + // TODO: support component 2, 3. + const components = isVec4 ? 4 : 1; + const programUniforms: ProgramUniform[] = + [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; + const x = + inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); + const inputVariables = [x, w]; + + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); - const declareInputs = [ - `@group(0) @binding(0) var x: array<${isVec4 && innerElementSize === 4 ? 'vec4' : 'f32'}>;`, - `@group(0) @binding(1) var w: array<${isVec4 ? 'vec4' : 'f32'}>;` - ]; let declareFunctions = ` - fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? 'vec4' : 'f32'}) { - result[flatIndex] = ${isVec4 ? 'vec4' : 'f32'}(value); + fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { + result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value); } - fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? 'vec4' : 'f32'}) { + fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3)); setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); }`; if (hasBias) { - declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } - + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + programUniforms.push(...createTensorShapeVariables(outputShape)); return { - ...metadata, - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], - dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), - getShaderSource: () => ` - ${utilFunctions} + name: 'Conv2DMatMul', + shaderCache: {hint: attributes.cacheKey}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms, + }), + getShaderSource: (shaderHelper: ShaderHelper) => ` + ${utilFunctions('uniforms.result_strides')} //struct Uniforms { xShape : vec4, wShape : vec4, outShape : vec4, // outShapeStrides: vec3, filterDims : vec2, pad : vec2, stride : vec2, // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; - ${declareInputs.join('')} - @group(0) @binding(${declareInputs.length}) var result: array<${ - isVec4 ? 'vec4' : 'f32'}>; - //@group(0) @binding(${declareInputs.length + 1}) var uniforms: Uniforms; - - const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); - const wShape : vec4 = vec4(${inputs[1].dims.join(',')}); - const outShape : vec4 = vec4(${outputShape.join(',')}); - const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); + ${ + shaderHelper.registerUniform('dimAOuter', 'i32') + .registerUniform('dimBOuter', 'i32') + .registerUniform('dimInner', 'i32') + .declareVariables(...inputVariables, output)} const filterDims : vec2 = vec2(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]}); const pad : vec2 = vec2(${attributes.pads[0]}, ${attributes.pads[1]}); const stride : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const dimAOuter : i32 = ${dimAOuter}; - const dimBOuter : i32 = ${dimBOuter}; - const dimInner : i32 = ${dimInner}; ${declareFunctions} ${ conv2dCommonSnippet( - isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0], - elementsSize[1], elementsSize[2])} + isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, attributes, elementsSize[0], elementsSize[1], + elementsSize[2], t)} ${ isVec4 ? - makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) : + makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( - elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined, + elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined, sequentialAccessByThreads)}` }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts new file mode 100644 index 0000000000000..d425155857e14 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -0,0 +1,257 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts +// +// modified to fit the needs of the project + +import {LOG_DEBUG} from '../../../log'; +import {TensorView} from '../../../tensor-view'; +import {ProgramInfo, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common'; +import {ConvTransposeAttributes} from '../conv-transpose'; +import {getActivationSnippet} from '../fuse-utils'; + +import {biasSnippet, typeSnippet} from './activation_util'; +import {utilFunctions} from './conv_util'; +import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; + +const conv2dTransposeCommonSnippet = + (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => { + const type = typeSnippet(innerElementSize, 'f32'); + const getWSnippet = (innerElementSize: number) => { + switch (innerElementSize) { + case 1: + return 'return w[getIndexFromCoords4D(coord, vec4(uniforms.w_shape))];'; + case 4: + return ` + let coord1 = vec4(coordX, coordY, col + 1, rowInner); + let coord2 = vec4(coordX, coordY, col + 2, rowInner); + let coord3 = vec4(coordX, coordY, col + 3, rowInner); + let v0 = w[getIndexFromCoords4D(coord, vec4(uniforms.w_shape))]; + let v1 = w[getIndexFromCoords4D(coord1, vec4(uniforms.w_shape))]; + let v2 = w[getIndexFromCoords4D(coord2, vec4(uniforms.w_shape))]; + let v3 = w[getIndexFromCoords4D(coord3, vec4(uniforms.w_shape))]; + return vec4(v0, v1, v2, v3); + `; + default: + throw new Error(`innerElementSize ${innerElementSize} is not supported.`); + } + }; + const coordASnippet = isChannelsLast ? ` + let coord = vec4(batch, iXR, iXC, xCh); + ` : + ` + let coord = vec4(batch, xCh, iXR, iXC); + `; + + const coordResSnippet = isChannelsLast ? ` + let coords = vec4( + batch, + row / outWidth, + row % outWidth, + col); + ` : + ` + let coords = vec4( + batch, + row, + col / outWidth, + col % outWidth); + `; + + const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]'; + const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]'; + const row = isChannelsLast ? 'row' : 'col'; + const col = isChannelsLast ? 'col' : 'row'; + + const readASnippet = ` + let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; + let outRow = ${row} / outWidth; + let outCol = ${row} % outWidth; + + let WRow = ${col} / (filterDims[1] * inChannels); + let WCol = ${col} / inChannels % filterDims[1]; + let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); + let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); + if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { + return ${type}(0.0); + } + if (xC < 0.0 || xC >= f32(${xWidth}) || fract(xC) > 0.0) { + return ${type}(0.0); + } + let iXR = i32(xR); + let iXC = i32(xC); + let xCh = ${col} % inChannels; + ${coordASnippet} + return x[getIndexFromCoords4D(coord, vec4(uniforms.x_shape))/${innerElementSize}];`; + + const sampleA = isChannelsLast ? ` + let col = colIn * ${innerElementSize}; + if (row < uniforms.dimAOuter && col < uniforms.dimInner) { + ${readASnippet} + } + return ${type}(0.0);` : + ` + let col = colIn * ${innerElementSize}; + if (row < uniforms.dimInner && col < uniforms.dimBOuter) { + ${readASnippet} + } + return ${type}(0.0);`; + + const sampleW = ` + let col = colIn * ${innerElementSize}; + let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); + let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; + if (${ + isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' : + 'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) { + let rowInner = row % inChannels; + let coord = vec4(coordX, coordY, col, rowInner); + ${getWSnippet(innerElementSize)} + } + return ${type}(0.0); + `; + + const {activationFunction, applyActivation} = getActivationSnippet(attributes, type); + const userCode = ` + ${activationFunction} + fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} { + ${isChannelsLast ? sampleA : sampleW} + } + + fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${type} { + ${isChannelsLast ? sampleW : sampleA} + } + + fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { + let col = colIn * ${innerElementSize}; + if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { + var value = valueInput; + let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; + ${coordResSnippet} + ${biasSnippet(addBias)} + ${applyActivation} + result[getIndexFromCoords4D(coords, vec4(uniforms.result_shape))/${innerElementSize}] = value; + } + }`; + return userCode; + }; + +export const createConv2DTransposeMatMulProgramInfo = + (inputs: readonly TensorView[], attributes: ConvTransposeAttributes, outputShape: readonly number[], + dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean, + sequentialAccessByThreads: boolean): ProgramInfo => { + const isChannelsLast = attributes.format === 'NHWC'; + const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; + const batchSize = outputShape[0]; + const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; + const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; + const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; + const isVec4 = + isChannelsLast ? inChannels % 4 === 0 && outChannels % 4 === 0 : outWidth % 4 === 0 && outChannels % 4 === 0; + + // TODO: fine tune size + const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; + const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; + const workGroupSize: [number, number, number] = isVec4 ? + [8, 8, 1] : + [(dispatchX <= 4 || dispatchY <= 4) ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; + const elementsPerThread = + isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 4, dispatchX > 4 && dispatchY <= 4 ? 1 : 4, 1]; + const dispatch = [ + Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), + Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), + Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2]) + ]; + + LOG_DEBUG('verbose', () => `[conv_backprop_mm_webgpu] dispatch = ${dispatch}`); + + const innerElementSize = isVec4 ? 4 : 1; + const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); + const components = isVec4 ? 4 : 1; + const programUniforms: ProgramUniform[] = + [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const inputVariables = [x, w]; + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + + let declareFunctions = ''; + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + + declareFunctions += ` + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; + }`; + } + + programUniforms.push(...createTensorShapeVariables(outputShape)); + + return { + name: 'Conv2DTransposeMatMul', + shaderCache: {hint: attributes.cacheKey}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms + }), + getShaderSource: (shaderHelper: ShaderHelper) => ` + ${utilFunctions('uniforms.result_strides')} + ${ + shaderHelper.registerUniform('dimAOuter', 'i32') + .registerUniform('dimBOuter', 'i32') + .registerUniform('dimInner', 'i32') + .declareVariables(...inputVariables, output)}; + const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); + const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ + attributes.kernelShape[isChannelsLast ? 2 : 3]}); + const effectiveFilterDims : vec2 = filterDims + vec2( + ${ + attributes.dilations[0] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, + ${ + attributes.dilations[1] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); + const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${ + attributes.pads[0] + attributes.pads[2]})/2, + i32(effectiveFilterDims[1]) - 1 - (${ + attributes.pads[1] + attributes.pads[3]})/2); + const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); + const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); + const dimAOuter : i32 = ${dimAOuter}; + const dimBOuter : i32 = ${dimBOuter}; + const dimInner : i32 = ${dimInner}; + ${declareFunctions} + ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} + ${ + isVec4 ? makeMatMulPackedVec4Source( + elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : + makeMatMulPackedSource( + elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, + undefined, sequentialAccessByThreads)}` + }; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index ec6df438129fb..2e6392aada454 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -20,13 +20,14 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; -import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; -import {inputVariable, outputVariable, ShaderHelper} from '../common'; +import {ProgramInfo} from '../../types'; +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; const createConvTranspose2DOpProgramShaderSource = (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false): string => { + outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false, + dataType: string): string => { const isChannelsLast = attributes.format === 'NHWC'; const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; @@ -39,12 +40,12 @@ const createConvTranspose2DOpProgramShaderSource = const outputChannelsPerGroup = wShape[1]; let declareFunctions = ` - fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? 'vec4' : 'f32'}) { - result[flatIndex] = ${isVec4 ? 'vec4' : 'f32'}(value); + fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) { + result[flatIndex] = ${isVec4 ? `vec4<${dataType}>` : dataType}(value); }`; if (hasBias) { declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${dataType}>` : dataType} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } @@ -66,33 +67,33 @@ const createConvTranspose2DOpProgramShaderSource = // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. - var dotProd: array, ${workPerThread}>; + var dotProd: array, ${workPerThread}>; for (var i = 0; i < ${workPerThread}; i++) { - dotProd[i] = vec4(0.0); + dotProd[i] = vec4<${dataType}>(0.0); } for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) { - var dyR = (f32(dyCorner.x) + f32(wR)) / f32(strides.x); + var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x); let wRPerm = filterDims[0] - 1 - wR; - if (dyR < 0.0 || dyR >= f32(outBackprop[1]) || + if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) { - let dyC = (f32(dyCorner.y) + f32(wC)) / f32(strides.y); - let dyC2 = (f32(dyCorner.y) + 1.0 + f32(wC)) / f32(strides.y); + let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y); + let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y); let wCPerm = filterDims[1] - 1 - wC; if (wCPerm < 0) { continue; } var bDyCVal = true; var bDyCVal2 = true; - if (dyC < 0.0 || dyC >= f32(outBackprop[2]) || + if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) || fract(dyC) > 0.0) { bDyCVal = false; } - if (dyC2 < 0.0 || dyC2 >= f32(outBackprop[2]) || + if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) || fract(dyC2) > 0.0) { bDyCVal2 = false; } @@ -108,7 +109,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -116,7 +117,7 @@ const createConvTranspose2DOpProgramShaderSource = xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - dotProd[1] = dotProd[1] + vec4(dot(xValue, wValue0), + dotProd[1] = dotProd[1] + vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -130,7 +131,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -145,7 +146,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -178,9 +179,9 @@ const createConvTranspose2DOpProgramShaderSource = if (wR % dilations.x != 0) { continue; } - let dyR = (f32(dyRCorner) + f32(wR)) / f32(strides[0]); + let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]); let wRPerm = filterDims.x - 1 - wR / dilations.x; - if (dyR < 0.0 || dyR >= f32(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || + if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } @@ -190,21 +191,21 @@ const createConvTranspose2DOpProgramShaderSource = if (wC % dilations.y != 0) { continue; } - let dyC = (f32(dyCCorner) + f32(wC)) / f32(strides.y); + let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(strides.y); let wCPerm = filterDims.y - 1 - wC / dilations.y; - if (dyC < 0.0 || dyC >= f32(outBackprop[${colDim}]) || + if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) || fract(dyC) > 0.0 || wCPerm < 0) { continue; } let idyC: u32 = u32(dyC); - + var inputChannel = groupId * ${inputChannelsPerGroup}; for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) { - let inputChannel = groupId * ${inputChannelsPerGroup} + d2; let xValue = ${ isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')}; dotProd = dotProd + xValue * wValue; + inputChannel = inputChannel + 1; } } } @@ -238,7 +239,7 @@ const createConvTranspose2DOpProgramShaderSource = }; export const createConvTranspose2DProgramInfo = - (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvTransposeAttributes, + (inputs: readonly TensorView[], attributes: ConvTransposeAttributes, squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]): ProgramInfo => { const hasBias = inputs.length > 2; // const isChannelsLast = attributes.format === 'NHWC'; @@ -256,15 +257,19 @@ export const createConvTranspose2DProgramInfo = ]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); return { - ...metadata, - outputs: [{ - dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, - dataType: inputs[0].dataType, - gpuDataType: GpuDataType.default - }], - dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), + name: 'ConvTranspose2D', + shaderCache: {hint: attributes.cacheKey}, + getRunData: () => ({ + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + outputs: [{ + dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, + dataType: inputs[0].dataType + }] + }), getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( - shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1), + shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false, + dataType), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts index 0ba48a33fbc47..6f2c0231104dc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts @@ -19,13 +19,13 @@ // // modified to fit the needs of the project -export const utilFunctions = ` +export const utilFunctions = (strideStr: string) => (` fn getIndexFromCoords4D(coords : vec4, shape : vec4) -> i32 { return dot(coords, vec4( shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1)); } fn getOutputIndexFromCoords(coords : vec4) -> i32 { return dot(coords, vec4( - outShapeStrides.x, outShapeStrides.y, outShapeStrides.z, 1)); + i32(${strideStr}.x), i32(${strideStr}.y), i32(${strideStr}.z), 1)); } -`; +`); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 8d43dbb378a69..a8f296ea0c865 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -21,9 +21,9 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; -import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; -import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from '../common'; -import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -70,8 +70,8 @@ const calculateResultSnippet = (transposeA: boolean, innerElementSize: number) = }; export const makeMatMulPackedVec4Source = - (workPerThread: number[], workgroupSize: [number, number, number], batchDims?: IndicesHelper, transposeA = false, - tileInner = 32, splitK = false, splitedDimInner = 32): string => { + (workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, + transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32): string => { const tileAOuter = workgroupSize[1] * workPerThread[1]; const tileBOuter = workgroupSize[0] * workPerThread[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; @@ -90,8 +90,8 @@ export const makeMatMulPackedVec4Source = workPerThread[0]} must be 4.`); } return ` -var mm_Asub : array, ${tileAWidth / innerElementSize}>, ${tileAHight}>; -var mm_Bsub : array, ${tileBOuter / workPerThread[0]}>, ${tileInner}>; +var mm_Asub: array, ${tileAWidth / innerElementSize}>, ${tileAHight}>; +var mm_Bsub: array, ${tileBOuter / workPerThread[0]}>, ${tileInner}>; const rowPerThread = ${workPerThread[1]}; const colPerThread = ${workPerThread[0]}; @@ -112,10 +112,10 @@ fn main(@builtin(local_invocation_id) localId : vec3, ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; + let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; - var acc: array, rowPerThread>; + var acc: array, rowPerThread>; // Loop over shared dimension. let tileRowB = localRow * ${rowPerThreadB}; @@ -179,8 +179,9 @@ const readDataFromSubASnippet = (transposeA: boolean) => // sequentialAccessByThreads means sequential data in memory is accessed by // threads, instead of a single thread (default behavior). export const makeMatMulPackedSource = - (workPerThread: number[], workgroupSize: [number, number, number], batchDims?: IndicesHelper, transposeA = false, - tileInner = 32, splitK = false, splitedDimInner = 32, sequentialAccessByThreads = false): string => { + (workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, + transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32, + sequentialAccessByThreads = false): string => { const tileAOuter = workPerThread[1] * workgroupSize[1]; const tileBOuter = workPerThread[0] * workgroupSize[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; @@ -222,7 +223,7 @@ export const makeMatMulPackedSource = workgroupBarrier(); // Compute acc values for a single thread. - var BCached : array; + var BCached : array<${type}, colPerThread>; for (var k = 0; k < tileInner; k = k + 1) { for (var inner = 0; inner < colPerThread; inner = inner + 1) { BCached[inner] = mm_Bsub[k][localCol + inner * ${workgroupSize[0]}]; @@ -283,7 +284,7 @@ for (var t = 0; t < numTiles; t = t + 1) { workgroupBarrier(); // Compute acc values for a single thread. - var BCached : array; + var BCached : array<${type}, colPerThread>; for (var k = 0; k < tileInner; k = k + 1) { for (var inner = 0; inner < colPerThread; inner = inner + 1) { BCached[inner] = mm_Bsub[k][tileCol + inner]; @@ -309,8 +310,8 @@ for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { `; return ` - var mm_Asub : array, ${tileAHight}>; - var mm_Bsub : array, ${tileInner}>; + var mm_Asub : array, ${tileAHight}>; + var mm_Bsub : array, ${tileInner}>; const rowPerThread = ${workPerThread[1]}; const colPerThread = ${workPerThread[0]}; const tileInner = ${tileInner}; @@ -321,10 +322,10 @@ fn main(@builtin(local_invocation_id) localId : vec3, @builtin(workgroup_id) workgroupId : vec3) { let batch = ${splitK ? '0' : 'i32(globalId.z)'}; ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; + let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; - var acc : array, rowPerThread>; + var acc : array, rowPerThread>; // Without this initialization strange values show up in acc. for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { @@ -338,18 +339,16 @@ fn main(@builtin(local_invocation_id) localId : vec3, }; const matMulReadWriteFnSource = - (component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[]): string => { - const batchAVariable = variables[0]; - const batchBVariable = variables[1]; - const batchVariable = variables[2]; - const aVariable = variables[3]; - const bVariable = variables[4]; - const outputVariable = variables[5]; - const broadCastADims = getBroadcastDims(batchAVariable.shape, batchVariable.shape); - const broadCastBDims = getBroadcastDims(batchBVariable.shape, batchVariable.shape); + (component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[], + batchShapes: Array, isChannelsLast = false): string => { + const [batchAShape, batchBShape, batchShape] = batchShapes; + const [batchVariable, aVariable, bVariable, outputVariable] = variables; + const broadCastADims = getBroadcastDims(batchAShape, batchShape); + const broadCastBDims = getBroadcastDims(batchBShape, batchShape); + const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor); const getAIndices = () => { - const aRank = aVariable.shape.length; - const batchRank = batchVariable.shape.length; + const aRank = aVariable.rank; + const batchRank = batchVariable.rank; let resStr = `var aIndices: ${aVariable.type.indices};`; for (let i = aRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { resStr += `\naIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`; @@ -362,8 +361,8 @@ const matMulReadWriteFnSource = return resStr; }; const getBIndices = () => { - const bRank = bVariable.shape.length; - const batchRank = batchVariable.shape.length; + const bRank = bVariable.rank; + const batchRank = batchVariable.rank; let resStr = `var bIndices: ${bVariable.type.indices};`; for (let i = bRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { resStr += `\nbIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`; @@ -377,10 +376,10 @@ const matMulReadWriteFnSource = }; const source = ` fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${ - typeSnippet(component)} { - var value = ${typeSnippet(component)}(0.0); + typeSnippet(component, dataType)} { + var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < dimAOuter && col < dimInner) + if(row < uniforms.dimAOuter && col < uniforms.dimInner) { ${getAIndices()} value = ${aVariable.getByIndices('aIndices')}; @@ -389,10 +388,10 @@ const matMulReadWriteFnSource = } fn mm_readB(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${ - typeSnippet(component)} { - var value = ${typeSnippet(component)}(0.0); + typeSnippet(component, dataType)} { + var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < dimInner && col < dimBOuter) + if(row < uniforms.dimInner && col < uniforms.dimBOuter) { ${getBIndices()} value = ${bVariable.getByIndices('bIndices')}; @@ -400,12 +399,15 @@ const matMulReadWriteFnSource = return value; } - fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component)}) { + fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) { let col = colIn * ${component}; - if (row < dimAOuter && col < dimBOuter) { + if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { var value = valueIn; let coords = vec3(batch, row, colIn); - ${hasBias ? 'value = value + bias[colIn];' : ''} + ${ + hasBias ? + `value = value + ${isChannelsLast ? 'bias[colIn]' : `${typeSnippet(component, dataType)}(bias[row])`};` : + '' } ${applyActivation} ${outputVariable.setByIndices('vec3(coords)', 'value')} } @@ -415,25 +417,25 @@ const matMulReadWriteFnSource = }; export const createMatmulProgramInfo = - (metadata: ProgramMetadata, inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, - outputShape: readonly number[], reshapedOutputShape?: readonly number[]): ProgramInfo => { + (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[], + reshapedOutputShape?: readonly number[], + isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => { const aShape = inputs[0].dims; const bShape = inputs[1].dims; const outerDimsA = aShape.slice(0, -2); const outerDimsB = bShape.slice(0, -2); + const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); - const batchDims = inputVariable('batchDims', inputs[0].dataType, outerDims); - const batchADims = inputVariable('batchADims', inputs[0].dataType, outerDimsA); - const batchBDims = inputVariable('batchBDims', inputs[0].dataType, outerDimsB); - const variables = [batchADims, batchBDims, batchDims]; + const enableBatchUniforms = enableShapesUniforms(outerDims.length); + const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims; + const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); const batchSize = ShapeUtil.size(outerDims); const dimAOuter = aShape[aShape.length - 2]; const dimInner = aShape[aShape.length - 1]; const dimBOuter = bShape[bShape.length - 1]; const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0; - const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes); // TODO: fine tune size const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; @@ -444,35 +446,83 @@ export const createMatmulProgramInfo = Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]) ]; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const components = isVec4 ? 4 : 1; - const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components); - const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components); - const output = - outputVariable('result', inputs[0].dataType, [batchSize, dimAOuter, dimBOuter / components], components); - variables.push(A); - variables.push(B); - variables.push(output); + + const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; + const enableAShapesUniforms = enableShapesUniforms(aShapeTemp.length); + const aShapeOrRank = enableAShapesUniforms ? aShapeTemp.length : aShapeTemp; + + const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; + const enableBShapesUniforms = enableShapesUniforms(bShapeTemp.length); + const bShapeOrRank = enableBShapesUniforms ? bShapeTemp.length : bShapeTemp; + + const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; + + const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); + const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); + const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); const inputVariables = [A, B]; + const programUniforms: ProgramUniform[] = + [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; + if (enableBatchUniforms) { + programUniforms.push(...createTensorShapeVariables(outerDims)); + } + if (enableAShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(aShapeTemp)); + } + if (enableBShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(bShapeTemp)); + } + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + inputDependencies.push(enableAShapesUniforms ? 'rank' : 'dims'); + inputDependencies.push(enableBShapesUniforms ? 'rank' : 'dims'); + const hasBias = inputs.length > 2; - const declareFunctions = matMulReadWriteFnSource(components, hasBias, applyActivation, variables); + const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); + const declareFunctions = matMulReadWriteFnSource( + components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], + isChannelsLast); if (hasBias) { - inputVariables.push(inputVariable('bias', inputs[2].dataType, [dimBOuter / components], components)); + const biasComponents = isChannelsLast ? components : 1; + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + + inputDependencies.push('rank'); } + programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); + const getShaderSource = (shaderHelper: ShaderHelper) => ` - const dimAOuter: i32 = ${dimAOuter}; - const dimBOuter: i32 = ${dimBOuter}; - const dimInner: i32 = ${dimInner}; - ${shaderHelper.declareVariables(...inputVariables, output)} - ${declareFunctions} + ${ + shaderHelper.registerUniform('dimAOuter', 'i32') + .registerUniform('dimBOuter', 'i32') + .registerUniform('dimInner', 'i32') + .registerInternalVariables(batchDims) + .declareVariables(...inputVariables, output)} ${activationFunction} + ${declareFunctions} ${ - isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, batchDims) : - makeMatMulPackedSource(elementsPerThread, workgroupSize, batchDims)} - ${batchDims.impl()}`; + isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : + makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} + `; + // TODO: turn clipMax and clipMin to uniforms. return { - ...metadata, - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + name: 'MatMul', + shaderCache: { + hint: activationAttributes.activationCacheKey + `${elementsPerThread}` + + `${activationAttributes.activation}` + + `${activationAttributes.clipMax}` + + `${activationAttributes.clipMin}` + + `${isVec4}` + + `${hasBias}` + + `${isChannelsLast}`, + inputDependencies + }, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms + }), getShaderSource, - dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}) }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts index 412e61a3cc0f9..b6c6853c8f222 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts @@ -8,7 +8,7 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {ComputeContext} from '../types'; import {createReduceProgramInfo, ReduceOp} from './reduce'; @@ -27,31 +27,11 @@ export interface ArgMinMaxAttributes extends AttributeWithCacheKey { selectLastIndex: number; } -const createArgMinMaxAttributesFromInputs = - (inputs: readonly TensorView[], attributes: ArgMinMaxAttributes): ArgMinMaxAttributes => - createAttributeWithCacheKey( - {axis: attributes.axis, keepDims: attributes.keepDims, selectLastIndex: attributes.selectLastIndex}); - -const createArgMinMaxProgramInfoLoader = - (inputs: readonly TensorView[], name: string, attributes: ArgMinMaxAttributes, reduceOp: ReduceOp): - ProgramInfoLoader => { - const updatedAttributes: ArgMinMaxAttributes = - inputs.length === 1 ? attributes : createArgMinMaxAttributesFromInputs(inputs, attributes); - const cacheHint = updatedAttributes.cacheKey + inputs.map(x => x.dims.toString()).join('_'); - const metadata: ProgramMetadata = {name, inputTypes: [GpuDataType.default], cacheHint}; - return { - ...metadata, - get: () => createReduceProgramInfo( - metadata, [inputs[0]], reduceOp, [updatedAttributes.axis], DataType.int64, updatedAttributes.keepDims) - }; - }; - - export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => { validateInputs(context.inputs); const argMinMaxOp: ReduceOp = (input, output, axes) => { const idxZero = []; - for (let k = 0; k < input.shape.length; k++) { + for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { idxZero.push(`inputIndices[${k}] = 0;`); // first element } @@ -65,14 +45,19 @@ export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes) '', output.setByOffset('global_idx', 'bestIndex') ]; }; - context.compute(createArgMinMaxProgramInfoLoader(context.inputs, 'ArgMin', attributes, argMinMaxOp), {inputs: [0]}); + + context.compute( + createReduceProgramInfo( + 'ArgMin', {hint: attributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [attributes.axis], DataType.int64, + attributes.keepDims), + {inputs: [0]}); }; export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => { validateInputs(context.inputs); const argMinMaxOp: ReduceOp = (input, output, axes) => { const idxZero = []; - for (let k = 0; k < input.shape.length; k++) { + for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { idxZero.push(`inputIndices[${k}] = 0;`); // first element } @@ -86,7 +71,12 @@ export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes) '', output.setByOffset('global_idx', 'bestIndex') ]; }; - context.compute(createArgMinMaxProgramInfoLoader(context.inputs, 'argMax', attributes, argMinMaxOp), {inputs: [0]}); + + context.compute( + createReduceProgramInfo( + 'argMax', {hint: attributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [attributes.axis], DataType.int64, + attributes.keepDims), + {inputs: [0]}); }; export const parseArgMinMaxAttributes = (attributes: Record): ArgMinMaxAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts new file mode 100644 index 0000000000000..e1f2a47301bfb --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -0,0 +1,635 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType} from '../types'; + +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; + +export const enum AttentionQkvFormat { + unknown, // enum value not set, or depends on qkv projection implementation details + qkvBNSH, // for non-packed qkv, permuted + qkvBSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention + qkvBSN3H, // for TRT fused attention, qkv are packed + qkvBNSHqkvBS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) + qKvBSNHxBSN2H, // for TRT fused cross attention, kv are packed + qkvTNH, // for memory efficient attention, qkv are not packed, and paddings are removed. + qkvTN3H, // for TRT fused attention, qkv are packed and paddings are removed +} + +export const enum AttentionMaskType { + none, // No mask + mask1dKeySeqLen, // [batch_size], key sequence length + mask1dEndStart, // [2 * batch_size] with end positions and start positions + mask1DKeySeqLenStart, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0], + // ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ..., + // key_start[batch_size - 1], key_end[batch_size - 1]] + mask2dDummy, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask. + mask2dKeyPadding, // [batch_size, total_sequence_length] + mask3dAttention, // [batch_size, sequence_length, total_sequence_length] + mask4dMegatron, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length] + maskUnknown +} + +export interface AttentionParameters { + batchSize: number; + sequenceLength: number; + pastSequenceLength: number; + kvSequenceLength: number; + totalSequenceLength: number; + maxSequenceLength: number; + inputHiddenSize: number; + hiddenSize: number; + vHiddenSize: number; + headSize: number; + vHeadSize: number; + numHeads: number; + isUnidirectional: boolean; + pastPresentShareBuffer: boolean; + maskFilterValue: number; + maskType: AttentionMaskType; + scale: number; + broadcastResPosBias: boolean; + passPastInKv: boolean; + qkvFormat: AttentionQkvFormat; +} + +export interface AttentionAttrs { + numHeads: number; + isUnidirectional: number; + maskFilterValue: number; + scale: number; + doRotary: number; + qkvHiddenSizes: number[]; + pastPresentShareBuffer: boolean; +} + +const validateAttentionInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { + // Abbreviation and Meanings: + // B: batch_size + // S: sequence_length (input sequence length of query) + // P: past_sequence_length (past sequence length of key or value) + // L: kv_sequence_length (input sequence length of key or value) + // M: max_sequence_length + // T: total_sequence_length = past_sequence_length + kv_sequence_length + // N: num_heads + // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size + // H_v: v_head_size + // D_i: input hidden size + // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size + // D_v: v_hidden_size = num_heads * v_head_size + + // When past state is used, Q, K and V should have same hidden size (unless we split it into past_key and past_value). + + // Input shapes: + // input (Q/K/V) : (B, S, D_i) + // weights (Q/K/V) : (D_i, D + D + D_v) + // bias (Q/K/V) : (D + D + D_v) + // mask_index : see below + // past (K/V) : (2, B, N, P, H) or NULL + // relative_position_bias : (B, N, S, T) or NULL + + // For mask_index, the following shapes are supported: + // NULL, (B, 1), (1, 1) + // (B), (2 * B), (3 * B + 2) + // (B, T) + // (B, S, T) + // (B, 1, M, M) + // + // When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger + // than hidden dimension of Q, K and V. + + const input = inputs[0]; + const weights = inputs[1]; + const bias = inputs[2]; + const maskIndex = inputs[3]; + const past = inputs[4]; + const relativePositionBias = inputs[5]; + + if (past && relativePositionBias) { + throw new Error('Attention cannot have both past and relative_position_bias'); + } + + if (input.dims.length !== 3) { + throw new Error('Input "input" must have 3 dimensions'); + } + + const batchSize = input.dims[0]; + const sequenceLength = input.dims[1]; + const inputHiddenSize = input.dims[2]; + + if (bias.dims.length !== 1) { + throw new Error('Input "bias" is expected to have 1 dimensions'); + } + + if (weights.dims.length !== 2) { + throw new Error('Input "weights" is expected to have 2 dimensions'); + } + + if (weights.dims[0] !== inputHiddenSize) { + throw new Error('Input 1 dimension 0 should have same length as dimension 2 of input 0'); + } + + if (bias.dims[0] !== weights.dims[1]) { + throw new Error('Input "bias" dimension 0 should have same length as dimension 1 of input "weights"'); + } + + let qHiddenSize = bias.dims[0] / 3; + let kHiddenSize = qHiddenSize; + let vHiddenSize = kHiddenSize; + if (attributes.qkvHiddenSizes.length > 0) { + if (attributes.qkvHiddenSizes.length !== 3) { + throw new Error('qkv_hidden_sizes attribute should have 3 elements'); + } + for (const sz of attributes.qkvHiddenSizes) { + if (sz % attributes.numHeads !== 0) { + throw new Error('qkv_hidden_sizes should be divisible by num_heads'); + } + } + + qHiddenSize = attributes.qkvHiddenSizes[0]; + kHiddenSize = attributes.qkvHiddenSizes[1]; + vHiddenSize = attributes.qkvHiddenSizes[2]; + } + + const kvSequenceLength = sequenceLength; + + if (qHiddenSize !== kHiddenSize) { + throw new Error('qkv_hidden_sizes first element should be same as the second'); + } + + if (bias.dims[0] !== qHiddenSize + kHiddenSize + vHiddenSize) { + throw new Error('Input "bias" dimension 0 should have same length as sum of Q/K/V hidden sizes'); + } + + let pastSequenceLength = 0; + if (past) { + if (kHiddenSize !== vHiddenSize) { + throw new Error('Input "past" expect k_hidden_size == v_hidden_size'); + } + if (past.dims.length !== 5) { + throw new Error('Input "past" must have 5 dimensions'); + } + if (past.dims[0] !== 2) { + throw new Error('Input "past" first dimension must be 2'); + } + if (past.dims[1] !== batchSize) { + throw new Error('Input "past" second dimension must be batch_size'); + } + if (past.dims[2] !== attributes.numHeads) { + throw new Error('Input "past" third dimension must be num_heads'); + } + if (past.dims[4] !== kHiddenSize / attributes.numHeads) { + throw new Error('Input "past" fifth dimension must be k_hidden_size / num_heads'); + } + + if (!attributes.pastPresentShareBuffer) { + pastSequenceLength = past.dims[3]; + } + // TODO: handle past_seq_len + } + + const totalSequenceLength = kvSequenceLength + pastSequenceLength; + const maxSequenceLength = -1; + + const maskType = AttentionMaskType.none; + if (maskIndex) { + // maskType = AttentionMaskType.MASK_UNKNOWN; + // TODO: handle mask + throw new Error('Mask not supported'); + } + + if (past) { + throw new Error('past is not supported'); + } + if (relativePositionBias) { + throw new Error('relativePositionBias is not supported'); + } + + return { + batchSize, + sequenceLength, + pastSequenceLength, + kvSequenceLength, + totalSequenceLength, + maxSequenceLength, + inputHiddenSize, + hiddenSize: qHiddenSize, + vHiddenSize, + headSize: Math.floor(qHiddenSize / attributes.numHeads), + vHeadSize: Math.floor(vHiddenSize / attributes.numHeads), + numHeads: attributes.numHeads, + isUnidirectional: false, + pastPresentShareBuffer: false, + maskFilterValue: attributes.maskFilterValue, + maskType, + scale: attributes.scale, + broadcastResPosBias: false, + passPastInKv: false, + qkvFormat: AttentionQkvFormat.qkvBNSH, + }; +}; + +export const parseAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => + createAttributeWithCacheKey({...attributes}); + +export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, n: number, d: number) => { + const components = getMaxComponents(d); + const inputHelper = outputVariable('x', input.dataType, input.dims, components); + + let threadMaxValue = 'threadMaxVector'; + if (components === 2) { + threadMaxValue = 'max(threadMaxVector.x, threadMaxVector.y)'; + } else if (components === 4) { + threadMaxValue = 'max(max(threadMaxVector.x, threadMaxVector.y), max(threadMaxVector.z, threadMaxVector.w))'; + } + const dataType = tensorTypeToWsglStorageType(input.dataType); + let WG = 64; + const dComp = d / components; + if (dComp < WG) { + WG = 1; + } else if (dComp / 8 < 64) { + WG = Math.ceil(dComp / 8); + } + const elementsPerWG = Math.ceil(d / components / WG); + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const dInv: ${dataType} = 1 / ${d}; + const dComp = ${d / components}; + var wgMax: array; + var wgSum: array; + + ${shaderHelper.declareVariables(inputHelper)} + @compute @workgroup_size(${WG}, 1, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_index) local_index : u32) { + let localOffset = local_index * ${elementsPerWG}; + let offset: u32 = workgroup_id.x * dComp + localOffset; + + var threadMaxVector = ${fillVector('f32', components, '-3.402823e+38f')}; + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + threadMaxVector = max(${castToF32(dataType, components, 'x[offset + i]')}, threadMaxVector); + } + wgMax[local_index] = ${threadMaxValue}; + workgroupBarrier(); + + var maxValue = -3.402823e+38f; + for (var i = 0u; i < ${WG}; i++) { + maxValue = max(wgMax[i], maxValue); + } + + var sumVector = ${fillVector('f32', components, '0')}; + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + sumVector += exp(${castToF32(dataType, components, 'x[offset + i]')} - maxValue); + } + wgSum[local_index] = ${sumVector('sumVector', components)}; + workgroupBarrier(); + + var sum: f32 = 0; + for (var i = 0u; i < ${WG}; i++) { + sum += wgSum[i]; + } + + if (sum == 0) { + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + x[offset + i] = ${fillVector(dataType, components, 'dInv')}; + } + } else { + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + let f32input = ${castToF32(dataType, components, 'x[offset + i]')}; + x[offset + i] = ${inputHelper.type.value}(exp(f32input - maxValue) / sum); + } + } + }`; + + context.compute( + { + name: 'AttentionProbsSoftmax', + shaderCache: {hint: `${d}`}, + getShaderSource, + getRunData: () => ({ + outputs: [], + dispatchGroup: {x: n}, + }), + }, + {inputs: [input], outputs: []}); +}; + +const computeAttentionProbs = + (context: ComputeContext, q: TensorView, key: TensorView, _bias: TensorView|undefined, + parameters: AttentionParameters, attributes: AttentionAttrs) => { + const probsShape = [ + parameters.batchSize, parameters.numHeads, parameters.sequenceLength, + parameters.kvSequenceLength + parameters.pastSequenceLength + ]; + // TODO: handle mask + + const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale; + + const dataType = tensorTypeToWsglStorageType(q.dataType); + + const components = getMaxComponents(parameters.headSize); + const qInput = inputVariable('q', q.dataType, q.dims, components); + const kInput = inputVariable('key', key.dataType, key.dims, components); + const output = outputVariable('output', q.dataType, probsShape); + + const vectorizedHeadSize = parameters.headSize / components; + const M = parameters.sequenceLength; + const N = parameters.totalSequenceLength; + const K = vectorizedHeadSize; + + const TILE_SIZE = 12; + + const dispatch = { + x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE), + y: Math.ceil(parameters.sequenceLength / TILE_SIZE), + z: parameters.batchSize * parameters.numHeads + }; + + const inputs = [q, key]; + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const M: u32 = ${M}u; + const N: u32 = ${N}u; + const K: u32 = ${K}u; + const alpha: ${dataType} = ${alpha}; + const beta: ${dataType} = 1.0; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + var tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + + ${shaderHelper.declareVariables(qInput, kInput, output)} + + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; + + // x holds the N and y holds the M + let headIdx = workgroup_id.z; + let m = workgroup_id.y * TILE_SIZE; + let n = workgroup_id.x * TILE_SIZE; + let lm = m + local_id.y; + let ln = n + local_id.x; + + let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K; + let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K; + + var value = ${fillVector(dataType, components)}; + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m + local_id.y < M && w + local_id.x < K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * K + w + local_id.x]; + } + if (n + local_id.y < N && w + local_id.x < K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * K + w + local_id.x]; + } + workgroupBarrier(); + + for (var k: u32 = 0u; k ({ + outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: dispatch, + }), + getShaderSource, + }, + {inputs, outputs: [-1]})[0]; + + computeInPlaceSoftmax( + context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength, + parameters.totalSequenceLength); + + return probs; + }; + +const computeVxAttentionScore = + (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => { + const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize]; + + const probsHelper = inputVariable('probs', probs.dataType, probs.dims); + const vHelper = inputVariable('v', v.dataType, v.dims); + const output = outputVariable('output', probs.dataType, outputShape); + + const dataType = tensorTypeToWsglStorageType(probs.dataType); + + const TILE_SIZE = 12; + const dispatch = { + x: Math.ceil(params.vHeadSize / TILE_SIZE), + y: Math.ceil(params.sequenceLength / TILE_SIZE), + z: params.batchSize * params.numHeads + }; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const M: u32 = ${params.sequenceLength}u; + const N: u32 = ${params.vHeadSize}u; + const K: u32 = ${params.totalSequenceLength}u; + const numHeads: u32 = ${params.numHeads}u; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + var tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + + ${shaderHelper.declareVariables(probsHelper, vHelper, output)} + + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; + + let headIdx = workgroup_id.z; + let m = workgroup_id.y * TILE_SIZE + local_id.y; + let n = workgroup_id.x * TILE_SIZE + local_id.x; + + let offsetA = headIdx * (M * K) + m * K; + let offsetB = headIdx * (N * K) + n; + + var value = ${dataType}(0); + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m < M && w + local_id.x < K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x]; + } + if (n < N && w + local_id.y < K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N]; + } + workgroupBarrier(); + for (var k: u32 = 0u; k ({ + outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: dispatch, + }), + getShaderSource, + }, + {inputs: [probs, v], outputs: [0]})[0]; + }; + +export const applyAttention = + (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined, + _past: TensorView|undefined, _pastKey: TensorView|undefined, _pastValue: TensorView|undefined, + relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => { + const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes); + + computeVxAttentionScore(context, probs, v, parameters); + }; + +const prepare = (context: ComputeContext, parameters: AttentionParameters) => { + const outputShape = [ + parameters.batchSize, + parameters.numHeads, + parameters.sequenceLength, + parameters.headSize, + ]; + + const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); + + const M = parameters.sequenceLength; + const K = parameters.inputHiddenSize; + const N = parameters.headSize; + + const TILE_SIZE = 12; + const dispatch = { + x: Math.ceil(parameters.headSize / TILE_SIZE), + y: Math.ceil(parameters.sequenceLength / TILE_SIZE), + z: parameters.batchSize * parameters.numHeads + }; + + const getShaderSource = () => ` + const M: u32 = ${M}u; + const K: u32 = ${K}u; + const N: u32 = ${N}u; + const numHeads: u32 = ${parameters.numHeads}; + const ldb = ${parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}u; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + + @group(0) @binding(0) var input: array<${dataType}>; + @group(0) @binding(1) var weight: array<${dataType}>; + @group(0) @binding(2) var bias: array<${dataType}>; + @group(0) @binding(3) var outputQ: array<${dataType}>; + @group(0) @binding(4) var outputK: array<${dataType}>; + @group(0) @binding(5) var outputV: array<${dataType}>; + + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; + + let batchIndex = workgroup_id.z / ${parameters.numHeads}; + let headNumber = workgroup_id.z % ${parameters.numHeads}; + let m = workgroup_id.y * TILE_SIZE + local_id.y; + let n = workgroup_id.x * TILE_SIZE + local_id.x; + + let inputOffset = batchIndex * (M * K) + m * K; + let biasOffsetQ = headNumber * ${parameters.headSize}; + let biasOffsetK = ${parameters.hiddenSize} + biasOffsetQ; + let biasOffsetV = ${parameters.hiddenSize} + biasOffsetK; + + var valueQ = ${dataType}(0); + var valueK = ${dataType}(0); + var valueV = ${dataType}(0); + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m < M && w + local_id.x < K) { + tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x]; + } + if (n < N && w + local_id.y < K) { + let offset = n + (w + local_id.y) * ldb; + tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset]; + tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset]; + tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset]; + } + workgroupBarrier(); + for (var k: u32 = 0u; k ({ + outputs: [ + {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, + {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, + {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, + ], + dispatchGroup: dispatch, + }), + getShaderSource, + }, + {inputs, outputs: [-1, -1, -1]}); +}; + +export const attention = (context: ComputeContext, attributes: AttentionAttrs): void => { + const params = validateAttentionInputs(context.inputs, attributes); + + const [q, k, v] = prepare(context, params); + + return applyAttention( + context, q, k, v, context.inputs[4], undefined, undefined, undefined, context.inputs[5], params, attributes); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts new file mode 100644 index 0000000000000..ec9da2613f406 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env} from 'onnxruntime-common'; + +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo} from '../types'; + +import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; + +export interface BatchNormAttributes extends AttributeWithCacheKey { + readonly epsilon: number; + readonly momentum: number; + readonly spatial: boolean; + readonly trainingMode: boolean; + readonly format: 'NHWC'|'NCHW'; + readonly outputCount: number; +} + +const validateInputs = (inputs: readonly TensorView[], attributes: BatchNormAttributes): void => { + if (!inputs || inputs.length !== 5) { + throw new Error('BatchNormalization requires 5 inputs'); + } + + const checkShapeEqual = (actual: readonly number[], expected: readonly number[], message: string) => { + const r = expected.length; + if (r !== actual.length) { + throw new Error(`${message}: num dimensions != ${r}`); + } + expected.forEach((v, i) => { + if (v !== actual[i]) { + throw new Error(`${message}: dim[${i}] do not match`); + } + }); + }; + + if (inputs[0].dims.length > 1) { + const shape = attributes.format === 'NHWC' ? + (attributes.spatial ? inputs[0].dims.slice(-1) : + inputs[0].dims.slice(-1).concat(inputs[0].dims.slice(1, inputs[0].dims.length - 1))) : + inputs[0].dims.slice(1, attributes.spatial ? 2 : undefined); + checkShapeEqual(inputs[1].dims, shape, 'Invalid input scale'); + checkShapeEqual(inputs[2].dims, shape, 'Invalid input B'); + checkShapeEqual(inputs[3].dims, shape, 'Invalid input mean'); + checkShapeEqual(inputs[4].dims, shape, 'Invalid input var'); + } else { + checkShapeEqual(inputs[1].dims, [1], 'Invalid input scale'); + checkShapeEqual(inputs[2].dims, [1], 'Invalid input B'); + checkShapeEqual(inputs[3].dims, [1], 'Invalid input mean'); + checkShapeEqual(inputs[4].dims, [1], 'Invalid input var'); + } +}; + +const createBatchNormInferenceProgramInfo = + (inputs: readonly TensorView[], attributes: BatchNormAttributes): ProgramInfo => { + const {epsilon, spatial, format} = attributes; + const yShape = inputs[0].dims; + const components = spatial ? getMaxComponents(yShape[yShape.length - 1]) : 1; + const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1; + const outputSize = ShapeUtil.size(yShape) / components; + // Only support uniforms for opset version >= 9 (spatial = true). + const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial; + const shapeOrRank = useShapesUniforms ? yShape.length : yShape; + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components); + const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims, cComponents); + const inputMean = inputVariable('inputMean', inputs[3].dataType, inputs[3].dims, cComponents); + const inputVar = inputVariable('inputVar', inputs[4].dataType, inputs[4].dims, cComponents); + const y = outputVariable('y', inputs[0].dataType, shapeOrRank, components); + // TODO: support inputs with different data type. Current we need to make sure all inputs have the same data type. + // Otherwise, the shader compilation will fail. + const calcCOffset = (): string => { + let cOffset = ''; + if (spatial) { + cOffset = `let cOffset = ${ + yShape.length === 1 ? '0u' : + format === 'NHWC' ? `outputIndices[${yShape.length - 1}] / ${components}` : + 'outputIndices[1]'};`; + } else { + if (format === 'NCHW') { + cOffset = ` + ${y.indicesSet('outputIndices', '0', '0')} + let cOffset = ${y.indicesToOffset('outputIndices')};`; + } else { + // update C channel. + cOffset = `var cIndices = ${scale.type.indices}(0); + cIndices[0] = outputIndices[${yShape.length - 1}];`; + // update D1 x ... x Dn channels. + for (let i = 1; i < scale.rank; i++) { + cOffset += `cIndices[${i}] = outputIndices[${i}];`; + } + cOffset += `let cOffset = ${scale.indicesToOffset('cIndices')};`; + } + } + return cOffset; + }; + const getInferenceModeShaderSource = (helper: ShaderHelper) => ` + const epsilon = ${epsilon}; + ${helper.registerUniform('outputSize', 'u32').declareVariables(x, scale, bias, inputMean, inputVar, y)} + ${helper.mainStart()} + ${helper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + var outputIndices = ${y.offsetToIndices(`global_idx * ${components}`)}; + ${calcCOffset()} + let scale = ${scale.getByOffset('cOffset')}; + let bias = ${bias.getByOffset('cOffset')}; + let inputMean = ${inputMean.getByOffset('cOffset')}; + let inputVar = ${inputVar.getByOffset('cOffset')}; + let x = ${x.getByOffset('global_idx')}; + let value = (x - inputMean) / sqrt(inputVar + epsilon) * scale + bias; + ${y.setByOffset('global_idx', 'value')} + }`; + return { + name: 'BatchNormalization', + shaderCache: { + hint: `${attributes.epsilon}_${attributes.format}_${spatial}_${components}`, + inputDependencies: useShapesUniforms ? ['rank', 'type', 'type', 'type', 'type'] : undefined, + }, + getShaderSource: getInferenceModeShaderSource, + getRunData: () => ({ + outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: useShapesUniforms ? + [ + {type: 'uint32', data: outputSize}, + ...createTensorShapeVariables(yShape), + ] : + [ + {type: 'uint32', data: outputSize}, + ], + }), + }; + }; + +export const parseBatchNormAttributes = (attributes: Record): BatchNormAttributes => + createAttributeWithCacheKey(attributes as Omit); + +export const batchNorm = (context: ComputeContext, attributes: Record): void => { + const {inputs, outputCount} = context; + const updatedAttributes = parseBatchNormAttributes({...attributes, outputCount}); + if (env.webgpu.validateInputContent) { + validateInputs(inputs, updatedAttributes); + } + if (attributes.trainingMode) { + throw new Error('BatchNormalization trainingMode is not supported yet.'); + } else { + context.compute(createBatchNormInferenceProgramInfo(inputs, updatedAttributes)); + } +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts new file mode 100644 index 0000000000000..e2b8412000ef9 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, ProgramInfo} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper} from './common'; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (inputs[0].dims.length !== 3) { + throw new Error('input should have 3 dimensions'); + } + + if (![320, 640, 1280].includes(inputs[0].dims[2])) { + throw new Error('number of channels should be 320, 640 or 1280'); + } + + if (inputs[1].dims.length !== 1) { + throw new Error('bias is expected to have 1 dimensions'); + } + + if (inputs[0].dims[2] !== inputs[1].dims[0]) { + throw new Error('last dimension of input and bias are not the same'); + } +}; + +const createBiasAddProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => { + const outputShape = inputs[0].dims; + + const channels = inputs[0].dims[2]; + // since channel number can be only 320/640/1280, it's always divisable by 4 + const outputSize = ShapeUtil.size(outputShape) / 4; + + const dataType = inputs[0].dataType; + const input = inputVariable('input', dataType, outputShape, 4); + const bias = inputVariable('bias', dataType, [channels], 4); + const residual = inputVariable('residual', dataType, outputShape, 4); + const output = outputVariable('output', dataType, outputShape, 4); + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const channels = ${channels}u / 4; + ${shaderHelper.declareVariables(input, bias, residual, output)} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let value = ${input.getByOffset('global_idx')} + + ${bias.getByOffset('global_idx % channels')} + ${residual.getByOffset('global_idx')}; + ${output.setByOffset('global_idx', 'value')} + }`; + + return { + name: 'BiasAdd', + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + }), + getShaderSource, + }; +}; + +export const biasAdd = (context: ComputeContext): void => { + validateInputs(context.inputs); + context.compute(createBiasAddProgramInfo(context.inputs)); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts new file mode 100644 index 0000000000000..a81a7a8f1df5c --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, ProgramInfo} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {erfImpl} from './unary-op'; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (inputs[0].dims.length !== 3) { + throw new Error('input should have 3 dimensions'); + } + + if (![2560, 5120, 10240].includes(inputs[0].dims[2])) { + throw new Error('hidden state should be 2560, 5120 or 10240'); + } + + if (inputs[1].dims.length !== 1) { + throw new Error('bias is expected to have 1 dimensions'); + } + + if (inputs[0].dims[2] !== inputs[1].dims[0]) { + throw new Error('last dimension of input and bias are not the same'); + } +}; + +const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => { + const outputShape = inputs[0].dims.slice(); + outputShape[2] = outputShape[2] / 2; + + const input = inputVariable('input', inputs[0].dataType, inputs[0].dims, 4); + const bias = inputVariable('bias', inputs[0].dataType, [inputs[0].dims[2]], 4); + const output = outputVariable('output', inputs[0].dataType, outputShape, 4); + + const outputSize = ShapeUtil.size(outputShape) / 4; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const M_SQRT2 = sqrt(2.0); + const halfChannels = ${inputs[0].dims[2] / 4 / 2}u; + + ${shaderHelper.declareVariables(input, bias, output)} + + ${erfImpl(`vec4<${dataType}>`, dataType)} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let biasIdx = global_idx % halfChannels; + let batchIndex = global_idx / halfChannels; + let inputOffset = biasIdx + batchIndex * halfChannels * 2; + let valueLeft = input[inputOffset] + bias[biasIdx]; + let valueRight = input[inputOffset + halfChannels] + bias[biasIdx + halfChannels]; + let geluRight = valueRight * 0.5 * (erf_vf32(valueRight / M_SQRT2) + 1); + + ${output.setByOffset('global_idx', 'valueLeft * geluRight')} + }`; + + return { + name: 'BiasSplitGelu', + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + }), + getShaderSource, + }; +}; + +export const biasSplitGelu = (context: ComputeContext): void => { + validateInputs(context.inputs); + context.compute(createBiasSplitGeluProgramInfo(context.inputs)); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 9c05080f7e118..c033c0ba05356 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -4,9 +4,9 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; type BuiltinFunctionName = string; type BinaryCustomExpression = (expressionA: string, expressionB: string) => string; @@ -17,11 +17,9 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{ const createBinaryOpProgramShader = (shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[], - vectorize: boolean, doBroadcast: boolean, funcCall: BinaryFunctionCall, typeA: number, typeB: number, - typeOutput: number, additionalImplementation?: string) => { - const outputSize = ShapeUtil.size(dimsOutput); - const vecSize = Math.ceil(outputSize / 4); - + vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall, + typeA: number, typeB: number, typeOutput: number, useShapesUniforms: boolean, + additionalImplementation?: string) => { let expressionScalar: BinaryCustomExpression; let expressionVector: BinaryCustomExpression; if (typeof funcCall === 'string') { @@ -33,37 +31,20 @@ const createBinaryOpProgramShader = expressionVector = funcCall.vector; } - let broadcastImpl = ''; - const output = outputVariable('outputData', typeOutput, dimsOutput, 4); - const a = inputVariable('aData', typeA, dimsA, 4); - const b = inputVariable('bData', typeB, dimsB, 4); - if (doBroadcast) { - const calcOffsetImpl = (dims: readonly number[]) => { - const strides = ShapeUtil.computeStrides(dims); - const offsets: string[] = []; - for (let i = dims.length - 1; i >= 0; i--) { - const idx = output.indicesGet('outputIndices', i + dimsOutput.length - dims.length); - offsets.push(`${strides[i]}u * (${idx} % ${dims[i]}u)`); - } - return offsets.length > 0 ? offsets.join('+') : '0u'; - }; - - broadcastImpl = ` - fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 { - return ${calcOffsetImpl(dimsA)}; - } - - fn calcOffsetB(outputIndices: ${output.type.indices}) -> u32 { - return ${calcOffsetImpl(dimsB)}; - } - `; - } + const inputAShapeOrRank = useShapesUniforms ? dimsA.length : dimsA; + const inputBShapeOrRank = useShapesUniforms ? dimsB.length : dimsB; + const outputShapeOrRank = useShapesUniforms ? dimsOutput.length : dimsOutput; + const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4); + const a = inputVariable('aData', typeA, inputAShapeOrRank, 4); + const b = inputVariable('bData', typeB, inputBShapeOrRank, 4); let assignment: string; if (vectorize) { if (doBroadcast) { const isAOneElement = ShapeUtil.size(dimsA) === 1; const isBOneElement = ShapeUtil.size(dimsB) === 1; + const aLastDimDivisibleBy4 = dimsA.length > 0 && dimsA[dimsA.length - 1] % 4 === 0; + const bLastDimDivisibleBy4 = dimsB.length > 0 && dimsB[dimsB.length - 1] % 4 === 0; if (isAOneElement || isBOneElement) { assignment = output.setByOffset( 'global_idx', @@ -73,11 +54,18 @@ const createBinaryOpProgramShader = } else { assignment = ` let outputIndices = ${output.offsetToIndices('global_idx * 4u')}; - let offsetA = calcOffsetA(outputIndices); - let offsetB = calcOffsetB(outputIndices); + let offsetA = ${a.broadcastedIndicesToOffset('outputIndices', output)}; + let offsetB = ${b.broadcastedIndicesToOffset('outputIndices', output)}; ${ output.setByOffset( - 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))} + 'global_idx', + expressionVector( + sharedDimensionDivisibleBy4 || aLastDimDivisibleBy4 ? + a.getByOffset('offsetA / 4u') : + `${a.type.value}(${a.getByOffset('offsetA / 4u')}[offsetA % 4u])`, + sharedDimensionDivisibleBy4 || bLastDimDivisibleBy4 ? + b.getByOffset('offsetB / 4u') : + `${b.type.value}(${b.getByOffset('offsetB / 4u')}[offsetB % 4u])`))} `; } } else { @@ -94,8 +82,8 @@ const createBinaryOpProgramShader = const expressionB = `bData[indexB${x}][componentB${x}]`; return ` let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; - let offsetA${x} = calcOffsetA(outputIndices${x}); - let offsetB${x} = calcOffsetB(outputIndices${x}); + let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; let indexA${x} = offsetA${x} / 4u; let indexB${x} = offsetB${x} / 4u; let componentA${x} = offsetA${x} % 4u; @@ -122,28 +110,28 @@ const createBinaryOpProgramShader = } return ` - ${shaderHelper.declareVariables(a, b, output)} + ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(a, b, output)} ${additionalImplementation ?? ''} - ${broadcastImpl} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} ${assignment} }`; }; const createBinaryOpProgramInfo = - (metadata: ProgramMetadata, a: TensorView, b: TensorView, funcCall: BinaryFunctionCall, + (name: string, cacheKey: string, a: TensorView, b: TensorView, funcCall: BinaryFunctionCall, additionalImplementation?: string, outputDataType: number = a.dataType): ProgramInfo => { const isBroadcast = !ShapeUtil.areEqual(a.dims, b.dims); let outputShape = a.dims; let outputSize = ShapeUtil.size(a.dims); let vectorize = false; + let sharedDimensionDivisibleBy4 = false; // TODO: deal with zero-sized tensors (eg. dims=[1,0]) - + const cacheKeyAux = [isBroadcast]; if (isBroadcast) { const calculatedShape = BroadcastUtil.calcShape(a.dims, b.dims, false); if (!calculatedShape) { @@ -153,7 +141,12 @@ const createBinaryOpProgramInfo = outputSize = ShapeUtil.size(outputShape); const isAOneElement = ShapeUtil.size(a.dims) === 1; const isBOneElement = ShapeUtil.size(b.dims) === 1; - + const aLastDimDivisibleBy4 = a.dims.length > 0 && a.dims[a.dims.length - 1] % 4 === 0; + const bLastDimDivisibleBy4 = b.dims.length > 0 && b.dims[b.dims.length - 1] % 4 === 0; + cacheKeyAux.push(isAOneElement); + cacheKeyAux.push(isBOneElement); + cacheKeyAux.push(aLastDimDivisibleBy4); + cacheKeyAux.push(bLastDimDivisibleBy4); // check whether vectorize can be enabled let sharedDimension = 1; for (let i = 1; i < outputShape.length; i++) { @@ -165,60 +158,76 @@ const createBinaryOpProgramInfo = break; } } - if (sharedDimension % 4 === 0 || isAOneElement || isBOneElement) { + if (sharedDimension % 4 === 0) { + sharedDimensionDivisibleBy4 = true; + vectorize = true; + } else if (isAOneElement || isBOneElement || aLastDimDivisibleBy4 || bLastDimDivisibleBy4) { vectorize = true; } } else { // element-wise vectorize = true; } - + cacheKeyAux.push(vectorize); + const useShapesUniforms = enableShapesUniforms(a.dims.length) && enableShapesUniforms(b.dims.length) && + enableShapesUniforms(outputShape.length); return { - ...metadata, + name, + shaderCache: { + hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'), + inputDependencies: useShapesUniforms ? ['rank', 'rank'] : ['dims', 'dims'], + }, getShaderSource: (shaderHelper) => createBinaryOpProgramShader( - shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType, - outputDataType, additionalImplementation), - outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], - dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}) + shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall, + a.dataType, b.dataType, outputDataType, useShapesUniforms, additionalImplementation), + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: outputDataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, + programUniforms: useShapesUniforms ? + [ + {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, + ...createTensorShapeVariables(a.dims), + ...createTensorShapeVariables(b.dims), + ...createTensorShapeVariables(outputShape), + ] : + [ + {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, + ], + }), }; }; -const createBinaryOpProgramInfoLoader = - (inputs: readonly TensorView[], name: string, funcCall: BinaryFunctionCall, additionalImplementation?: string, - cacheKey?: string, outputDataType?: number): ProgramInfoLoader => { - const metadata: - ProgramMetadata = {name, inputTypes: [GpuDataType.default, GpuDataType.default], cacheHint: cacheKey}; - return { - ...metadata, - get: () => createBinaryOpProgramInfo( - metadata, inputs[0], inputs[1], funcCall, additionalImplementation, outputDataType) - }; +const runBinaryOp = + (context: ComputeContext, name: string, funcCall: BinaryFunctionCall, additionalImplementation?: string, + cacheKey?: string, outputDataType?: number): void => { + context.compute(createBinaryOpProgramInfo( + name, cacheKey ?? '', context.inputs[0], context.inputs[1], funcCall, additionalImplementation, + outputDataType)); }; export const add = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Add', (a, b) => `${a}+${b}`)); + runBinaryOp(context, 'Add', (a, b) => `${a}+${b}`); }; export const div = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Div', (a, b) => `${a}/${b}`)); + runBinaryOp(context, 'Div', (a, b) => `${a}/${b}`); }; export const equal = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader( - context.inputs, 'Equal', ({scalar: (a, b) => `u32(${a}==${b})`, vector: (a, b) => `vec4(${a}==${b})`}), - undefined, undefined, DataType.bool)); + runBinaryOp( + context, 'Equal', ({scalar: (a, b) => `u32(${a}==${b})`, vector: (a, b) => `vec4(${a}==${b})`}), undefined, + undefined, DataType.bool); }; export const mul = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Mul', (a, b) => `${a}*${b}`)); + runBinaryOp(context, 'Mul', (a, b) => `${a}*${b}`); }; export const pow = (context: ComputeContext): void => { const type = inputVariable('input', context.inputs[0].dataType, context.inputs[0].dims).type.value; const roundStr = type === 'i32' ? 'round' : ''; - context.compute(createBinaryOpProgramInfoLoader( - context.inputs, 'Pow', - ({scalar: (a, b) => `pow_custom(${a},${b})`, vector: (a, b) => `pow_vector_custom(${a},${b})`}), + runBinaryOp( + context, 'Pow', ({scalar: (a, b) => `pow_custom(${a},${b})`, vector: (a, b) => `pow_vector_custom(${a},${b})`}), ` fn pow_custom(a : ${type}, b : ${type}) -> ${type} { if (b == ${type}(0.0)) { @@ -233,34 +242,33 @@ export const pow = (context: ComputeContext): void => { // TODO: implement vectorized pow return vec4<${type}>(pow_custom(a.x, b.x), pow_custom(a.y, b.y), pow_custom(a.z, b.z), pow_custom(a.w, b.w)); } - `)); + `); }; export const sub = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Sub', (a, b) => `${a}-${b}`)); + runBinaryOp(context, 'Sub', (a, b) => `${a}-${b}`); }; export const greater = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader( - context.inputs, 'Greater', ({scalar: (a, b) => `u32(${a}>${b})`, vector: (a, b) => `vec4(${a}>${b})`}), - undefined, undefined, DataType.bool)); + runBinaryOp( + context, 'Greater', ({scalar: (a, b) => `u32(${a}>${b})`, vector: (a, b) => `vec4(${a}>${b})`}), undefined, + undefined, DataType.bool); }; export const less = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader( - context.inputs, 'Less', ({scalar: (a, b) => `u32(${a}<${b})`, vector: (a, b) => `vec4(${a}<${b})`}), - undefined, undefined, DataType.bool)); + runBinaryOp( + context, 'Less', ({scalar: (a, b) => `u32(${a}<${b})`, vector: (a, b) => `vec4(${a}<${b})`}), undefined, + undefined, DataType.bool); }; export const greaterOrEqual = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader( - context.inputs, 'GreaterOrEqual', - ({scalar: (a, b) => `u32(${a}>=${b})`, vector: (a, b) => `vec4(${a}>=${b})`}), undefined, undefined, - DataType.bool)); + runBinaryOp( + context, 'GreaterOrEqual', ({scalar: (a, b) => `u32(${a}>=${b})`, vector: (a, b) => `vec4(${a}>=${b})`}), + undefined, undefined, DataType.bool); }; export const lessOrEqual = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader( - context.inputs, 'LessOrEqual', ({scalar: (a, b) => `u32(${a}<=${b})`, vector: (a, b) => `vec4(${a}<=${b})`}), - undefined, undefined, DataType.bool)); + runBinaryOp( + context, 'LessOrEqual', ({scalar: (a, b) => `u32(${a}<=${b})`, vector: (a, b) => `vec4(${a}<=${b})`}), + undefined, undefined, DataType.bool); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 0ab777bfbdee9..b7a391ee667bb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -3,6 +3,7 @@ import {DataType} from '../../../wasm-common'; import {ShapeUtil} from '../../util'; +import {ProgramUniform} from '../types'; /** * constant value for a workgroup size. @@ -57,10 +58,11 @@ interface IndicesHelperTypes { * create an instance of an indices helper: * - `inputVariable()`: create an indices helper instance for an input. * - `outputVariable()`: create an indices helper instance for an output. + * - `internalVariable()`: create an indices helper instance for an internal variable. * * An indices helper instance contains helper functions for the following operations: * - access readonly basic information, including: `name`(the name of the input or output), `usage`(whether it's an - * input or an output) and `shape`(the passed in shape). + * input, an output or an internal variable) and `shape`(the passed in shape). * - `type`: access readonly type information, including: `indices`(the type of indices), `value`(the type of value at * runtime), `storage`(the type of value at storage) and `tensor`(the tensor type as represented in TensorView). * - generate WGSL code for getting indices from offset. Use `offsetToIndices()` for WGSL code snippet to calculate @@ -102,6 +104,16 @@ export interface IndicesHelper { */ readonly indicesToOffset: (varIndices: string) => string; + /** + * WGSL code of an `u32` expression for getting original offset from broadcasted indices. + * + * @param varIndices - a `type.indices` expression representing the output indices. + * @param output - output IndicesHelper. + * + * @returns an `u32` expression + */ + readonly broadcastedIndicesToOffset: (varIndices: string, output: IndicesHelper) => string; + /** * WGSL code of generating an indices literal * @@ -181,14 +193,24 @@ export interface IndicesHelper { readonly name: string; /** - * whether the helper is for an input or an output. + * whether the helper is for an input, an output or an internal variable. + */ + readonly usage: 'input'|'output'|'internal'; + + /** + * the rank of the input or output. + */ + readonly rank: number; + + /** + * a string representing the variable name for the shape of the input or output. */ - readonly usage: 'input'|'output'; + readonly shape: string; /** - * the shape of the input or output. + * a string representing the variable name for the strides of the input or output. */ - readonly shape: readonly number[]; + readonly strides: string; } const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => { @@ -237,20 +259,88 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 = return typeof mappedType === 'string' ? mappedType : mappedType[1]; }; +export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => + dims.length === 0 ? [] : [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}]; + +/** + * A helper function to get maximum vector size for specified data length + * @param size + */ +export const getMaxComponents = (size: number) => { + // we cannot use vec3 type since it has alignment of 16 bytes + if (size % 4 === 0) { + return 4; + } else if (size % 2 === 0) { + return 2; + } + + return 1; +}; + +/** + * A helper function that initializes variable as a scalar or vector. e.g. f32(0) or vec4f(0,0,0,0) + * @param dataType + * @param components + * @param value + */ +export const fillVector = (dataType = 'f32', components?: number, value = '0') => { + if (!components || components === 1) { + return `${dataType}(${value})`; + } + + return `vec${components}<${dataType}>(${value})`; +}; + +/** + * A helper function that casts value or vector to f32 + * @param dataType + * @param components + * @param value + */ +export const castToF32 = (dataType: string, components: number, value: string) => { + if (dataType === 'f32') { + return value; + } + if (components === 1) { + return `f32(${value})`; + } + + return `vec${components}f(${value})`; +}; + +/** + * A helper function that returns scalar or sums all components of a vector + * @param name + * @param components + */ +export const sumVector = (name: string, components: number) => { + if (components === 4) { + return `(${name}.x + ${name}.y + ${name}.z + ${name}.w)`; + } else if (components === 2) { + return `(${name}.x + ${name}.y)`; + } else if (components === 3) { + return `(${name}.x + ${name}.y + ${name}.z)`; + } + + return name; +}; + /** * A helper function to get a IndicesHelper for a given input or output. * * @param name - the name of the input or output. * @param tensorType - the tensor type of the input or output. - * @param shape - the tensor shape of the input or output. - * @param isInput - whether the helper is for an input or an output. + * @param shapeOrRank - the tensor shape or the rank of the input or output. + * @param usage - the usage of the indices helper. * @param components - indicates the number of components of each element. 1 for scalar, 2 for vec2, 3 for vec3, 4 for * vec4. */ const createIndicesHelper = - (name: string, tensorType: number, shape: readonly number[], isInput: boolean, + (name: string, tensorType: number, shapeOrRank: number|readonly number[], usage: IndicesHelper['usage'], components: 1|2|3|4): IndicesHelper => { - const rank = shape.length; + const useUniform = typeof shapeOrRank === 'number'; + const rank = useUniform ? shapeOrRank : shapeOrRank.length; + const rankIdentity = [...new Array(rank).keys()]; const indicesType = rank < 2 ? 'u32' : rank <= 4 ? `vec${rank}` : `array`; const mappedType = getWgslMappedType(tensorType, components); const valueType = typeof mappedType === 'string' ? mappedType : mappedType[1]; @@ -262,18 +352,21 @@ const createIndicesHelper = const implementationUsed = { offsetToIndices: false, indicesToOffset: false, + broadcastedIndicesToOffset: false, set: false, setByIndices: false, get: false, getByIndices: false, }; - const strides = ShapeUtil.computeStrides(shape); + const uniformPrefix = useUniform ? 'uniforms.' : ''; + const shape = `${uniformPrefix}${name}_shape`; + const strides = `${uniformPrefix}${name}_strides`; let o2iSnippet = ''; for (let i = 0; i < rank - 1; i++) { o2iSnippet += ` - let dim${i} = current / ${strides[i]}u; - let rest${i} = current % ${strides[i]}u; + let dim${i} = current / ${strides}[${i}]; + let rest${i} = current % ${strides}[${i}]; indices[${i}] = dim${i}; current = rest${i}; `; @@ -296,7 +389,7 @@ const createIndicesHelper = const offsets: string[] = []; if (rank >= 2) { for (let i = rank - 1; i >= 0; i--) { - offsets.push(`${strides[i]}u * (indices[${i}])`); + offsets.push(`${strides}[${i}] * (indices[${i}])`); } } @@ -329,6 +422,26 @@ const createIndicesHelper = } }; + const broadcastedIndicesToOffsetImplementation: {[key: string]: string} = {}; + const broadcastedIndicesToOffset = (varIndices: string, output: IndicesHelper) => { + implementationUsed.broadcastedIndicesToOffset = true; + const implKey = `${output.name}broadcastedIndicesTo${name}Offset`; + if (implKey in broadcastedIndicesToOffsetImplementation) { + return `${implKey}(${varIndices})`; + } + const offsets = []; + for (let i = rank - 1; i >= 0; i--) { + const idx = output.indicesGet('outputIndices', i + output.rank - rank); + offsets.push(`${indicesGet(strides, i)} * (${idx} % ${indicesGet(shape, i)})`); + } + broadcastedIndicesToOffsetImplementation[implKey] = + `fn ${implKey}(outputIndices: ${output.type.indices}) -> u32 { + return ${offsets.length > 0 ? offsets.join('+') : '0u'}; + }`; + + return `${implKey}(${varIndices})`; + }; + const setByOffset = (offset: number|string, value: string) => (() => { if (type.storage === type.value) { return `${name}[${offset}]=${value};`; @@ -370,11 +483,11 @@ const createIndicesHelper = }`; const getImplementation = rank < 2 ? '' : (() => { - const params = shape.map((_, i) => `d${i}: u32`).join(', '); - const dims = shape.map((_, i) => `d${i}`).join(', '); + const functionParams = rankIdentity.map(i => `d${i}: u32`).join(', '); + const dimsParams = rankIdentity.map(i => `d${i}`).join(', '); return ` - fn get_${name}(${params}) -> ${valueType} { - return get_${name}ByIndices(${indices(dims)}); + fn get_${name}(${functionParams}) -> ${valueType} { + return get_${name}ByIndices(${indices(dimsParams)}); }`; })(); @@ -413,11 +526,11 @@ const createIndicesHelper = }`; const setImplementation = rank < 2 ? '' : (() => { - const params = shape.map((_, i) => `d${i}: u32`).join(', '); - const dims = shape.map((_, i) => `d${i}`).join(', '); + const functionParams = rankIdentity.map(i => `d${i}: u32`).join(', '); + const dimsParams = rankIdentity.map(i => `d${i}`).join(', '); return ` - fn set_${name}(${params}, value: ${valueType}) { - set_${name}ByIndices(${indices(dims)}, value); + fn set_${name}(${functionParams}, value: ${valueType}) { + set_${name}ByIndices(${indices(dimsParams)}, value); }`; })(); @@ -456,12 +569,19 @@ const createIndicesHelper = const impl = () => { const impls = []; + if (!useUniform) { + impls.push(`const ${shape} = ${type.indices}(${shapeOrRank.join(',')});`); + impls.push(`const ${strides} = ${type.indices}(${ShapeUtil.computeStrides(shapeOrRank).join(',')});`); + } if (implementationUsed.offsetToIndices) { impls.push(offsetToIndicesImplementation); } if (implementationUsed.indicesToOffset) { impls.push(indicesToOffsetImplementation); } + if (implementationUsed.broadcastedIndicesToOffset) { + Object.values(broadcastedIndicesToOffsetImplementation).forEach(impl => impls.push(impl)); + } if (implementationUsed.set) { impls.push(setImplementation); } @@ -482,6 +602,7 @@ const createIndicesHelper = type, offsetToIndices, indicesToOffset, + broadcastedIndicesToOffset, indices, indicesGet, indicesSet, @@ -492,9 +613,11 @@ const createIndicesHelper = getByOffset, getByIndices, // isVec4, - usage: isInput ? 'input' : 'output', + usage, name, - shape + strides, + shape, + rank }; }; @@ -503,26 +626,41 @@ const createIndicesHelper = * * @param name - the name of the input. * @param type - the tensor type of the input. - * @param shape - the tensor shape of the input. + * @param shapeOrRank - the tensor shape or the rank of the input. * @param components - the number of components of the input. available values are 1, 2, 3, 4. default is 1. * @returns an IndicesHelper for the input. */ export const inputVariable = - (name: string, type: number, shape: readonly number[], components: 1|2|3|4 = 1): IndicesHelper => - createIndicesHelper(name, type, shape, true, components); + (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => + createIndicesHelper(name, type, shapeOrRank, 'input', components); /** * Create a IndicesHelper for an output. * * @param name - the name of the output. * @param type - the tensor type of the output. - * @param shape - the tensor shape of the output. - * @param components - the number of components of the input. available values are 1, 2, 3, 4. default is 1. + * @param shapeOrRank - the tensor shape or the rank of the output. + * @param components - the number of components of the output. available values are 1, 2, 3, 4. default is 1. * @returns an IndicesHelper for the output. */ export const outputVariable = - (name: string, type: number, shape: readonly number[], components: 1|2|3|4 = 1): IndicesHelper => - createIndicesHelper(name, type, shape, false, components); + (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => + createIndicesHelper(name, type, shapeOrRank, 'output', components); + +/** + * Create a IndicesHelper for an internal variable. + * + * @param name - the name of the variable. + * @param type - the tensor type of the variable. + * @param shapeOrRank - the tensor shape or the rank of the variable. + * @param components - the number of components of the variable. available values are 1, 2, 3, 4. default is 1. + * @returns an IndicesHelper for the variable. + */ +export const internalVariable = + (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => + createIndicesHelper(name, type, shapeOrRank, 'internal', components); + +export type UniformsArrayType = Array<{name: string; type: string}>; /** * A ShaderHelper is a helper class for generating WGSL code. @@ -572,9 +710,28 @@ export interface ShaderHelper { declareVariables(...variables: IndicesHelper[]): string; /** - * Get additional implementation that needs to be added to the shader source. + * A helper function to register one uniform. Can be called multiple times to register multiple uniforms. + * + * @param name - the name of the uniform. + * @param type - the type of the uniform. + */ + registerUniform(name: string, type: string): ShaderHelper; + + /** + * A helper function to register multiple uniforms. Can be called multiple times to register multiple uniforms. + * + * @param uniforms - an array of uniforms. Each element of the array is an object with 2 properties: `name` and + * `type`. */ - readonly additionalImplementations: string; + registerUniforms(uniforms: UniformsArrayType): ShaderHelper; + + /** + * A helper function to register multiple internal variables. Can be called multiple times to register multiple + * internal variables. + * + * @param variables - an array of IndicesHelper for the variables. + */ + registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper; } class ShaderHelperImpl implements ShaderHelper { @@ -595,11 +752,12 @@ class ShaderHelperImpl implements ShaderHelper { const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3, @builtin(local_invocation_id) local_id : vec3` : `@builtin(local_invocation_index) local_index : u32, - @builtin(workgroup_id) workgroup_id : vec3`; + @builtin(workgroup_id) workgroup_id : vec3, + @builtin(num_workgroups) num_workgroups : vec3`; const globalIdxDefinition = is1DimensionDispatch ? 'let global_idx = global_id.x;' : - `let global_idx = (workgroup_id.z * ${this.normalizedDispatchGroup[0] * this.normalizedDispatchGroup[1]}u + - workgroup_id.y * ${this.normalizedDispatchGroup[0]}u + workgroup_id.x) * ${ + `let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + + workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${ workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_index;`; return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ}) @@ -608,27 +766,87 @@ class ShaderHelperImpl implements ShaderHelper { `; } - declareVariable(variable: IndicesHelper, bindingIndex: number): string { - this.indicesHelpers.push(variable); + private appendVariableUniforms(variable: IndicesHelper): void { + if (variable.rank !== 0) { + if (variable.shape.startsWith('uniforms.')) { + this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: variable.type.indices}); + } + if (variable.strides.startsWith('uniforms.')) { + this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices}); + } + } + } + + private declareVariable(variable: IndicesHelper, bindingIndex: number): string { + if (variable.usage === 'internal') { + throw new Error('cannot use internal variable with declareVariable(). use registerInternalVariables() instead.'); + } + this.variables.push(variable); + this.appendVariableUniforms(variable); + const access = variable.usage === 'input' ? 'read' : 'read_write'; const storageType = variable.type.storage; return `@group(0) @binding(${bindingIndex}) var ${variable.name}: array<${storageType}>;`; } declareVariables(...variables: IndicesHelper[]): string { - let i = 0; - return variables.filter(v => ShapeUtil.size(v.shape) > 0).map(v => this.declareVariable(v, i++)).join('\n'); + return variables.map(v => this.declareVariable(v, this.variableIndex++)).join('\n'); + } + + private registerInternalVariable(variable: IndicesHelper): void { + if (variable.usage !== 'internal') { + throw new Error( + 'cannot use input or output variable with registerInternalVariable(). use declareVariables() instead.'); + } + + this.internalVariables.push(variable); + this.appendVariableUniforms(variable); + } + + registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper { + variables.forEach(v => this.registerInternalVariable(v)); + return this; + } + + registerUniform(name: string, type: string): ShaderHelper { + this.uniforms.push({name, type}); + return this; } - private indicesHelpers: IndicesHelper[] = []; + registerUniforms(additionalUniforms: UniformsArrayType): ShaderHelper { + this.uniforms = this.uniforms.concat(additionalUniforms); + return this; + } + + private internalVariables: IndicesHelper[] = []; + private variables: IndicesHelper[] = []; + private uniforms: UniformsArrayType = []; + private uniformDeclaration(): string { + if (this.uniforms.length === 0) { + return ''; + } + const uniformSnippets: string[] = []; + for (const {name, type} of this.uniforms) { + uniformSnippets.push(`${name}:${type}`); + } + + return ` + struct Uniforms { ${uniformSnippets.join(', ')} }; + @group(0) @binding(${this.variableIndex}) var uniforms: Uniforms;`; + } + private variableIndex = 0; + + /** + * Get additional implementation that needs to be added to the shader source. + */ get additionalImplementations(): string { - return this.indicesHelpers.map(i => i.impl()).join('\n'); + return this.uniformDeclaration() + this.variables.map(i => i.impl()).join('\n') + + this.internalVariables.map(i => i.impl()).join('\n'); } } -export const createShaderHelper = (dispatchGroup: [number, number, number]): ShaderHelper => - new ShaderHelperImpl(dispatchGroup); +export const createShaderHelper = (dispatchGroup: [number, number, number]) => new ShaderHelperImpl(dispatchGroup); /** * This function comes from https://github.com/tensorflow/tfjs/blob/master/tfjs-core/src/ops/broadcast_util.ts#L18-L40 @@ -653,3 +871,6 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly } return dims; }; + +// TODO: remove this limitation once >4D dims are supported by uniform. +export const enableShapesUniforms = (rank: number): boolean => rank <= 4; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 279632c190ded..43cc4a4c080bd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -4,9 +4,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface ConcatAttributes extends AttributeWithCacheKey { readonly axis: number; @@ -33,12 +33,10 @@ const validateInputs = (inputs: readonly TensorView[]): void => { } }; -const createConcatProgramMetadata = (inputCount: number, cacheHint: string) => - ({name: 'Concat', inputTypes: Array(inputCount).fill(GpuDataType.default), cacheHint}); - -const calculateInputIndexImpl = (numberOfTensors: number): string => ` +const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => ` fn calculateInputIndex(index: u32) -> u32 { - for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) { + let sizeInConcatAxis = array(${sizeInConcatAxisStr}); + for (var i: u32 = 0u; i < ${numberOfTensors}; i += 1u ) { if (index < sizeInConcatAxis[i]) { return i; } @@ -65,82 +63,107 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe return codeLines.join('\n'); }; -const createConcatProgramInfo = - (metadata: ProgramMetadata, inputs: readonly TensorView[], axis: number): ProgramInfo => { - const inputShape = inputs[0].dims.slice(); - if (axis >= inputShape.length || axis < (-1 * inputShape.length)) { - throw new Error('axis specified for concat doesn\'t match input dimensionality'); +const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): ProgramInfo => { + const inputShape = inputs[0].dims.slice(); + if (axis >= inputShape.length || axis < (-1 * inputShape.length)) { + throw new Error('axis specified for concat doesn\'t match input dimensionality'); + } + const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis; + // ensure all of the non-concatenated axes match each other + // calculate the shape of the output tensor while we do that + const outputShape = inputShape.slice(0); + for (let i = 1; i < inputs.length; i++) { + const dataNShape = inputs[i].dims.slice(); + for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) { + // add to the placeholder for computing output shape + if (axisIndex === adjustedAxis) { + outputShape[adjustedAxis] += dataNShape[axisIndex]; } - const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis; - // ensure all of the non-concatenated axes match each other - // calculate the shape of the output tensor while we do that - const outputShape = inputShape.slice(0); - for (let i = 1; i < inputs.length; i++) { - const dataNShape = inputs[i].dims.slice(); - for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) { - // add to the placeholder for computing output shape - if (axisIndex === adjustedAxis) { - outputShape[adjustedAxis] += dataNShape[axisIndex]; - } - // ensure all non-cancatenated axes match each other - else if (inputShape[axisIndex] !== dataNShape[axisIndex]) { - throw new Error('non concat dimensions must match'); - } - } + // ensure all non-cancatenated axes match each other + else if (inputShape[axisIndex] !== dataNShape[axisIndex]) { + throw new Error('non concat dimensions must match'); } + } + } - const outputSize = ShapeUtil.size(outputShape); - - const sizeInConcatAxis = new Array(inputs.length); - const inputVars = new Array(inputs.length); - const dataType = inputs[0].dataType; + const outputSize = ShapeUtil.size(outputShape); + + const sizeInConcatAxis = new Array(inputs.length); + const inputVars = new Array(inputs.length); + const dataType = inputs[0].dataType; + + let previousSum = 0; + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + const inputShapeOrRanks = []; + const enableInputShapesUniforms = []; + const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; + for (let i = 0; i < inputs.length; ++i) { + previousSum += inputs[i].dims[adjustedAxis]; + sizeInConcatAxis[i] = previousSum; + enableInputShapesUniforms.push(enableShapesUniforms(inputs[i].dims.length)); + inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims); + inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]); + inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims'); + programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]}); + } + for (let i = 0; i < inputs.length; ++i) { + if (enableInputShapesUniforms[i]) { + programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); + } + } - let previousSum = 0; - for (let i = 0; i < inputs.length; ++i) { - previousSum += inputs[i].dims[adjustedAxis]; - sizeInConcatAxis[i] = previousSum; + const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); + if (enableOutputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(outputShape)); + } - inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims); - } + const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; + const output = outputVariable('output', dataType, outputShapeOrRank); - const output = outputVariable('output', dataType, outputShape); + const indicesAxis = output.indicesGet('indices', adjustedAxis); + const sizeInConcatAxisStr = + Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); + const getShaderSource = (shaderHelper: ShaderHelper) => ` - const indicesAxis = output.indicesGet('indices', adjustedAxis); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(...inputVars, output)} + ${(() => { + shaderHelper.registerUniform('outputSize', 'u32'); + for (let i = 0; i < inputs.length; i++) { + shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); + } + return shaderHelper.declareVariables(...inputVars, output); + })()} - const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); - ${calculateInputIndexImpl(sizeInConcatAxis.length)} + ${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} var indices = ${output.offsetToIndices('global_idx')}; let inputIndex = calculateInputIndex(${indicesAxis}); if (inputIndex != 0u) { + let sizeInConcatAxis = array(${sizeInConcatAxisStr}); ${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u]; } ${assignOutputData(inputVars, output)} }`; - return { - ...metadata, - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], - getShaderSource, - dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) - }; - }; - -const createConcatProgramInfoLoader = - (inputs: readonly TensorView[], attributes: ConcatAttributes): ProgramInfoLoader => { - const metadata = createConcatProgramMetadata(inputs.length, attributes.cacheKey); - return {...metadata, get: () => createConcatProgramInfo(metadata, inputs, attributes.axis)}; - }; + + return { + name: 'Concat', + shaderCache: {hint: `${axis}`, inputDependencies}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms, + }), + getShaderSource, + }; +}; export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => { validateInputs(context.inputs); - context.compute(createConcatProgramInfoLoader(context.inputs, attributes)); + context.compute(createConcatProgramInfo(context.inputs, attributes.axis)); }; export const parseConcatAttributes = (attributes: Record): ConcatAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 1b7b7e0b29a25..14482272bad38 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -3,21 +3,18 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; -import {getActicationSnippet} from './fuse-utils'; +import {getActivationSnippet} from './fuse-utils'; -const createGroupedConvProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({ - name: 'GroupedConv', - inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : - [GpuDataType.default, GpuDataType.default], - cacheHint -}); - -const createGroupedConvProgramInfo = - (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvAttributes, +/** + * naive grouped conv implementation, supports 1d/2d conv + * @param squeezeOutputShapeFunction - an optional function to squeeze the output shape, only used in conv1d + */ +export const createGroupedConvProgramInfo = + (inputs: readonly TensorView[], attributes: ConvAttributes, squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]): ProgramInfo => { const hasBias = inputs.length > 2; const processBias = hasBias ? 'value += b[output_channel];' : ''; @@ -25,14 +22,13 @@ const createGroupedConvProgramInfo = const wShape = inputs[1].dims; const outputChannelsPerGroup = wShape[0] / attributes.group; - const {activationFunction, applyActivation} = getActicationSnippet(attributes); - const isChannelLast = attributes.format === 'NHWC'; const outputShape = calculateOutputShape( xShape, wShape, attributes.dilations, attributes.pads, attributes.strides, isChannelLast); const outputSize = ShapeUtil.size(outputShape); const output = outputVariable('output', inputs[0].dataType, outputShape); + const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); const x = inputVariable('x', inputs[0].dataType, xShape); const w = inputVariable('w', inputs[1].dataType, wShape); const inputVars = [x, w]; @@ -87,27 +83,15 @@ const createGroupedConvProgramInfo = ${output.setByOffset('global_idx', 'value')} }`; return { - ...metadata, - outputs: [{ - dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, - dataType: inputs[0].dataType, - gpuDataType: GpuDataType.default - }], + name: 'GroupedConv', + shaderCache: {hint: attributes.cacheKey}, + getRunData: () => ({ + outputs: [{ + dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, + dataType: inputs[0].dataType + }], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + }), getShaderSource, - dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) - }; - }; - -/** - * naive grouped conv implementation, supports 1d/2d conv - * @param squeezeOutputShapeFunction - an optional function to squeeze the output shape, only used in conv1d - */ -export const createGroupedConvProgramInfoLoader = - (inputs: readonly TensorView[], attributes: ConvAttributes, - squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]): ProgramInfoLoader => { - const metadata = createGroupedConvProgramMetadata(inputs.length > 2, attributes.cacheKey); - return { - ...metadata, - get: () => createGroupedConvProgramInfo(inputs, metadata, attributes, squeezeOutputShapeFunction) }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index e7d1ddf771650..e880afe09a5d8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -1,14 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {ComputeContext} from '../types'; +import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; import {createConvTranspose2DProgramInfo} from './3rd-party/conv_backprop_webgpu'; import {ConvAttributes} from './conv'; import {parseInternalActivationAttributes} from './fuse-utils'; +import {createTransposeProgramInfo} from './transpose'; const computeTotalPad = (inDim: number, stride: number, adj: number, kernel: number, dilation: number, outSize: number) => @@ -63,7 +64,7 @@ const getAdjustedConvTransposeAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims - if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 0) === 0) { + if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 1) === 0) { kernelShape.length = 0; for (let i = 2; i < inputs[1].dims.length; ++i) { kernelShape.push(inputs[1].dims[i]); @@ -95,9 +96,11 @@ const getAdjustedConvTransposeAttributes = // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - Object.assign( - newAttributes, - {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey: attributes.cacheKey}); + const cacheKey = attributes.cacheKey + [ + kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','), + dilations.join(',') + ].join('_'); + Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey}); return newAttributes; }; @@ -197,41 +200,62 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvTranspose if (attributes.outputShape.length !== 0 && attributes.outputShape.length !== inputs[0].dims.length - 2) { throw new Error('invalid output shape'); } - - // TODO : Need to add support for float64 - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('ConvTranspose input(X,W) should be float tensor'); - } - - if (inputs.length === 3 && inputs[2].dataType !== DataType.float) { - throw new Error('ConvTranspose input(bias) should be float tensor'); - } }; -const createConvTranspose2DProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({ - name: 'ConvTranspose2D', - inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : - [GpuDataType.default, GpuDataType.default], - cacheHint -}); - -const createConvTranspose2DProgramInfoLoader = - (inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]): ProgramInfoLoader => { - const hasBias = inputs.length === 3; - const metadata = createConvTranspose2DProgramMetadata(hasBias, attributes.cacheKey); - return { - ...metadata, - get: () => createConvTranspose2DProgramInfo(inputs, metadata, attributes, squeezeOutputShapeFunction) - }; - }; +// for transposing weight tensor from [C, M/group, KH, KW] to [KH, KW, M/group, C] +const weightTransposePerm = [2, 3, 1, 0]; const convTranspose2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); + const isChannelsLast = attributes.format === 'NHWC'; + const hasBias = inputs.length === 3; + if (adjustedAttributes.group !== 1) { + context.compute(createConvTranspose2DProgramInfo(inputs, adjustedAttributes)); + return; + } + const outputShape = adjustedAttributes.outputShape; + const outHeight = outputShape[isChannelsLast ? 1 : 2]; + const outWidth = outputShape[isChannelsLast ? 2 : 3]; + const outChannels = outputShape[isChannelsLast ? 3 : 1]; + const weightHeight = inputs[1].dims[2]; + const weightWidth = inputs[1].dims[3]; + const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; - context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); + const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; + const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; + const dimInner = weightHeight * weightWidth * inputChannels; + + const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; + + + // STEP.1: transpose weight + const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? + context.compute( + createTransposeProgramInfo(inputs[1], weightTransposePerm), + {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + if (attributes.wIsConst && !context.kernelCustomData.wT) { + context.kernelCustomData.wT = transposedWeight; + } + + // STEP.2: prepare reshaped inputs + const convTransposeInputs = [inputs[0], transposedWeight]; + if (hasBias) { + if (!isChannelsLast && inputs[2].dims.length === 1) { + convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); + } else { + convTransposeInputs.push(inputs[2]); + } + } + + // STEP.3: compute matmul + context.compute( + createConv2DTransposeMatMulProgramInfo( + convTransposeInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, + sequentialAccessByThreads), + {inputs: convTransposeInputs}); }; + const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => { // extend the input to 2D by adding H dimension const isChannelLast = attributes.format === 'NHWC'; @@ -271,7 +295,7 @@ const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttri kernelShape = [1].concat(kernelShape); const adjustedAttributes = getAdjustedConvTransposeAttributes({...attributes, pads, strides, dilations, kernelShape}, inputs); - context.compute(createConvTranspose2DProgramInfoLoader( + context.compute(createConvTranspose2DProgramInfo( inputs, adjustedAttributes, outputShape => isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [outputShape[0], outputShape[1], outputShape[3]])); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 95a64e5787841..c7ea0cffe51c3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -1,17 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {PoolConvUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext} from '../types'; -import {createGroupedConvProgramInfoLoader} from './conv-grouped'; -import {createConv2DMatMulProgramInfoLoader} from './conv2d-mm'; +import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; +import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; +import {createGroupedConvProgramInfo} from './conv-grouped'; import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils'; -import {createMatmulProgramInfoLoader} from './matmul'; -import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose'; +import {createTransposeProgramInfo} from './transpose'; export const calculateOutputShape = (inputShape: readonly number[], kernelShape: readonly number[], dilations: readonly number[], @@ -42,7 +41,7 @@ export interface ConvAttributes extends InternalActivationAttributes, AttributeW } // for transposing weight tensor from [M, C/group, KH, KW] to [KH, KW, C/group, M] -const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [2, 3, 1, 0]}); +const weightTransposeAttribute = [2, 3, 1, 0]; const validateInputs = (inputs: readonly TensorView[], attributes: ConvAttributes): void => { // Refer to the below link for all input checks @@ -93,15 +92,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvAttribute if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) { throw new Error('invalid kernel shape'); } - - // TODO : Need to add support for float64 - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('Conv input(X,W) should be float tensor'); - } - - if (inputs.length === 3 && inputs[2].dataType !== DataType.float) { - throw new Error('Conv input(bias) should be float tensor'); - } }; const getAdjustedConvAttributes = (attributes: T, inputs: readonly TensorView[]): T => { @@ -144,15 +134,14 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // check attributes - const hasBias = inputs.length === 3; // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */ - const isChannelsLast = attributes.format === 'NHWC'; - if (!isChannelsLast || attributes.group !== 1) { - context.compute(createGroupedConvProgramInfoLoader(inputs, adjustedAttributes)); + if (attributes.group !== 1) { + context.compute(createGroupedConvProgramInfo(inputs, adjustedAttributes)); return; } - // const batchSize = context.inputs[0].dims[0]; + const isChannelsLast = attributes.format === 'NHWC'; + const hasBias = inputs.length === 3; const inputHeight = inputs[0].dims[isChannelsLast ? 1 : 2]; const inputWidth = inputs[0].dims[isChannelsLast ? 2 : 3]; const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; @@ -165,57 +154,61 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut const outHeight = outputShape[isChannelsLast ? 1 : 2]; const outWidth = outputShape[isChannelsLast ? 2 : 3]; const outChannels = outputShape[isChannelsLast ? 3 : 1]; - const batch = outputShape[0]; - const sameSize = - isChannelsLast && weightHeight === inputHeight && weightWidth === inputWidth && attributes.autoPad === 'VALID'; + const sameSize = isChannelsLast && weightHeight === inputHeight && weightWidth === inputWidth && + attributes.pads[0] === 0 && attributes.pads[1] === 0; if (sameSize || (weightHeight === 1 && weightWidth === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1 && attributes.strides[0] === 1 && attributes.strides[1] === 1 && attributes.pads[0] === 0 && attributes.pads[1] === 0)) { // conv2dByMatMul - const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? - context.compute( - { - ...transposeProgramMetadata, - cacheHint: weightTransposeAttribute.cacheKey, - get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm) - }, - {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; - if (attributes.wIsConst && !context.kernelCustomData.wT) { - context.kernelCustomData.wT = transposedWeight; - } - + const batch = outputShape[0]; + let xReshaped, wReshaped, matmulOutputShape; const matmulInputs = []; - matmulInputs.push(inputs[0].reshape([batch, inputHeight * inputWidth, inputChannels])); - matmulInputs.push(transposedWeight.reshape([1, inputChannels, outChannels])); + if (isChannelsLast) { + const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? + context.compute( + createTransposeProgramInfo(inputs[1], weightTransposeAttribute), + {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + if (attributes.wIsConst && !context.kernelCustomData.wT) { + context.kernelCustomData.wT = transposedWeight; + } + if (sameSize) { + const sharedDim = inputHeight * inputWidth * inputChannels; + xReshaped = inputs[0].reshape([1, batch, sharedDim]); + wReshaped = transposedWeight.reshape([1, sharedDim, outChannels]); + matmulOutputShape = [1, batch, outChannels]; + } else { + xReshaped = inputs[0].reshape([batch, inputHeight * inputWidth, inputChannels]); + wReshaped = transposedWeight.reshape([1, inputChannels, outChannels]); + matmulOutputShape = [batch, outHeight * outWidth, outChannels]; + } + matmulInputs.push(xReshaped); + matmulInputs.push(wReshaped); + } else { + xReshaped = inputs[0].reshape([batch, inputChannels, inputHeight * inputWidth]); + wReshaped = inputs[1].reshape([1, outChannels, inputChannels]); + matmulOutputShape = [batch, outChannels, outHeight * outWidth]; + matmulInputs.push(wReshaped); + matmulInputs.push(xReshaped); + } if (hasBias) { matmulInputs.push(inputs[2]); } - const matmulOutputShape = [batch, outHeight * outWidth, outChannels]; context.compute( - createMatmulProgramInfoLoader(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape), + createMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast), {inputs: matmulInputs}); - return; } // TODO: implement conv2dWithIm2Col() - const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; - const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; - const dimInner = weightHeight * weightWidth * inputChannels; - const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; // STEP.1: transpose weight const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? context.compute( - { - ...transposeProgramMetadata, - cacheHint: weightTransposeAttribute.cacheKey, - get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm) - }, + createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; @@ -224,16 +217,15 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // STEP.2: prepare reshaped inputs const convInputs = [inputs[0], transposedWeight]; if (hasBias) { - if (!isChannelsLast && inputs[2].dims.length === 1) { - convInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); - } else { - convInputs.push(inputs[2]); - } + convInputs.push(inputs[2]); } // STEP.3: compute matmul + const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; + const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; + const dimInner = weightHeight * weightWidth * inputChannels; context.compute( - createConv2DMatMulProgramInfoLoader( + createConv2DMatMulProgramInfo( convInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, sequentialAccessByThreads), {inputs: convInputs}); @@ -260,7 +252,7 @@ const conv1d = (context: ComputeContext, attributes: ConvAttributes): void => { const dilations = [1].concat(attributes.dilations); const kernelShape = [1].concat(attributes.kernelShape); const adjustedAttributes = getAdjustedConvAttributes({...attributes, pads, strides, dilations, kernelShape}, inputs); - context.compute(createGroupedConvProgramInfoLoader( + context.compute(createGroupedConvProgramInfo( inputs, adjustedAttributes, outputShape => isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [])); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts deleted file mode 100644 index 21c0b97042fbb..0000000000000 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import {TensorView} from '../../tensor-view'; -import {GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; - -import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; -import {ConvAttributes} from './conv'; - - -const createConv2DMatMulProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({ - name: 'Conv2DMatMul', - inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : - [GpuDataType.default, GpuDataType.default], - cacheHint -}); - -export const createConv2DMatMulProgramInfoLoader = - (inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[], dimAOuter: number, - dimBOuter: number, dimInner: number, hasBias: boolean, sequentialAccessByThreads: boolean): ProgramInfoLoader => { - const metadata = createConv2DMatMulProgramMetadata(hasBias, attributes.cacheKey); - return { - ...metadata, - get: () => createConv2DMatMulProgramInfo( - inputs, metadata, attributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, - sequentialAccessByThreads) - }; - }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index fc9ebf004ad25..a233d37a79e65 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -4,7 +4,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; @@ -54,7 +54,7 @@ class EinsumTerm { } class EinsumEquation { - constructor(inputs: readonly TensorView[], equation: string) { + constructor(inputs: readonly TensorView[], public readonly equation: string) { this.hasEllipsis = false; this.symbolToInfo = new Map(); this.lhs = new Array(); @@ -177,111 +177,101 @@ class EinsumEquation { outputDims: number[]; // Output dimensions of the equation } // End of class EinsumEquation - -const createEinsumProgramMetadata = (inputCount: number, cacheHint: string): ProgramMetadata => - ({name: 'Einsum', inputTypes: Array(inputCount).fill(GpuDataType.default), cacheHint}); - -const createEinsumProgramInfo = - (metadata: ProgramMetadata, inputs: readonly TensorView[], einsumEquation: EinsumEquation): ProgramInfo => { - const dataType = inputs[0].dataType; - const inputVars = new Array(inputs.length); - for (let i = 0; i < inputs.length; ++i) { - inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims); - } - const outputShape = einsumEquation.outputDims; - const outputSize = ShapeUtil.size(outputShape); - const output = outputVariable('output', dataType, outputShape); - const idxCopy: string[] = []; - const rhsSymbols = Array.from(einsumEquation.rhs.symbolToIndices.keys()); - const initProd = 'var prod = 1.0;'; - const initSum = 'var sum = 0.0;'; - const updateSum = 'sum += prod;'; - const reduceOpsSetIndices: string[] = []; - const reduceOpsLoopHeaders: string[] = []; - const reduceOpsLoopFooters: string[] = []; - const reduceOpCompute: string[] = []; - const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === rhsSymbols.length; - einsumEquation.symbolToInfo.forEach((info, symbol) => { - if (rhsSymbols.includes(symbol)) { - const outputIndex = rhsSymbols.indexOf(symbol); - einsumEquation.lhs.forEach((term, i) => { - if (info.inputIndices.includes(i)) { - const indices = term.symbolToIndices.get(symbol); - if (indices === undefined) { - throw new Error('Invalid symbol error'); - } - indices.forEach((index) => { - idxCopy.push(`${ - inputVars[i].indicesSet( - `input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`); - }); - } +const createEinsumProgramInfo = (inputs: readonly TensorView[], einsumEquation: EinsumEquation): ProgramInfo => { + const dataType = inputs[0].dataType; + const inputVars = new Array(inputs.length); + for (let i = 0; i < inputs.length; ++i) { + inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims); + } + const outputShape = einsumEquation.outputDims; + const outputSize = ShapeUtil.size(outputShape); + const output = outputVariable('output', dataType, outputShape); + const idxCopy: string[] = []; + const rhsSymbols = Array.from(einsumEquation.rhs.symbolToIndices.keys()); + const initProd = 'var prod = 1.0;'; + const initSum = 'var sum = 0.0;'; + const updateSum = 'sum += prod;'; + const reduceOpsSetIndices: string[] = []; + const reduceOpsLoopHeaders: string[] = []; + const reduceOpsLoopFooters: string[] = []; + const reduceOpCompute: string[] = []; + const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === rhsSymbols.length; + einsumEquation.symbolToInfo.forEach((info, symbol) => { + if (rhsSymbols.includes(symbol)) { + const outputIndex = rhsSymbols.indexOf(symbol); + einsumEquation.lhs.forEach((term, i) => { + if (info.inputIndices.includes(i)) { + const indices = term.symbolToIndices.get(symbol); + if (indices === undefined) { + throw new Error('Invalid symbol error'); + } + indices.forEach((index) => { + idxCopy.push(`${ + inputVars[i].indicesSet(`input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`); }); - } else { - einsumEquation.lhs.forEach((term, i) => { - const info = einsumEquation.symbolToInfo.get(symbol); - if (info === undefined) { - throw new Error('Invalid symbol error'); - } - if (info.inputIndices.includes(i)) { - const indices = term.symbolToIndices.get(symbol); - if (indices === undefined) { - throw new Error('Invalid symbol error'); - } - indices.forEach((index) => { - reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`); - }); - reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`); - } + } + }); + } else { + einsumEquation.lhs.forEach((term, i) => { + const info = einsumEquation.symbolToInfo.get(symbol); + if (info === undefined) { + throw new Error('Invalid symbol error'); + } + if (info.inputIndices.includes(i)) { + const indices = term.symbolToIndices.get(symbol); + if (indices === undefined) { + throw new Error('Invalid symbol error'); + } + indices.forEach((index) => { + reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`); }); - reduceOpsLoopHeaders.push(`for(var ${symbol}: u32 = 0; ${symbol} < ${ - einsumEquation.symbolToInfo.get(symbol)?.dimValue}; ${symbol}++) {`); - reduceOpsLoopFooters.push('}'); + reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`); } }); - const reduceOps = isReduceOpsWithoutLoop ? - [ - ...idxCopy, - `let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};` - ] : - [ - ...idxCopy, - initSum, - ...reduceOpsLoopHeaders, - ...reduceOpsSetIndices, - initProd, - ...reduceOpCompute, - updateSum, - ...reduceOpsLoopFooters, - ]; - const getShaderSource = (shaderHelper: ShaderHelper) => ` + reduceOpsLoopHeaders.push(`for(var ${symbol}: u32 = 0; ${symbol} < ${ + einsumEquation.symbolToInfo.get(symbol)?.dimValue}; ${symbol}++) {`); + reduceOpsLoopFooters.push('}'); + } + }); + const reduceOps = isReduceOpsWithoutLoop ? + [ + ...idxCopy, + `let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};` + ] : + [ + ...idxCopy, + initSum, + ...reduceOpsLoopHeaders, + ...reduceOpsSetIndices, + initProd, + ...reduceOpCompute, + updateSum, + ...reduceOpsLoopFooters, + ]; + const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.declareVariables(...inputVars, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} var outputIndices = ${output.offsetToIndices('global_idx')}; - ${inputVars.map((inputVar, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')} + ${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')} ${reduceOps.join('\n')}; ${output.setByOffset('global_idx', 'sum')}; }`; - return { - ...metadata, - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], - getShaderSource, - dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) - }; - }; - -const createEinsumProgramInfoLoader = - (inputs: readonly TensorView[], einsumEquation: EinsumEquation, attributes: EinsumAttributes): - ProgramInfoLoader => { - const metadata = createEinsumProgramMetadata(inputs.length, attributes.cacheKey); - return {...metadata, get: () => createEinsumProgramInfo(metadata, inputs, einsumEquation)}; - }; + return { + name: 'Einsum', + shaderCache: {hint: einsumEquation.equation}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + }), + getShaderSource, + }; +}; export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => { const einsumEquation = new EinsumEquation(context.inputs, attributes.equation); - context.compute(createEinsumProgramInfoLoader(context.inputs, einsumEquation, attributes)); + context.compute(createEinsumProgramInfo(context.inputs, einsumEquation)); }; export const parseEinsumAttributes = (attributes: Record): EinsumAttributes => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index 824ce682c0c4b..d998013352d77 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -3,14 +3,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; - -export const expandProgramMetadata = { - name: 'Expand', - inputTypes: [GpuDataType.default] -}; +import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -45,21 +40,25 @@ const calculateOutputShape = (inputShape: readonly number[], shape: readonly num (inputShape.length > shape.length) ? getAdjustedShape(inputShape, shape) : getAdjustedShape(shape, inputShape); -const createExpandProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => { +const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => { const inputShape = inputs[0].dims; const shape = Array.from(inputs[1].getBigInt64Array(), Number); const outputShape: number[] = calculateOutputShape(inputShape, shape); const outputSize = ShapeUtil.size(outputShape); const dataType = inputs[0].dataType; - const input = inputVariable('input', dataType, inputShape); - const output = outputVariable('output', dataType, outputShape); + const enableInputShapeUniform = enableShapesUniforms(inputShape.length); + const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape; + const input = inputVariable('input', dataType, inputShapeOrRank); + const enableOutputShapeUniform = enableShapesUniforms(outputShape.length); + const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape; + const output = outputVariable('output', dataType, outputShapeOrRank); const getShaderSource = (shaderHelper: ShaderHelper) => ` const inputShape = ${input.indices(...inputShape)}; - ${shaderHelper.declareVariables(input, output)} + ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} let outputIndices = ${output.offsetToIndices('global_idx')}; var inputIndices: ${input.type.indices}; for (var i = 0; i < ${inputShape.length}; i++) { @@ -73,19 +72,26 @@ const createExpandProgramInfo = (metadata: ProgramMetadata, inputs: readonly Ten } ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} }`; + const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; + if (enableInputShapeUniform) { + programUniforms.push(...createTensorShapeVariables(inputShape)); + } + if (enableOutputShapeUniform) { + programUniforms.push(...createTensorShapeVariables(outputShape)); + } return { - ...metadata, + name: 'Expand', + shaderCache: {hint: `${outputShape}`, inputDependencies: [enableInputShapeUniform ? 'rank' : 'dims']}, getShaderSource, - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], - dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms + }) }; }; export const expand = (context: ComputeContext): void => { validateInputs(context.inputs); - const outputShape = Array.from(context.inputs[1].getBigInt64Array(), Number); - const cacheHint = outputShape.toString(); - context.compute( - {...expandProgramMetadata, cacheHint, get: () => createExpandProgramInfo(expandProgramMetadata, context.inputs)}, - {inputs: [0]}); + context.compute(createExpandProgramInfo(context.inputs), {inputs: [0]}); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 92105859a8c0e..0b5c0db2b5112 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -10,17 +10,20 @@ export interface InternalActivationAttributes { readonly activationCacheKey: string; } -export const getActicationSnippet = - (attributes: InternalActivationAttributes): {activationFunction: string; applyActivation: string} => { +export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): + {activationFunction: string; applyActivation: string} => { switch (attributes.activation) { case 'Relu': - return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'}; + return {activationFunction: '', applyActivation: `value = max(value, ${valueType}(0.0));`}; case 'Sigmoid': - return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'}; + return { + activationFunction: '', + applyActivation: `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));` + }; case 'Clip': return { - activationFunction: - `const clip_min_=f32(${attributes.clipMin!});const clip_max_=f32(${attributes.clipMax!});`, + activationFunction: `const clip_min_=${valueType}(${attributes.clipMin!});const clip_max_=${valueType}(${ + attributes.clipMax!});`, applyActivation: 'value = clamp(value, clip_min_, clip_max_);' }; // TODO: adding other activations that can be fused. diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts index a7d355bc13704..9924a50e2ae6f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts @@ -4,7 +4,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; @@ -28,7 +28,7 @@ const validateInputs = (inputs: readonly TensorView[]): void => { }; const createGatherElementsProgramInfo = - (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: GatherElementsAttributes): ProgramInfo => { + (inputs: readonly TensorView[], attributes: GatherElementsAttributes): ProgramInfo => { const inputShape = inputs[0].dims; const inputOutputDataType = inputs[0].dataType; const inputRank = inputShape.length; @@ -86,10 +86,13 @@ const createGatherElementsProgramInfo = }`; return { - ...metadata, - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + name: 'GatherElements', + shaderCache: {hint: attributes.cacheKey}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + }), getShaderSource, - dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) }; }; @@ -99,12 +102,5 @@ export const parseGatherElementsAttributes = (attributes: Record { const inputs = context.inputs; validateInputs(inputs); - - const metadata = { - name: 'GatherElements', - inputTypes: [GpuDataType.default, GpuDataType.default], - cacheHint: attributes.cacheKey, - }; - - context.compute(createGatherElementsProgramInfo(metadata, context.inputs, attributes)); + context.compute(createGatherElementsProgramInfo(context.inputs, attributes)); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 47aae13d6799d..5d6d6debadb9a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -4,9 +4,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; export interface GatherAttributes extends AttributeWithCacheKey { axis: number; @@ -18,68 +18,99 @@ const validateInputs = (inputs: readonly TensorView[]): void => { } }; -const createGatherProgramInfo = - (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: GatherAttributes): ProgramInfo => { - const inputShape = inputs[0].dims; - const indicesShape = inputs[1].dims; - - const inputRank = inputShape.length; - const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank); - - const outputShape = inputShape.slice(0); - outputShape.splice(axis, 1, ...indicesShape); - - const axisDimLimit = inputShape[axis]; - const outputSize = ShapeUtil.size(outputShape); - - const data = inputVariable('data', inputs[0].dataType, inputs[0].dims); - const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims); - const output = outputVariable('output', inputs[0].dataType, outputShape); - const calcDataIndices = (): string => { - const indicesRank = indicesShape.length; - let calcStr = `var indicesIndices = ${indices.type.indices}(0);`; - for (let i = 0; i < indicesRank; i++) { - calcStr += `${indicesRank > 1 ? `indicesIndices[${i}]` : 'indicesIndices'} = ${ - outputShape.length > 1 ? `outputIndices[${axis + i}]` : 'outputIndices'};`; - } - calcStr += ` +const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: GatherAttributes): ProgramInfo => { + const inputShape = inputs[0].dims; + const indicesShape = inputs[1].dims; + + const inputRank = inputShape.length; + const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank); + + const outputShape = inputShape.slice(0); + outputShape.splice(axis, 1, ...indicesShape); + + const axisDimLimit = inputShape[axis]; + const outputSize = ShapeUtil.size(outputShape); + + const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length); + const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims; + const enableIndicesShapesUniforms = enableShapesUniforms(inputs[1].dims.length); + const indicesShapeOrRank = enableIndicesShapesUniforms ? inputs[1].dims.length : inputs[1].dims; + const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); + const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; + + const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank); + const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank); + const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank); + + const programUniforms: ProgramUniform[] = + [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; + if (enableInputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + } + if (enableIndicesShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + } + if (enableOutputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(outputShape)); + } + + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + inputDependencies.push(enableInputShapesUniforms ? 'rank' : 'dims'); + inputDependencies.push(enableIndicesShapesUniforms ? 'rank' : 'dims'); + + const calcDataIndices = (): string => { + const indicesRank = indicesShape.length; + let calcStr = `var indicesIndices = ${indices.type.indices}(0);`; + for (let i = 0; i < indicesRank; i++) { + calcStr += `${indicesRank > 1 ? `indicesIndices[${i}]` : 'indicesIndices'} = ${ + outputShape.length > 1 ? `outputIndices[uniforms.axis + ${i}]` : 'outputIndices'};`; + } + calcStr += ` var idx = ${indices.getByIndices('indicesIndices')}; if (idx < 0) { - idx = idx + ${axisDimLimit}; + idx = idx + uniforms.axisDimLimit; } var dataIndices = ${data.type.indices}(0); `; - for (let i = 0, j = 0; i < inputRank; i++) { - if (i === axis) { - calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = u32(idx);`; - j += indicesRank; - } else { - calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = ${ - outputShape.length > 1 ? `outputIndices[${j}]` : 'outputIndices'};`; - j++; - } - } - return calcStr; - }; + for (let i = 0, j = 0; i < inputRank; i++) { + if (i === axis) { + calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = u32(idx);`; + j += indicesRank; + } else { + calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = ${ + outputShape.length > 1 ? `outputIndices[${j}]` : 'outputIndices'};`; + j++; + } + } + return calcStr; + }; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(data, indices, output)} + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${ + shaderHelper.registerUniform('outputSize', 'u32') + .registerUniform('axisDimLimit', 'i32') + .registerUniform('axis', 'u32') + .declareVariables(data, indices, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} let outputIndices = ${output.offsetToIndices('global_idx')}; ${calcDataIndices()}; let value = ${data.getByIndices('dataIndices')}; ${output.setByOffset('global_idx', 'value')}; }`; - return { - ...metadata, - outputs: [ - {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, - ], - getShaderSource, - dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) - }; - }; + return { + name: 'Gather', + shaderCache: {hint: attributes.cacheKey, inputDependencies}, + getRunData: () => ({ + outputs: [ + {dims: outputShape, dataType: inputs[0].dataType}, + ], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms + }), + getShaderSource, + }; +}; export const parseGatherAttributes = (attributes: Record): GatherAttributes => createAttributeWithCacheKey({axis: attributes.axis as number}); @@ -87,12 +118,5 @@ export const parseGatherAttributes = (attributes: Record): Gath export const gather = (context: ComputeContext, attributes: GatherAttributes): void => { const inputs = context.inputs; validateInputs(inputs); - - const metadata = { - name: 'Gather', - inputTypes: [GpuDataType.default, GpuDataType.default], - cacheHint: attributes.cacheKey, - }; - - context.compute(createGatherProgramInfo(metadata, context.inputs, attributes)); + context.compute(createGatherProgramInfo(context.inputs, attributes)); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index 1a36d4a7545d6..6e9dee41ce488 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -4,7 +4,7 @@ import {TensorView} from '../../tensor-view'; import {GemmUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; @@ -53,39 +53,38 @@ const offsetC = (m: number, n: number, dims: readonly number[]): string => { return offset; }; -const createGemmProgramInfo = - (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: GemmAttributes): ProgramInfo => { - const aShape = inputs[0].dims.slice(); - const bShape = inputs[1].dims.slice(); - const [M, N, K] = GemmUtil.getShapeOfGemmResult( - aShape, attributes.transA, bShape, attributes.transB, inputs.length === 3 ? inputs[2].dims : undefined); - const outputShape = [M, N]; - if (!outputShape) { - throw new Error('Can\'t use gemm on the given tensors'); - } - const outputSize = ShapeUtil.size(outputShape); - let line = ''; - if (attributes.transA && attributes.transB) { - line = 'value += a[k * M + m] * b[n * K + k];'; - } else if (attributes.transA && !attributes.transB) { - line = 'value += a[k * M + m] * b[k * N + n];'; - } else if (!attributes.transA && attributes.transB) { - line = 'value += a[m * K + k] * b[n * K + k];'; - } else if (!attributes.transA && !attributes.transB) { - line = 'value += a[m * K + k] * b[k * N + n];'; - } - - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= alpha;'; - const calculateC = inputs.length === 3 ? `value += beta * c[${offsetC(M, N, inputs[2].dims)}];` : ''; - const inputStorageBuffersDeclarations = [ - `@group(0) @binding(0) var a : array<${dataType}>;`, - `@group(0) @binding(1) var b : array<${dataType}>;` - ]; - if (inputs.length === 3) { - inputStorageBuffersDeclarations.push(`@group(0) @binding(2) var c : array<${dataType}>;`); - } - const getShaderSource = (shaderHelper: ShaderHelper) => ` +const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAttributes): ProgramInfo => { + const aShape = inputs[0].dims.slice(); + const bShape = inputs[1].dims.slice(); + const [M, N, K] = GemmUtil.getShapeOfGemmResult( + aShape, attributes.transA, bShape, attributes.transB, inputs.length === 3 ? inputs[2].dims : undefined); + const outputShape = [M, N]; + if (!outputShape) { + throw new Error('Can\'t use gemm on the given tensors'); + } + const outputSize = ShapeUtil.size(outputShape); + let line = ''; + if (attributes.transA && attributes.transB) { + line = 'value += a[k * M + m] * b[n * K + k];'; + } else if (attributes.transA && !attributes.transB) { + line = 'value += a[k * M + m] * b[k * N + n];'; + } else if (!attributes.transA && attributes.transB) { + line = 'value += a[m * K + k] * b[n * K + k];'; + } else if (!attributes.transA && !attributes.transB) { + line = 'value += a[m * K + k] * b[k * N + n];'; + } + + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= alpha;'; + const calculateC = inputs.length === 3 ? `value += beta * c[${offsetC(M, N, inputs[2].dims)}];` : ''; + const inputStorageBuffersDeclarations = [ + `@group(0) @binding(0) var a : array<${dataType}>;`, + `@group(0) @binding(1) var b : array<${dataType}>;` + ]; + if (inputs.length === 3) { + inputStorageBuffersDeclarations.push(`@group(0) @binding(2) var c : array<${dataType}>;`); + } + const getShaderSource = (shaderHelper: ShaderHelper) => ` const M: u32 = ${M}u; const N: u32 = ${N}u; const K: u32 = ${K}u; @@ -111,28 +110,20 @@ const createGemmProgramInfo = output[global_id.x] = value; }`; - return { - ...metadata, - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], - getShaderSource, - dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) - }; - }; - -const createGemmProgramInfoLoader = (inputs: readonly TensorView[], attributes: GemmAttributes): ProgramInfoLoader => { - const metadata = { + return { name: 'Gemm', - inputTypes: inputs.length === 3 ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : - [GpuDataType.default, GpuDataType.default], - cacheHint: attributes.cacheKey + shaderCache: {hint: attributes.cacheKey}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + }), + getShaderSource, }; - - return {...metadata, get: () => createGemmProgramInfo(metadata, inputs, attributes)}; }; export const gemm = (context: ComputeContext, attributes: GemmAttributes): void => { validateInputs(context.inputs); - context.compute(createGemmProgramInfoLoader(context.inputs, attributes)); + context.compute(createGemmProgramInfo(context.inputs, attributes)); }; export const parseGemmAttributes = (attributes: Record): GemmAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 5a148bda0a9f7..97f633c7cf47e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -1,20 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; -import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; export interface InstanceNormAttributes extends AttributeWithCacheKey { epsilon: number; format: 'NHWC'|'NCHW'; } +const metadata = { + name: 'InstanceNormalization' +}; + const createInstanceNormProgramInfo = - (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => { + (inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => { const xShape = inputs[0].dims; const outputShape = xShape; @@ -96,89 +101,179 @@ const createInstanceNormProgramInfo = }`; return { ...metadata, - outputs: [ - {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, - ], + shaderCache: {hint: attributes.cacheKey}, + getRunData: () => ({ + outputs: [ + {dims: outputShape, dataType: inputs[0].dataType}, + ], + dispatchGroup: {x: normCount} + }), getShaderSource, - dispatchGroup: () => ({x: normCount}) }; }; +const computeMean = + (context: ComputeContext, input: TensorView, scale: TensorView, bias: TensorView, n: number, h: number, c: number, + epsilon: number) => { + const components = getMaxComponents(c); + const inputHelper = inputVariable('input', input.dataType, input.dims, components); + const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components); + const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components); + + const WG = 64; + // we will store channel scale and channel shift in [2, components] matrix + // or in vec2 when components == 1 + const outputType = components === 1 ? 'vec2f' : `mat2x${components}f`; + const sumCastType = components === 1 ? 'f32' : `vec${components}f`; + const setOutputValue = (var1: string, var2: string) => `${outputType}(${var1}, ${var2})`; + const unitsOfWork = n * c / components; + const wgSize = Math.ceil(h / WG); + + const getMeanShaderSource = (shaderHelper: ShaderHelper) => ` + const H: u32 = ${h}; + const C: u32 = ${c / components}; + const imageSize: u32 = ${h * c / components}; + + ${shaderHelper.declareVariables(inputHelper)} + @group(0) @binding(1) var output : array<${outputType}>; + + ${shaderHelper.mainStart(WG)} + let currentImageNumber = global_idx / ${WG} / C; + let currentChannelNumber = (global_idx / ${WG}) % C; + let wgId = global_idx % ${WG}; + let wgOffset = wgId * ${wgSize}; + if (wgOffset >= H) { + return; + } + let wgMax = min(wgOffset + ${wgSize}, H); + + let offset = currentImageNumber * imageSize + currentChannelNumber; + var sum = ${fillVector('f32', components)}; + var squaredSum = ${fillVector('f32', components)}; + for (var i: u32 = wgOffset; i < wgMax; i++) { + let value = ${sumCastType}(input[offset + i * C]); + sum += value; + squaredSum += value * value; + } + output[global_idx] = ${setOutputValue('sum', 'squaredSum')}; + }`; + + const meanValues = context.compute( + { + name: 'InstanceNormComputeMean', + shaderCache: {hint: JSON.stringify({components, n, h, c})}, + getRunData: () => ({ + outputs: [ + {dims: [n, c, WG, 2], dataType: DataType.float}, + ], + dispatchGroup: {x: n * c / components}, + }), + getShaderSource: getMeanShaderSource, + }, + {inputs: [input], outputs: [-1]})[0]; + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const H: u32 = ${h}; + const C: u32 = ${c / components}; + const imageSize: u32 = ${WG * c / components}; + const epsilon: f32 = ${epsilon}; + + @group(0) @binding(0) var input : array<${outputType}>; + @group(0) @binding(1) var scale : array<${scaleHelper.type.storage}>; + @group(0) @binding(2) var bias : array<${biasHelper.type.storage}>; + @group(0) @binding(3) var output : array<${outputType}>; + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(unitsOfWork)} + let currentImageNumber = global_idx / C; + let currentChannelNumber = global_idx % C; + + let offset = currentImageNumber * imageSize; + var sum = ${fillVector('f32', components)}; + var squaredSum = ${fillVector('f32', components)}; + for (var i: u32 = 0; i < ${WG}; i++) { + let value = input[offset + i + currentChannelNumber * ${WG}]; + sum += value[0]; + squaredSum += value[1]; + } + sum = sum / f32(H); + squaredSum = squaredSum / f32(H); + let invStdDev = 1 / sqrt(squaredSum - sum * sum + epsilon); + let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]); + let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale; + + output[global_idx] = ${setOutputValue('channelScale', 'channelShift')}; + }`; + + return context.compute( + { + name: 'InstanceNormComputeChannelScaleShift', + shaderCache: {hint: JSON.stringify({components, n, h, c, epsilon})}, + getRunData: () => ({ + outputs: [ + {dims: [n, c, 2], dataType: DataType.float}, + ], + dispatchGroup: {x: Math.ceil(unitsOfWork / 64 /* workgroup size */)}, + }), + getShaderSource, + }, + {inputs: [meanValues, scale, bias], outputs: [-1]})[0]; + }; + const createInstanceNormNHWCProgramInfo = - (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => { + (context: ComputeContext, inputs: readonly TensorView[], attributes: InstanceNormAttributes) => { const xShape = inputs[0].dims; const outputShape = xShape; - const outputSize = ShapeUtil.size(outputShape); const N = xShape[0]; const C = xShape[xShape.length - 1]; const H = ShapeUtil.sizeFromDimension(xShape, 1) / C; + const components = getMaxComponents(C); + const outputSize = ShapeUtil.size(outputShape) / components; + const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components); + const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`; + const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`; + // first compute mean + const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon); - const normCount = C * N; const getShaderSource = (shaderHelper: ShaderHelper) => ` - const N: u32 = ${N}; const H: u32 = ${H}; - const C: u32 = ${C}; - const normSizeTyped: ${dataType} = ${H}; - const imageSize: u32 = ${H * C}; - const epsilon: f32 = ${attributes.epsilon}; + const C: u32 = ${C / components}; - @group(0) @binding(0) var x : array<${dataType}>; - @group(0) @binding(1) var scale : array<${dataType}>; - @group(0) @binding(2) var bias : array<${dataType}>; - @group(0) @binding(3) var output : array<${dataType}>; + @group(0) @binding(0) var input : array<${inputHelper.type.storage}>; + @group(0) @binding(1) var scaleInput : array<${scaleType}>; + @group(0) @binding(2) var output : array<${outputHelper.type.storage}>; ${shaderHelper.mainStart()} - let currentImageNumber = global_idx / C; + let currentImageNumber = global_idx / (C * H); let currentChannelNumber = global_idx % C; - // offset is channel num * N - let offset = currentImageNumber * imageSize; - if (offset >= ${outputSize}) { return; } - var mean: ${dataType} = 0; - - for (var i: u32 = 0u; i < H; i++) { - mean = mean + x[offset + i * C + currentChannelNumber]; - } - mean = mean / normSizeTyped; - - var squaredNorm: ${dataType} = 0; - for (var i: u32 = 0u; i < H; i++) { - let deviation: f32 = x[offset + i * C + currentChannelNumber] - mean; - squaredNorm = squaredNorm + deviation * deviation; - } - let invStdDev = 1 / sqrt(squaredNorm / normSizeTyped + epsilon); - let channelScale = invStdDev * scale[currentChannelNumber]; - let channelShift = bias[currentChannelNumber] - mean * channelScale; - for (var i: u32 = 0u; i < H; i++) { - let currentOffset = offset + i * C + currentChannelNumber; - output[currentOffset] = x[currentOffset] * channelScale + channelShift; - } + let scaleOffset = currentImageNumber * C + currentChannelNumber; + let scale = scaleInput[scaleOffset]; + output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1])); }`; - return { - ...metadata, - outputs: [ - {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, - ], - getShaderSource, - dispatchGroup: () => ({x: Math.ceil(normCount / 64 /* workgroup size */)}) - }; + context.compute( + { + name: 'InstanceNormalization', + shaderCache: {hint: `${attributes.cacheKey}`}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + }), + getShaderSource, + }, + {inputs: [inputs[0], channelScaleShift]}); }; export const parseInstanceNormAttributes = (attributes: InstanceNormAttributes): InstanceNormAttributes => createAttributeWithCacheKey({epsilon: attributes.epsilon, format: attributes.format}); export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => { - const metadata = { - name: 'InstanceNormalization', - inputTypes: [GpuDataType.default, GpuDataType.default, GpuDataType.default], - cacheHint: attributes.cacheKey, - }; - if (attributes.format === 'NHWC') { - context.compute(createInstanceNormNHWCProgramInfo(metadata, context.inputs, attributes)); + createInstanceNormNHWCProgramInfo(context, context.inputs, attributes); } else { - context.compute(createInstanceNormProgramInfo(metadata, context.inputs, attributes)); + context.compute(createInstanceNormProgramInfo(context.inputs, attributes)); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index d6a79e9460c3f..8a9eeecf2c68d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; -import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; export interface LayerNormAttributes extends AttributeWithCacheKey { axis: number; @@ -18,117 +18,109 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length < 2) { throw new Error('layerNorm requires at least 2 inputs.'); } - - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('inputs should be float type'); - } }; const createLayerNormProgramInfo = - (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: LayerNormAttributes, outputCount: number): - ProgramInfo => { - const xShape = inputs[0].dims; - const scale = inputs[1]; - const bias = inputs[2]; - - const outputShape = xShape; - const outputSize = ShapeUtil.size(outputShape); - const axis = ShapeUtil.normalizeAxis(attributes.axis, xShape.length); - const normCount = ShapeUtil.sizeToDimension(xShape, axis); - const normSize = ShapeUtil.sizeFromDimension(xShape, axis); - - const scaleSize = ShapeUtil.size(scale.dims); - const biasSize = bias ? ShapeUtil.size(bias.dims) : 0; - if (scaleSize !== normSize || (bias && biasSize !== normSize)) { - throw new Error(`Size of X.shape()[axis:] == ${normSize}. + (inputs: readonly TensorView[], attributes: LayerNormAttributes, outputCount: number): ProgramInfo => { + const xShape = inputs[0].dims; + const scale = inputs[1]; + const bias = inputs[2]; + + const outputShape = xShape; + const axis = ShapeUtil.normalizeAxis(attributes.axis, xShape.length); + const normCount = ShapeUtil.sizeToDimension(xShape, axis); + const normSize = ShapeUtil.sizeFromDimension(xShape, axis); + + const scaleSize = ShapeUtil.size(scale.dims); + const biasSize = bias ? ShapeUtil.size(bias.dims) : 0; + if (scaleSize !== normSize || (bias && biasSize !== normSize)) { + throw new Error(`Size of X.shape()[axis:] == ${normSize}. Size of scale and bias (if provided) must match this. Got scale size of ${scaleSize} and bias size of ${biasSize}`); - } - - const meanInvStdDevDim = []; - for (let i = 0; i < xShape.length; ++i) { - if (i < axis) { - meanInvStdDevDim.push(xShape[i]); - } else { - meanInvStdDevDim.push(1); - } - } - - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - - const hasMeanDataOutput = outputCount > 1; - const hasInvStdOutput = outputCount > 2; - let bindingIndex = 0; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const normSize: u32 = ${normSize}; - const normSizeTyped: ${dataType} = ${normSize}; + } + + const meanInvStdDevDim = []; + for (let i = 0; i < xShape.length; ++i) { + if (i < axis) { + meanInvStdDevDim.push(xShape[i]); + } else { + meanInvStdDevDim.push(1); + } + } + + const components = getMaxComponents(normSize); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('scale', scale.dataType, scale.dims, components), + ]; + if (bias) { + variables.push(inputVariable('bias', bias.dataType, bias.dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + + const hasMeanDataOutput = outputCount > 1; + const hasInvStdOutput = outputCount > 2; + + if (hasMeanDataOutput) { + variables.push(outputVariable('meanDataOutput', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdOutput) { + variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); + } + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const normSize: f32 = ${normSize}; + const normSizeVectorized: u32 = ${normSize / components}; const epsilon: f32 = ${attributes.epsilon}; - @group(0) @binding(${bindingIndex++}) var x : array<${dataType}>; - @group(0) @binding(${bindingIndex++}) var scale : array<${dataType}>; - ${bias ? `@group(0) @binding(${bindingIndex++}) var bias : array<${dataType}>;` : ''} - @group(0) @binding(${bindingIndex++}) var output : array<${dataType}>; - ${ - hasMeanDataOutput ? - `@group(0) @binding(${bindingIndex++}) var meanDataOutput : array<${dataType}>` : - ''}; - ${ - hasInvStdOutput ? - `@group(0) @binding(${bindingIndex++}) var invStdOutput : array<${dataType}>` : - ''}; - + ${shaderHelper.declareVariables(...variables)} ${shaderHelper.mainStart()} - let offset = global_idx * normSize; - if (offset >= ${outputSize}) { return; } - var mean: ${dataType} = 0; - var meanSquare: ${dataType} = 0; - - for (var h: u32 = 0u; h < normSize; h++) { - mean = mean + x[h + offset]; - meanSquare = meanSquare + x[h + offset] * x[h + offset]; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(normCount)} + let offset = global_idx * normSizeVectorized; + var meanVector = ${fillVector('f32', components)}; + var meanSquareVector = ${fillVector('f32', components)}; + + for (var h: u32 = 0u; h < normSizeVectorized; h++) { + let value = ${castToF32(dataType, components, 'x[h + offset]')}; + meanVector += value; + meanSquareVector += value * value; } - mean = mean / normSizeTyped; - meanSquare = sqrt(meanSquare / normSizeTyped - mean * mean + epsilon); - - for (var j: u32 = 0; j < normSize; j++) { - output[j + offset] = (x[j + offset] - mean) / meanSquare * scale[j] ${bias ? '+ bias[j]' : ''}; + let mean = ${sumVector('meanVector', components)} / normSize; + let meanSquare = sqrt(${sumVector('meanSquareVector', components)} + / normSize - mean * mean + epsilon); + + for (var j: u32 = 0; j < normSizeVectorized; j++) { + let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; + let f32scale = ${castToF32(dataType, components, 'scale[j]')}; + output[j + offset] = ${variables[0].type.value}((f32input - mean) / meanSquare * f32scale + ${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''} + ); } ${hasMeanDataOutput ? 'meanDataOutput[global_idx] = mean' : ''}; ${hasInvStdOutput ? 'invStdOutput[global_idx] = 1 / meanSquare' : ''}; }`; - const outputs = [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}]; - if (hasMeanDataOutput) { - outputs.push( - {dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, - ); - } - if (hasInvStdOutput) { - outputs.push( - {dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, - ); - } - - return { - ...metadata, - outputs, - getShaderSource, - dispatchGroup: () => ({x: Math.ceil(normCount / 64 /* workgroup size */)}) - }; - }; + const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; + if (hasMeanDataOutput) { + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); + } + if (hasInvStdOutput) { + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); + } + + return { + name: 'LayerNormalization', + shaderCache: {hint: `${attributes.cacheKey}|${outputCount}|${inputs.length}`}, + getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)}}), + getShaderSource, + }; + }; export const parseLayerNormAttributes = (attributes: LayerNormAttributes): LayerNormAttributes => createAttributeWithCacheKey({axis: attributes.axis, epsilon: attributes.epsilon}); export const layerNorm = (context: ComputeContext, attributes: LayerNormAttributes): void => { validateInputs(context.inputs); - - const metadata = { - name: 'LayerNormalization', - inputTypes: context.inputs.length === 2 ? [GpuDataType.default, GpuDataType.default] : - [GpuDataType.default, GpuDataType.default, GpuDataType.default], - cacheHint: attributes.cacheKey + context.outputCount.toString(10) + context.inputs.length.toString(10), - }; - - context.compute(createLayerNormProgramInfo(metadata, context.inputs, attributes, context.outputCount)); + context.compute(createLayerNormProgramInfo(context.inputs, attributes, context.outputCount)); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 837ac8410f291..19ca4ac5358ae 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -1,31 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {BroadcastUtil} from '../../util'; -import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types'; +import {ComputeContext} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; -import {InternalActivationAttributes} from './fuse-utils'; - - -const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({ - name: 'MatMul', - inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : - [GpuDataType.default, GpuDataType.default], - cacheHint -}); - -export const createMatmulProgramInfoLoader = - (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[], - reshapedOutputShape?: readonly number[]): ProgramInfoLoader => { - const metadata = createMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey); - return { - ...metadata, - get: () => createMatmulProgramInfo(metadata, inputs, activationAttributes, outputShape, reshapedOutputShape) - }; - }; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -35,10 +15,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (inputs[0].dims[inputs[0].dims.length - 1] !== inputs[1].dims[inputs[1].dims.length - 2]) { throw new Error('shared dimension does not match.'); } - - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('inputs should be float type'); - } }; export const matMul = (context: ComputeContext): void => { @@ -47,5 +23,5 @@ export const matMul = (context: ComputeContext): void => { if (!outputShape) { throw new Error('Can\'t use matmul on the given tensors'); } - context.compute(createMatmulProgramInfoLoader(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts new file mode 100644 index 0000000000000..b7726a36bcaad --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts @@ -0,0 +1,335 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType} from '../types'; + +import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention'; +import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {createTransposeProgramInfo, TransposeAttributes} from './transpose'; + +const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { + const query = inputs[0]; + const key = inputs[1]; + const value = inputs[2]; + const bias = inputs[3]; + const keyPaddingMask = inputs[4]; + const relativePositionBias = inputs[5]; + const pastKey = inputs[6]; + const pastValue = inputs[7]; + + // Abbreviation and Meanings: + // B: batch_size + // S: sequence_length (input sequence length of query) + // P: past_sequence_length (past sequence length of key or value) + // L: kv_sequence_length (input sequence length of key or value) + // M: max_sequence_length + // T: total_sequence_length = past_sequence_length + kv_sequence_length + // N: num_heads + // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size + // H_v: v_head_size + // D_i: input hidden size + // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size + // D_v: v_hidden_size = num_heads * v_head_size + + // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None + // relative_position_bias : (B, 1, S, L) + // past_key : (B, N, S*, H) + // past_value : (B, N, S*, H) + // When no packing for q/k/v: + // query (Q) : (B, S, D) + // key (K) : (B, L, D) or (B, N, S*, H) + // value (V) : (B, L, D_v) or (B, N, S*, H) + // bias (Q/K/V) : (D + D + D_v) + // When packed kv is used: + // query (Q) : (B, S, D) + // key (K) : (B, L, N, 2, H) + // value (V) : None + // bias (Q/K/V) : None + // When packed qkv is used: + // query (Q) : (B, L, N, 3, H) or (B, S, 3*D) + // key (K) : None + // value (V) : None + // bias (Q/K/V) : None or (D + D + D_v) + + if (query.dims.length !== 3 && query.dims.length !== 5) { + throw new Error('Input query is expected to have 3 or 5 dimensions'); + } + + const dmmhaPacking = false; + const batchSize = query.dims[0]; + const sequenceLength = query.dims[1]; + const hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : + attributes.numHeads * query.dims[4]; + let kvSequenceLength = sequenceLength; + + let pastSequenceLength = 0; + let maxSequenceLength = 0; + const headSize = Math.floor(hiddenSize / attributes.numHeads); + if (pastKey && pastValue) { + if (pastKey.dims.length !== 4) { + throw new Error('Input "past_key" is expected to have 4 dimensions'); + } + if (pastValue.dims.length !== 4) { + throw new Error('Input "past_value" is expected to have 4 dimensions'); + } + pastSequenceLength = pastKey.dims[2]; + maxSequenceLength = pastKey.dims[2]; + } else if (pastKey || pastValue) { + throw new Error('Input "past_key" and "past_value" shall be both present or both absent'); + } + + let qkvFormat: AttentionQkvFormat; + if (key) { + if (query.dims.length !== 3) { + throw new Error('Input "query" is expected to have 3 dimensions when key is given'); + } + if (key.dims.length < 3 || key.dims.length > 5) { + throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions'); + } + if (query.dims[0] !== key.dims[0]) { + throw new Error('Input "query" and "key" shall have same dim 0 (batch size)'); + } + + if (key.dims.length === 3) { + if (key.dims[2] !== query.dims[2]) { + throw new Error('Input "query" and "key" shall have same dim 2 (hidden_size)'); + } + qkvFormat = AttentionQkvFormat.qkvBSNH; + kvSequenceLength = key.dims[1]; + } else if (key.dims.length === 5) { + if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) { + throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv'); + } + if (value) { + throw new Error('Expect "value" be none when "key" has packed kv format.'); + } + qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H; + kvSequenceLength = key.dims[1]; + } else { // key_dims.size() == 4 (cross-attention with past_key) + if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) { + throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); + } + + qkvFormat = AttentionQkvFormat.unknown; + kvSequenceLength = key.dims[2]; + } + } else { // packed QKV + if (query.dims.length !== 3 && query.dims.length !== 5) { + throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty'); + } + if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) { + throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv'); + } + + qkvFormat = AttentionQkvFormat.qkvBSN3H; + } + + if (bias) { + if (bias.dims.length !== 1) { + throw new Error('Input "bias" is expected to have 1 dimension'); + } + + if (value) { + if (query.dims.length === 5 && query.dims[3] === 2) { + throw new Error('bias is not allowed for packed kv.'); + } + } + } + + let maskType: AttentionMaskType = AttentionMaskType.none; + if (keyPaddingMask) { + maskType = AttentionMaskType.maskUnknown; + const maskDims = keyPaddingMask.dims; + if (maskDims.length === 1) { + if (maskDims[0] === batchSize) { + maskType = AttentionMaskType.mask1dKeySeqLen; + } else if (maskDims[0] === 3 * batchSize + 2) { + maskType = AttentionMaskType.mask1DKeySeqLenStart; + } + } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === kvSequenceLength) { + maskType = AttentionMaskType.mask2dKeyPadding; + } + if (maskType === AttentionMaskType.maskUnknown) { + throw new Error('Input "key_padding_mask" shape shall be (batch_size) or (batch_size, kv_sequence_length)'); + } + throw new Error('Mask not supported'); + } + + let passPastInKv = false; + let vHiddenSize = hiddenSize; + if (value) { + if (value.dims.length !== 3 && value.dims.length !== 4) { + throw new Error('Input "value" is expected to have 3 or 4 dimensions'); + } + + if (query.dims[0] !== value.dims[0]) { + throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)'); + } + + if (value.dims.length === 3) { + if (kvSequenceLength !== value.dims[1]) { + throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)'); + } + vHiddenSize = value.dims[2]; + } else { + if (kvSequenceLength !== value.dims[2]) { + throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)'); + } + vHiddenSize = value.dims[1] * value.dims[3]; + passPastInKv = true; + } + } + + const totalSequenceLength = pastSequenceLength + kvSequenceLength; + const broadcastResPosBias = false; + // if (extraAddQk) { + // if (extraAddQk.dims[0] === 1) { + // broadcastResPosBias = true; + // } + // } + + if (keyPaddingMask) { + throw new Error('Key padding mask is not supported'); + } + if (relativePositionBias) { + throw new Error('extraAddQk is not supported'); + } + if (pastKey) { + throw new Error('pastKey is not supported'); + } + if (pastValue) { + throw new Error('pastValue is not supported'); + } + + return { + batchSize, + sequenceLength, + pastSequenceLength, + kvSequenceLength, + totalSequenceLength, + maxSequenceLength, + inputHiddenSize: 0, + hiddenSize, + vHiddenSize, + headSize, + vHeadSize: Math.floor(vHiddenSize / attributes.numHeads), + numHeads: attributes.numHeads, + isUnidirectional: false, + pastPresentShareBuffer: false, + maskFilterValue: attributes.maskFilterValue, + maskType, + scale: attributes.scale, + broadcastResPosBias, + passPastInKv, + qkvFormat, + }; +}; + + +export const parseMultiHeadAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => + createAttributeWithCacheKey({...attributes}); + +const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]}); + +const addBiasTranspose = + (context: ComputeContext, qkv: TensorView, bias: TensorView, batchSize: number, sequenceLength: number, + hiddenSize: number, biasOffset: number) => { + const outputShape = [batchSize, sequenceLength, hiddenSize]; + const outputSize = ShapeUtil.size(outputShape); + + const dataType = tensorTypeToWsglStorageType(qkv.dataType); + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const biasOffset = ${biasOffset}u; + const hiddenSize = ${hiddenSize}u; + + @group(0) @binding(0) var qkv: array<${dataType}>; + @group(0) @binding(1) var bias: array<${dataType}>; + @group(0) @binding(2) var qkv_with_bias: array<${dataType}>; + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let biasOffsetIdx = (global_idx % hiddenSize) + biasOffset; + + qkv_with_bias[global_idx] = qkv[global_idx] + bias[biasOffsetIdx]; + }`; + + return context.compute( + { + name: 'MultiHeadAttentionAddBias', + shaderCache: {hint: JSON.stringify({batchSize, sequenceLength, hiddenSize, biasOffset})}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + }), + getShaderSource, + }, + {inputs: [qkv, bias], outputs: [-1]})[0]; + }; + +const maybeTransposeToBNSHAndAddBias = + (context: ComputeContext, batchSize: number, numHeads: number, sequenceLength: number, headSize: number, + input: TensorView, bias?: TensorView, biasOffset?: number) => { + // const newDims = []; + + let reshapedInput = input; + if (!bias) { + if (input.dims.length === 3) { + reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]); + } + return context.compute( + createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), + {inputs: [reshapedInput], outputs: [-1]})[0]; + } else { + if (sequenceLength === 1) { + throw new Error('AddBiasReshape is not implemented. Please export your model with packed QKV or KV'); + } else { + reshapedInput = + addBiasTranspose(context, input, bias, batchSize, sequenceLength, numHeads * headSize, biasOffset!); + reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]); + return context.compute( + createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), + {inputs: [reshapedInput], outputs: [-1]})[0]; + } + } + }; + +export const multiHeadAttention = (context: ComputeContext, attributes: AttentionAttrs): void => { + const params = validateInputs(context.inputs, attributes); + + if (context.inputs[0].dims.length === 5) { + throw new Error('Packed QKV is not implemented'); + } + + if (context.inputs[1]?.dims.length === 5) { + throw new Error('Packed KV is not implemented'); + } + + // applyAttention expects BNSH inputs + const kvBNSH = context.inputs[1] && context.inputs[2] && context.inputs[1].dims.length === 4 && + context.inputs[2].dims.length === 4; + + const Q = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, context.inputs[0], + context.inputs[3], 0); + + if (kvBNSH) { + return applyAttention( + context, Q, context.inputs[1], context.inputs[2], context.inputs[4], undefined, undefined, undefined, + context.inputs[5], params, attributes); + } + + const K = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.kvSequenceLength, params.headSize, context.inputs[1], + context.inputs[3], params.hiddenSize); + + const V = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.kvSequenceLength, params.vHeadSize, context.inputs[2], + context.inputs[3], 2 * params.hiddenSize); + + applyAttention( + context, Q, K, V, context.inputs[4], undefined, context.inputs[6], context.inputs[7], context.inputs[5], params, + attributes); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index c2f89fd2845df..18859e253aa02 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -5,7 +5,7 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; @@ -36,8 +36,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => { }; const getPadConstant = - (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], - inputStrides: readonly number[], pads: number[], dataType: string, constantValue: number): string => { + (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[], + dataType: string, constantValue: number): string => { const inputRank = inputDims.length; let block = ''; @@ -66,8 +66,7 @@ const getPadConstant = }; const getPadReflect = - (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], - inputStrides: readonly number[], pads: number[]): string => { + (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => { const inputRank = inputDims.length; let block = ''; @@ -97,8 +96,7 @@ const getPadReflect = }; const getPadEdge = - (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], - inputStrides: readonly number[], pads: number[]): string => { + (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => { const inputRank = inputDims.length; let block = ''; @@ -124,8 +122,7 @@ const getPadEdge = }; const getPadWrap = - (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], - inputStrides: readonly number[], pads: number[]): string => { + (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => { const inputRank = inputDims.length; let block = ''; @@ -151,18 +148,17 @@ const getPadWrap = }; const getPadSnippet = - (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], - inputStrides: readonly number[], attributes: PadAttributes, dataType: string): string => { + (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], attributes: PadAttributes, + dataType: string): string => { switch (attributes.mode) { case 0: - return getPadConstant( - output, outputDims, inputDims, inputStrides, attributes.pads, dataType, attributes.value); + return getPadConstant(output, inputDims, inputStrides, attributes.pads, dataType, attributes.value); case 1: - return getPadReflect(output, outputDims, inputDims, inputStrides, attributes.pads); + return getPadReflect(output, inputDims, inputStrides, attributes.pads); case 2: - return getPadEdge(output, outputDims, inputDims, inputStrides, attributes.pads); + return getPadEdge(output, inputDims, inputStrides, attributes.pads); case 3: - return getPadWrap(output, outputDims, inputDims, inputStrides, attributes.pads); + return getPadWrap(output, inputDims, inputStrides, attributes.pads); default: throw new Error('Invalid mode'); } @@ -179,10 +175,9 @@ const generatePadCode = const output = outputVariable('output', inputs[0].dataType, outputDims); const input = inputVariable('x', inputs[0].dataType, inputDims); - const padSnippet = getPadSnippet(output, outputDims, inputDims, inputStrides, attributes, dataType); + const padSnippet = getPadSnippet(output, inputDims, inputStrides, attributes, dataType); const padCode = ` ${shaderHelper.declareVariables(input, output)} - ${output.impl()} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} @@ -195,21 +190,23 @@ const generatePadCode = return padCode; }; -const createPadProgramInfo = - (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: PadAttributes): ProgramInfo => { - const outputShape = ShapeUtil.padShape(inputs[0].dims.slice(), attributes.pads); - return { - ...metadata, - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], - getShaderSource: shaderHelper => generatePadCode(shaderHelper, inputs, attributes, 'f32'), - dispatchGroup: () => ({x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}) - }; - }; +const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttributes): ProgramInfo => { + const outputShape = ShapeUtil.padShape(inputs[0].dims.slice(), attributes.pads); + return { + name: 'Pad', + shaderCache: {hint: attributes.cacheKey}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)} + }), + getShaderSource: shaderHelper => generatePadCode(shaderHelper, inputs, attributes, 'f32'), + }; +}; const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes: PadAttributes): PadAttributes => { if (inputs.length > 1) { const bigInt64Pads = inputs[1].getBigInt64Array(); - const value = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : 0.0; + const value = (inputs.length >= 3 && inputs[2].data) ? inputs[2].getFloat32Array()[0] : 0.0; const inputRank = inputs[0].dims.length; const updatePads = new Int32Array(2 * inputRank).fill(0); @@ -220,7 +217,7 @@ const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes updatePads[Number(axes[i]) + inputRank] = Number(bigInt64Pads[i + axes.length]); } } else { - bigInt64Pads.forEach((i, v) => updatePads[Number(i)] = (Number(v))); + bigInt64Pads.forEach((v, i) => updatePads[Number(i)] = (Number(v))); } const pads: number[] = []; @@ -232,16 +229,10 @@ const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes } }; -const createPadProgramInfoLoader = (inputs: readonly TensorView[], attributes: PadAttributes): ProgramInfoLoader => { - const updatedAttributes = createPadAttributesFromInputs(inputs, attributes); - const metadata: - ProgramMetadata = {name: 'Pad', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey}; - return {...metadata, get: () => createPadProgramInfo(inputs, metadata, updatedAttributes)}; -}; - export const pad = (context: ComputeContext, attributes: PadAttributes): void => { validateInputs(context.inputs); - context.compute(createPadProgramInfoLoader(context.inputs, attributes), {inputs: [0]}); + const updatedAttributes = createPadAttributesFromInputs(context.inputs, attributes); + context.compute(createPadProgramInfo(context.inputs, updatedAttributes), {inputs: [0]}); }; export const parsePadAttributes = (attributes: Record): PadAttributes => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 120a0e9de5490..1538644412afd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -4,7 +4,7 @@ import {TensorView} from '../../tensor-view'; import {PoolConvUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; @@ -18,16 +18,18 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 1) { throw new Error('Pool ops requires 1 input.'); } - if (inputs[0].dims.length !== 4) { - throw new Error('Pool ops supports 2-D inputs only for now.'); + if (inputs[0].dims.length !== 4 && inputs[0].dims.length !== 3) { + throw new Error('Pool ops supports 1-D or 2-D inputs only for now.'); } }; const getAdjustedPoolAttributesAndOutputShape = ( input: TensorView, attributes: AttributeType, isGlobalOperator: boolean): [AttributeType, number[]] => { const isChannelsLast = attributes.format === 'NHWC'; - const inputShapeAsChannelFirst = - isChannelsLast ? [input.dims[0], input.dims[3], input.dims[1], input.dims[2]] : input.dims.slice(); + const inputShapeAsChannelFirst = input.dims.slice(); + if (isChannelsLast) { + inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!); // Move channel to the second position. + } const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations'); const kernelShape = attributes.kernelShape.slice(); const strides = attributes.strides.slice(); @@ -44,22 +46,16 @@ const getAdjustedPoolAttributesAndOutputShape = ( - shaderHelper: ShaderHelper, x: IndicesHelper, outputShape: readonly number[], attributes: AttributeType, - op1: string, op2: string, start: string): string => { + shaderHelper: ShaderHelper, x: IndicesHelper, xShape: readonly number[], outputShape: readonly number[], + attributes: AttributeType, op1: string, op2: string, start: string): string => { const isChannelsLast = attributes.format === 'NHWC'; - const inputDims = x.shape; + const inputDims = xShape; const dataType = x.type.value; const rank = inputDims.length; const outputSize = ShapeUtil.size(outputShape); @@ -76,22 +72,22 @@ const generatePoolingCode = = ${inputDims[dimIdxW]}) { - pad++; - continue; - } - let x_val = x[${x.indicesToOffset('xIndices')}]; - ${op1} - }`; + for (var i: u32 = 0u; i < ${kw}u; i++) { + xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; + if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] >= ${inputDims[dimIdxW]}) { + pad++; + continue; + } + let x_val = x[${x.indicesToOffset('xIndices')}]; + ${op1} + }`; } else { codeW = ` - for (var i: u32 = 0u; i < ${kw}u; i++) { - xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; - let x_val = x[${x.indicesToOffset('xIndices')}]; - ${op1} - }`; + for (var i: u32 = 0u; i < ${kw}u; i++) { + xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; + let x_val = x[${x.indicesToOffset('xIndices')}]; + ${op1} + }`; } if (attributes.kernelShape.length === 2) { @@ -237,30 +233,32 @@ export interface AveragePoolAttributes extends PoolCommonAttributes, AttributeWi } const createAveragePoolProgramInfo = - (input: TensorView, metadata: ProgramMetadata, isGlobalOperator: boolean, attributes: AveragePoolAttributes): - ProgramInfo => { - const [adjustedAttributes, outputShape] = - getAdjustedPoolAttributesAndOutputShape(input, attributes, isGlobalOperator); - const kernelSize = ShapeUtil.size(adjustedAttributes.kernelShape); - - const x = inputVariable('x', input.dataType, input.dims); - const dataType = x.type.value; - - const op1 = 'value += x_val;'; - let op2 = ''; - if (adjustedAttributes.countIncludePad) { - op2 += `value /= ${dataType}(${kernelSize});`; - } else { - op2 += `value /= ${dataType}(${kernelSize} - pad);`; - } - return { - ...metadata, - outputs: [{dims: outputShape, dataType: input.dataType, gpuDataType: GpuDataType.default}], - getShaderSource: shaderHelper => - generatePoolingCode(shaderHelper, x, outputShape, adjustedAttributes, op1, op2, '0.0'), - dispatchGroup: () => ({x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}) - }; - }; + (name: string, input: TensorView, isGlobalOperator: boolean, attributes: AveragePoolAttributes): ProgramInfo => { + const [adjustedAttributes, outputShape] = + getAdjustedPoolAttributesAndOutputShape(input, attributes, isGlobalOperator); + const kernelSize = ShapeUtil.size(adjustedAttributes.kernelShape); + + const x = inputVariable('x', input.dataType, input.dims); + const dataType = x.type.value; + + const op1 = 'value += x_val;'; + let op2 = ''; + if (adjustedAttributes.countIncludePad) { + op2 += `value /= ${dataType}(${kernelSize});`; + } else { + op2 += `value /= ${dataType}(${kernelSize} - pad);`; + } + return { + name, + shaderCache: {hint: attributes.cacheKey}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: input.dataType}], + dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)} + }), + getShaderSource: shaderHelper => + generatePoolingCode(shaderHelper, x, input.dims, outputShape, adjustedAttributes, op1, op2, '0.0'), + }; + }; export const parseAveragePoolAttributes = (attributes: Record): AveragePoolAttributes => { const countIncludePad = (attributes.count_include_pad as number) === 0 ? false : true; @@ -276,9 +274,7 @@ export const parseAveragePoolAttributes = (attributes: Record): export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => { validateInputs(context.inputs); - const metadata = {name: 'AveragePool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey}; - context.compute( - {...metadata, get: () => createAveragePoolProgramInfo(context.inputs[0], metadata, false, attributes)}); + context.compute(createAveragePoolProgramInfo('AveragePool', context.inputs[0], false, attributes)); }; const globalPoolAttributes = { @@ -300,9 +296,7 @@ export const parseGlobalAveragePoolAttributes = (attributes: Record { validateInputs(context.inputs); - const metadata = {name: 'GlobalAveragePool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey}; - context.compute( - {...metadata, get: () => createAveragePoolProgramInfo(context.inputs[0], metadata, true, attributes)}); + context.compute(createAveragePoolProgramInfo('GlobalAveragePool', context.inputs[0], true, attributes)); }; export interface MaxPoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey { @@ -311,28 +305,29 @@ export interface MaxPoolAttributes extends PoolCommonAttributes, AttributeWithCa } const createMaxPoolProgramInfo = - (input: TensorView, metadata: ProgramMetadata, isGlobalOperator: boolean, attributes: MaxPoolAttributes): - ProgramInfo => { - const [adjustedAttributes, outputShape] = - getAdjustedPoolAttributesAndOutputShape(input, attributes, isGlobalOperator); - const op1 = ` + (name: string, input: TensorView, isGlobalOperator: boolean, attributes: MaxPoolAttributes): ProgramInfo => { + const [adjustedAttributes, outputShape] = + getAdjustedPoolAttributesAndOutputShape(input, attributes, isGlobalOperator); + const op1 = ` value = max(x_val, value); `; - const op2 = ''; - const x = inputVariable('x', input.dataType, input.dims); - return { - ...metadata, - outputs: [{dims: outputShape, dataType: input.dataType, gpuDataType: GpuDataType.default}], - getShaderSource: shaderHelper => - generatePoolingCode(shaderHelper, x, outputShape, adjustedAttributes, op1, op2, '-1e5'), - dispatchGroup: () => ({x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}) - }; - }; + const op2 = ''; + const x = inputVariable('x', input.dataType, input.dims); + return { + name, + shaderCache: {hint: attributes.cacheKey}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: input.dataType}], + dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)} + }), + getShaderSource: shaderHelper => + generatePoolingCode(shaderHelper, x, input.dims, outputShape, adjustedAttributes, op1, op2, '-1e5'), + }; + }; export const maxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => { validateInputs(context.inputs); - const metadata = {name: 'MaxPool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey}; - context.compute({...metadata, get: () => createMaxPoolProgramInfo(context.inputs[0], metadata, false, attributes)}); + context.compute(createMaxPoolProgramInfo('MaxPool', context.inputs[0], false, attributes)); }; export const parseMaxPoolAttributes = (attributes: Record): MaxPoolAttributes => { @@ -358,6 +353,5 @@ export const parseGlobalMaxPoolAttributes = (attributes: Record export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => { validateInputs(context.inputs); - const metadata = {name: 'GlobalMaxPool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey}; - context.compute({...metadata, get: () => createMaxPoolProgramInfo(context.inputs[0], metadata, true, attributes)}); + context.compute(createMaxPoolProgramInfo('GlobalMaxPool', context.inputs[0], true, attributes)); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/range.ts b/js/web/lib/wasm/jsep/webgpu/ops/range.ts new file mode 100644 index 0000000000000..9cf66111bf707 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/range.ts @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env} from 'onnxruntime-common'; + +import {DataType} from '../../../wasm-common'; +import {ComputeContext, ProgramInfo} from '../types'; + +import {outputVariable, ShaderHelper} from './common'; + +const validateInputsContent = (start: number, limit: number, delta: number): void => { + const sameStartLimit = start === limit; + const increasingRangeNegativeStep = start < limit && delta < 0; + const decreasingRangePositiveStep = start > limit && delta > 0; + + if (sameStartLimit || increasingRangeNegativeStep || decreasingRangePositiveStep) { + throw new Error('Range these inputs\' contents are invalid.'); + } +}; + +const createRangeProgramInfo = (start: number, limit: number, delta: number, dataType: DataType): ProgramInfo => { + const numElements = Math.abs(Math.ceil((limit - start) / delta)); + const outputShape: number[] = [numElements]; + const outputSize = numElements; + + const output = outputVariable('output', dataType, outputShape); + const wgslType = output.type.storage; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.declareVariables(output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + output[global_idx] = ${wgslType}(${start}) + ${wgslType}(global_idx) * ${wgslType}(${delta}); + }`; + return { + name: 'Range', + shaderCache: {hint: [start, limit, delta].map(x => x.toString()).join('_')}, + getShaderSource, + getRunData: () => ( + {outputs: [{dims: outputShape, dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}}) + }; +}; + +export const range = (context: ComputeContext): void => { + let start = 0; + let limit = 0; + let delta = 0; + if (context.inputs[0].dataType === DataType.int32) { + start = context.inputs[0].getInt32Array()[0]; + limit = context.inputs[1].getInt32Array()[0]; + delta = context.inputs[2].getInt32Array()[0]; + } else if (context.inputs[0].dataType === DataType.float) { + start = context.inputs[0].getFloat32Array()[0]; + limit = context.inputs[1].getFloat32Array()[0]; + delta = context.inputs[2].getFloat32Array()[0]; + } + if (env.webgpu.validateInputContent) { + validateInputsContent(start, limit, delta); + } + + context.compute(createRangeProgramInfo(start, limit, delta, context.inputs[0].dataType), {inputs: []}); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts new file mode 100644 index 0000000000000..1365d1e9a12a4 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createReduceAttributesFromInputs, ReduceAttributes} from './reduce'; +import {createTransposeProgramInfo} from './transpose'; + +const reduceOps: {[key: string]: string} = { + max: 'select(bestValue, candidate, candidate > bestValue)', + min: 'select(bestValue, candidate, candidate < bestValue)', + mean: 'bestValue + candidate', + sum: 'bestValue + candidate', + prod: 'bestValue * candidate', + sumSquare: 'bestValue + candidate * candidate', + logSumExp: 'bestValue + exp(candidate)', + l1: 'bestValue + abs(candidate)', + l2: 'bestValue + candidate * candidate', + logSum: 'bestValue + candidate' +}; + +const reduceSharedOps: {[key: string]: string} = { + max: 'select(bestValue, candidate, candidate > bestValue)', + min: 'select(bestValue, candidate, candidate < bestValue)', + mean: 'bestValue + candidate', + sum: 'bestValue + candidate', + prod: 'bestValue * candidate', + sumSquare: 'bestValue + candidate', + logSumExp: 'bestValue + candidate', + l1: 'bestValue + candidate', + l2: 'bestValue + candidate', + logSum: 'bestValue + candidate' +}; + +const reduceInitValues: {[key: string]: string} = { + max: '_A[offset]', + min: '_A[offset]', + mean: '0', + sum: '0', + prod: '1', + sumSquare: '0', + logSumExp: '0', + l1: '0', + l2: '0', + logSum: '0' +}; + +const reduceOutputValues: {[key: string]: string} = { + max: 'bestValue', + min: 'bestValue', + sum: 'bestValue', + prod: 'bestValue', + sumSquare: 'bestValue', + logSumExp: 'log(bestValue)', + l1: 'bestValue', + l2: 'sqrt(bestValue)', + logSum: 'log(bestValue)' +}; + +const getInnerMostAxes = (numInnerAxes: number, rank: number): number[] => { + const res = []; + for (let i = rank - numInnerAxes; i < rank; ++i) { + res.push(i); + } + return res; +}; + +const computeOutAndReduceShapes = (shape: readonly number[], axes: readonly number[]): [number[], number[]] => { + const outputShape = []; + const rank = shape.length; + for (let dim = 0; dim < rank; dim++) { + if (axes.indexOf(dim) === -1) { + outputShape.push(shape[dim]); + } + } + const reduceShape = axes.map(dim => shape[dim]); + return [outputShape, reduceShape]; +}; + +const expandShapeToKeepDim = (shape: number[], axes: number[]): number[] => { + const rank = shape.length + axes.length; + const expandShape = []; + let shapeIdx = 0; + for (let dim = 0; dim < rank; dim++) { + if (axes.indexOf(dim) === -1) { + expandShape.push(shape[shapeIdx++]); + } else { + expandShape.push(1); + } + } + return expandShape; +}; + +const areAxesInnerMostDims = (axes: number[], rank: number): boolean => { + for (let i = 0; i < axes.length; ++i) { + if (axes[axes.length - i - 1] !== rank - 1 - i) { + return false; + } + } + return true; +}; + +const getAxesPermutation = (axes: number[], rank: number): number[] => { + const res = []; + if (!areAxesInnerMostDims(axes, rank)) { + for (let i = 0; i < rank; ++i) { + if (axes.indexOf(i) === -1) { + res.push(i); + } + } + axes.forEach(axis => res.push(axis)); + } + return res; +}; + +export const createReduceSharedProgramInfo = + (name: string, shaderCache: ProgramShaderCacheInfo, inputs: readonly TensorView[], reduceType: string, + outputDataType: DataType, outputShape: number[], reduceShape: number[]): ProgramInfo => { + const inputShape = inputs[0].dims; + + const outputSize = ShapeUtil.size(outputShape); + const reduceSize = ShapeUtil.size(reduceShape); + + const input = inputVariable('_A', inputs[0].dataType, inputShape); + const output = outputVariable('output', outputDataType, outputShape); + + const workgroupSize = 32; + + const sharedMemorySnippet = ` + var aBestValues : array<${output.type.storage}, ${workgroupSize}>; + `; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.registerUniform('reduceSize', 'u32').declareVariables(input, output)} + ${sharedMemorySnippet} + fn DIV_CEIL(a : u32, b : u32) -> u32 { + return ((a - 1u) / b + 1u); + } + ${shaderHelper.mainStart(workgroupSize)} + let local_idx = local_id.x; + + let outputIndex = global_idx / ${workgroupSize}; + let offset = outputIndex * uniforms.reduceSize; + + var bestValue = ${output.type.storage}(${reduceInitValues[reduceType]}); + let Length = uniforms.reduceSize; + for (var k = local_idx; k < Length; k = k + ${workgroupSize}) { + let candidate = ${output.type.storage}(${input.getByOffset('offset + k')}); + bestValue = ${reduceOps[reduceType]}; + } + aBestValues[local_idx] = bestValue; + workgroupBarrier(); + + var reduceSize = min(Length, ${workgroupSize}u); + for (var currentSize = reduceSize / 2u; reduceSize > 1u; + currentSize = reduceSize / 2u) { + let interval = DIV_CEIL(reduceSize, 2u); + if (local_idx < currentSize) { + let candidate = aBestValues[local_idx + interval]; + bestValue = ${reduceSharedOps[reduceType]}; + aBestValues[local_idx] = bestValue; + } + reduceSize = interval; + workgroupBarrier(); + } + + if (local_idx == 0u) { + ${ + output.setByOffset( + 'outputIndex', + `${ + reduceType === 'mean' ? `bestValue / ${output.type.storage}(uniforms.reduceSize)` : + `${reduceOutputValues[reduceType]}`}`)}; + } + }`; + + // One work group is responsible for only one element of output. + return { + name, + shaderCache, + getShaderSource, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: outputDataType}], + dispatchGroup: {x: outputSize}, + programUniforms: [{type: 'uint32', data: reduceSize}] + }), + }; + }; + +const reduceCommon = + (context: ComputeContext, name: string, attributes: ReduceAttributes, + reduceType: 'sum'|'sumSquare'|'prod'|'min'|'max'|'mean'|'logSumExp'|'l1'|'l2'|'logSum'): void => { + const updatedAttributes: ReduceAttributes = + context.inputs.length === 1 ? attributes : createReduceAttributesFromInputs(context.inputs, attributes); + + let updatedAxes = updatedAttributes.axes; + if (updatedAxes.length === 0 && !updatedAttributes.noopWithEmptyAxes) { + updatedAxes = context.inputs[0].dims.map((_dim, i) => i); + } + const normalizeAxes = ShapeUtil.normalizeAxes(updatedAxes, context.inputs[0].dims.length); + + let axes = normalizeAxes; + let input = context.inputs[0]; + const permutedAxes = getAxesPermutation(axes, context.inputs[0].dims.length); + if (permutedAxes.length > 0) { + input = context.compute( + createTransposeProgramInfo(context.inputs[0], permutedAxes), {inputs: [0], outputs: [-1]})[0]; + axes = getInnerMostAxes(axes.length, input.dims.length); + } + + const [outputShape, reduceShape] = computeOutAndReduceShapes(input.dims, axes); + let finalOutputShape = outputShape; + if (updatedAttributes.keepDims) { + finalOutputShape = expandShapeToKeepDim(outputShape, normalizeAxes); + } + + context.compute( + createReduceSharedProgramInfo( + name, {hint: updatedAttributes.cacheKey, inputDependencies: ['type']}, [input], reduceType, + context.inputs[0].dataType, finalOutputShape, reduceShape), + {inputs: [input]}); + }; + +export const reduceMeanShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceMeanShared', attributes, 'mean'); +}; + +export const reduceL1Shared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceL1Shared', attributes, 'l1'); +}; + +export const reduceL2Shared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceL2Shared', attributes, 'l2'); +}; + +export const reduceLogSumExpShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceLogSumExpShared', attributes, 'logSumExp'); +}; + +export const reduceMaxShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceMaxShared', attributes, 'max'); +}; + +export const reduceMinShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceMinShared', attributes, 'min'); +}; + +export const reduceProdShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceProdShared', attributes, 'prod'); +}; + +export const reduceSumShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceSumShared', attributes, 'sum'); +}; + +export const reduceSumSquareShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceSumSquareShared', attributes, 'sumSquare'); +}; + +export const reduceLogSumShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceLogSumShared', attributes, 'logSum'); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index 598b1db033c61..b5c956e57a9b1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -5,9 +5,10 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {reduceL1Shared, reduceL2Shared, reduceLogSumExpShared, reduceLogSumShared, reduceMaxShared, reduceMeanShared, reduceMinShared, reduceProdShared, reduceSumShared, reduceSumSquareShared} from './reduce-shared'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length === 0 || inputs.length > 2) { @@ -31,8 +32,8 @@ export type ReduceOp = const noOp: ReduceOp = (input) => ['', '', `var value = ${input.getByOffset('inputOffset')};`, '']; export const createReduceProgramInfo = - (metadata: ProgramMetadata, inputs: readonly TensorView[], reduceOp: ReduceOp, axesInput: number[], - outputDataType: DataType, keepDims = false, noopWithEmptyAxes = false): ProgramInfo => { + (name: string, shaderCache: ProgramShaderCacheInfo, inputs: readonly TensorView[], reduceOp: ReduceOp, + axesInput: number[], outputDataType: DataType, keepDims = false, noopWithEmptyAxes = false): ProgramInfo => { const outputShape: number[] = []; const inputShape = inputs[0].dims; @@ -96,14 +97,17 @@ export const createReduceProgramInfo = }`; return { - ...metadata, + name, + shaderCache, getShaderSource, - outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], - dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: outputDataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + }), }; }; -const createReduceAttributesFromInputs = +export const createReduceAttributesFromInputs = (inputs: readonly TensorView[], attributes: ReduceAttributes): ReduceAttributes => { const axes: number[] = []; if (inputs[1].dims[0] > 0) { @@ -113,26 +117,22 @@ const createReduceAttributesFromInputs = {axes, keepDims: attributes.keepDims, noopWithEmptyAxes: attributes.noopWithEmptyAxes}); }; -const createReduceProgramInfoLoader = - (inputs: readonly TensorView[], name: string, attributes: ReduceAttributes, - reduceOp: ReduceOp): ProgramInfoLoader => { +const runReduceProgram = + (context: ComputeContext, name: string, attributes: ReduceAttributes, reduceOp: ReduceOp): void => { + const inputs = context.inputs; const updatedAttributes: ReduceAttributes = inputs.length === 1 ? attributes : createReduceAttributesFromInputs(inputs, attributes); - const metadata: ProgramMetadata = { - name, - inputTypes: [GpuDataType.default], - cacheHint: updatedAttributes.cacheKey + '_' + inputs[0].dims.map(d => d.toString()).join(',') - }; - return { - ...metadata, - get: () => createReduceProgramInfo( - metadata, [inputs[0]], - updatedAttributes.noopWithEmptyAxes && updatedAttributes.axes.length === 0 ? noOp : reduceOp, - updatedAttributes.axes, inputs[0].dataType, updatedAttributes.keepDims, updatedAttributes.noopWithEmptyAxes) - }; + + context.compute( + createReduceProgramInfo( + name, {hint: updatedAttributes.cacheKey}, [inputs[0]], + updatedAttributes.noopWithEmptyAxes && updatedAttributes.axes.length === 0 ? noOp : reduceOp, + updatedAttributes.axes, inputs[0].dataType, updatedAttributes.keepDims, + updatedAttributes.noopWithEmptyAxes), + {inputs: [0]}); }; -export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceLogSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -140,10 +140,10 @@ export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttribut `value += ${input.getByOffset('inputOffset')};`, 'value = log(value);', ]; - context.compute(createReduceProgramInfoLoader(context.inputs, 'ReduceLogSum', attributes, reduceOp), {inputs: [0]}); + runReduceProgram(context, 'ReduceLogSum', attributes, reduceOp); }; -export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceL1Naive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -151,10 +151,10 @@ export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): `value += abs(${input.getByOffset('inputOffset')});`, '', ]; - context.compute(createReduceProgramInfoLoader(context.inputs, 'ReduceL1', attributes, reduceOp), {inputs: [0]}); + runReduceProgram(context, 'ReduceL1', attributes, reduceOp); }; -export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceL2Naive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, @@ -162,10 +162,10 @@ export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): `t = ${input.getByOffset('inputOffset')}; value += (t * t);`, 'value = sqrt(value);', ]; - context.compute(createReduceProgramInfoLoader(context.inputs, 'ReduceL2', attributes, reduceOp), {inputs: [0]}); + runReduceProgram(context, 'ReduceL2', attributes, reduceOp); }; -export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceLogSumExpNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -173,15 +173,14 @@ export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttri `value += exp(${input.getByOffset('inputOffset')});`, 'value = log(value);', ]; - context.compute( - createReduceProgramInfoLoader(context.inputs, 'ReduceLogSumExp', attributes, reduceOp), {inputs: [0]}); + runReduceProgram(context, 'ReduceLogSumExp', attributes, reduceOp); }; -export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceMaxNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, _output, axes) => { const idxZero = []; - for (let k = 0; k < input.shape.length; k++) { + for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { idxZero.push(input.indicesSet('inputIndices', k, 0)); } @@ -194,16 +193,17 @@ export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes) '', ]; }; - context.compute(createReduceProgramInfoLoader(context.inputs, 'ReduceMax', attributes, reduceOp), {inputs: [0]}); + runReduceProgram(context, 'ReduceMax', attributes, reduceOp); }; -export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceMeanNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output, axes) => { let size = 1.0; - for (let k = 0; k < input.shape.length; k++) { + for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { - size *= input.shape[k]; + // TODO: this depends on the input dims. If we want to use uniform, this need to be updated. + size *= context.inputs[0].dims[k]; } } @@ -214,14 +214,14 @@ export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes `let value = ${output.type.value}(sum / ${size});`, ]; }; - context.compute(createReduceProgramInfoLoader(context.inputs, 'ReduceMean', attributes, reduceOp), {inputs: [0]}); + runReduceProgram(context, 'ReduceMean', attributes, reduceOp); }; -export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceMinNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, _output, axes) => { const idxZero = []; - for (let k = 0; k < input.shape.length; k++) { + for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { idxZero.push(`inputIndices[${k}] = 0;`); // first element } @@ -234,10 +234,10 @@ export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes) '', ]; }; - context.compute(createReduceProgramInfoLoader(context.inputs, 'ReduceMin', attributes, reduceOp), {inputs: [0]}); + runReduceProgram(context, 'ReduceMin', attributes, reduceOp); }; -export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceProdNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(1);`, @@ -245,10 +245,10 @@ export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes `value *= ${input.getByOffset('inputOffset')};`, '', ]; - context.compute(createReduceProgramInfoLoader(context.inputs, 'ReduceProd', attributes, reduceOp), {inputs: [0]}); + runReduceProgram(context, 'ReduceProd', attributes, reduceOp); }; -export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -256,10 +256,10 @@ export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes) `value += ${input.getByOffset('inputOffset')};`, '', ]; - context.compute(createReduceProgramInfoLoader(context.inputs, 'ReduceSum', attributes, reduceOp), {inputs: [0]}); + runReduceProgram(context, 'ReduceSum', attributes, reduceOp); }; -export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, @@ -267,8 +267,109 @@ export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttri `t = ${input.getByOffset('inputOffset')}; value += t * t;`, '', ]; - context.compute( - createReduceProgramInfoLoader(context.inputs, 'ReduceSumSquare', attributes, reduceOp), {inputs: [0]}); + runReduceProgram(context, 'ReduceSumSquare', attributes, reduceOp); +}; + +const useNaiveReduceMethod = + (shape: readonly number[], axes: readonly number[], noopWithEmptyAxes: boolean): boolean => { + if (axes.length === 0) { + return noopWithEmptyAxes ? true : false; + } + + let outputSize = 1; + let reduceSize = 1; + for (let dim = 0; dim < axes.length; dim++) { + if (axes.indexOf(dim) === -1) { + outputSize *= shape[dim]; + } else { + reduceSize *= shape[dim]; + } + } + + // The condition data is very rough, although considering the count of Execution Unit (EU), the potential + // work groups in a EU and the counts of loops in the naive and shared methods, also doing experiments + // on some machines. + return reduceSize < 32 && outputSize > 1024 ? true : false; + }; + +export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceMeanNaive(context, attributes); + } else { + reduceMeanShared(context, attributes); + } +}; + +export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceL1Naive(context, attributes); + } else { + reduceL1Shared(context, attributes); + } +}; + +export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceL2Naive(context, attributes); + } else { + reduceL2Shared(context, attributes); + } +}; + +export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceLogSumExpNaive(context, attributes); + } else { + reduceLogSumExpShared(context, attributes); + } +}; + +export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceMaxNaive(context, attributes); + } else { + reduceMaxShared(context, attributes); + } +}; + +export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceMinNaive(context, attributes); + } else { + reduceMinShared(context, attributes); + } +}; + +export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceProdNaive(context, attributes); + } else { + reduceProdShared(context, attributes); + } +}; + +export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceSumNaive(context, attributes); + } else { + reduceSumShared(context, attributes); + } +}; + +export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceSumSquareNaive(context, attributes); + } else { + reduceSumSquareShared(context, attributes); + } +}; + +export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceLogSumNaive(context, attributes); + } else { + reduceLogSumShared(context, attributes); + } }; export const parseReduceAttributes = (attributes: Record): ReduceAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index 8b9dbbf57ac75..973a607f9377e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -5,7 +5,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; @@ -105,50 +105,51 @@ const validateInputs = } }; -const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: CoordinateTransformMode): string => - 'fn getOriginalCoordinateFromResizedCoordinate(xResized: f32, xScale: f32, lengthResized: f32,\ - lengthOriginal: f32, roiStart: f32, roiEnd: f32) -> f32 { ' + +const getOriginalCoordinateFromResizedCoordinate = + (coordinateTransferMode: CoordinateTransformMode, dType: string): string => + `fn getOriginalCoordinateFromResizedCoordinate(xResized: ${dType}, xScale: ${dType}, lengthResized: ${dType}, + lengthOriginal: ${dType}, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` + (() => { - switch (coordinateTransferMode) { - case 'asymmetric': - return 'return xResized / xScale;'; - case 'pytorch_half_pixel': - return 'if (lengthResized > 1) { \ + switch (coordinateTransferMode) { + case 'asymmetric': + return 'return xResized / xScale;'; + case 'pytorch_half_pixel': + return 'if (lengthResized > 1) { \ return (xResized + 0.5) / xScale - 0.5; \ } else { \ return 0.0; \ }'; - case 'tf_half_pixel_for_nn': - return 'return (xResized + 0.5) / xScale;'; - case 'align_corners': - return 'if (lengthResized == 1) { \ + case 'tf_half_pixel_for_nn': + return 'return (xResized + 0.5) / xScale;'; + case 'align_corners': + return 'if (lengthResized == 1) { \ return 0.0; \ } else { \ return xResized * (lengthOriginal - 1) / (lengthResized - 1); \ }'; - case 'tf_crop_and_resize': - return 'if (lengthResized > 1) { \ + case 'tf_crop_and_resize': + return `if (lengthResized > 1) { \ return roiStart * (lengthOriginal - 1) + \ (xResized * (roiEnd - roiStart) * (lengthOriginal - 1)) / (lengthResized - 1); \ } else { \ - return 0.5 * (roiStart + roiEnd) * f32(lengthOriginal - 1); \ - }'; - case 'half_pixel_symmetric': - return [ - 'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;', - 'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);', - 'return offset + ((xResized + 0.5) / xScale) - 0.5;' - ].join('\n'); - case 'half_pixel': - return 'return ((xResized + 0.5) / xScale) - 0.5;'; - default: - throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`); - } - })() + + return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1); \ + }`; + case 'half_pixel_symmetric': + return [ + 'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;', + 'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);', + 'return offset + ((xResized + 0.5) / xScale) - 0.5;' + ].join('\n'); + case 'half_pixel': + return 'return ((xResized + 0.5) / xScale) - 0.5;'; + default: + throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`); + } + })() + '}'; -const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number): string => - 'fn getNearestPixelFromOriginal(xOriginal: f32, isDownSample: bool) -> f32 {' + (() => { +const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number, dType: string): string => + `fn getNearestPixelFromOriginal(xOriginal: ${dType}, isDownSample: bool) -> ${dType} {` + (() => { switch (nearestMode) { case 'round_prefer_ceil': return 'if (fract(xOriginal) == 0.5) { \ @@ -218,50 +219,49 @@ const initOutputShape = return outputShape; }; -const adjustOutputShape = - (inputShape: readonly number[], outputShape: readonly number[], scales: number[], attributes: ResizeAttributes): - number[] => { - const scaleInPolicy = (() => { - switch (attributes.keepAspectRatioPolicy) { - case 'not_larger': - return attributes.axes.length > 0 ? Math.min(...attributes.axes.map(i => scales[i]), Number.MAX_VALUE) : - Math.min(...scales, Number.MAX_VALUE); - case 'not_smaller': - return attributes.axes.length > 0 ? Math.max(...attributes.axes.map(i => scales[i]), Number.MIN_VALUE) : - Math.max(...scales, Number.MIN_VALUE); - default: - throw new Error(`Keep aspect ratio policy ${attributes.keepAspectRatioPolicy} is not supported`); - } - })(); - scales.fill(1.0, 0, scales.length); - const adjustedOutputShape = inputShape.slice(); - if (attributes.axes.length > 0) { - attributes.axes.forEach((v) => scales[v] = scaleInPolicy); - attributes.axes.forEach((v) => adjustedOutputShape[v] = Math.round(inputShape[v] * scales[v])); - } else { - scales.fill(scaleInPolicy, 0, scales.length); - adjustedOutputShape.forEach((v, i) => adjustedOutputShape[i] = Math.round(v * scales[i])); - } - return adjustedOutputShape; - }; +const adjustOutputShape = (inputShape: readonly number[], scales: number[], attributes: ResizeAttributes): number[] => { + const scaleInPolicy = (() => { + switch (attributes.keepAspectRatioPolicy) { + case 'not_larger': + return attributes.axes.length > 0 ? Math.min(...attributes.axes.map(i => scales[i]), Number.MAX_VALUE) : + Math.min(...scales, Number.MAX_VALUE); + case 'not_smaller': + return attributes.axes.length > 0 ? Math.max(...attributes.axes.map(i => scales[i]), Number.MIN_VALUE) : + Math.max(...scales, Number.MIN_VALUE); + default: + throw new Error(`Keep aspect ratio policy ${attributes.keepAspectRatioPolicy} is not supported`); + } + })(); + scales.fill(1.0, 0, scales.length); + const adjustedOutputShape = inputShape.slice(); + if (attributes.axes.length > 0) { + attributes.axes.forEach((v) => scales[v] = scaleInPolicy); + attributes.axes.forEach((v) => adjustedOutputShape[v] = Math.round(inputShape[v] * scales[v])); + } else { + scales.fill(scaleInPolicy, 0, scales.length); + adjustedOutputShape.forEach((v, i) => adjustedOutputShape[i] = Math.round(v * scales[i])); + } + return adjustedOutputShape; +}; const calculateOriginalIndicesFromOutputIndices = (output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[], roi: readonly number[]): string => ` - fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array { + fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array<${ + output.type.value}, ${outputShape.length}> { const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); const outputShape = array(${outputShape.map(i => `${i}u`).join(',')}); - const scales = array(${scales.map(i => `${i}f`).join(',')}); - const roi = array(${roi.map(i => `${i}f`).join(',')}); - var originalIndices: array; + const scales = array<${output.type.value}, ${scales.length}>(${scales.map(i => `${i}f`).join(',')}); + const roi = array<${output.type.value}, ${roi.length}>(${roi.map(i => `${i}f`).join(',')}); + var originalIndices: array<${output.type.value}, ${outputShape.length}>; for (var i:u32 = 0; i < ${outputShape.length}; i++) { var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; if (scales[i] == 1.0) { - originalIndices[i] = f32(outputIndex); + originalIndices[i] = ${output.type.value}(outputIndex); } else { - originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i], - f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]); + originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(${output.type.value}(outputIndex), scales[i], + ${output.type.value}(outputShape[i]), ${output.type.value}(inputShape[i]), roi[i], roi[i + ${ + inputShape.length}]); } } return originalIndices; @@ -273,8 +273,8 @@ const calculateInputIndicesFromOutputIndices = fn calculateInputIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); const outputShape = array(${outputShape.map(i => `${i}u`).join(',')}); - const scales = array(${scales.map(i => `${i}f`).join(',')}); - const roi = array(${roi.map(i => `${i}f`).join(',')}); + const scales = array<${input.type.value}, ${scales.length}>(${scales.map(i => `${i}`).join(',')}); + const roi = array<${input.type.value}, ${roi.length}>(${roi.map(i => `${i}`).join(',')}); var inputIndices: ${input.type.indices}; for (var i:u32 = 0; i < ${outputShape.length}; i++) { var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; @@ -282,12 +282,13 @@ const calculateInputIndicesFromOutputIndices = if (scales[i] == 1.0) { inputIndex = outputIndex; } else { - var original_idx = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i], - f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]); - if (!${useExtrapolation} || (original_idx >= 0 && original_idx < f32(inputShape[i]))) { + var original_idx = getOriginalCoordinateFromResizedCoordinate(${input.type.value}(outputIndex), scales[i], + ${input.type.value}(outputShape[i]), ${input.type.value}(inputShape[i]), roi[i], roi[i + ${ + inputShape.length}]); + if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${input.type.value}(inputShape[i]))) { if (original_idx < 0) { inputIndex = 0; - } else if (original_idx > (f32(inputShape[i]) - 1)) { + } else if (original_idx > (${input.type.value}(inputShape[i]) - 1)) { inputIndex = inputShape[i] - 1; } else { inputIndex = u32(getNearestPixelFromOriginal(original_idx, scales[i] < 1)); @@ -314,12 +315,13 @@ const checkInputIndices = (input: IndicesHelper, inputShape: readonly number[]): }`; const bilinearInterpolation = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], - scales: readonly number[], useExtrapolation: boolean, extrapolationValue: number): string => { + (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], scales: readonly number[], + useExtrapolation: boolean, extrapolationValue: number): string => { const [batchIdx, heightIdx, widthIdx, channelIdx] = inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]); + const dType = input.type.value; return ` - fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> f32 { + fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} { var inputIndices: ${input.type.indices}; inputIndices[${heightIdx}] = max(0, min(row, ${inputShape[heightIdx]} - 1)); inputIndices[${widthIdx}] = max(0, min(col, ${inputShape[widthIdx]} - 1)); @@ -330,10 +332,10 @@ const bilinearInterpolation = return input[${input.indicesToOffset('inputIndices')}]; } - fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> f32 { + fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> ${dType} { var originalIndices = calculateOriginalIndicesFromOutputIndices(outputIndices); - var row:f32 = originalIndices[${heightIdx}]; - var col:f32 = originalIndices[${widthIdx}]; + var row:${dType} = originalIndices[${heightIdx}]; + var col:${dType} = originalIndices[${widthIdx}]; if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${ inputShape[widthIdx]} - 1)) { return ${extrapolationValue}; @@ -350,14 +352,14 @@ const bilinearInterpolation = channel = u32(originalIndices[${channelIdx}]); batch = u32(originalIndices[${batchIdx}]); } - var x11: f32 = getInputValue(batch, channel, row1, col1); - var x12: f32 = getInputValue(batch, channel, row1, col2); - var x21: f32 = getInputValue(batch, channel, row2, col1); - var x22: f32 = getInputValue(batch, channel, row2, col2); - var dx1: f32 = row - f32(row1); - var dx2: f32 = f32(row2 ) - row; - var dy1 = col - f32(col1); - var dy2 = f32(col2) - col; + var x11: ${dType} = getInputValue(batch, channel, row1, col1); + var x12: ${dType} = getInputValue(batch, channel, row1, col2); + var x21: ${dType} = getInputValue(batch, channel, row2, col1); + var x22: ${dType} = getInputValue(batch, channel, row2, col2); + var dx1: ${dType} = row - ${dType}(row1); + var dx2: ${dType} = ${dType}(row2) - row; + var dy1 = col - ${dType}(col1); + var dy2 = ${dType}(col2) - col; return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1); }`; }; @@ -367,24 +369,24 @@ const bicubicInterpolation = scales: readonly number[], roi: readonly number[], cubicCoeffA: number, useExtrapolation: boolean, extrapolationValue: number, excludeOutside: boolean): string => { const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2]; - + const dType = input.type.value; const createCubicInterpolationFunction = (idx: number): string => { const direction = idx === heightIdx ? 'row' : 'col'; return ` fn ${direction}CubicInterpolation(inputIndices: ${input.type.indices}, outputIndices: ${ - output.type.indices}) -> f32 { + output.type.indices}) -> ${dType} { var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : `outputIndices[${idx}]`}; - var originalIdx: f32 = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), ${scales[idx]}, - f32(${outputShape[idx]}), f32(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); - var fractOriginalIdx: f32 = originalIdx - floor(originalIdx); + var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(outputIndex), ${scales[idx]}, + ${dType}(${outputShape[idx]}), ${dType}(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); + var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx); var coefs = getCubicInterpolationCoefs(fractOriginalIdx); if (${useExtrapolation} && (originalIdx < 0 || originalIdx > (${inputShape[idx]} - 1))) { return ${extrapolationValue}; } - var data: array = array(0.0, 0.0, 0.0, 0.0); + var data: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0); for (var i: i32 = -1; i < 3; i++) { - var ${direction}: f32 = originalIdx + f32(i); + var ${direction}: ${dType} = originalIdx + ${dType}(i); if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) { if (${excludeOutside}) { coefs[i + 1] = 0.0; @@ -407,12 +409,12 @@ const bicubicInterpolation = return ` ${createCubicInterpolationFunction(heightIdx)}; ${createCubicInterpolationFunction(widthIdx)}; - fn getCubicInterpolationCoefs(s: f32) -> array { + fn getCubicInterpolationCoefs(s: ${dType}) -> array<${dType}, 4> { var absS = abs(s); - var coeffs: array = array(0.0, 0.0, 0.0, 0.0); - var oneMinusAbsS: f32 = 1.0 - absS; - var twoMinusAbsS: f32 = 2.0 - absS; - var onePlusAbsS: f32 = 1.0 + absS; + var coeffs: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0); + var oneMinusAbsS: ${dType} = 1.0 - absS; + var twoMinusAbsS: ${dType} = 2.0 - absS; + var onePlusAbsS: ${dType} = 1.0 + absS; coeffs[0] = ((${cubicCoeffA} * onePlusAbsS - 5 * ${cubicCoeffA}) * onePlusAbsS + 8 * ${ cubicCoeffA}) * onePlusAbsS - 4 * ${cubicCoeffA}; coeffs[1] = ((${cubicCoeffA} + 2) * absS - (${cubicCoeffA} + 3)) * absS * absS + 1; @@ -422,12 +424,12 @@ const bicubicInterpolation = return coeffs; } - fn cubicInterpolation1D(x: array, coefs: array) -> f32 { - var coefsSum: f32 = coefs[0] + coefs[1] + coefs[2] + coefs[3]; + fn cubicInterpolation1D(x: array<${dType}, 4>, coefs: array<${dType}, 4>) -> ${dType} { + var coefsSum: ${dType} = coefs[0] + coefs[1] + coefs[2] + coefs[3]; return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum; } - fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> f32 { + fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> ${dType} { var inputIndices: ${input.type.indices} = outputIndices; return colCubicInterpolation(inputIndices, outputIndices); } @@ -435,8 +437,8 @@ const bicubicInterpolation = }; const createResizeProgramInfo = - (metadata: ProgramMetadata, inputTensor: TensorView, attributes: ResizeAttributes, opsetVersion: number, - scalesInput: readonly number[], sizes: readonly number[], roiInput: readonly number[]): ProgramInfo => { + (inputTensor: TensorView, attributes: ResizeAttributes, opsetVersion: number, scalesInput: readonly number[], + sizes: readonly number[], roiInput: readonly number[]): ProgramInfo => { const inputShape = inputTensor.dims; const roi = updateRoI(roiInput, attributes.axes, inputShape.length); @@ -445,7 +447,7 @@ const createResizeProgramInfo = if (scalesInput.length === 0) { scales = inputShape.map((value, index) => value === 0 ? 1.0 : outputShape[index] / value); if (attributes.keepAspectRatioPolicy !== 'stretch') { - outputShape = adjustOutputShape(inputShape, outputShape, scales, attributes); + outputShape = adjustOutputShape(inputShape, scales, attributes); } } const output = outputVariable('output', inputTensor.dataType, outputShape); @@ -453,14 +455,16 @@ const createResizeProgramInfo = const outputSize = ShapeUtil.size(outputShape); const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]); const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize'; + const dataType = input.type.value; const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode)}; + ${noScale ? '' : ` + ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode, dataType)}; ${(() => { switch (attributes.mode) { case 'nearest': return ` ${checkInputIndices(input, inputShape)}; - ${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion)}; + ${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion, dataType)}; ${ calculateInputIndicesFromOutputIndices( input, output, inputShape, outputShape, scales, roi, useExtrapolation)}; @@ -470,7 +474,7 @@ const createResizeProgramInfo = ${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales, roi)}; ${ bilinearInterpolation( - input, output, inputShape, outputShape, scales, useExtrapolation, attributes.extrapolationValue)}; + input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)}; `; case 'cubic': return ` @@ -483,23 +487,22 @@ const createResizeProgramInfo = throw Error('Invalid resize mode'); } })()}; + `} ${shaderHelper.declareVariables(input, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - if (${noScale}) { - output[global_idx] = input[global_idx]; - } else { - let outputIndices = ${output.offsetToIndices('global_idx')}; - var inputIndices: ${input.type.indices}; - ${(() => { + ${noScale ? 'output[global_idx] = input[global_idx];' : ` + let outputIndices = ${output.offsetToIndices('global_idx')}; + var inputIndices: ${input.type.indices}; + ${(() => { switch (attributes.mode) { case 'nearest': return `inputIndices = calculateInputIndicesFromOutputIndices(outputIndices); - if (checkInputIndices(inputIndices)) { - output[global_idx] = input[${input.indicesToOffset('inputIndices')}]; - } else { - output[global_idx] = ${attributes.extrapolationValue}; - }`; + if (checkInputIndices(inputIndices)) { + output[global_idx] = input[${input.indicesToOffset('inputIndices')}]; + } else { + output[global_idx] = ${attributes.extrapolationValue}; + }`; case 'linear': return 'output[global_idx] = bilinearInterpolation(outputIndices);'; case 'cubic': @@ -508,30 +511,20 @@ const createResizeProgramInfo = throw Error(`Unsupported resize mode: ${attributes.mode}`); } })()}; - } + `} }`; return { - ...metadata, - getShaderSource, - outputs: [{dims: outputShape, dataType: inputTensor.dataType, gpuDataType: GpuDataType.default}], - dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) - }; - }; - -export const createResizeProgramInfoLoader = - (input: TensorView, attributes: ResizeAttributes, opsetVersion: number, scales: readonly number[], - sizes: readonly number[], roi: readonly number[]): ProgramInfoLoader => { - const metadata: ProgramMetadata = { name: 'Resize', - inputTypes: [GpuDataType.default], - cacheHint: attributes.cacheKey + opsetVersion.toString() + - (scales.length > 0 ? '_scales_' + scales.toString() : '') + - (sizes.length > 0 ? '_sizes_' + sizes.toString() : ''), - }; - return { - ...metadata, - get: () => createResizeProgramInfo(metadata, input, attributes, opsetVersion, scales, sizes, roi) + shaderCache: { + hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${ + sizes.length > 0 ? sizes : ''}|${noScale}` + }, + getShaderSource, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputTensor.dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + }) }; }; @@ -549,7 +542,7 @@ export const resize = (context: ComputeContext, attributes: ResizeAttributes): v const opsetVersion = getOpsetVersionFromCustomDataBuffer(context); validateInputs(context.inputs, attributes, opsetVersion, scales, sizes, roi); context.compute( - createResizeProgramInfoLoader(context.inputs[0], attributes, opsetVersion, scales, sizes, roi), {inputs: [0]}); + createResizeProgramInfo(context.inputs[0], attributes, opsetVersion, scales, sizes, roi), {inputs: [0]}); }; export const parseResizeAttributes = (attributes: Record): ResizeAttributes => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index 7bfdd73b8af18..7e500f865c19b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; -import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; export interface SkipLayerNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ -18,9 +18,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => { throw new Error('layerNorm requires at least 3 inputs.'); } - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('inputs should be float type'); - } const input: TensorView = inputs[0]; const skip: TensorView = inputs[1]; const gamma: TensorView = inputs[2]; @@ -74,98 +71,91 @@ const validateInputs = (inputs: readonly TensorView[]): void => { }; const createSkipLayerNormProgramInfo = - (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: SkipLayerNormAttributes, outputCount: number, - isTraining: boolean): ProgramInfo => { - const inputShape = inputs[0].dims; - const inputSize = ShapeUtil.size(inputShape); - const outputShape = inputShape; - const outputSize = inputSize; - const hiddenSize = inputShape.slice(-1)[0]; - const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : []; - const hasBetaInput = inputs.length > 3; - const hasBiasInput = inputs.length > 4; - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const hasMeanOutput = isTraining && outputCount > 1; - const hasInvStdDevOutput = isTraining && outputCount > 2; - const hasInputSkipBiasSumOutput = outputCount > 3; - let bindingNumber = 0; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const hiddenSize: u32 = ${hiddenSize}; + (inputs: readonly TensorView[], attributes: SkipLayerNormAttributes, outputCount: number, isTraining: boolean): + ProgramInfo => { + const inputShape = inputs[0].dims; + const inputSize = ShapeUtil.size(inputShape); + const outputShape = inputShape; + const outputSize = inputSize; + const hiddenSize = inputShape.slice(-1)[0]; + const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : []; + const hasBetaInput = inputs.length > 3; + const hasBiasInput = inputs.length > 4; + const hasMeanOutput = isTraining && outputCount > 1; + const hasInvStdDevOutput = isTraining && outputCount > 2; + const hasInputSkipBiasSumOutput = outputCount > 3; + + const components = getMaxComponents(hiddenSize); + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), + inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), + ]; + if (hasBetaInput) { + variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); + } + if (hasBiasInput) { + variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + if (hasMeanOutput) { + variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdDevOutput) { + variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); + } + if (hasInputSkipBiasSumOutput) { + variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components)); + } + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const hiddenSize: f32 = ${hiddenSize}; + const hiddenSizeVectorized: u32 = ${hiddenSize / components}; const epsilon: f32 = ${attributes.epsilon}; - @group(0) @binding(${bindingNumber++}) var x : array<${dataType}>; - @group(0) @binding(${bindingNumber++}) var skip : array<${dataType}>; - @group(0) @binding(${bindingNumber++}) var gamma : array<${dataType}>; - ${hasBetaInput ? `@group(0) @binding(${bindingNumber++}) var beta : array<${dataType}>;` : ''} - ${hasBiasInput ? `@group(0) @binding(${bindingNumber++}) var bias : array<${dataType}>;` : ''} - @group(0) @binding(${bindingNumber++}) var output : array<${dataType}>; - ${ - hasMeanOutput ? - `@group(0) @binding(${bindingNumber++}) var meanOutput : array<${dataType}>;` : - ''} - ${ - hasInvStdDevOutput ? - `@group(0) @binding(${bindingNumber++}) var invStdOutput : array<${dataType}>;` : - ''} - ${ - hasInputSkipBiasSumOutput ? - `@group(0) @binding(${bindingNumber++}) var inputSkipBiasSum : array<${dataType}>;` : - ''} + ${shaderHelper.declareVariables(...variables)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize / hiddenSize)} - let offset = global_idx * hiddenSize; - var sum: f32 = 0.0; - var squareSum: f32 = 0.0; - for (var i: u32 = 0; i < hiddenSize; i++) { + let offset = global_idx * hiddenSizeVectorized; + var sum = ${fillVector('f32', components)}; + var squareSum = ${fillVector('f32', components)}; + for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { let skipValue = skip[offset + i]; let biasValue = ${hasBiasInput ? 'bias[i]' : '0.0'}; let inputValue = x[offset + i]; let value = inputValue + skipValue + biasValue; ${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''} output[offset + i] = value; - sum += value; - squareSum += value * value; + let f32Value = ${castToF32(dataType, components, 'value')}; + sum += f32Value; + squareSum += f32Value * f32Value; } - let mean: f32 = sum / f32(hiddenSize); - let variance: f32 = sqrt(squareSum / f32(hiddenSize) - mean * mean + epsilon); + let mean = ${sumVector('sum', components)} / hiddenSize; + let variance = sqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon); ${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''} ${hasInvStdDevOutput ? 'invStdOutput[global_idx] = 1.0 / variance;' : ''} - for (var i: u32 = 0; i < hiddenSize; i++) { - output[offset + i] = (output[offset + i] - mean) / variance * gamma[i] + ${hasBetaInput ? 'beta[i]' : '0.0'}; + for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { + output[offset + i] = (output[offset + i] - ${dataType}(mean)) / ${dataType}(variance) * gamma[i] + + ${hasBetaInput ? 'beta[i]' : '0.0'}; } }`; - const outputs = [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}]; - if (outputCount > 1) { - outputs.push({dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}); - } - if (outputCount > 2) { - outputs.push({dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}); - } - if (outputCount > 3) { - outputs.push({dims: inputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}); - } - - return { - ...metadata, - getShaderSource, - outputs, - dispatchGroup: () => ({x: Math.ceil(outputSize / hiddenSize / 64)}) - }; - }; - -const createSkipLayerNormProgramInfoLoader = - (inputs: readonly TensorView[], attributes: SkipLayerNormAttributes, outputCount: number, isTraining: boolean): - ProgramInfoLoader => { - const inputTypes = new Array(inputs.length).fill(GpuDataType.default); - const metadata: ProgramMetadata = { - name: 'SkipLayerNormalization', - inputTypes, - cacheHint: attributes.cacheKey, - }; + const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; + if (outputCount > 1) { + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); + } + if (outputCount > 2) { + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); + } + if (outputCount > 3) { + outputs.push({dims: inputShape, dataType: inputs[0].dataType}); + } + return { - ...metadata, - get: () => createSkipLayerNormProgramInfo(metadata, inputs, attributes, outputCount, isTraining) + name: 'SkipLayerNormalization', + shaderCache: {hint: attributes.cacheKey}, + getShaderSource, + getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}), }; }; @@ -186,7 +176,7 @@ export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNorm outputs.push(3); } context.compute( - createSkipLayerNormProgramInfoLoader(context.inputs, attributes, context.outputCount, isTraining), {outputs}); + createSkipLayerNormProgramInfo(context.inputs, attributes, context.outputCount, isTraining), {outputs}); }; export const parseSkipLayerNormAttributes = (attributes: Record): SkipLayerNormAttributes => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 257b9ebc1fdac..7458579bf4340 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata, TensorInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; export interface SliceAttributes extends AttributeWithCacheKey { readonly starts: number[]; @@ -77,118 +77,147 @@ const fixStartEndValues = }; const calculateInputIndicesImpl = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[]): - string => `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { + (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], + enableInputShapeUniforms: boolean): string => + `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { var inputIndices: ${input.type.indices}; var carry = 0u; for (var i = ${inputShape.length}; i >= 0; i--) { + let input_shape_i = ${ + enableInputShapeUniforms ? `uniforms.input_shape${inputShape.length > 1 ? '[i]' : ''}` : 'inputShape[i]'}; + let steps_i = ${ + enableInputShapeUniforms ? `uniforms.steps${inputShape.length > 1 ? '[i]' : ''}` : 'steps[i]'}; + let signs_i = ${ + enableInputShapeUniforms ? `uniforms.signs${inputShape.length > 1 ? '[i]' : ''}` : 'signs[i]'}; + let starts_i = ${ + enableInputShapeUniforms ? `uniforms.starts${inputShape.length > 1 ? '[i]' : ''}` : 'starts[i]'}; var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; - var inputIndex = outputIndex * steps[i] + starts[i] + carry; - carry = inputIndex / inputShape[i]; - inputIndex = inputIndex % inputShape[i]; - if (signs[i] < 0) { - inputIndex = inputShape[i] - inputIndex - 1u + starts[i]; + var inputIndex = outputIndex * steps_i + starts_i + carry; + carry = inputIndex / input_shape_i; + inputIndex = inputIndex % input_shape_i; + if (signs_i < 0) { + inputIndex = input_shape_i - inputIndex - 1u + starts_i; } ${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'} = inputIndex; } return inputIndices; }`; -const createSliceProgramInfo = - (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: SliceAttributes): ProgramInfo => { - const inputShape = inputs[0].dims; - const inputSize = ShapeUtil.size(inputShape); - const axes = (attributes.axes.length > 0) ? ShapeUtil.normalizeAxes(attributes.axes, inputShape.length) : - [...Array(inputShape.length).keys()]; - let steps = readInput(inputs, 4); - steps.forEach((step) => step !== 0 || (() => { - throw new Error('step cannot be 0'); - })); - if (steps.length === 0) { - steps = Array(axes.length).fill(1); - } - const starts = attributes.starts.map((start, i) => fixStartEndValues(start, i, inputShape, axes, steps)); +const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: SliceAttributes): ProgramInfo => { + const inputShape = inputs[0].dims; + const inputSize = ShapeUtil.size(inputShape); + const axes = (attributes.axes.length > 0) ? ShapeUtil.normalizeAxes(attributes.axes, inputShape.length) : + [...Array(inputShape.length).keys()]; + let steps = readInput(inputs, 4); + steps.forEach((step) => step !== 0 || (() => { + throw new Error('step cannot be 0'); + })); + if (steps.length === 0) { + steps = Array(axes.length).fill(1); + } + const starts = attributes.starts.map((start, i) => fixStartEndValues(start, i, inputShape, axes, steps)); - const ends = attributes.ends.map((end, i) => fixStartEndValues(end, i, inputShape, axes, steps)); + const ends = attributes.ends.map((end, i) => fixStartEndValues(end, i, inputShape, axes, steps)); - if (axes.length !== inputShape.length) { - for (let i = 0; i < inputShape.length; ++i) { - if (!axes.includes(i)) { - starts.splice(i, 0, 0); - ends.splice(i, 0, inputShape[i]); - steps.splice(i, 0, 1); - } - } - } - const signs = steps.map(step => Math.sign(step)); - // Convert negative steps to positive steps and reverse starts and ends - steps.forEach((step, i, array) => { - if (step < 0) { - const numSteps = (ends[i] - starts[i]) / step; - const newEnd = starts[i]; - const newStart = newEnd + numSteps * steps[i]; - starts[i] = newStart; - ends[i] = newEnd; - array[i] = -step; - } - }); - - const outputShape = inputShape.slice(0); - axes.forEach((axis, _) => { - outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]); - }); + if (axes.length !== starts.length || axes.length !== ends.length) { + throw new Error('start, ends and axes should have the same number of elements'); + } - const outputTensorInfo: - TensorInfo = {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}; + if (axes.length !== inputShape.length) { + for (let i = 0; i < inputShape.length; ++i) { + if (!axes.includes(i)) { + starts.splice(i, 0, 0); + ends.splice(i, 0, inputShape[i]); + steps.splice(i, 0, 1); + } + } + } + const signs = steps.map(step => Math.sign(step)); + // Convert negative steps to positive steps and reverse starts and ends + steps.forEach((step, i, array) => { + if (step < 0) { + const numSteps = (ends[i] - starts[i]) / step; + const newEnd = starts[i]; + const newStart = newEnd + numSteps * steps[i]; + starts[i] = newStart; + ends[i] = newEnd; + array[i] = -step; + } + }); + // Output rank is expected to be less than or equal to the input rank. + const enableShapeUniforms = enableShapesUniforms(inputs[0].dims.length); + const inputShapeOrRank = enableShapeUniforms ? inputs[0].dims.length : inputs[0].dims; - const output = outputVariable('output', inputs[0].dataType, outputShape); - const input = inputVariable('input', inputs[0].dataType, inputShape); - const outputSize = ShapeUtil.size(outputShape); + const outputShape = inputShape.slice(0); + axes.forEach((axis, _) => { + outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]); + }); + const outputShapeOrRank = enableShapeUniforms ? outputShape.length : outputShape; + + const outputTensorInfo: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType}; + + const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank); + const input = inputVariable('input', inputs[0].dataType, inputShapeOrRank); + const outputSize = ShapeUtil.size(outputShape); + const programUniforms: ProgramUniform[] = []; + const uniforms: UniformsArrayType = []; + if (enableShapeUniforms) { + uniforms.push({name: 'starts', type: starts.length > 1 ? `vec${starts.length}` : 'u32'}); + uniforms.push({name: 'signs', type: signs.length > 1 ? `vec${signs.length}` : 'i32'}); + uniforms.push({name: 'steps', type: steps.length > 1 ? `vec${steps.length}` : 'u32'}); + programUniforms.push({type: 'uint32', data: starts}); + programUniforms.push({type: 'int32', data: signs}); + programUniforms.push({type: 'uint32', data: steps}); + } + uniforms.push({name: 'outputSize', type: 'u32'}); + programUniforms.push({type: 'uint32', data: outputSize}); + if (enableShapeUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + programUniforms.push(...createTensorShapeVariables(outputShape)); + } - const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(input, output)} - const signs = array(${signs.map(i => `${i}i`).join(',')}); - const starts = array(${starts.map(i => `${i}u`).join(',')}); - const ends = array(${ends.map(i => `${i}u`).join(',')}); - const steps = array(${steps.map(i => `${i}u`).join(',')}); - const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)} + ${enableShapeUniforms ? '' : [ + `const signs = array(${signs.map(i => `${i}i`).join(',')});`, + `const starts = array(${starts.map(i => `${i}u`).join(',')});`, + `const steps = array(${steps.map(i => `${i}u`).join(',')});`, + `const inputShape = array(${inputShape.map(i => `${i}u`).join(',')});` + ].join('\n')} - ${calculateInputIndicesImpl(input, output, inputShape, outputShape)} + ${calculateInputIndicesImpl(input, output, inputShape, outputShape, enableShapeUniforms)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} let outputIndices = ${output.offsetToIndices('global_idx')}; let inputIndices = calculateInputIndices(outputIndices); ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} }`; - return { - ...metadata, - getShaderSource, - outputs: [outputTensorInfo], - dispatchGroup: () => ({x: Math.ceil(inputSize / 64 /* workgroup size */)}) - }; - }; - -const createSliceProgramInfoLoader = - (inputs: readonly TensorView[], attributes: SliceAttributes): ProgramInfoLoader => { - const updatedAttributes = createSliceAttributesFromInputs(inputs, attributes); - const metadata: ProgramMetadata = { - name: 'Slice', - inputTypes: [GpuDataType.default], - cacheHint: updatedAttributes.cacheKey + (inputs.length > 4 ? 'steps_' + inputs[4].dims.toString() : '') - }; - return {...metadata, get: () => createSliceProgramInfo(metadata, inputs, updatedAttributes)}; - }; + return { + name: 'Slice', + shaderCache: { + hint: enableShapeUniforms ? `${signs.length}_${starts.length}_${steps.length}` : + `${attributes.cacheKey} | ${inputs[4]?.dims ?? ''}`, + inputDependencies: [enableShapeUniforms ? 'rank' : 'dims'] + }, + getShaderSource, + getRunData: () => ({ + outputs: [outputTensorInfo], + dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)}, + programUniforms + }) + }; +}; export const slice = (context: ComputeContext, attributes: SliceAttributes): void => { validateInputs(context.inputs, attributes); - const programInfoLoader = createSliceProgramInfoLoader(context.inputs, attributes); - const program = programInfoLoader.get(); - if (ShapeUtil.size(program.outputs[0].dims) > 0) { - context.compute(programInfoLoader, {inputs: [0]}); - } else { - // TODO: support empty output - throw new Error('slice: output size is 0'); - } + const updatedAttributes = createSliceAttributesFromInputs(context.inputs, attributes); + context.compute(createSliceProgramInfo(context.inputs, updatedAttributes), {inputs: [0]}); + // if (ShapeUtil.size(program.outputs[0].dims) > 0) { + // context.compute(programInfoLoader, {inputs: [0]}); + // } else { + // // TODO: support empty output + // throw new Error('slice: output size is 0'); + // } }; export const parseSliceAttributes = (attributes: Record): SliceAttributes => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index 495a4bcea4f47..378a7e738dac9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -8,9 +8,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; -import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 1) { @@ -22,14 +22,7 @@ export interface SoftmaxAttributes extends AttributeWithCacheKey { readonly axis: number; } -export const softmaxProgramMetadata = { - name: 'Softmax', - inputTypes: [GpuDataType.default] -}; - - const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => { - const dataType = tensorTypeToWsglStorageType(input.dataType); const shape = input.dims; const outputSize = ShapeUtil.size(shape); const WG = 64; @@ -43,35 +36,49 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut const cols = shape[axis]; const rows = outputSize / cols; - + const components = getMaxComponents(cols); + const packedCols = cols / components; + + const maxVector = (name: string, components: number) => { + if (components === 4) { + return `max(max(${name}.x, ${name}.y), max(${name}.z, ${name}.w))`; + } else if (components === 2) { + return `max(${name}.x, ${name}.y)`; + } else if (components === 3) { + return `max(max(${name}.x, ${name}.y), ${name}.z)`; + } + + return name; + }; + const x = inputVariable('x', input.dataType, input.dims, components); + const output = outputVariable('result', input.dataType, input.dims, components); + const valueType = x.type.value; // 6.2.4 in wgsl spec - const threadMaxDecl = dataType === 'f32' ? 'var threadMax: f32 = -3.402823e+38f;' : 'var threadMax: f16 = -65504.0h;'; - const getShaderSource = (_shaderHelper: ShaderHelper) => ` - var rowMaxShared : ${dataType}; - var rowSumShared : ${dataType}; - var threadShared : array<${dataType}, ${WG}>; - - @group(0) @binding(0) var x : array<${dataType}>; - @group(0) @binding(1) var result : array<${dataType}>; - - fn getValue(row: i32, col: i32, row_stride: i32) -> ${dataType} { + const threadMaxDecl = tensorTypeToWsglStorageType(input.dataType) === 'f32' ? + `var threadMax = ${valueType}(-3.402823e+38f);` : + `var threadMax = ${valueType}(-65504.0h);`; + const getShaderSource = (shaderHelper: ShaderHelper) => ` + var rowMaxShared : ${valueType}; + var rowSumShared : ${valueType}; + var threadShared : array<${valueType}, ${WG}>; + + fn getValue(row: i32, col: i32, row_stride: i32) -> ${valueType} { let index = row * row_stride + col; return x[index]; } - fn setValue(row: i32, col: i32, row_stride: i32, value: ${dataType}) { + fn setValue(row: i32, col: i32, row_stride: i32, value: ${valueType}) { let index = row * row_stride + col; result[index] = value; } - - @compute @workgroup_size(${WG}, 1, 1) - fn main(@builtin(local_invocation_id) local_id : vec3, @builtin(global_invocation_id) global_id : vec3u) { + ${shaderHelper.registerUniform('packedCols', 'i32').declareVariables(x, output)} + ${shaderHelper.mainStart()} let gindex = i32(global_id.x); let lindex = i32(local_id.x); const wg = ${WG}; let row = gindex / wg; - let cols = ${cols}; - let row_stride : i32 = ${cols}; + let cols = uniforms.packedCols; + let row_stride : i32 = uniforms.packedCols; // find the rows max ${threadMaxDecl} @@ -93,12 +100,12 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut workgroupBarrier(); } if (lindex == 0) { - rowMaxShared = threadShared[0]; + rowMaxShared = ${valueType}(${maxVector('threadShared[0]', components)}); } workgroupBarrier(); // find the rows sum - var threadSum: ${dataType} = 0.0; + var threadSum = ${valueType}(0.0); for (var col = lindex; col < cols; col += wg) { let subExp = exp(getValue(row, col, row_stride) - rowMaxShared); threadSum += subExp; @@ -113,7 +120,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut workgroupBarrier(); } if (lindex == 0) { - rowSumShared = threadShared[0]; + rowSumShared = ${valueType}(${sumVector('threadShared[0]', components)}); } workgroupBarrier(); @@ -124,21 +131,20 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut } }`; return { - ...softmaxProgramMetadata, - outputs: [{dims: shape, dataType: input.dataType, gpuDataType: GpuDataType.default}], + name: 'Softmax', + shaderCache: {hint: `${components}`, inputDependencies: ['type']}, + getRunData: () => ({ + outputs: [{dims: shape, dataType: input.dataType}], + dispatchGroup: {x: rows}, + programUniforms: [{type: 'uint32', data: packedCols}] + }), getShaderSource, - dispatchGroup: () => ({x: rows}) }; }; - export const softmax = (context: ComputeContext, attributes: SoftmaxAttributes): void => { validateInputs(context.inputs); - context.compute({ - ...softmaxProgramMetadata, - cacheHint: attributes.cacheKey, - get: () => createSoftmaxProgramInfo(context.inputs[0], attributes) - }); + context.compute(createSoftmaxProgramInfo(context.inputs[0], attributes)); }; export const parseSoftmaxAttributes = (attributes: Record): SoftmaxAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 3367091bbac23..fd60d81b87ae1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -4,7 +4,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata, TensorInfo} from '../types'; +import {ComputeContext, ProgramInfo, TensorInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; @@ -61,31 +61,30 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => { }`; }; -const createSplitProgramInfo = - (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => { - const inputShape = inputs[0].dims; - const inputSize = ShapeUtil.size(inputShape); - const dataType = inputs[0].dataType; - const rank = inputShape.length; - const axis = attributes.axis; - const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis; - const outputs = new Array(attributes.numOutputs); - const input = inputVariable('input', dataType, inputShape); - const sizeInConcatAxis = new Array(attributes.numOutputs); - const outputsTensorInfo: TensorInfo[] = []; - const outputShapes: number[][] = []; - let previousSum = 0; - for (let i = 0; i < attributes.numOutputs; i++) { - previousSum += attributes.splitSizes[i]; - sizeInConcatAxis[i] = previousSum; - const outputShape = inputShape.slice(); - outputShape[attributes.axis] = attributes.splitSizes[i]; - outputShapes.push(outputShape); - outputs[i] = outputVariable(`output${i}`, dataType, outputShapes[i]); - outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}); - } - const indicesAxis = rank < 2 ? 'indices' : `indices[${adjustedAxis}]`; - const getShaderSource = (shaderHelper: ShaderHelper) => ` +const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => { + const inputShape = inputs[0].dims; + const inputSize = ShapeUtil.size(inputShape); + const dataType = inputs[0].dataType; + const rank = inputShape.length; + const axis = attributes.axis; + const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis; + const outputs = new Array(attributes.numOutputs); + const input = inputVariable('input', dataType, inputShape); + const sizeInConcatAxis = new Array(attributes.numOutputs); + const outputsTensorInfo: TensorInfo[] = []; + const outputShapes: number[][] = []; + let previousSum = 0; + for (let i = 0; i < attributes.numOutputs; i++) { + previousSum += attributes.splitSizes[i]; + sizeInConcatAxis[i] = previousSum; + const outputShape = inputShape.slice(); + outputShape[attributes.axis] = attributes.splitSizes[i]; + outputShapes.push(outputShape); + outputs[i] = outputVariable(`output${i}`, dataType, outputShapes[i]); + outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); + } + const indicesAxis = rank < 2 ? 'indices' : `indices[${adjustedAxis}]`; + const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.declareVariables(input, ...outputs)} const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); ${calculateOutputIndexImpl(sizeInConcatAxis.length)} @@ -101,25 +100,22 @@ const createSplitProgramInfo = } writeBufferData(outputNumber, indices, global_idx); }`; - return { - ...metadata, - getShaderSource, - outputs: outputsTensorInfo, - dispatchGroup: () => ({x: Math.ceil(inputSize / 64 /* workgroup size */)}) - }; - }; - -const createSplitProgramInfoLoader = - (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfoLoader => { - const updatedAttributes = inputs.length === 1 ? attributes : createSplitAttributesFromInputs(inputs, attributes); - const metadata: - ProgramMetadata = {name: 'Split', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey}; - return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], updatedAttributes)}; - }; + return { + name: 'Split', + shaderCache: {hint: attributes.cacheKey}, + getShaderSource, + getRunData: () => ({ + outputs: outputsTensorInfo, + dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)}, + }) + }; +}; export const split = (context: ComputeContext, attributes: SplitAttributes): void => { validateInputs(context.inputs); - context.compute(createSplitProgramInfoLoader(context.inputs, attributes), {inputs: [0]}); + const updatedAttributes = + context.inputs.length === 1 ? attributes : createSplitAttributesFromInputs(context.inputs, attributes); + context.compute(createSplitProgramInfo(context.inputs, updatedAttributes), {inputs: [0]}); }; export const parseSplitAttributes = (attributes: Record): SplitAttributes => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index 109c29bfc8a80..e294541a775ca 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -4,15 +4,10 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; -export const tileProgramMetadata = { - name: 'Tile', - inputTypes: [GpuDataType.default] -}; - const getRepeats = (repeatsTensorView: TensorView): readonly number[] => Array.from(repeatsTensorView.getBigInt64Array(), Number); @@ -52,18 +47,17 @@ const getOutputShape = (inputShape: readonly number[], repeats: readonly number[ return outputShape; }; -export const createTileProgramInfo = - (tileProgramMetadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => { - const inputShape = inputs[0].dims; - const repeats: readonly number[] = getRepeats(inputs[1]); - const outputShape = getOutputShape(inputShape, repeats); - const outputSize = ShapeUtil.size(outputShape); +export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => { + const inputShape = inputs[0].dims; + const repeats: readonly number[] = getRepeats(inputs[1]); + const outputShape = getOutputShape(inputShape, repeats); + const outputSize = ShapeUtil.size(outputShape); - const dataType = inputs[0].dataType; - const input = inputVariable('input', dataType, inputShape); - const output = outputVariable('output', dataType, outputShape); + const dataType = inputs[0].dataType; + const input = inputVariable('input', dataType, inputShape); + const output = outputVariable('output', dataType, outputShape); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => ` const inputShape = ${input.indices(...inputShape)}; ${shaderHelper.declareVariables(input, output)} ${shaderHelper.mainStart()} @@ -78,19 +72,18 @@ export const createTileProgramInfo = ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} }`; - return { - ...tileProgramMetadata, - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], - getShaderSource, - dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) - }; - }; + return { + name: 'Tile', + shaderCache: {hint: `${repeats}`}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + }), + getShaderSource, + }; +}; export const tile = (context: ComputeContext): void => { validateInputs(context.inputs); - const repeats: readonly number[] = getRepeats(context.inputs[1]); - const cacheHint = repeats.toString(); - context.compute( - {...tileProgramMetadata, cacheHint, get: () => createTileProgramInfo(tileProgramMetadata, context.inputs)}, - {inputs: [0]}); + context.compute(createTileProgramInfo(context.inputs), {inputs: [0]}); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 38dcaeab54c54..c4d43e9f466f5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -4,30 +4,25 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface TransposeAttributes extends AttributeWithCacheKey { readonly perm: number[]; } -export const transposeProgramMetadata = { - name: 'Transpose', - inputTypes: [GpuDataType.default] -}; - const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 1) { throw new Error('Transpose requires 1 input.'); } }; -const getAdjustedPerm = (inputShape: readonly number[], perm: number[]): number[] => - (perm && perm.length !== inputShape.length) ? [...(inputShape.keys())].reverse() : perm; +const getAdjustedPerm = (inputRank: number, perm: number[]): number[] => + (perm && perm.length !== inputRank) ? [...(new Array(inputRank).keys())].reverse() : perm; const getOutputShape = (inputShape: readonly number[], perm: number[]): readonly number[] => - ShapeUtil.sortBasedOnPerm(inputShape, getAdjustedPerm(inputShape, perm)); + ShapeUtil.sortBasedOnPerm(inputShape, getAdjustedPerm(inputShape.length, perm)); const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, output: IndicesHelper): string => { const reverseFunc = []; @@ -41,26 +36,23 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou }; export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => { - const dataType = inputTensor.dataType; - const inputShape = inputTensor.dims; - const perm = getAdjustedPerm(inputShape, permAttr); - const outputShape = getOutputShape(inputShape, perm); - const rank = inputShape.length; - const outputSize = ShapeUtil.size(outputShape); - // A dims=[${inputs[0].dims.toString()}] - // out Dims=[${unpackedOutputShape.toString()}] - // based on perm=[${perm.toString()}] - - const output = outputVariable('output', dataType, outputShape); - const input = inputVariable('a', dataType, inputShape); + const inputDataType = inputTensor.dataType; + const inputRank = inputTensor.dims.length; + const perm = getAdjustedPerm(inputRank, permAttr); + const useShapesUniforms = enableShapesUniforms(inputRank); + const outputShape = getOutputShape(inputTensor.dims, perm); + const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape; + const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims; + const output = outputVariable('output', inputDataType, outShapeOrRank); + const input = inputVariable('a', inputDataType, inShapeOrRank); const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(input, output)} + ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} - ${permFunctionBody(perm, rank, input, output)} + ${permFunctionBody(perm, inputRank, input, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let indices = ${output.offsetToIndices('global_idx')}; let aIndices = perm(indices); @@ -68,20 +60,31 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu ${output.setByOffset('global_idx', input.getByIndices('aIndices'))} }`; return { - ...transposeProgramMetadata, - outputs: [{dims: outputShape, dataType: inputTensor.dataType, gpuDataType: GpuDataType.default}], + name: 'Transpose', + shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']}, + getRunData: (inputs) => { + const outputSize = ShapeUtil.size(outputShape); + return { + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: useShapesUniforms ? + [ + {type: 'uint32', data: outputSize}, + ...createTensorShapeVariables(inputs[0].dims), + ...createTensorShapeVariables(outputShape), + ] : + [ + {type: 'uint32', data: outputSize}, + ], + }; + }, getShaderSource, - dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) }; }; export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => { validateInputs(context.inputs); - context.compute({ - ...transposeProgramMetadata, - cacheHint: attributes.cacheKey, - get: () => createTransposeProgramInfo(context.inputs[0], attributes.perm) - }); + context.compute(createTransposeProgramInfo(context.inputs[0], attributes.perm)); }; export const parseTransposeAttributes = (attributes: Record): TransposeAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index f08d7a77d1099..119609e06f5a3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -5,7 +5,7 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {ComputeContext, ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; @@ -29,12 +29,12 @@ const createElementwiseProgramShader = const output = outputVariable('outputData', outputDataType, [vecSize], 4); return ` - ${shaderHelper.declareVariables(input, output)} + ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)} ${additionalImplementation ?? ''} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} let a = ${input.getByOffset('global_idx')}; ${output.setByOffset('global_idx', expression)} @@ -42,51 +42,47 @@ const createElementwiseProgramShader = }; const createElementwiseProgramInfo = - (metadata: ProgramMetadata, input: TensorView, outputDataType: number, funcCall: ElementwiseFunctionCall, - additionalImplementation?: string): ProgramInfo => ({ - ...metadata, + (input: TensorView, name: string, funcCall: ElementwiseFunctionCall, additionalImplementation?: string, + cacheKey?: string, outputDataType: number = input.dataType): ProgramInfo => ({ + name, + shaderCache: {hint: cacheKey, inputDependencies: ['type']}, getShaderSource: shaderHelper => createElementwiseProgramShader( shaderHelper, ShapeUtil.size(input.dims), input.dataType, outputDataType, funcCall, additionalImplementation), - outputs: [{dims: input.dims, dataType: outputDataType, gpuDataType: GpuDataType.default}], - dispatchGroup: (inputTensors) => - ({x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)}) + getRunData: (inputTensors) => ({ + outputs: [{dims: input.dims, dataType: outputDataType}], + dispatchGroup: + {x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)}, + programUniforms: [ + {type: 'uint32', data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, + ], + }) }); -const createElementwiseProgramInfoLoader = - (input: TensorView, name: string, funcCall: ElementwiseFunctionCall, additionalImplementation?: string, - cacheKey?: string, outputDataType: number = input.dataType): ProgramInfoLoader => { - const metadata: ProgramMetadata = {name, inputTypes: [GpuDataType.default], cacheHint: cacheKey}; - return { - ...metadata, - get: () => createElementwiseProgramInfo(metadata, input, outputDataType, funcCall, additionalImplementation) - }; - }; - export const abs = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Abs', 'abs')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Abs', 'abs')); }; export const acos = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Acos', 'acos')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Acos', 'acos')); }; export const acosh = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Acosh', 'acosh')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Acosh', 'acosh')); }; export const asin = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Asin', 'asin')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Asin', 'asin')); }; export const asinh = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Asinh', 'asinh')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Asinh', 'asinh')); }; export const atan = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Atan', 'atan')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Atan', 'atan')); }; export const atanh = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Atanh', 'atanh')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Atanh', 'atanh')); }; export interface CastAttributes extends AttributeWithCacheKey { @@ -119,8 +115,8 @@ export const cast = (context: ComputeContext, attributes: CastAttributes): void default: throw new RangeError(`not supported type (specified in attribute 'to' from 'Cast' operator): ${attributes.to}`); } - context.compute(createElementwiseProgramInfoLoader( - context.inputs[0], 'Cast', func, undefined, attributes.cacheKey, attributes.to)); + context.compute( + createElementwiseProgramInfo(context.inputs[0], 'Cast', func, undefined, attributes.cacheKey, attributes.to)); }; export interface ClipAttributes extends AttributeWithCacheKey { @@ -128,10 +124,17 @@ export interface ClipAttributes extends AttributeWithCacheKey { readonly max: number; } -export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => { +const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { + const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; + const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP; + return createAttributeWithCacheKey({min, max}); +}; + +export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => { + const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs); const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); context.compute( - createElementwiseProgramInfoLoader( + createElementwiseProgramInfo( context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, ` const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min})); const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max})); @@ -139,27 +142,17 @@ export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): vo attributes.cacheKey), {inputs: [0]}); }; -const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { - const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; - const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP; - return createAttributeWithCacheKey({min, max}); -}; - -export const clip = (context: ComputeContext): void => { - const attributes = generateClipAttributesFromInputs(context.inputs); - clipV10(context, attributes); -}; export const ceil = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Ceil', 'ceil')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Ceil', 'ceil')); }; export const cos = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Cos', 'cos')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Cos', 'cos')); }; export const cosh = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Cosh', 'cosh')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Cosh', 'cosh')); }; export interface AlphaAttributes extends AttributeWithCacheKey { @@ -170,7 +163,7 @@ export const parseAlphaAttributes = (attributes: Record): Alpha createAttributeWithCacheKey(attributes as {alpha: number}); export const elu = (context: ComputeContext, attributes: AlphaAttributes): void => { - context.compute(createElementwiseProgramInfoLoader( + context.compute(createElementwiseProgramInfo( context.inputs[0], 'Elu', a => `elu_vf32(${a})`, ` const elu_alpha_: f32 = f32(${attributes.alpha}); @@ -200,79 +193,79 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} { export const erf = (context: ComputeContext): void => { const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfoLoader( + context.compute(createElementwiseProgramInfo( context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType))); }; export const exp = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Exp', 'exp')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Exp', 'exp')); }; export const floor = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Floor', 'floor')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Floor', 'floor')); }; export const gelu = (context: ComputeContext): void => { const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfoLoader( + context.compute(createElementwiseProgramInfo( context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, erfImpl(`vec4<${dataType}>`, dataType))); }; export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => { - context.compute(createElementwiseProgramInfoLoader( + context.compute(createElementwiseProgramInfo( context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4(0.0))`, `const leaky_relu_alpha_: f32 = f32(${attributes.alpha});`, attributes.cacheKey)); }; export const not = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Not', a => `!${a}`)); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Not', a => `!${a}`)); }; export const neg = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Neg', a => `-${a}`)); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Neg', a => `-${a}`)); }; export const reciprocal = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Reciprocal', a => `1.0/${a}`)); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Reciprocal', a => `1.0/${a}`)); }; export const relu = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader( + context.compute(createElementwiseProgramInfo( context.inputs[0], 'Relu', a => `select(vec4(0.0), ${a}, ${a} > vec4(0.0))`)); }; export const sigmoid = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`)); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`)); }; export const sin = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sin', 'sin')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin')); }; export const sinh = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sinh', 'sinh')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sinh', 'sinh')); }; export const sqrt = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sqrt', 'sqrt')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sqrt', 'sqrt')); }; export const tan = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Tan', 'tan')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tan', 'tan')); }; export const tanh = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Tanh', 'tanh')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', 'tanh')); }; export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => { - context.compute(createElementwiseProgramInfoLoader( + context.compute(createElementwiseProgramInfo( context.inputs[0], 'ThresholdedRelu', a => `select(vec4(0.0), ${a}, ${a} > thresholded_relu_alpha_)`, `const thresholded_relu_alpha_: vec4 = vec4(${attributes.alpha});`, attributes.cacheKey)); return 0; }; export const log = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Log', 'log')); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Log', 'log')); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts new file mode 100644 index 0000000000000..6f66dd86b4088 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {BroadcastUtil, ShapeUtil} from '../../util'; +import {ComputeContext, ProgramInfo} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper} from './common'; + +const createWhereOpProgramShader = + (shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], isBroadcast: boolean, + typeOutput: number) => { + const outputSize = ShapeUtil.size(dimsOutput); + const vecSize = Math.ceil(outputSize / 4); + + const output = outputVariable('outputData', typeOutput, dimsOutput, 4); + const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4); + const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4); + const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4); + + let assignment: string; + const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`; + if (!isBroadcast) { + assignment = output.setByOffset( + 'global_idx', + expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx'))); + } else { + const singleAssignment = (resStr: string, x: number, typeCast = '') => { + const expressionA = `aData[indexA${x}][componentA${x}]`; + const expressionB = `bData[indexB${x}][componentB${x}]`; + // eslint-disable-next-line no-bitwise + const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; + return ` + let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; + let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let offsetC${x} = ${c.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let indexA${x} = offsetA${x} / 4u; + let indexB${x} = offsetB${x} / 4u; + let indexC${x} = offsetC${x} / 4u; + let componentA${x} = offsetA${x} % 4u; + let componentB${x} = offsetB${x} % 4u; + ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)}); + `; + }; + if (typeOutput === DataType.bool) { + assignment = ` + var data = vec4(0); + ${singleAssignment('data', 0, 'u32')} + ${singleAssignment('data', 1, 'u32')} + ${singleAssignment('data', 2, 'u32')} + ${singleAssignment('data', 3, 'u32')} + outputData[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`; + } else { + assignment = ` + ${singleAssignment('outputData[global_idx]', 0)} + ${singleAssignment('outputData[global_idx]', 1)} + ${singleAssignment('outputData[global_idx]', 2)} + ${singleAssignment('outputData[global_idx]', 3)} + `; + } + } + + return ` + ${shaderHelper.declareVariables(c, a, b, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + ${assignment} + }`; + }; + +const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => { + const dimsA = inputs[1].dims; + const dimsB = inputs[2].dims; + const dimsC = inputs[0].dims; + const outputDataType = inputs[1].dataType; + + const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC)); + let outputShape = dimsA; + let outputSize = ShapeUtil.size(dimsA); + // TODO: deal with zero-sized tensors (eg. dims=[1,0]) + + if (isBroadcast) { + const calculatedShape = BroadcastUtil.calcShape(BroadcastUtil.calcShape(dimsA, dimsB, false)!, dimsC, false); + if (!calculatedShape) { + throw new Error('Can\'t perform where op on the given tensors'); + } + outputShape = calculatedShape; + outputSize = ShapeUtil.size(outputShape); + } + + return { + name: 'Where', + getShaderSource: (shaderHelper) => + createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType), + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: outputDataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)} + }), + }; +}; + +export const where = (context: ComputeContext): void => { + context.compute(createWhereOpProgramInfo(context.inputs)); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index cf2687e4c7382..0b0a545f46481 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -32,18 +32,12 @@ export class ProgramManager { setArtifact(key: unknown, artifact: Artifact): void { this.repo.set(key, artifact); } - run(buildArtifact: Artifact, inputsTensorView: readonly TensorView[], inputs: GpuData[], outputs: GpuData[], - dispatchGroup: [number, number, number]): void { + run(buildArtifact: Artifact, inputTensorViews: readonly TensorView[], outputTensorViews: readonly TensorView[], + inputs: GpuData[], outputs: GpuData[], dispatchGroup: [number, number, number], + uniformBufferBinding: GPUBindingResource|undefined): void { const device = this.backend.device; - const computePassEncoder = this.backend.getComputePassEncoder(); - const profilingEnabled = this.backend.supportTimestampQuery && this.backend.env.webgpu.profilingMode === 'default'; - if (profilingEnabled) { - // profiling write start timestamp - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (computePassEncoder as any).writeTimestamp(this.backend.profilingQuerySet, 0); - } + const computePassEncoder = this.backend.getComputePassEncoder(); computePassEncoder.setPipeline(buildArtifact.computePipeline); const entries = []; for (const input of inputs) { @@ -52,6 +46,9 @@ export class ProgramManager { for (const output of outputs) { entries.push({binding: entries.length, resource: {buffer: output.buffer}}); } + if (uniformBufferBinding) { + entries.push({binding: entries.length, resource: uniformBufferBinding}); + } const bindGroup = device.createBindGroup( {layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name}); computePassEncoder.setBindGroup(0, bindGroup); @@ -60,43 +57,39 @@ export class ProgramManager { this.backend.pendingDispatchNumber++; - if (profilingEnabled) { - // profiling write end timestamp - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (computePassEncoder as any).writeTimestamp(this.backend.profilingQuerySet, 1); - if (this.backend.profilingQueryData == null) { - this.backend.profilingQueryData = + if (this.backend.isQueryEnabled()) { + if (typeof this.backend.queryData === 'undefined') { + this.backend.queryData = this.backend.gpuDataManager.create( // eslint-disable-next-line no-bitwise - this.backend.gpuDataManager.create(16, GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE); + this.backend.querySetCount * 8, GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE); } - // eslint-disable-next-line no-bitwise - const syncData = this.backend.gpuDataManager.create(16, GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST); + const syncData = this.backend.gpuDataManager.create( + // eslint-disable-next-line no-bitwise + this.backend.querySetCount * 8, GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST); this.backend.endComputePass(); - this.backend.getCommandEncoder().resolveQuerySet( - this.backend.profilingQuerySet, 0, 2, this.backend.profilingQueryData.buffer, 0); + this.backend.getCommandEncoder().resolveQuerySet(this.backend.querySet!, 0, 2, this.backend.queryData.buffer, 0); this.backend.getCommandEncoder().copyBufferToBuffer( - this.backend.profilingQueryData.buffer, 0, syncData.buffer, 0, 16); + this.backend.queryData.buffer, 0, syncData.buffer, 0, this.backend.querySetCount * 8); this.backend.flush(); const kernelId = this.backend.currentKernelId!; const kernelInfo = this.backend.kernels.get(kernelId)!; const kernelName = `[${kernelInfo[0]}] ${kernelInfo[1]}`; - syncData.buffer.mapAsync(GPUMapMode.READ).then(() => { + void syncData.buffer.mapAsync(GPUMapMode.READ).then(() => { const mappedData = new BigUint64Array(syncData.buffer.getMappedRange()); const startTimeU64 = mappedData[0]; const endTimeU64 = mappedData[1]; syncData.buffer.unmap(); - if (typeof this.backend.profilingTimeBase === 'undefined') { - this.backend.profilingTimeBase = startTimeU64; + if (typeof this.backend.queryTimeBase === 'undefined') { + this.backend.queryTimeBase = startTimeU64; } - const startTime = Number(startTimeU64 - this.backend.profilingTimeBase); - const endTime = Number(endTimeU64 - this.backend.profilingTimeBase); + const startTime = Number(startTimeU64 - this.backend.queryTimeBase); + const endTime = Number(endTimeU64 - this.backend.queryTimeBase); if (!Number.isSafeInteger(startTime) || !Number.isSafeInteger(endTime)) { throw new RangeError('incorrect timestamp range'); @@ -104,11 +97,11 @@ export class ProgramManager { this.backend.gpuDataManager.release(syncData.id); let inputShapes = ''; - inputsTensorView.forEach((value, i) => { + inputTensorViews.forEach((value, i) => { inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `; }); let outputShapes = ''; - buildArtifact.programInfo.outputs.forEach((value, i) => { + outputTensorViews.forEach((value, i) => { outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `; }); // eslint-disable-next-line no-console @@ -134,7 +127,7 @@ export class ProgramManager { const userCode = programInfo.getShaderSource(shaderHelper); const code = `${extensions.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`; const shaderModule = device.createShaderModule({code, label: programInfo.name}); - LOG_DEBUG('verbose', () => `[WebGPU] shader code: ${code}`); + LOG_DEBUG('verbose', () => `[WebGPU] ${programInfo.name} shader code: ${code}`); const computePipeline = device.createComputePipeline( {compute: {module: shaderModule, entryPoint: 'main'}, layout: 'auto', label: programInfo.name}); @@ -142,7 +135,8 @@ export class ProgramManager { return {programInfo, computePipeline}; } - normalizeDispatchGroupSize(dispatchGroup: ReturnType): [number, number, number] { + normalizeDispatchGroupSize(dispatchGroup: ReturnType['dispatchGroup']): + [number, number, number] { const x = typeof dispatchGroup === 'number' ? dispatchGroup : dispatchGroup.x; const y = typeof dispatchGroup === 'number' ? 1 : (dispatchGroup.y || 1); const z = typeof dispatchGroup === 'number' ? 1 : (dispatchGroup.z || 1); diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 78f80b89774e2..23fa33a9bba8f 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -21,69 +21,95 @@ export interface GpuData { export interface TensorInfo { dims: readonly number[]; dataType: number; - gpuDataType: GpuDataType; } -export interface ProgramVariable { - type: 'float'|'int'; - name: string; - arrayLength?: number; - data: number|number[]; +export interface ProgramUniform { + type: 'int32'|'float32'|'uint32'; + data: number|readonly number[]; } +/** + * Represent the dependency of a program on a specific input tensor. + * + * - 'none': the shader/uniform does not depend on this input's info + * - 'type': the shader/uniform depends on data type of this input + * - 'rank': the shader/uniform depends on data type and the rank of this input + * - 'dims': the shader/uniform depends on data type and the dims of this input + * - 'data': the shader/uniform depends on data type, the dims and the data of this input + */ +export type ProgramInputTensorInfoDependency = 'none'|'type'|'rank'|'dims'|'data'; -export interface ProgramMetadata { +/** + * Represent information about a program's cache for shader. + */ +export interface ProgramShaderCacheInfo { /** - * the name of the program. used for debugging and profiling + * an optional string as a cache hint in the artifact cache. If this is not specified, the cache hint will be empty. + * + * This hint string should only contains initializing-time information, such as the attributes or any information of + * initializers. It should NOT contain any runtime information, such as the shape of inputs. */ - name: string; + hint?: string; /** - * gpu data types for each input - */ - inputTypes: GpuDataType[]; - /** - * an optional string as a cache hint in the artifact cache + * an optional list of dependencies of the program on the input tensors. If this is not specified, the program depends + * on 'dims' of all inputs. */ - cacheHint?: string; + inputDependencies?: ProgramInputTensorInfoDependency[]; } /** - * A ProgramInfoLoader allows + * Represent information about a program's cache for uniform. */ -export interface ProgramInfoLoader extends ProgramMetadata { +export interface ProgramUniformCacheInfo { + /** + * an optional string as a cache hint in the uniform cache. If this is not specified, the cache hint will be empty. + * + * This hint string should only contains runtime information, such as the shape of inputs. + */ + hint?: string; + /** - * a function to get the program info + * an optional list of dependencies of the program on the input tensors. If this is not specified, the program depends + * on 'none' of all inputs. */ - get(): ProgramInfo; + inputDependencies?: ProgramInputTensorInfoDependency[]; } + /** * A set of data that represent a shader program */ -export interface ProgramInfo extends ProgramMetadata { +export interface ProgramInfo { /** - * information of uniform variables + * the name of the program. used for debugging and profiling */ - variables?: ProgramVariable[]; + name: string; + /** - * tensor info for outputs + * an optional object describing the cache information of the program shader. + * + * If this is not specified, assume hint is empty and inputDependencies are ['dims'] for all inputs. */ - outputs: TensorInfo[]; + shaderCache?: ProgramShaderCacheInfo; + /** - * the shader's processing source code + * the shader's processing source code. + * + * This function will be called when shader cache missed. */ getShaderSource: (shaderHelper: ShaderHelper) => string; + /** - * default is "main" + * A function to get run data required to run the program. + * + * This function will be called every time the program is executed. Should keep this function as simple as possible. */ - // entryPoint: string; - - dispatchGroup: (inputs: readonly TensorView[]) => { - x: number; - y?: number; - z?: number; + getRunData: (inputs: readonly TensorView[]) => { + outputs: readonly TensorInfo[]; + dispatchGroup: {x: number; y?: number; z?: number}; + programUniforms?: readonly ProgramUniform[]; }; } @@ -143,7 +169,6 @@ export interface ComputeContext { */ readonly outputCount: number; - compute(program: ProgramInfoLoader|ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): - TensorView[]; + compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[]; output(index: number, dims: readonly number[]): number; } diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index e5a2d8c2351b8..efeb086256cf3 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -1,22 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; +import type {Env, InferenceSession, Tensor} from 'onnxruntime-common'; -/** - * tuple elements are: ORT element type; dims; tensor data - */ -export type SerializableTensor = [Tensor.Type, readonly number[], Tensor.DataType]; +export type SerializableTensorMetadata = + [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu']; -/** - * tuple elements are: InferenceSession handle; input names; output names - */ -export type SerializableSessionMetadata = [number, string[], string[]]; +export type GpuBufferMetadata = { + gpuBuffer: Tensor.GpuBufferType; + download?: () => Promise; + dispose?: () => void; +}; -/** - * tuple elements are: modeldata.offset, modeldata.length - */ -export type SerializableModeldata = [number, number]; +export type UnserializableTensorMetadata = + [dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']| + [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned']; + +export type TensorMetadata = SerializableTensorMetadata|UnserializableTensorMetadata; + +export type SerializableSessionMetadata = [sessionHandle: number, inputNames: string[], outputNames: string[]]; + +export type SerializableModeldata = [modelDataOffset: number, modelDataLength: number]; interface MessageError { err?: string; @@ -58,10 +62,10 @@ interface MessageReleaseSession extends MessageError { interface MessageRun extends MessageError { type: 'run'; in ?: { - sessionId: number; inputIndices: number[]; inputs: SerializableTensor[]; outputIndices: number[]; + sessionId: number; inputIndices: number[]; inputs: SerializableTensorMetadata[]; outputIndices: number[]; options: InferenceSession.RunOptions; }; - out?: SerializableTensor[]; + out?: SerializableTensorMetadata[]; } interface MesssageEndProfiling extends MessageError { @@ -69,5 +73,10 @@ interface MesssageEndProfiling extends MessageError { in ?: number; } +interface MessageIsOrtEnvInitialized extends MessageError { + type: 'is-ort-env-initialized'; + out?: boolean; +} + export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize| - MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling; + MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized; diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index 2ea3645bf387e..1cb6d9e391e4f 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -3,15 +3,47 @@ /// -import {OrtWasmMessage} from '../proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, releaseSession, run} from '../wasm-core-impl'; +// +// * type hack for "HTMLImageElement" +// +// in typescript, the type of "HTMLImageElement" is defined in lib.dom.d.ts, which is conflict with lib.webworker.d.ts. +// when we use webworker, the lib.webworker.d.ts will be used, which does not have HTMLImageElement defined. +// +// we will get the following errors complaining that HTMLImageElement is not defined: +// +// ==================================================================================================================== +// +// ../common/dist/cjs/tensor-factory.d.ts:187:29 - error TS2552: Cannot find name 'HTMLImageElement'. Did you mean +// 'HTMLLIElement'? +// +// 187 fromImage(imageElement: HTMLImageElement, options?: TensorFromImageElementOptions): +// Promise | TypedTensor<'uint8'>>; +// ~~~~~~~~~~~~~~~~ +// +// node_modules/@webgpu/types/dist/index.d.ts:83:7 - error TS2552: Cannot find name 'HTMLImageElement'. Did you mean +// 'HTMLLIElement'? +// +// 83 | HTMLImageElement +// ~~~~~~~~~~~~~~~~ +// +// ==================================================================================================================== +// +// `HTMLImageElement` is only used in type declaration and not in real code. So we define it as `unknown` here to +// bypass the type check. +// +declare global { + type HTMLImageElement = unknown; +} + +import {OrtWasmMessage, SerializableTensorMetadata} from '../proxy-messages'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl'; import {initializeWebAssembly} from '../wasm-factory'; self.onmessage = (ev: MessageEvent): void => { switch (ev.data.type) { case 'init-wasm': try { - initializeWebAssembly(ev.data.in) + initializeWebAssembly(ev.data.in!) .then( () => postMessage({type: 'init-wasm'} as OrtWasmMessage), err => postMessage({type: 'init-wasm', err} as OrtWasmMessage)); @@ -21,11 +53,10 @@ self.onmessage = (ev: MessageEvent): void => { break; case 'init-ort': try { - initRuntime(ev.data.in).then(() => postMessage({type: 'init-ort'} as OrtWasmMessage), err => postMessage({ - type: 'init-ort', - err - } as OrtWasmMessage)); - postMessage({type: 'init-ort'} as OrtWasmMessage); + initRuntime(ev.data.in!).then(() => postMessage({type: 'init-ort'} as OrtWasmMessage), err => postMessage({ + type: 'init-ort', + err + } as OrtWasmMessage)); } catch (err) { postMessage({type: 'init-ort', err} as OrtWasmMessage); } @@ -59,8 +90,7 @@ self.onmessage = (ev: MessageEvent): void => { break; case 'release': try { - const handler = ev.data.in!; - releaseSession(handler); + releaseSession(ev.data.in!); postMessage({type: 'release'} as OrtWasmMessage); } catch (err) { postMessage({type: 'release', err} as OrtWasmMessage); @@ -69,10 +99,16 @@ self.onmessage = (ev: MessageEvent): void => { case 'run': try { const {sessionId, inputIndices, inputs, outputIndices, options} = ev.data.in!; - run(sessionId, inputIndices, inputs, outputIndices, options) + run(sessionId, inputIndices, inputs, outputIndices, new Array(outputIndices.length).fill(null), options) .then( outputs => { - postMessage({type: 'run', out: outputs} as OrtWasmMessage, extractTransferableBuffers(outputs)); + if (outputs.some(o => o[3] !== 'cpu')) { + postMessage({type: 'run', err: 'Proxy does not support non-cpu tensor location.'}); + } else { + postMessage( + {type: 'run', out: outputs} as OrtWasmMessage, + extractTransferableBuffers(outputs as SerializableTensorMetadata[])); + } }, err => { postMessage({type: 'run', err} as OrtWasmMessage); @@ -90,6 +126,14 @@ self.onmessage = (ev: MessageEvent): void => { postMessage({type: 'end-profiling', err} as OrtWasmMessage); } break; + case 'is-ort-env-initialized': + try { + const ortEnvInitialized = isOrtEnvInitialized(); + postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage); + } catch (err) { + postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage); + } + break; default: } }; diff --git a/js/web/lib/wasm/proxy-worker/tsconfig.json b/js/web/lib/wasm/proxy-worker/tsconfig.json new file mode 100644 index 0000000000000..ec1044612a569 --- /dev/null +++ b/js/web/lib/wasm/proxy-worker/tsconfig.json @@ -0,0 +1,8 @@ +{ + "extends": "../../../tsconfig.json", + "compilerOptions": { + "lib": ["WebWorker"] + }, + "include": ["main.ts", "../../build-def.d.ts"], + "exclude": [] +} diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 815b223e40379..069a1fa452dbc 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -3,7 +3,7 @@ import {Env, env, InferenceSession} from 'onnxruntime-common'; -import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensor} from './proxy-messages'; +import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; import * as core from './wasm-core-impl'; import {initializeWebAssembly} from './wasm-factory'; @@ -22,8 +22,9 @@ const createSessionAllocateCallbacks: Array> = []; const createSessionCallbacks: Array> = []; const releaseSessionCallbacks: Array> = []; -const runCallbacks: Array> = []; +const runCallbacks: Array> = []; const endProfilingCallbacks: Array> = []; +const isOrtEnvInitializedCallbacks: Array> = []; const ensureWorker = (): void => { if (initializing || !initialized || aborted || !proxyWorker) { @@ -92,6 +93,13 @@ const onProxyWorkerMessage = (ev: MessageEvent): void => { endProfilingCallbacks.shift()![0](); } break; + case 'is-ort-env-initialized': + if (ev.data.err) { + isOrtEnvInitializedCallbacks.shift()![1](ev.data.err); + } else { + isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!); + } + break; default: } }; @@ -121,9 +129,18 @@ export const initializeWebAssemblyInstance = async(): Promise => { return new Promise((resolve, reject) => { proxyWorker?.terminate(); - // eslint-disable-next-line @typescript-eslint/no-var-requires, @typescript-eslint/no-require-imports - proxyWorker = require('worker-loader?inline=no-fallback!./proxy-worker/main').default() as Worker; + + const workerUrl = URL.createObjectURL(new Blob( + [ + // This require() function is handled by esbuild plugin to load file content as string. + // eslint-disable-next-line @typescript-eslint/no-require-imports + require('./proxy-worker/main') + ], + {type: 'text/javascript'})); + proxyWorker = new Worker(workerUrl, {name: 'ort-wasm-proxy-worker'}); + proxyWorker.onerror = (ev: ErrorEvent) => reject(ev); proxyWorker.onmessage = onProxyWorkerMessage; + URL.revokeObjectURL(workerUrl); initWasmCallbacks = [resolve, reject]; const message: OrtWasmMessage = {type: 'init-wasm', in : env.wasm}; proxyWorker.postMessage(message); @@ -177,6 +194,10 @@ export const createSessionFinalize = async(modeldata: SerializableModeldata, opt export const createSession = async(model: Uint8Array, options?: InferenceSession.SessionOptions): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + // check unsupported options + if (options?.preferredOutputLocation) { + throw new Error('session option "preferredOutputLocation" is not supported for proxy.'); + } ensureWorker(); return new Promise((resolve, reject) => { createSessionCallbacks.push([resolve, reject]); @@ -202,17 +223,27 @@ export const releaseSession = async(sessionId: number): Promise => { }; export const run = async( - sessionId: number, inputIndices: number[], inputs: SerializableTensor[], outputIndices: number[], - options: InferenceSession.RunOptions): Promise => { + sessionId: number, inputIndices: number[], inputs: TensorMetadata[], outputIndices: number[], + outputs: Array, options: InferenceSession.RunOptions): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + // check inputs location + if (inputs.some(t => t[3] !== 'cpu')) { + throw new Error('input tensor on GPU is not supported for proxy.'); + } + // check outputs location + if (outputs.some(t => t)) { + throw new Error('pre-allocated output tensor is not supported for proxy.'); + } ensureWorker(); - return new Promise((resolve, reject) => { + return new Promise((resolve, reject) => { runCallbacks.push([resolve, reject]); - const message: OrtWasmMessage = {type: 'run', in : {sessionId, inputIndices, inputs, outputIndices, options}}; - proxyWorker!.postMessage(message, core.extractTransferableBuffers(inputs)); + const serializableInputs = inputs as SerializableTensorMetadata[]; // every input is on CPU. + const message: OrtWasmMessage = + {type: 'run', in : {sessionId, inputIndices, inputs: serializableInputs, outputIndices, options}}; + proxyWorker!.postMessage(message, core.extractTransferableBuffers(serializableInputs)); }); } else { - return core.run(sessionId, inputIndices, inputs, outputIndices, options); + return core.run(sessionId, inputIndices, inputs, outputIndices, outputs, options); } }; @@ -228,3 +259,16 @@ export const endProfiling = async(sessionId: number): Promise => { core.endProfiling(sessionId); } }; + +export const isOrtEnvInitialized = async(): Promise => { + if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + ensureWorker(); + return new Promise((resolve, reject) => { + isOrtEnvInitializedCallbacks.push([resolve, reject]); + const message: OrtWasmMessage = {type: 'is-ort-env-initialized'}; + proxyWorker!.postMessage(message); + }); + } else { + return core.isOrtEnvInitialized(); + } +}; diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler-inference.ts similarity index 55% rename from js/web/lib/wasm/session-handler.ts rename to js/web/lib/wasm/session-handler-inference.ts index d8c5ae7886fe4..3ca34d957c572 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler-inference.ts @@ -1,17 +1,44 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {readFile} from 'fs'; -import {env, InferenceSession, SessionHandler, Tensor} from 'onnxruntime-common'; -import {promisify} from 'util'; +import {readFile} from 'node:fs/promises'; +import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata} from './proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper'; +import {SerializableModeldata, TensorMetadata} from './proxy-messages'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper'; +import {isGpuBufferSupportedType} from './wasm-common'; -let runtimeInitialized: boolean; let runtimeInitializationPromise: Promise|undefined; -export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { +export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { + switch (tensor.location) { + case 'cpu': + return [tensor.type, tensor.dims, tensor.data, 'cpu']; + case 'gpu-buffer': + return [tensor.type, tensor.dims, {gpuBuffer: tensor.gpuBuffer}, 'gpu-buffer']; + default: + throw new Error(`invalid data location: ${tensor.location} for ${getName()}`); + } +}; + +export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { + switch (tensor[3]) { + case 'cpu': + return new Tensor(tensor[0], tensor[2], tensor[1]); + case 'gpu-buffer': { + const dataType = tensor[0]; + if (!isGpuBufferSupportedType(dataType)) { + throw new Error(`not supported data type: ${dataType} for deserializing GPU tensor`); + } + const {gpuBuffer, download, dispose} = tensor[2]; + return Tensor.fromGpuBuffer(gpuBuffer, {dataType, dims: tensor[1], download, dispose}); + } + default: + throw new Error(`invalid data location: ${tensor[3]}`); + } +}; + +export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHandler { private sessionId: number; inputNames: string[]; @@ -29,19 +56,18 @@ export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { } async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { - if (!runtimeInitialized) { + if (!(await isOrtEnvInitialized())) { if (!runtimeInitializationPromise) { runtimeInitializationPromise = initializeRuntime(env); } await runtimeInitializationPromise; runtimeInitializationPromise = undefined; - runtimeInitialized = true; } if (typeof pathOrBuffer === 'string') { if (typeof process !== 'undefined' && process.versions && process.versions.node) { // node - const model = await promisify(readFile)(pathOrBuffer); + const model = await readFile(pathOrBuffer); [this.sessionId, this.inputNames, this.outputNames] = await createSession(model, options); } else { // browser @@ -74,25 +100,31 @@ export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { inputIndices.push(index); }); + const outputArray: Array = []; const outputIndices: number[] = []; Object.entries(fetches).forEach(kvp => { const name = kvp[0]; - // TODO: support pre-allocated output + const tensor = kvp[1]; const index = this.outputNames.indexOf(name); if (index === -1) { throw new Error(`invalid output '${name}'`); } + outputArray.push(tensor); outputIndices.push(index); }); - const outputs = - await run(this.sessionId, inputIndices, inputArray.map(t => [t.type, t.dims, t.data]), outputIndices, options); + const inputs = + inputArray.map((t, i) => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`)); + const outputs = outputArray.map( + (t, i) => t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null); + + const results = await run(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); - const result: SessionHandler.ReturnType = {}; - for (let i = 0; i < outputs.length; i++) { - result[this.outputNames[outputIndices[i]]] = new Tensor(outputs[i][0], outputs[i][2], outputs[i][1]); + const resultMap: SessionHandler.ReturnType = {}; + for (let i = 0; i < results.length; i++) { + resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); } - return result; + return resultMap; } startProfiling(): void { diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts new file mode 100644 index 0000000000000..7de3f4dc2c89e --- /dev/null +++ b/js/web/lib/wasm/session-handler-training.ts @@ -0,0 +1,137 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; + +import {SerializableModeldata, TensorMetadata} from './proxy-messages'; +import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; +import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; + +export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { + private sessionId: number; + private checkpointId: number; + + inputNames: string[]; + outputNames: string[]; + + inputEncodedNames: number[]; + outputEncodedNames: number[]; + + async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { + let buffer: Uint8Array; + if (typeof uriOrBuffer === 'string') { + const response = await fetch(uriOrBuffer); + const arrayBuffer = await response.arrayBuffer(); + buffer = new Uint8Array(arrayBuffer); + } else { + buffer = uriOrBuffer; + } + return createSessionAllocate(buffer); + } + + async createTrainingSession( + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions) { + if (!isOrtEnvInitialized()) { + await initRuntime(env); + } + const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); + const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); + // 0 is supposed to be the nullptr + let evalModelData: SerializableModeldata = [0, 0]; + let optimizerModelData: SerializableModeldata = [0, 0]; + + if (evalModelUriOrBuffer !== '') { + evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); + } + if (optimizerModelUriOrBuffer !== '') { + optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); + } + + this.checkpointId = createCheckpointHandle(checkpointData); + [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] = + createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); + } + + /** + * Helper method that converts a feeds or fetches datatype to two arrays, one of values and one that stores the + * corresponding name as a number referring to the index in the list of names provided. + * + * @param feeds meant to match either SessionHandler.FeedsType or SessionHandler.FetchesType + * @param names either inputNames or outputNames + * @returns a tuple of a list of values and a list of indices. + */ + convertMapIntoValuesArrayAndIndicesArray( + feeds: {[name: string]: T}, names: string[], mapFunc: (val: T, index: number) => U): [T[], number[], U[]] { + const values: T[] = []; + const indices: number[] = []; + Object.entries(feeds).forEach(kvp => { + const name = kvp[0]; + const tensor = kvp[1]; + const index = names.indexOf(name); + if (index === -1) { + throw new Error(`invalid input '${name}`); + } + values.push(tensor); + indices.push(index); + }); + + const uList = values.map(mapFunc); + return [values, indices, uList]; + } + + /** + * Helper method that converts the TensorMetadata that the wasm-core functions return to the + * SessionHandler.ReturnType. Any outputs in the provided outputArray that are falsy will be populated with the + * corresponding result. + * + * @param results used to populate the resultMap if there is no value for that outputName already + * @param outputArray used to populate the resultMap. If null or undefined, use the corresponding result from results + * @param outputIndices specifies which outputName the corresponding value for outputArray refers to. + * @returns a map of output names and OnnxValues. + */ + convertTensorMetadataToReturnType( + results: TensorMetadata[], outputArray: Array, outputIndices: number[]): SessionHandler.ReturnType { + const resultMap: SessionHandler.ReturnType = {}; + for (let i = 0; i < results.length; i++) { + resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); + } + return resultMap; + } + + async runTrainStep( + feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions): Promise { + const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( + feeds, this.inputNames, + (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`)); + + const [outputArray, outputIndices, outputs] = + this.convertMapIntoValuesArrayAndIndicesArray( + fetches, this.outputNames, + (t, i): TensorMetadata|null => + t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null); + + const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); + return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); + } + + async getParametersSize(trainableOnly: boolean): Promise { + return getParametersSize(this.sessionId, trainableOnly); + } + + async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise { + await loadParametersBuffer(this.sessionId, array, trainableOnly); + } + async getContiguousParameters(trainableOnly: boolean): Promise { + const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly); + return decodeTensorMetadata(tensorResult); + } + + async dispose(): Promise { + return releaseTrainingSessionAndCheckpoint( + this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); + } +} diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 2659b471733f5..45ea48a2df209 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -75,6 +75,19 @@ const setExecutionProviders = checkLastError(`Can't set a session config entry: 'deviceType' - ${webnnOptions.deviceType}.`); } } + if (webnnOptions?.numThreads) { + let numThreads = webnnOptions.numThreads; + // Just ignore invalid webnnOptions.numThreads. + if (typeof numThreads != 'number' || !Number.isInteger(numThreads) || numThreads < 0) { + numThreads = 0; + } + const keyDataOffset = allocWasmString('numThreads', allocs); + const valueDataOffset = allocWasmString(numThreads.toString(), allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== + 0) { + checkLastError(`Can't set a session config entry: 'numThreads' - ${webnnOptions.numThreads}.`); + } + } if (webnnOptions?.powerPreference) { const keyDataOffset = allocWasmString('powerPreference', allocs); const valueDataOffset = allocWasmString(webnnOptions.powerPreference, allocs); @@ -88,6 +101,21 @@ const setExecutionProviders = break; case 'webgpu': epName = 'JS'; + if (typeof ep !== 'string') { + const webgpuOptions = ep as InferenceSession.WebGpuExecutionProviderOption; + if (webgpuOptions?.preferredLayout) { + if (webgpuOptions.preferredLayout !== 'NCHW' && webgpuOptions.preferredLayout !== 'NHWC') { + throw new Error(`preferredLayout must be either 'NCHW' or 'NHWC': ${webgpuOptions.preferredLayout}`); + } + const keyDataOffset = allocWasmString('preferredLayout', allocs); + const valueDataOffset = allocWasmString(webgpuOptions.preferredLayout, allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== + 0) { + checkLastError( + `Can't set a session config entry: 'preferredLayout' - ${webgpuOptions.preferredLayout}.`); + } + } + } break; case 'wasm': case 'cpu': diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index 389773f3e8884..b9eff45e890c4 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -164,3 +164,35 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro throw new Error(`unsupported logging level: ${logLevel}`); } }; + +/** + * Check whether the given tensor type is supported by GPU buffer + */ +export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' || + type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32'; + +/** + * Map string data location to integer value + */ +export const dataLocationStringToEnum = (location: Tensor.DataLocation): number => { + switch (location) { + case 'none': + return 0; + case 'cpu': + return 1; + case 'cpu-pinned': + return 2; + case 'texture': + return 3; + case 'gpu-buffer': + return 4; + default: + throw new Error(`unsupported data location: ${location}`); + } +}; + +/** + * Map integer data location to string value + */ +export const dataLocationEnumToString = (location: number): Tensor.DataLocation|undefined => + (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer'] as const)[location]; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index fcca82ab2aa54..3aacf8f4d90e0 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -3,13 +3,15 @@ import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata, SerializableTensor} from './proxy-messages'; +import {SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; -import {logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; +let ortEnvInitialized = false; + /** * get the input/output count of the session. * @param sessionHandle the handle representing the session. should be non-zero. @@ -57,15 +59,46 @@ export const initRuntime = async(env: Env): Promise => { const initJsep = require('./jsep/init').init; await initJsep(getInstance(), env); } + + ortEnvInitialized = true; +}; + +/** + * valid data locations for input/output tensors. + */ +type SupportedTensorDataLocationForInputOutput = 'cpu'|'cpu-pinned'|'gpu-buffer'; + +type IOBindingState = { + /** + * the handle of IO binding. + */ + readonly handle: number; + + /** + * the preferred location for each output tensor. + * + * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer'. + */ + readonly outputPreferredLocations: readonly SupportedTensorDataLocationForInputOutput[]; + + /** + * enum value of the preferred location for each output tensor. + */ + readonly outputPreferredLocationsEncoded: readonly number[]; }; /** - * tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded + * tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded; bindingState */ -type SessionMetadata = [number, number[], number[]]; +type SessionMetadata = [ + inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], + bindingState: IOBindingState|null +]; const activeSessions = new Map(); +export const isOrtEnvInitialized = (): boolean => ortEnvInitialized; + /** * allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession. * @returns a 2-elements tuple - the pointer and size of the allocated buffer @@ -92,6 +125,7 @@ export const createSessionFinalize = let sessionHandle = 0; let sessionOptionsHandle = 0; + let ioBindingHandle = 0; let allocs: number[] = []; const inputNamesUTF8Encoded = []; const outputNamesUTF8Encoded = []; @@ -108,6 +142,7 @@ export const createSessionFinalize = const inputNames = []; const outputNames = []; + const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = []; for (let i = 0; i < inputCount; i++) { const name = wasm._OrtGetInputName(sessionHandle, i); if (name === 0) { @@ -122,15 +157,45 @@ export const createSessionFinalize = checkLastError('Can\'t get an output name.'); } outputNamesUTF8Encoded.push(name); - outputNames.push(wasm.UTF8ToString(name)); + const nameString = wasm.UTF8ToString(name); + outputNames.push(nameString); + + if (!BUILD_DEFS.DISABLE_WEBGPU) { + const location = typeof options?.preferredOutputLocation === 'string' ? + options.preferredOutputLocation : + options?.preferredOutputLocation?.[nameString] ?? 'cpu'; + if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { + throw new Error(`Not supported preferred output location: ${location}.`); + } + outputPreferredLocations.push(location); + } } - activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded]); + // use IO binding only when at least one output is preffered to be on GPU. + let bindingState: IOBindingState|null = null; + if (!BUILD_DEFS.DISABLE_WEBGPU && outputPreferredLocations.some(l => l === 'gpu-buffer')) { + ioBindingHandle = wasm._OrtCreateBinding(sessionHandle); + if (ioBindingHandle === 0) { + checkLastError('Can\'t create IO binding.'); + } + + bindingState = { + handle: ioBindingHandle, + outputPreferredLocations, + outputPreferredLocationsEncoded: outputPreferredLocations.map(l => dataLocationStringToEnum(l)), + }; + } + + activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]); return [sessionHandle, inputNames, outputNames]; } catch (e) { inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + if (ioBindingHandle !== 0) { + wasm._OrtReleaseBinding(ioBindingHandle); + } + if (sessionHandle !== 0) { wasm._OrtReleaseSession(sessionHandle); } @@ -161,7 +226,13 @@ export const releaseSession = (sessionId: number): void => { if (!session) { throw new Error(`cannot release session. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + + if (ioBindingState) { + wasm._OrtReleaseBinding(ioBindingState.handle); + } + + wasm.jsepUnregisterBuffers?.(sessionId); inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -169,18 +240,84 @@ export const releaseSession = (sessionId: number): void => { activeSessions.delete(sessionId); }; +export const prepareInputOutputTensor = + (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): + void => { + if (!tensor) { + tensorHandles.push(0); + return; + } + + const wasm = getInstance(); + + const dataType = tensor[0]; + const dims = tensor[1]; + const location = tensor[3]; + + let rawData: number; + let dataByteLength: number; + + if (dataType === 'string' && location === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); + } + + if (location === 'gpu-buffer') { + const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; + dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); + } else { + const data = tensor[2]; + + if (Array.isArray(data)) { + // string tensor + dataByteLength = 4 * data.length; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + let dataIndex = rawData / 4; + for (let i = 0; i < data.length; i++) { + if (typeof data[i] !== 'string') { + throw new TypeError(`tensor data at index ${i} is not a string`); + } + wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); + } + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } + + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + let dimIndex = dimsOffset / 4; + dims.forEach(d => wasm.HEAP32[dimIndex++] = d); + const tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(location)); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + } + tensorHandles.push(tensor); + } finally { + wasm.stackRestore(stack); + } + }; + /** * perform inference run */ export const run = async( - sessionId: number, inputIndices: number[], inputs: SerializableTensor[], outputIndices: number[], - options: InferenceSession.RunOptions): Promise => { + sessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], + outputTensors: Array, options: InferenceSession.RunOptions): Promise => { const wasm = getInstance(); const session = activeSessions.get(sessionId); if (!session) { throw new Error(`cannot run inference. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; const inputCount = inputIndices.length; const outputCount = outputIndices.length; @@ -188,171 +325,200 @@ export const run = async( let runOptionsHandle = 0; let runOptionsAllocs: number[] = []; - const inputValues: number[] = []; - const inputAllocs: number[] = []; + const inputTensorHandles: number[] = []; + const outputTensorHandles: number[] = []; + const inputOutputAllocs: number[] = []; + + const beforeRunStack = wasm.stackSave(); + const inputValuesOffset = wasm.stackAlloc(inputCount * 4); + const inputNamesOffset = wasm.stackAlloc(inputCount * 4); + const outputValuesOffset = wasm.stackAlloc(outputCount * 4); + const outputNamesOffset = wasm.stackAlloc(outputCount * 4); try { [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); // create input tensors for (let i = 0; i < inputCount; i++) { - const dataType = inputs[i][0]; - const dims = inputs[i][1]; - const data = inputs[i][2]; - - let dataOffset: number; - let dataByteLength: number; - - if (Array.isArray(data)) { - // string tensor - dataByteLength = 4 * data.length; - dataOffset = wasm._malloc(dataByteLength); - inputAllocs.push(dataOffset); - let dataIndex = dataOffset / 4; - for (let i = 0; i < data.length; i++) { - if (typeof data[i] !== 'string') { - throw new TypeError(`tensor data at index ${i} is not a string`); - } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], inputAllocs); - } - } else { - dataByteLength = data.byteLength; - dataOffset = wasm._malloc(dataByteLength); - inputAllocs.push(dataOffset); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), dataOffset); - } + prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]); + } - const stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); - try { - let dimIndex = dimsOffset / 4; - dims.forEach(d => wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(dataType), dataOffset, dataByteLength, dimsOffset, dims.length); - if (tensor === 0) { - checkLastError(`Can't create tensor for input[${i}].`); - } - inputValues.push(tensor); - } finally { - wasm.stackRestore(stack); - } + // create output tensors + for (let i = 0; i < outputCount; i++) { + prepareInputOutputTensor( + outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]); + } + + let inputValuesIndex = inputValuesOffset / 4; + let inputNamesIndex = inputNamesOffset / 4; + let outputValuesIndex = outputValuesOffset / 4; + let outputNamesIndex = outputNamesOffset / 4; + for (let i = 0; i < inputCount; i++) { + wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i]; + wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]]; + } + for (let i = 0; i < outputCount; i++) { + wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i]; + wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; } - const beforeRunStack = wasm.stackSave(); - const inputValuesOffset = wasm.stackAlloc(inputCount * 4); - const inputNamesOffset = wasm.stackAlloc(inputCount * 4); - const outputValuesOffset = wasm.stackAlloc(outputCount * 4); - const outputNamesOffset = wasm.stackAlloc(outputCount * 4); - - try { - let inputValuesIndex = inputValuesOffset / 4; - let inputNamesIndex = inputNamesOffset / 4; - let outputValuesIndex = outputValuesOffset / 4; - let outputNamesIndex = outputNamesOffset / 4; + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; + + if (inputNamesUTF8Encoded.length !== inputCount) { + throw new Error(`input count from feeds (${ + inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`); + } + + // process inputs for (let i = 0; i < inputCount; i++) { - wasm.HEAPU32[inputValuesIndex++] = inputValues[i]; - wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]]; + const index = inputIndices[i]; + const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]); + if (errorCode !== 0) { + checkLastError(`Can't bind input[${i}] for session=${sessionId}.`); + } } + + // process pre-allocated outputs for (let i = 0; i < outputCount; i++) { - wasm.HEAPU32[outputValuesIndex++] = 0; - wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; + const index = outputIndices[i]; + const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated. + + if (location) { + // output is pre-allocated. bind the tensor. + const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0); + if (errorCode !== 0) { + checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`); + } + } else { + // output is not pre-allocated. reset preferred location. + const errorCode = + wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], 0, outputPreferredLocationsEncoded[index]); + if (errorCode !== 0) { + checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`); + } + } } + } - // jsepOnRunStart is only available when JSEP is enabled. - wasm.jsepOnRunStart?.(sessionId); + let errorCode: number; - // support RunOptions - let errorCode = wasm._OrtRun( + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + errorCode = await wasm._OrtRunWithBinding( + sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle); + } else { + errorCode = await wasm._OrtRun( sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount, outputValuesOffset, runOptionsHandle); + } - const runPromise = wasm.jsepRunPromise; - if (runPromise) { - // jsepRunPromise is a Promise object. It is only available when JSEP is enabled. - // - // OrtRun() is a synchrnous call, but it internally calls async functions. Emscripten's ASYNCIFY allows it to - // work in this way. However, OrtRun() does not return a promise, so when code reaches here, it is earlier than - // the async functions are finished. - // - // To make it work, we created a Promise and resolve the promise when the C++ code actually reaches the end of - // OrtRun(). If the promise exists, we need to await for the promise to be resolved. - errorCode = await runPromise; - } + if (errorCode !== 0) { + checkLastError('failed to call OrtRun().'); + } + + const output: TensorMetadata[] = []; - const jsepOnRunEnd = wasm.jsepOnRunEnd; - if (jsepOnRunEnd) { - // jsepOnRunEnd is only available when JSEP is enabled. - // - // it returns a promise, which is resolved or rejected when the following async functions are finished: - // - collecting GPU validation errors. - await jsepOnRunEnd(sessionId); + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + if (tensor === outputTensorHandles[i]) { + // output tensor is pre-allocated. no need to copy data. + output.push(outputTensors[i]!); + continue; } - const output: SerializableTensor[] = []; + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); - if (errorCode !== 0) { - checkLastError('failed to call OrtRun().'); - } + let keepOutputTensor = false; + let type: Tensor.Type|undefined, dataOffset = 0; + try { + const errorCode = wasm._OrtGetTensorData( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + if (errorCode !== 0) { + checkLastError(`Can't access output tensor data on index ${i}.`); + } + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + const size = dims.reduce((a, b) => a * b, 1); + type = tensorDataTypeEnumToString(dataType); - const beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); + const preferredLocation = ioBindingState?.outputPreferredLocations[outputIndices[i]]; - let type: Tensor.Type|undefined, dataOffset = 0; - try { - errorCode = wasm._OrtGetTensorData( - tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); - if (errorCode !== 0) { - checkLastError(`Can't access output tensor data on index ${i}.`); + if (type === 'string') { + if (preferredLocation === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); } - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); } - wasm._OrtFree(dimsOffset); - - const size = dims.length === 0 ? 1 : dims.reduce((a, b) => a * b); - type = tensorDataTypeEnumToString(dataType); - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); + output.push([type, dims, stringData, 'cpu']); + } else { + // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU + // tensor for it. There is no mapping GPU buffer for an empty tensor. + if (preferredLocation === 'gpu-buffer' && size > 0) { + const gpuBuffer = wasm.jsepGetBuffer(dataOffset); + const elementSize = getTensorElementSize(dataType); + if (elementSize === undefined || !isGpuBufferSupportedType(type)) { + throw new Error(`Unsupported data type: ${type}`); } - output.push([type, dims, stringData]); + + // do not release the tensor right now. it will be released when user calls tensor.dispose(). + keepOutputTensor = true; + + output.push([ + type, dims, { + gpuBuffer, + download: wasm.jsepCreateDownloader(gpuBuffer, size * elementSize, type), + dispose: () => { + wasm._OrtReleaseTensor(tensor); + } + }, + 'gpu-buffer' + ]); } else { const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); const data = new typedArrayConstructor(size); new Uint8Array(data.buffer, data.byteOffset, data.byteLength) .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); - output.push([type, dims, data]); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); + output.push([type, dims, data, 'cpu']); } + } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); + } + if (!keepOutputTensor) { wasm._OrtReleaseTensor(tensor); } } + } - return output; - } finally { - wasm.stackRestore(beforeRunStack); + if (ioBindingState) { + wasm._OrtClearBoundOutputs(ioBindingState.handle); } + + return output; } finally { - inputValues.forEach(v => wasm._OrtReleaseTensor(v)); - inputAllocs.forEach(p => wasm._free(p)); + wasm.stackRestore(beforeRunStack); + + inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + inputOutputAllocs.forEach(p => wasm._free(p)); if (runOptionsHandle !== 0) { wasm._OrtReleaseRunOptions(runOptionsHandle); @@ -380,11 +546,11 @@ export const endProfiling = (sessionId: number): void => { wasm._OrtFree(profileFileName); }; -export const extractTransferableBuffers = (tensors: readonly SerializableTensor[]): ArrayBufferLike[] => { +export const extractTransferableBuffers = (tensors: readonly SerializableTensorMetadata[]): ArrayBufferLike[] => { const buffers: ArrayBufferLike[] = []; for (const tensor of tensors) { const data = tensor[2]; - if (!Array.isArray(data) && data.buffer) { + if (!Array.isArray(data) && 'buffer' in data) { buffers.push(data.buffer); } } diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 7648f0c473f07..2b7d492cc70ba 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -1,15 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import * as path from 'node:path'; import {Env} from 'onnxruntime-common'; -import * as path from 'path'; import {OrtWasmModule} from './binding/ort-wasm'; import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded'; /* eslint-disable @typescript-eslint/no-require-imports */ -const ortWasmFactory: EmscriptenModuleFactory = - BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); +let ortWasmFactory: EmscriptenModuleFactory; + +if (!BUILD_DEFS.DISABLE_TRAINING) { + ortWasmFactory = require('./binding/ort-training-wasm-simd.js'); +} else { + ortWasmFactory = + BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); +} const ortWasmFactoryThreaded: EmscriptenModuleFactory = !BUILD_DEFS.DISABLE_WASM_THREAD ? (BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm-threaded.js') : @@ -72,10 +78,13 @@ const isSimdSupported = (): boolean => { }; const getWasmFileName = (useSimd: boolean, useThreads: boolean) => { - if (useThreads) { - return useSimd ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-threaded.wasm'; + if (useSimd) { + if (!BUILD_DEFS.DISABLE_TRAINING) { + return 'ort-training-wasm-simd.wasm'; + } + return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm'; } else { - return useSimd ? 'ort-wasm-simd.wasm' : 'ort-wasm.wasm'; + return useThreads ? 'ort-wasm-threaded.wasm' : 'ort-wasm.wasm'; } }; @@ -128,7 +137,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise typeof Blob !== 'undefined') { return URL.createObjectURL(new Blob( [ - // This require() function is handled by webpack to load file content of the corresponding .worker.js + // This require() function is handled by esbuild plugin to load file content as string. // eslint-disable-next-line @typescript-eslint/no-require-imports require('./binding/ort-wasm-threaded.worker.js') ], @@ -161,7 +170,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise if (typeof Blob === 'undefined') { config.mainScriptUrlOrBlob = path.join(__dirname, 'ort-wasm-threaded.js'); } else { - const scriptSourceCode = `var ortWasmThreaded=(function(){var _scriptDir;return ${factory.toString()}})();`; + const scriptSourceCode = `var ortWasmThreaded=${factory.toString()};`; config.mainScriptUrlOrBlob = new Blob([scriptSourceCode], {type: 'text/javascript'}); } } diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts new file mode 100644 index 0000000000000..c0a4235113148 --- /dev/null +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -0,0 +1,455 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceSession, Tensor} from 'onnxruntime-common'; + +import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages'; +import {setRunOptions} from './run-options'; +import {setSessionOptions} from './session-options'; +import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {prepareInputOutputTensor} from './wasm-core-impl'; +import {getInstance} from './wasm-factory'; +import {checkLastError} from './wasm-utils'; + +const NO_TRAIN_FUNCS_MSG = + 'Built without training API\'s enabled. Use the onnxruntime-web/training import for training ' + + 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + + 'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.'; + +/** + * Runs the checkLastError function which will throw an error, if the provided error code matches the specified + * pattern for an error code. + * @param errCode number to evaluated for if it's an error + * @param message message to pass into checkLastError + * @param checkNeqZero when true, treats not equal to zero as an error. + * When false, treats equal to zero as an error. + */ +const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => { + if (checkNeqZero && errCode !== 0) { + checkLastError(message); + } else if (!checkNeqZero && errCode === 0) { + checkLastError(message); + } +}; + +export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { + const wasm = getInstance(); + + const [checkpointDataOffset, checkpointDataLength] = checkpointData; + let checkpointHandle = 0; + + try { + if (wasm._OrtTrainingLoadCheckpoint) { + checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false); + return checkpointHandle; + } catch (e) { + if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { + wasm._OrtTrainingReleaseCheckpoint(checkpointHandle); + } + throw e; + } finally { + // free buffer from wasm heap + wasm._OrtFree(checkpointData[0]); + } +}; + +const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + try { + const dataOffset = wasm.stackAlloc(8); + if (wasm._OrtTrainingGetModelInputOutputCount) { + const errorCode = + wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel); + ifErrCodeCheckLastError(errorCode, 'Can\'t get session input/output count.'); + return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + wasm.stackRestore(stack); + } +}; + +const getModelInputOutputNamesLoop = + (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => { + const names = []; + const wasm = getInstance(); + + const namesUTF8Encoded = []; + + for (let i = 0; i < count; i++) { + if (wasm._OrtTrainingGetModelInputOutputName) { + const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); + ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false); + + namesUTF8Encoded.push(name); + names.push(wasm.UTF8ToString(name)); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } + return [names, namesUTF8Encoded]; + }; + +const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { + const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false); + + const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false); + const [outputNames, outputNamesUTF8Encoded] = + getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false); + + return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; +}; + +export const createTrainingSessionHandle = + (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata, + optimizerModelData: SerializableModeldata, + options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => { + const wasm = getInstance(); + + let trainingSessionHandle = 0; + let sessionOptionsHandle = 0; + let allocs: number[] = []; + let inputNamesUTF8Encoded: number[] = []; + let outputNamesUTF8Encoded: number[] = []; + + let inputNames: string[] = []; + let outputNames: string[] = []; + + try { + [sessionOptionsHandle, allocs] = setSessionOptions(options); + if (wasm._OrtTrainingCreateSession) { + trainingSessionHandle = wasm._OrtTrainingCreateSession( + sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0], + evalModelData[1], optimizerModelData[0], optimizerModelData[1]); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false); + + [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] = + getTrainingModelInputOutputNames(trainingSessionHandle); + return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded]; + + } catch (e) { + if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { + wasm._OrtTrainingReleaseSession(trainingSessionHandle); + } + throw e; + } finally { + wasm._free(trainModelData[0]); + wasm._free(evalModelData[0]); + wasm._free(optimizerModelData[0]); + + if (sessionOptionsHandle !== 0) { + wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + } + allocs.forEach(alloc => wasm._free(alloc)); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + } + }; + +/** + * Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the + * WASM tensors. + * + * @param trainingSessionId + * @param indices for each tensor, the index of the input or output name that the tensor corresponds with + * @param tensors list of TensorMetaData + * @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting + * handles of the allocated tensors on the heap + * @param inputOutputAllocs modified in-place by this method + * @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor + */ +const createAndAllocateTensors = + (trainingSessionId: number, indices: number[], tensors: Array, tensorHandles: number[], + inputOutputAllocs: number[], indexAdd: number) => { + const count = indices.length; + + // creates the tensors + for (let i = 0; i < count; i++) { + prepareInputOutputTensor( + tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]); + } + + // moves to heap + const wasm = getInstance(); + const valuesOffset = wasm.stackAlloc(count * 4); + let valuesIndex = valuesOffset / 4; + for (let i = 0; i < count; i++) { + wasm.HEAPU32[valuesIndex++] = tensorHandles[i]; + } + + return valuesOffset; + }; + +/** + * Retrieves the information from the output tensor handles, copies to an array, and frees the WASM information + * associated with the tensor handle. + * + * @param outputValuesOffset + * @param outputCount + * @returns list of TensorMetadata retrieved from the output handles. + */ +const moveOutputToTensorMetadataArr = + (outputValuesOffset: number, outputCount: number, outputTensorHandles: number[], + outputTensors: Array) => { + const wasm = getInstance(); + const output: TensorMetadata[] = []; + + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + if (tensor === outputTensorHandles[i]) { + // output tensor is pre-allocated. no need to copy data. + output.push(outputTensors[i]!); + continue; + } + + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); + + let type: Tensor.Type|undefined, dataOffset = 0; + try { + const errorCode = wasm._OrtGetTensorData( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`); + + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); + + const size = dims.reduce((a, b) => a * b, 1); + type = tensorDataTypeEnumToString(dataType); + + if (type === 'string') { + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); + } + output.push([type, dims, stringData, 'cpu']); + } else { + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); + const data = new typedArrayConstructor(size); + new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); + output.push([type, dims, data, 'cpu']); + } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); + } + wasm._OrtReleaseTensor(tensor); + } + } + + return output; + }; + +export const runTrainStep = async( + trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], + outputTensors: Array, options: InferenceSession.RunOptions): Promise => { + const wasm = getInstance(); + + const inputCount = inputIndices.length; + const outputCount = outputIndices.length; + + let runOptionsHandle = 0; + let runOptionsAllocs: number[] = []; + + const inputTensorHandles: number[] = []; + const outputTensorHandles: number[] = []; + const inputOutputAllocs: number[] = []; + + const beforeRunStack = wasm.stackSave(); + + try { + // prepare parameters by moving them to heap + [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); + + // handle inputs -- you don't want anything added to the index + const inputValuesOffset = createAndAllocateTensors( + trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0); + // handle outputs + // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor + const outputValuesOffset = createAndAllocateTensors( + trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount); + + if (wasm._OrtTrainingRunTrainStep) { + const errorCode = wasm._OrtTrainingRunTrainStep( + trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle); + ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); + } finally { + wasm.stackRestore(beforeRunStack); + + inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + inputOutputAllocs.forEach(p => wasm._free(p)); + + if (runOptionsHandle !== 0) { + wasm._OrtReleaseRunOptions(runOptionsHandle); + } + runOptionsAllocs.forEach(p => wasm._free(p)); + } +}; + +export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + + try { + const sizeOffset = wasm.stackAlloc(4); + if (wasm._OrtTrainingGetParametersSize) { + const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); + ifErrCodeCheckLastError(errorCode, 'Can\'t get parameters size'); + + return wasm.HEAP32[sizeOffset / 4]; + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + wasm.stackRestore(stack); + } +}; + +export const getContiguousParameters = + async(trainingSessionId: number, trainableOnly: boolean): Promise => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + + const tensorTypeAsString = 'float32'; + const locationAsString = 'cpu'; + + const parametersSize = getParametersSize(trainingSessionId, trainableOnly); + let tensor = 0; + + // allocates a buffer of the correct size on the WASM heap + const paramsByteLength = 4 * parametersSize; + const paramsOffset = wasm._malloc(paramsByteLength); + + // handles the dimensions-related createTensor parameters + const dims = [parametersSize]; + + const dimsOffset = wasm.stackAlloc(4); + const dimsIndex = dimsOffset / 4; + wasm.HEAP32[dimsIndex] = parametersSize; + + try { + // wraps allocated array in a tensor + tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(tensorTypeAsString), paramsOffset, paramsByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(locationAsString)); + ifErrCodeCheckLastError( + tensor, `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, false); + + if (wasm._OrtTrainingCopyParametersToBuffer) { + const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly); + ifErrCodeCheckLastError(errCode, 'Can\'t get contiguous parameters.'); + + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + // copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString); + const data = new typedArrayConstructor(parametersSize); + const output: TensorMetadata[] = []; + new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + .set(wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength)); + output.push([tensorTypeAsString, dims, data, locationAsString]); + if (output.length !== 1) { + throw new Error(`something unexpected happened in the getContiguousParameters function. Expected output length of + one, got ${output.length}`); + } else { + return output[0]; + } + } finally { + if (tensor !== 0) { + wasm._OrtReleaseTensor(tensor); + } + wasm._free(paramsOffset); + wasm._free(dimsOffset); + wasm.stackRestore(stack); + } +}; + +export const loadParametersBuffer = + async(trainingSessionId: number, buffer: Uint8Array, trainableOnly: boolean): Promise => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + + const tensorTypeAsString = 'float32'; + const locationAsString = 'cpu'; + + // allocates & copies JavaScript buffer to WASM heap + const bufferByteLength = buffer.length; + const bufferCount = bufferByteLength / 4; + const bufferOffset = wasm._malloc(bufferByteLength); + wasm.HEAPU8.set(buffer, bufferOffset); + + // allocates and handles moving dimensions information to WASM memory + const dimsOffset = wasm.stackAlloc(4); + wasm.HEAP32[dimsOffset / 4] = bufferCount; + const dimsLength = 1; + let tensor = 0; + + try { + tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(tensorTypeAsString), bufferOffset, bufferByteLength, dimsOffset, dimsLength, + dataLocationStringToEnum(locationAsString)); + ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false); + + if (wasm._OrtTrainingCopyParametersFromBuffer) { + const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly); + ifErrCodeCheckLastError(errCode, 'Can\'t copy buffer to parameters.'); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + if (tensor !== 0) { + wasm._OrtReleaseTensor(tensor); + } + wasm.stackRestore(stack); + wasm._free(bufferOffset); + wasm._free(dimsOffset); + } +}; + +export const releaseTrainingSessionAndCheckpoint = + (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): + void => { + const wasm = getInstance(); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + + if (wasm._OrtTrainingReleaseSession) { + wasm._OrtTrainingReleaseSession(sessionId); + } + if (wasm._OrtTrainingReleaseCheckpoint) { + wasm._OrtTrainingReleaseCheckpoint(checkpointId); + } + }; diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 9567bc172c9ed..890c5a0f34765 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -25,7 +25,7 @@ "@types/minimatch": "^5.1.2", "@types/minimist": "^1.2.2", "@types/platform": "^1.3.4", - "@webgpu/types": "^0.1.30", + "@webgpu/types": "^0.1.38", "base64-js": "^1.5.1", "chai": "^4.3.7", "electron": "^23.1.2", @@ -323,9 +323,9 @@ } }, "node_modules/@webgpu/types": { - "version": "0.1.30", - "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.30.tgz", - "integrity": "sha512-9AXJSmL3MzY8ZL//JjudA//q+2kBRGhLBFpkdGksWIuxrMy81nFrCzj2Am+mbh8WoU6rXmv7cY5E3rdlyru2Qg==", + "version": "0.1.38", + "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.38.tgz", + "integrity": "sha512-7LrhVKz2PRh+DD7+S+PVaFd5HxaWQvoMqBbsV9fNJO1pjUs1P8bM2vQVNfk+3URTqbuTI7gkXi0rfsN0IadoBA==", "dev": true }, "node_modules/accepts": { @@ -3767,9 +3767,9 @@ } }, "@webgpu/types": { - "version": "0.1.30", - "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.30.tgz", - "integrity": "sha512-9AXJSmL3MzY8ZL//JjudA//q+2kBRGhLBFpkdGksWIuxrMy81nFrCzj2Am+mbh8WoU6rXmv7cY5E3rdlyru2Qg==", + "version": "0.1.38", + "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.38.tgz", + "integrity": "sha512-7LrhVKz2PRh+DD7+S+PVaFd5HxaWQvoMqBbsV9fNJO1pjUs1P8bM2vQVNfk+3URTqbuTI7gkXi0rfsN0IadoBA==", "dev": true }, "accepts": { diff --git a/js/web/package.json b/js/web/package.json index 8ae5b733e5f21..9b4531d7766fe 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -20,10 +20,11 @@ }, "scripts": { "preprepare": "node -e \"require('node:fs').copyFileSync('./node_modules/long/index.d.ts', './node_modules/long/umd/index.d.ts')\"", - "prepare": "tsc", + "prepare": "tsc --build ./script", "build:doc": "node ./script/generate-webgl-operator-md && node ./script/generate-webgpu-operator-md", "pull:wasm": "node ./script/pull-prebuilt-wasm-artifacts", "test:e2e": "node ./test/e2e/run", + "prebuild": "tsc -p . --noEmit && tsc -p lib/wasm/proxy-worker --noEmit", "build": "node ./script/build", "test": "tsc --build ../scripts && node ../scripts/prepare-onnx-node-tests && node ./script/test-runner-cli", "prepack": "node ./script/build && node ./script/prepack" @@ -42,7 +43,7 @@ "@types/minimatch": "^5.1.2", "@types/minimist": "^1.2.2", "@types/platform": "^1.3.4", - "@webgpu/types": "^0.1.30", + "@webgpu/types": "^0.1.38", "base64-js": "^1.5.1", "chai": "^4.3.7", "electron": "^23.1.2", @@ -66,18 +67,48 @@ "main": "dist/ort-web.node.js", "exports": { ".": { - "node": { - "types": "./types.d.ts", - "default": "./dist/ort-web.node.js" - }, + "node": "./dist/ort.node.min.js", "default": { - "types": "./types.d.ts", - "default": "./dist/ort.min.js" + "import": "./dist/esm/ort.min.js", + "require": "./dist/cjs/ort.min.js", + "default": { + "development": "./dist/ort.js", + "default": "./dist/ort.min.js" + } } }, + "./experimental": { + "import": "./dist/esm/ort.all.min.js", + "require": "./dist/cjs/ort.all.min.js", + "default": { + "development": "./dist/ort.all.js", + "default": "./dist/ort.all.min.js" + } + }, + "./wasm": { + "import": "./dist/esm/ort.wasm.min.js", + "require": "./dist/cjs/ort.wasm.min.js", + "default": "./dist/ort.wasm.min.js" + }, + "./wasm-core": { + "import": "./dist/esm/ort.wasm-core.min.js", + "require": "./dist/cjs/ort.wasm-core.min.js", + "default": "./dist/ort.wasm-core.min.js" + }, + "./webgl": { + "import": "./dist/esm/ort.webgl.min.js", + "require": "./dist/cjs/ort.webgl.min.js", + "default": "./dist/ort.webgl.min.js" + }, "./webgpu": { - "types": "./types.d.ts", + "import": "./dist/esm/ort.webgpu.min.js", + "require": "./dist/cjs/ort.webgpu.min.js", "default": "./dist/ort.webgpu.min.js" + }, + "./training": { + "import": "./dist/esm/ort.training.wasm.min.js", + "require": "./dist/cjs/ort.training.wasm.min.js", + "default": "./dist/ort.training.wasm.min.js" } }, "types": "./types.d.ts", diff --git a/js/web/script/build.ts b/js/web/script/build.ts index d3a5be429bfa1..5151f27582c1f 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -1,186 +1,447 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {spawnSync} from 'child_process'; -import * as fs from 'fs-extra'; +import * as esbuild from 'esbuild'; import minimist from 'minimist'; -import npmlog from 'npmlog'; -import * as path from 'path'; - -// CMD args -const args = minimist(process.argv); - -// --bundle-mode=prod (default) -// --bundle-mode=dev -// --bundle-mode=perf -// --bundle-mode=node -const MODE = args['bundle-mode'] || 'prod'; -if (['prod', 'dev', 'perf', 'node'].indexOf(MODE) === -1) { - throw new Error(`unknown build mode: ${MODE}`); -} +import * as fs from 'node:fs/promises'; +import * as path from 'node:path'; -// --wasm (default) -// --no-wasm -const WASM = typeof args.wasm === 'undefined' ? true : !!args.wasm; - -// -a; --analyzer -const ANALYZER = !!args.a || !!args.analyzer; - -// -f; --filter= -const FILTER = args.f || args.filter; - -// Path variables -const ROOT_FOLDER = path.join(__dirname, '..'); -const WASM_BINDING_FOLDER = path.join(ROOT_FOLDER, 'lib', 'wasm', 'binding'); -const WASM_BINDING_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm.js'); -const WASM_BINDING_THREADED_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.js'); -const WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-simd-threaded.jsep.js'); -const WASM_BINDING_THREADED_WORKER_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.worker.js'); -const WASM_BINDING_THREADED_MIN_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.min.js'); -const WASM_BINDING_SIMD_THREADED_JSEP_MIN_JS_PATH = - path.join(WASM_BINDING_FOLDER, 'ort-wasm-simd-threaded.jsep.min.js'); -const WASM_BINDING_THREADED_MIN_WORKER_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.min.worker.js'); - -const WASM_DIST_FOLDER = path.join(ROOT_FOLDER, 'dist'); -const WASM_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm.wasm'); -const WASM_THREADED_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-threaded.wasm'); -const WASM_SIMD_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd.wasm'); -const WASM_SIMD_THREADED_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd-threaded.wasm'); -const WASM_SIMD_JSEP_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd.jsep.wasm'); -const WASM_SIMD_THREADED_JSEP_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm'); -const WASM_THREADED_WORKER_JS_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-threaded.worker.js'); -const WASM_THREADED_JS_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-threaded.js'); -const WASM_SIMD_THREADED_JSEP_JS_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd-threaded.jsep.js'); - -function validateFile(path: string): void { - npmlog.info('Build', `Ensure file: ${path}`); - if (!fs.pathExistsSync(path)) { - throw new Error(`file does not exist: ${path}`); - } - if (fs.statSync(path).size === 0) { - throw new Error(`file is empty: ${path}`); - } -} +/** + * @summary Build script for ort-web using esbuild. + */ -if (WASM) { - npmlog.info('Build', 'Validating WebAssembly artifacts...'); - try { - validateFile(WASM_BINDING_JS_PATH); - validateFile(WASM_BINDING_THREADED_JS_PATH); - validateFile(WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH); - validateFile(WASM_BINDING_THREADED_WORKER_JS_PATH); - validateFile(WASM_WASM_PATH); - validateFile(WASM_THREADED_WASM_PATH); - validateFile(WASM_SIMD_WASM_PATH); - validateFile(WASM_SIMD_THREADED_WASM_PATH); - validateFile(WASM_SIMD_JSEP_WASM_PATH); - validateFile(WASM_SIMD_THREADED_JSEP_WASM_PATH); - } catch (e) { - npmlog.error('Build', `WebAssembly files are not ready. build WASM first. ERR: ${e}`); - throw e; - } - npmlog.info('Build', 'Validating WebAssembly artifacts... DONE'); +const args = minimist(process.argv.slice(2)); +/** + * --bundle-mode=prod (default) + * Build multiple ort-web bundles for production. + * + * --bundle-mode=dev + * Build a single ort-web bundle for development, and a test bundle. + * + * --bundle-mode=perf + * Build a single ort-web bundle for performance test, and a test bundle. + * + * --bundle-mode=node + * Build a single ort-web bundle for nodejs. + */ +const BUNDLE_MODE: 'prod'|'dev'|'perf'|'node' = args['bundle-mode'] || 'prod'; + +/** + * --debug + * Enable debug mode. In this mode, esbuild metafile feature will be enabled. Simple bundle analysis will be printed. + * + * --debug=verbose + * Enable debug mode. In this mode, esbuild metafile feature will be enabled. Detailed bundle analysis will be + * printed. + * + * --debug=save + * Enable debug mode. In this mode, esbuild metafile feature will be enabled. Full bundle analysis will be saved to a + * file as JSON. + */ +const DEBUG = args.debug; // boolean|'verbose'|'save' - const VERSION = require(path.join(__dirname, '../package.json')).version; - const COPYRIGHT_BANNER = `/*! - * ONNX Runtime Web v${VERSION} +const SOURCE_ROOT_FOLDER = path.join(__dirname, '../..'); // /js/ +const DEFAULT_DEFINE = { + 'BUILD_DEFS.DISABLE_WEBGL': 'false', + 'BUILD_DEFS.DISABLE_WEBGPU': 'false', + 'BUILD_DEFS.DISABLE_WASM': 'false', + 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', + 'BUILD_DEFS.DISABLE_WASM_THREAD': 'false', + 'BUILD_DEFS.DISABLE_TRAINING': 'true', +}; + +const COPYRIGHT_HEADER = `/*! + * ONNX Runtime Web v${require('../package.json').version} * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. - */ -`; - - npmlog.info('Build', 'Minimizing file "ort-wasm-threaded.js"...'); - try { - const terser = spawnSync( - 'npx', - [ - 'terser', WASM_BINDING_THREADED_JS_PATH, '--compress', 'passes=2', '--format', 'comments=false', '--mangle', - 'reserved=[_scriptDir,startWorker]', '--module' - ], - {shell: true, encoding: 'utf-8', cwd: ROOT_FOLDER}); - if (terser.status !== 0) { - console.error(terser.error); - process.exit(terser.status === null ? undefined : terser.status); - } + */`; - fs.writeFileSync(WASM_BINDING_THREADED_MIN_JS_PATH, terser.stdout); - fs.writeFileSync(WASM_THREADED_JS_PATH, `${COPYRIGHT_BANNER}${terser.stdout}`); +interface OrtBuildOptions { + isProduction?: boolean; + isNode?: boolean; + format: 'iife'|'cjs'|'esm'; + outputBundleName: string; + define?: Record; +} - validateFile(WASM_BINDING_THREADED_MIN_JS_PATH); - validateFile(WASM_THREADED_JS_PATH); - } catch (e) { - npmlog.error('Build', `Failed to run terser on ort-wasm-threaded.js. ERR: ${e}`); - throw e; +async function buildBundle(options: esbuild.BuildOptions) { + const result = await esbuild.build({ + logLevel: DEBUG ? (DEBUG === 'verbose' || DEBUG === 'save' ? 'verbose' : 'debug') : 'info', + metafile: !!DEBUG, + absWorkingDir: SOURCE_ROOT_FOLDER, + bundle: true, + banner: {js: COPYRIGHT_HEADER}, + ...options + }); + if (DEBUG) { + if (DEBUG === 'save') { + await fs.writeFile( + `${path.basename(options.outfile!)}.esbuild.metafile.json`, JSON.stringify(result.metafile!, null, 2)); + } else { + console.log(await esbuild.analyzeMetafile(result.metafile!, {verbose: DEBUG === 'verbose'})); + } } - npmlog.info('Build', 'Minimizing file "ort-wasm-threaded.js"... DONE'); - - npmlog.info('Build', 'Minimizing file "ort-wasm-simd-threaded.jsep.js"...'); - try { - const terser = spawnSync( - 'npx', - [ - 'terser', WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH, '--compress', 'passes=2', '--format', 'comments=false', - '--mangle', 'reserved=[_scriptDir,startWorker]', '--module' - ], - {shell: true, encoding: 'utf-8', cwd: ROOT_FOLDER}); - if (terser.status !== 0) { - console.error(terser.error); - process.exit(terser.status === null ? undefined : terser.status); + return result; +} + +async function minifyCode(sourceCode: string): Promise { + const result = await esbuild.transform(sourceCode, { + minify: true, + legalComments: 'none', + }); + return result.code; +} + +async function buildOrt({ + isProduction = false, + isNode = false, + format, + outputBundleName, + define = DEFAULT_DEFINE, +}: OrtBuildOptions) { + // #region Plugin: resolve ignore imports + + /** + * This plugin is used to ignore a few nodejs imports that are not used in the browser. Those imported functions are + * not really used in the browser because they are usually put behind a feature check. However, esbuild does not know + * that. It will complain about those imports are not available in the browser. + * + * This plugin will ignore those imports and replace them with empty exports. + */ + const excludeNodejsImports = { + name: 'exclude-nodejs-imports', + setup(build: esbuild.PluginBuild) { + build.onResolve({filter: /(^node:|^worker_threads$|^fs$|^path$|^perf_hooks$|^os$)/}, args => ({ + namespace: 'nodejs-ignore', + path: args.path, + sideEffects: false, + })); + build.onLoad({filter: /.*/, namespace: 'nodejs-ignore'}, args => { + switch (args.path) { + case 'node:fs/promises': + case 'node:fs': + case 'fs': + return {contents: 'export const readFile = undefined;'}; + case 'node:os': + case 'os': + return {contents: 'export const cpus = undefined;'}; + case 'node:path': + case 'path': + return {contents: 'export const join = undefined;'}; + default: + return {contents: ''}; + } + }); + }, + }; + // #endregion + + // #region Plugin: web assembly multi-thread worker loader + + /** + * This plugin is used to load web assembly multi-thread worker code as string. + * + * This allows to create the worker from a Blob, so we don't need to create a separate file for the worker. + */ + const wasmThreadedHandler = { + name: 'wasm-threaded-handler', + setup(build: esbuild.PluginBuild) { + build.onLoad({filter: /[\\/]ort-wasm-threaded\.worker\.js$/}, async args => { + let contents = await fs.readFile(args.path, {encoding: 'utf-8'}); + if (isProduction) { + contents = await minifyCode(contents); + } + return {loader: 'text', contents}; + }); + }, + }; + // #endregion + + // #region Plugin: generated emscripten .js loader + + /** + * This plugin is used to patch the generated emscripten .js file for multi-thread build. + * + * Since we use inline worker for multi-thread, we make an optimization to use function.toString() to get the + * implementation of the exported `ortWasmThreaded` function to reduce the size of the bundle. However, the generated + * function uses a variable `_scriptDir` which is defined inside an IIFE closure. When we use function.toString(), the + * worker code will throw "_scriptDir is not defined" error. + * + * To fix this error, we need to patch the generated code to replace access to `_scriptDir` with `typeof _scriptDir + * !== "undefined" && _scriptDir`. + */ + const emscriptenThreadedJsHandler = { + name: 'emscripten-threaded-js-handler', + setup(build: esbuild.PluginBuild) { + build.onLoad({filter: /ort-wasm.*-threaded.*\.js$/}, async args => { + let contents = await fs.readFile(args.path, {encoding: 'utf-8'}); + // For debug build, Emscripten generates the following code: + // + // if (_scriptDir) { + // scriptDirectory = _scriptDir; + // } + // + // We replace it with: + // + // if (typeof _scriptDir !== "undefined" && _scriptDir) { + // scriptDirectory = _scriptDir; + // } + contents = contents.replace('if (_scriptDir) {', 'if (typeof _scriptDir !== "undefined" && _scriptDir) {'); + + // For release build, Emscripten generates the following code: + // + // ...,_scriptDir&&(H=_scriptDir),... + // We replace it with: + // ...,(typeof _scriptDir !== "undefined" && _scriptDir)&&(H=_scriptDir),... + contents = + contents.replace(/_scriptDir(&&\(.+=_scriptDir\))/, '(typeof _scriptDir !== "undefined" && _scriptDir)$1'); + + return {contents}; + }); } + }; + // #endregion - fs.writeFileSync(WASM_BINDING_SIMD_THREADED_JSEP_MIN_JS_PATH, terser.stdout); - fs.writeFileSync(WASM_SIMD_THREADED_JSEP_JS_PATH, `${COPYRIGHT_BANNER}${terser.stdout}`); + // #region Plugin: proxy worker loader - validateFile(WASM_BINDING_SIMD_THREADED_JSEP_MIN_JS_PATH); - validateFile(WASM_SIMD_THREADED_JSEP_JS_PATH); - } catch (e) { - npmlog.error('Build', `Failed to run terser on ort-wasm-threaded.js. ERR: ${e}`); - throw e; - } - npmlog.info('Build', 'Minimizing file "ort-wasm-simd-threaded.jsep.js"... DONE'); - - npmlog.info('Build', 'Minimizing file "ort-wasm-threaded.worker.js"...'); - try { - const terser = spawnSync( - 'npx', - [ - 'terser', WASM_BINDING_THREADED_WORKER_JS_PATH, '--compress', 'passes=2', '--format', 'comments=false', - '--mangle', 'reserved=[_scriptDir,startWorker]', '--toplevel' - ], - {shell: true, encoding: 'utf-8'}); - if (terser.status !== 0) { - console.error(terser.error); - process.exit(terser.status === null ? undefined : terser.status); + /** + * This plugin is used to load proxy worker code as string. + */ + const proxyWorkerHandler = { + name: 'proxy-worker-handler', + setup(build: esbuild.PluginBuild) { + build.onResolve( + {filter: /proxy-worker\/main$/}, + async args => ({path: args.path, namespace: 'proxy-worker', pluginData: args.resolveDir})); + + build.onLoad({filter: /.*/, namespace: 'proxy-worker'}, async args => { + const result = await buildBundle({ + entryPoints: [path.resolve(args.pluginData, args.path)], + outfile: `web/dist/${outputBundleName}.proxy.js`, + platform: 'browser', + plugins: [excludeNodejsImports, wasmThreadedHandler, emscriptenThreadedJsHandler], + define: { + ...build.initialOptions.define, + 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', + }, + sourcemap: isProduction ? false : 'inline', + minify: isProduction, + write: false, + }); + + return {loader: 'text', contents: result.outputFiles![0].text}; + }); + }, + }; + // #endregion + + await buildBundle({ + entryPoints: ['web/lib/index.ts'], + outfile: `web/dist/${outputBundleName}.js`, + platform: isNode ? 'node' : 'browser', + format, + globalName: 'ort', + plugins: isNode ? undefined : + [excludeNodejsImports, proxyWorkerHandler, wasmThreadedHandler, emscriptenThreadedJsHandler], + external: isNode ? ['onnxruntime-common'] : undefined, + define, + sourcemap: isProduction ? 'linked' : 'inline', + minify: isProduction, + }); +} + +async function buildTest() { + const isProduction = BUNDLE_MODE === 'perf'; + + await buildBundle({ + absWorkingDir: path.join(SOURCE_ROOT_FOLDER, 'web/test'), + + entryPoints: ['test-main.ts'], + outfile: isProduction ? 'ort.test.min.js' : 'ort.test.js', + platform: 'browser', + format: 'iife', + define: DEFAULT_DEFINE, + sourcemap: isProduction ? false : 'inline', + sourceRoot: path.join(SOURCE_ROOT_FOLDER, 'web/test'), + external: ['../../node'], + plugins: [ + // polyfill nodejs modules + require('esbuild-plugin-polyfill-node').polyfillNode({globals: false}), + // make "ort" external + { + name: 'make-ort-external', + setup(build: esbuild.PluginBuild) { + build.onResolve( + {filter: /^onnxruntime-common$/}, + _args => ({path: 'onnxruntime-common', namespace: 'make-ort-external'})); + build.onLoad( + {filter: /.*/, namespace: 'make-ort-external'}, + _args => ({contents: 'module.exports = globalThis.ort;'})); + } + } + ], + minify: isProduction, + }); +} + +async function main() { + // tasks for each esbuild bundle + const buildTasks: Array> = []; + /** + * add one build task + */ + const addBuildTask = async (task: Promise) => { + if (DEBUG) { + // in DEBUG mode, build sequentially + await task; + } else { + buildTasks.push(task); } + }; + /** + * add all 6 build tasks for web bundles. Includes: + * - IIFE, debug: [name].js + * - IIFE, production: [name].min.js + * - CJS, debug: cjs/[name].js + * - CJS, production: cjs/[name].min.js + * - ESM, debug: esm/[name].js + * - ESM, production: esm/[name].min.js + */ + const addAllWebBuildTasks = async (options: Omit) => { + // [name].js + await addBuildTask(buildOrt({ + ...options, + format: 'iife', + })); + // [name].min.js + await addBuildTask(buildOrt({ + ...options, + outputBundleName: options.outputBundleName + '.min', + format: 'iife', + isProduction: true, + })); + // cjs/[name].js + await addBuildTask(buildOrt({ + ...options, + outputBundleName: 'cjs/' + options.outputBundleName, + format: 'cjs', + })); + // cjs/[name].min.js + await addBuildTask(buildOrt({ + ...options, + outputBundleName: 'cjs/' + options.outputBundleName + '.min', + format: 'cjs', + isProduction: true, + })); + // esm/[name].js + await addBuildTask(buildOrt({ + ...options, + outputBundleName: 'esm/' + options.outputBundleName, + format: 'esm', + })); + // esm/[name].min.js + await addBuildTask(buildOrt({ + ...options, + outputBundleName: 'esm/' + options.outputBundleName + '.min', + format: 'esm', + isProduction: true, + })); + }; - fs.writeFileSync(WASM_BINDING_THREADED_MIN_WORKER_JS_PATH, terser.stdout); - fs.writeFileSync(WASM_THREADED_WORKER_JS_PATH, `${COPYRIGHT_BANNER}${terser.stdout}`); + if (BUNDLE_MODE === 'node' || BUNDLE_MODE === 'prod') { + // ort.node.min.js + await addBuildTask(buildOrt({ + isProduction: true, + isNode: true, + format: 'cjs', + outputBundleName: 'ort.node.min', + define: { + ...DEFAULT_DEFINE, + 'BUILD_DEFS.DISABLE_WEBGPU': 'true', + 'BUILD_DEFS.DISABLE_WEBGL': 'true', + 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', + 'BUILD_DEFS.DISABLE_WASM_THREAD': 'true', + }, + })); + } - validateFile(WASM_BINDING_THREADED_MIN_WORKER_JS_PATH); - validateFile(WASM_THREADED_WORKER_JS_PATH); - } catch (e) { - npmlog.error('Build', `Failed to run terser on ort-wasm-threaded.worker.js. ERR: ${e}`); - throw e; + if (BUNDLE_MODE === 'dev') { + // ort.all.js + await addBuildTask(buildOrt({ + outputBundleName: 'ort.all', + format: 'iife', + })); } - npmlog.info('Build', 'Minimizing file "ort-wasm-threaded.worker.js"... DONE'); -} -npmlog.info('Build', 'Building bundle...'); -{ - npmlog.info('Build.Bundle', 'Running webpack to generate bundles...'); - const webpackArgs = ['webpack', '--env', `--bundle-mode=${MODE}`]; - if (ANALYZER) { - webpackArgs.push('--env', '-a'); + if (BUNDLE_MODE === 'perf') { + // ort.all.min.js + await addBuildTask(buildOrt({ + isProduction: true, + outputBundleName: 'ort.all.min', + format: 'iife', + })); } - if (FILTER) { - webpackArgs.push('--env', `-f=${FILTER}`); + + if (BUNDLE_MODE === 'prod') { + // ort.all[.min].js + await addAllWebBuildTasks({outputBundleName: 'ort.all'}); + + // ort[.min].js + await addAllWebBuildTasks({ + outputBundleName: 'ort', + define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true'}, + }); + // ort.webgpu[.min].js + await addAllWebBuildTasks({ + outputBundleName: 'ort.webgpu', + define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true'}, + }); + // ort.wasm[.min].js + await addAllWebBuildTasks({ + outputBundleName: 'ort.wasm', + define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true'}, + }); + // ort.webgl[.min].js + await addAllWebBuildTasks({ + outputBundleName: 'ort.webgl', + define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WASM': 'true'}, + }); + // ort.wasm-core[.min].js + await addAllWebBuildTasks({ + outputBundleName: 'ort.wasm-core', + define: { + ...DEFAULT_DEFINE, + 'BUILD_DEFS.DISABLE_WEBGPU': 'true', + 'BUILD_DEFS.DISABLE_WEBGL': 'true', + 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', + 'BUILD_DEFS.DISABLE_WASM_THREAD': 'true', + }, + }); + // ort.training.wasm[.min].js + await addAllWebBuildTasks({ + outputBundleName: 'ort.training.wasm', + define: { + ...DEFAULT_DEFINE, + 'BUILD_DEFS.DISABLE_TRAINING': 'false', + 'BUILD_DEFS.DISABLE_WEBGPU': 'true', + 'BUILD_DEFS.DISABLE_WEBGL': 'true', + }, + }); } - npmlog.info('Build.Bundle', `CMD: npx ${webpackArgs.join(' ')}`); - const webpack = spawnSync('npx', webpackArgs, {shell: true, stdio: 'inherit', cwd: ROOT_FOLDER}); - if (webpack.status !== 0) { - console.error(webpack.error); - process.exit(webpack.status === null ? undefined : webpack.status); + + if (BUNDLE_MODE === 'dev' || BUNDLE_MODE === 'perf') { + await addBuildTask(buildTest()); + } + + await Promise.all(buildTasks); + + if (BUNDLE_MODE === 'prod') { + // generate package.json files under each of the dist folders for commonJS and ESModule + // this trick allows typescript to import this package as different module type + // see also: https://evertpot.com/universal-commonjs-esm-typescript-packages/ + await fs.writeFile(path.resolve(__dirname, '../dist/cjs', 'package.json'), '{"type": "commonjs"}'); + await fs.writeFile(path.resolve(__dirname, '../dist/esm', 'package.json'), '{"type": "module"}'); } - npmlog.info('Build.Bundle', 'Running webpack to generate bundles... DONE'); } -npmlog.info('Build', 'Building bundle... DONE'); + +void main(); diff --git a/js/web/script/generate-webgpu-operator-md.ts b/js/web/script/generate-webgpu-operator-md.ts index 7408f17004f5e..eab8175a941bd 100644 --- a/js/web/script/generate-webgpu-operator-md.ts +++ b/js/web/script/generate-webgpu-operator-md.ts @@ -16,6 +16,8 @@ const COMMENTS: Record = { 'Reshape': 'no GPU kernel', 'Shape': 'no GPU kernel; an ORT warning is generated - need to fix', 'Resize': 'CoordinateTransformMode align_corners is not supported with downsampling', + 'Attention': 'need implementing mask and past/present', + 'MultiHeadAttention': 'need implementing mask and past/present', }; /* eslint-disable max-len */ diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index f90f568879146..ee955ec8d4f17 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -51,6 +51,10 @@ Options: -P[=<...>], --perf[=<...>] Generate performance number. Cannot be used with flag --debug. This flag can be used with a number as value, specifying the total count of test cases to run. The test cases may be used multiple times. Default value is 10. -c, --file-cache Enable file cache. + -i=<...>, --io-binding=<...> Specify the IO binding testing type. Should be one of the following: + none (default) + gpu-tensor use pre-allocated GPU tensors for inputs and outputs + gpu-location use pre-allocated GPU tensors for inputs and set preferredOutputLocation to 'gpu-buffer' *** Session Options *** -u=<...>, --optimized-model-file-path=<...> Specify whether to dump the optimized model. @@ -108,7 +112,8 @@ export declare namespace TestRunnerCliArgs { type Mode = 'suite0'|'suite1'|'model'|'unittest'|'op'; type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'xnnpack'|'webnn'; type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'; - type BundleMode = 'prod'|'dev'|'perf'; + type BundleMode = 'dev'|'perf'; + type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location'; } export interface TestRunnerCliArgs { @@ -124,22 +129,19 @@ export interface TestRunnerCliArgs { /** * Bundle Mode * - * this field affects the behavior of Karma and Webpack. + * this field affects the behavior of Karma and build script. * - * For Karma, if flag '--bundle-mode' is not set, the default behavior is 'dev' - * For Webpack, if flag '--bundle-mode' is not set, the default behavior is 'prod' - * - * For running tests, the default mode is 'dev'. If flag '--perf' is set, the mode will be set to 'perf'. - * - * Mode | Output File | Main | Source Map | Webpack Config - * ------ | --------------------- | -------------------- | ------------------ | -------------- - * prod | /dist/ort.min.js | /lib/index.ts | source-map | production - * node | /dist/ort-web.node.js | /lib/index.ts | source-map | production - * dev | /test/ort.dev.js | /test/test-main.ts | inline-source-map | development - * perf | /test/ort.perf.js | /test/test-main.ts | (none) | production + * Mode "perf": + * - use "dist/ort.all.min.js" as main file + * - use "test/ort.test.min.js" as test file + * Mode "dev": + * - use "dist/ort.all.js" as main file + * - use "test/ort.test.js" as test file */ bundleMode: TestRunnerCliArgs.BundleMode; + ioBindingMode: TestRunnerCliArgs.IOBindingMode; + logConfig: Test.Config['log']; /** @@ -326,7 +328,11 @@ function parseWebgpuFlags(args: minimist.ParsedArgs): Partial { if (profilingMode !== undefined && profilingMode !== 'off' && profilingMode !== 'default') { throw new Error('Flag "webgpu-profiling-mode" is invalid'); } - return {profilingMode}; + const validateInputContent = args['webgpu-validate-input-content']; + if (validateInputContent !== undefined && typeof validateInputContent !== 'boolean') { + throw new Error('Flag "webgpu-validate-input-content" is invalid'); + } + return {profilingMode, validateInputContent}; } function parseGlobalEnvFlags(args: minimist.ParsedArgs): NonNullable { @@ -416,6 +422,13 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs logConfig.push({category: 'TestRunner.Perf', config: {minimalSeverity: 'verbose'}}); } + // Option: -i=<...>, --io-binding=<...> + const ioBindingArg = args['io-binding'] || args.i; + const ioBindingMode = (typeof ioBindingArg !== 'string') ? 'none' : ioBindingArg; + if (['none', 'gpu-tensor', 'gpu-location'].indexOf(ioBindingMode) === -1) { + throw new Error(`not supported io binding mode ${ioBindingMode}`); + } + // Option: -u, --optimized-model-file-path const optimizedModelFilePath = args['optimized-model-file-path'] || args.u || undefined; if (typeof optimizedModelFilePath !== 'undefined' && typeof optimizedModelFilePath !== 'string') { @@ -455,6 +468,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs npmlog.verbose('TestRunnerCli.Init', ` Env: ${env}`); npmlog.verbose('TestRunnerCli.Init', ` Debug: ${debug}`); npmlog.verbose('TestRunnerCli.Init', ` Backend: ${backend}`); + npmlog.verbose('TestRunnerCli.Init', ` IO Binding Mode: ${ioBindingMode}`); npmlog.verbose('TestRunnerCli.Init', 'Parsing commandline arguments... DONE'); return { @@ -467,6 +481,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs logConfig, profile, times: perf ? times : undefined, + ioBindingMode: ioBindingMode as TestRunnerCliArgs['ioBindingMode'], optimizedModelFilePath, graphOptimizationLevel: graphOptimizationLevel as TestRunnerCliArgs['graphOptimizationLevel'], fileCache, diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index a75321d45f1ef..74a03290332a8 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -28,6 +28,7 @@ async function main() { npmlog.verbose('TestRunnerCli.Init.Config', inspect(args)); + const DIST_ROOT = path.join(__dirname, '..', 'dist'); const TEST_ROOT = path.join(__dirname, '..', 'test'); const TEST_DATA_MODEL_NODE_ROOT = path.join(TEST_ROOT, 'data', 'node'); const TEST_DATA_OP_ROOT = path.join(TEST_ROOT, 'data', 'ops'); @@ -257,7 +258,7 @@ async function main() { times?: number): Test.ModelTest { if (times === 0) { npmlog.verbose('TestRunnerCli.Init.Model', `Skip test data from folder: ${testDataRootFolder}`); - return {name: path.basename(testDataRootFolder), backend, modelUrl: '', cases: []}; + return {name: path.basename(testDataRootFolder), backend, modelUrl: '', cases: [], ioBinding: args.ioBindingMode}; } let modelUrl: string|null = null; @@ -323,6 +324,16 @@ async function main() { } } + let ioBinding: Test.IOBindingMode; + if (backend !== 'webgpu' && args.ioBindingMode !== 'none') { + npmlog.warn( + 'TestRunnerCli.Init.Model', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`); + ioBinding = 'none'; + } else { + ioBinding = args.ioBindingMode; + } + + npmlog.verbose('TestRunnerCli.Init.Model', 'Finished preparing test data.'); npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); npmlog.verbose('TestRunnerCli.Init.Model', ` Model file: ${modelUrl}`); @@ -330,7 +341,7 @@ async function main() { npmlog.verbose('TestRunnerCli.Init.Model', ` Test set(s): ${cases.length} (${caseCount})`); npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); - return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases}; + return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases, ioBinding}; } function tryLocateModelTestFolder(searchPattern: string): string { @@ -390,6 +401,13 @@ async function main() { for (const test of tests) { test.backend = backend; test.opset = test.opset || {domain: '', version: MAX_OPSET_VERSION}; + if (backend !== 'webgpu' && args.ioBindingMode !== 'none') { + npmlog.warn( + 'TestRunnerCli.Init.Op', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`); + test.ioBinding = 'none'; + } else { + test.ioBinding = args.ioBindingMode; + } } npmlog.verbose('TestRunnerCli.Init.Op', 'Finished preparing test data.'); npmlog.verbose('TestRunnerCli.Init.Op', '==============================================================='); @@ -436,9 +454,6 @@ async function main() { npmlog.info('TestRunnerCli.Run', '(3/4) Running build to generate bundle...'); const buildCommand = `node ${path.join(__dirname, 'build')}`; const buildArgs = [`--bundle-mode=${args.env === 'node' ? 'node' : args.bundleMode}`]; - if (args.backends.indexOf('wasm') === -1) { - buildArgs.push('--no-wasm'); - } npmlog.info('TestRunnerCli.Run', `CMD: ${buildCommand} ${buildArgs.join(' ')}`); const build = spawnSync(buildCommand, buildArgs, {shell: true, stdio: 'inherit'}); if (build.status !== 0) { @@ -458,7 +473,14 @@ async function main() { npmlog.info('TestRunnerCli.Run', '(4/4) Running tsc... DONE'); npmlog.info('TestRunnerCli.Run', '(4/4) Running mocha...'); - const mochaArgs = ['mocha', path.join(TEST_ROOT, 'test-main'), `--timeout ${args.debug ? 9999999 : 60000}`]; + const mochaArgs = [ + 'mocha', + '--timeout', + `${args.debug ? 9999999 : 60000}`, + '-r', + path.join(DIST_ROOT, 'ort.node.min.js'), + path.join(TEST_ROOT, 'test-main'), + ]; npmlog.info('TestRunnerCli.Run', `CMD: npx ${mochaArgs.join(' ')}`); const mocha = spawnSync('npx', mochaArgs, {shell: true, stdio: 'inherit'}); if (mocha.status !== 0) { @@ -493,19 +515,13 @@ async function main() { karmaArgs.push('--force-localhost'); } if (webgpu) { - if (browser.includes('Canary')) { - chromiumFlags.push('--enable-dawn-features=allow_unsafe_apis,use_dxc'); - } else { - chromiumFlags.push('--enable-dawn-features=use_dxc'); - chromiumFlags.push('--disable-dawn-features=disallow_unsafe_apis'); - } + // flag 'allow_unsafe_apis' is required to enable experimental features like fp16 and profiling inside pass. + // flag 'use_dxc' is required to enable DXC compiler. + chromiumFlags.push('--enable-dawn-features=allow_unsafe_apis,use_dxc'); } if (webnn) { chromiumFlags.push('--enable-experimental-web-platform-features'); } - if (config.options.globalEnvFlags?.webgpu?.profilingMode === 'default') { - chromiumFlags.push('--disable-dawn-features=disallow_unsafe_apis'); - } karmaArgs.push(`--bundle-mode=${args.bundleMode}`); karmaArgs.push(...chromiumFlags.map(flag => `--chromium-flags=${flag}`)); if (browser.startsWith('Edge')) { diff --git a/js/web/script/tsconfig.json b/js/web/script/tsconfig.json new file mode 100644 index 0000000000000..23b1fc96eb558 --- /dev/null +++ b/js/web/script/tsconfig.json @@ -0,0 +1,6 @@ +{ + "extends": "../../tsconfig.tools.json", + "compilerOptions": { + "sourceMap": true + } +} diff --git a/js/web/test/data/ops/attention.jsonc b/js/web/test/data/ops/attention.jsonc new file mode 100644 index 0000000000000..bd4483027cc25 --- /dev/null +++ b/js/web/test/data/ops/attention.jsonc @@ -0,0 +1,557 @@ +[ + { + "name": "Attention Basic", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [4, 3], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [213, 213], + "dims": [1, 2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic Batch 2 with 2 heads", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16 + ], + "dims": [2, 2, 8], + "type": "float32" + }, + { + "data": [ + 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 + ], + "dims": [8, 6], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [320, 321, 320, 321, 320, 321, 320, 321], + "dims": [2, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863], + "dims": [1, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-1.328187108039856, -1.297916054725647, -0.8599594831466675], + "dims": [1, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic one head, batch 2", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987 + ], + "dims": [2, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 2 head, batch 1", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643], + "dims": [2, 6], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989, -0.989, 1.1103, -1.6898], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.8701779842376709, -2.6158859729766846, 0.8710794448852539, -2.5763747692108154, 0.9005484580993652, + -2.182751178741455, 2.1661579608917236, -2.1045265197753906, 1.6716957092285156, -1.797281265258789, + 1.7134947776794434, -1.765358328819275 + ], + "dims": [2, 3, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 2", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987 + ], + "dims": [2, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [ + 1.1103, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, -1.6898, -0.989, -1.9029953479766846, 0.8710794448852539, + -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, 1.7134947776794434 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.6956915855407715, -2.8863370418548584, 1.3899128437042236, 1.6789076328277588, -1.4083852767944336, + -1.7009180784225464, -3.1053788661956787, 3.5959298610687256, 1.1027096509933472, -0.009643087163567543, + -1.694351315498352, -2.9284396171569824, 1.734721302986145, 2.0606398582458496, -0.2571452260017395, + 3.671973943710327, -5.285338401794434, -6.833454132080078, 1.7506506443023682, -2.262148380279541, + 2.5110034942626953, 1.440049171447754, -0.9423203468322754, 1.7506506443023682, -1.86212158203125, + -0.5036701560020447, -5.732386589050293, -1.5674757957458496, 1.7506510019302368, -2.264472246170044 + ], + "dims": [2, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 1", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846 + ], + "dims": [1, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [1, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987 + ], + "dims": [3, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326, + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326, + 3.7965505123138428, -2.3799397945404053, -3.9530906677246094, 0.5844926834106445, -2.9756431579589844, + 2.448162794113159, 4.34546422958374, 1.9380426406860352, 0.5870105624198914, -2.7368364334106445, + -0.4769568145275116, 4.255186557769775, -3.9529950618743896, 0.6987408995628357, -2.9756433963775635 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, + 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, + 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987 + ], + "dims": [3, 3, 10], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, + 0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, + -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236 + ], + "dims": [10, 15], + "type": "float32" + }, + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -8.01101303100586, -5.782258987426758, 6.016238689422607, 0.26747000217437744, -6.992541313171387, + -8.011263847351074, -5.782248020172119, 5.366001129150391, 0.26747000217437744, -6.99449348449707, + -8.011263847351074, -5.782265663146973, 6.016238689422607, 0.26747000217437744, -6.992537021636963, + -6.102723598480225, -7.28973388671875, -4.578637599945068, 7.2203369140625, -6.028444766998291, + -6.102705478668213, -7.2897748947143555, -3.7882626056671143, 5.393260478973389, -5.754333972930908, + -1.3616288900375366, -7.289827823638916, -6.341128349304199, 6.329389572143555, -5.751791954040527, + -2.3945987224578857, -14.532954216003418, 3.969801902770996, 12.744998931884766, -11.1966552734375, + -2.4002532958984375, -14.538958549499512, -6.684961318969727, 12.476543426513672, -9.24352741241455, + -4.787771701812744, -8.640848159790039, 3.969801902770996, -0.6471102833747864, -11.1966552734375 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 1 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, + 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, + 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987 + ], + "dims": [3, 3, 10], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, + 0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, + -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236 + ], + "dims": [10, 15], + "type": "float32" + }, + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + 1.3541864156723022, -7.813620090484619, -6.758509635925293, 7.597365856170654, -13.926229476928711, + -1.322464108467102, -7.297357559204102, -0.05962071940302849, 6.347561836242676, -5.869992256164551, + -1.3616288900375366, -7.28973388671875, 0.0386197566986084, 6.329389572143555, -5.751791954040527, + -2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375, + -2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375, + 1.021930456161499, -2.373898983001709, 3.8501391410827637, -0.6108309626579285, -9.256340980529785 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/batch-norm.jsonc b/js/web/test/data/ops/batch-norm.jsonc new file mode 100644 index 0000000000000..4ea16f290dc8f --- /dev/null +++ b/js/web/test/data/ops/batch-norm.jsonc @@ -0,0 +1,446 @@ +[ + { + "name": "BatchNormalization with no attributes", + "operator": "BatchNormalization", + "attributes": [], + "cases": [ + { + "name": "T[64]", + "inputs": [ + { + "data": [ + 2.02384, -0.935186, 0.488569, -0.513934, -1.27082, -0.131913, -1.806, -0.37904, 0.667796, -1.14826, + 1.2522, 0.0300339, 2.4758, 1.55511, 0.385341, 1.46645, -1.09355, -2.56309, 0.976015, -1.47036, 0.89486, + 0.580989, -1.12418, -0.339189, 1.3314, 0.418893, -0.301401, -1.2983, -0.839063, 0.170261, 1.15486, + -0.255735, -0.589851, -0.416289, -0.952648, -0.360487, 0.253287, 0.437195, 0.32023, 0.209606, -0.279519, + -0.546527, 0.265286, -1.07383, -1.65879, 1.1222, 0.946612, 0.822549, 0.64689, -0.292639, -0.73995, + -0.694949, 1.33899, -0.0652476, 1.61791, 1.49692, -0.761145, -0.201874, -1.15431, -1.83111, -0.705267, + -0.143026, -0.129819, -0.799425 + ], + "dims": [64], + "type": "float32" + }, + { + "data": [0.241661], + "dims": [1], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "float32" + }, + { + "data": [1], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.489082, -0.225997, 0.118068, -0.124197, -0.307105, -0.031878, -0.436439, -0.0915989, 0.16138, -0.277489, + 0.302606, 0.007258, 0.598301, 0.375807, 0.0931215, 0.354382, -0.264267, -0.619395, 0.235864, -0.355328, + 0.216252, 0.140402, -0.271669, -0.0819684, 0.321747, 0.10123, -0.0728365, -0.313746, -0.202768, 0.0411454, + 0.279085, -0.0618009, -0.142543, -0.1006, -0.230217, -0.0871152, 0.0612094, 0.105652, 0.0773867, + 0.0506533, -0.0675486, -0.132074, 0.064109, -0.259501, -0.400863, 0.271191, 0.228758, 0.198777, 0.156327, + -0.0707191, -0.178816, -0.167941, 0.323581, -0.0157677, 0.390985, 0.361745, -0.183938, -0.0487849, + -0.27895, -0.442507, -0.170435, -0.0345637, -0.031372, -0.193189 + ], + "dims": [64], + "type": "float32" + } + ] + }, + { + "name": "T[2,3,4,4,4]", + "inputs": [ + { + "data": [ + 2.02384, -0.935186, 0.488569, -0.513934, -1.27082, -0.131913, -1.806, -0.37904, 0.667796, -1.14826, + 1.2522, 0.0300339, 2.4758, 1.55511, 0.385341, 1.46645, -1.09355, -2.56309, 0.976015, -1.47036, 0.89486, + 0.580989, -1.12418, -0.339189, 1.3314, 0.418893, -0.301401, -1.2983, -0.839063, 0.170261, 1.15486, + -0.255735, -0.589851, -0.416289, -0.952648, -0.360487, 0.253287, 0.437195, 0.32023, 0.209606, -0.279519, + -0.546527, 0.265286, -1.07383, -1.65879, 1.1222, 0.946612, 0.822549, 0.64689, -0.292639, -0.73995, + -0.694949, 1.33899, -0.0652476, 1.61791, 1.49692, -0.761145, -0.201874, -1.15431, -1.83111, -0.705267, + -0.143026, -0.129819, -0.799425, 0.168795, 0.740422, -0.377683, 0.432598, -2.07414, -2.85251, 0.273531, + 0.0532606, 1.31052, -0.769382, 0.9976, 0.850536, -1.53812, -0.00496016, 0.931242, 0.0517056, -0.497829, + 0.275869, 0.860001, 1.23747, 0.179686, 1.5914, 0.740327, 0.798208, 2.12478, 1.74205, -0.322054, + -0.0112451, 0.204525, -0.431252, -1.3114, 0.186204, 0.780569, -1.42994, 1.63344, -0.00839034, -0.187035, + 1.8406, 1.32053, -0.636963, 0.408944, -1.50846, -1.2076, -0.129118, -0.0441307, 1.47558, 1.07251, 1.05295, + -0.420297, -1.13402, -0.524053, 3.20754, -0.588935, -0.527549, 0.591928, -1.10529, 0.520412, 0.19404, + -1.21229, -0.399594, -0.280935, -0.363324, -0.00804771, 1.43102, -0.523222, 1.17608, -0.53195, 0.914993, + 2.69308, -0.517211, 0.472273, -0.464725, -0.929768, -0.631145, 0.919709, -0.27391, 1.76689, 0.894897, + 0.235798, 1.2544, 0.858985, -0.139707, 0.354544, 0.200878, 0.353255, 0.0722632, -1.56074, 1.03685, + 1.73434, 0.193269, -0.864609, 0.842739, -0.372717, 0.584484, 0.16315, 1.60674, -0.0611289, -1.24544, + 1.33361, -0.961942, -0.15732, -0.348637, 0.361842, 0.7386, 0.517256, 1.20406, -2.07277, -1.01983, -1.9163, + 0.239934, 0.177979, 0.464564, 0.988822, 0.284607, -1.56099, -0.429143, 0.111043, -0.0853688, -0.319176, + -0.279777, 0.520971, -1.078, -0.670242, 0.065652, 0.468538, -0.825062, 0.370068, 1.68751, -1.16928, + -0.411782, 1.61624, -0.973004, 2.64703, -0.220014, -1.43954, -0.018692, 1.34982, -0.95197, -1.72586, + 1.32725, 0.280984, 0.00847463, 0.512869, 0.0378154, 0.13898, 0.35758, -0.084558, 1.04045, -1.79933, + 1.3002, 0.390457, 1.22267, 0.959344, -0.964296, -0.0935597, 0.288953, -0.158046, 0.532672, -0.500988, + 0.25187, -2.14384, -0.633315, 1.24612, -1.41525, 0.36494, -0.00714732, -0.608963, 0.508496, 0.995365, + 1.21159, -0.169055, -0.968783, 1.52779, -0.082381, 2.2049, 0.928655, 0.120245, 0.911429, -0.885258, + -1.2072, 0.770694, 2.36621, 1.08456, -1.60069, 0.0345025, 0.359559, -0.785411, 0.466532, -0.78543, + 0.024879, 1.59337, 1.13718, -1.27073, -0.263788, -1.7702, 0.203263, 1.34631, 1.11914, -2.04911, -0.804137, + 0.466763, 2.18386, 1.4689, 0.898297, -0.648948, 0.252202, 1.12501, -0.204563, 0.124608, 0.377214, + 0.894327, -0.249118, 0.709188, 0.999397, -1.4079, 0.193873, 0.657753, -0.709732, 1.09897, -0.145793, + 0.779199, 0.88378, -1.2676, 1.15709, 0.62295, -0.370894, -0.103268, -1.55949, -0.470747, 0.100394, + 0.422334, -0.0685312, -0.434488, -0.568974, -0.256987, 2.01276, -0.923322, -0.613144, 1.50676, 0.65756, + 1.20524, 1.10395, -0.975241, 2.44035, 1.08276, 0.330393, -0.508918, -1.25545, 0.189815, -0.156263, + -0.960866, 1.0859, -0.674478, 2.76743, 1.21399, 1.71666, -1.73198, -1.1062, 0.951285, -0.713336, 1.61586, + 1.96514, 0.002603, 0.0953297, 0.949256, -1.76552, 0.372816, -0.781229, 1.50532, 1.28462, 1.31116, + 0.731908, 1.54835, 0.371081, 0.409244, -0.106938, -1.79396, -1.61198, -0.80869, -1.10381, 1.1872, + -0.832439, 0.0755941, -1.09553, 0.960059, 1.44252, -0.196482, -1.07364, 0.165547, 0.630078, 1.56569, + -0.669592, 1.15974, 0.0953399, -0.202313, 0.812631, -0.318567, -0.16644, 0.887062, -0.0264821, -0.740725, + 0.0797577, -1.1037, 0.90236, 1.13427, 0.364186, -2.01043, -0.415748, 0.116046, 0.369949, 0.317886, + 0.530332, 1.48341, 0.74666, -1.64142, 0.22569, 1.18015, 1.31827, -1.33904, -0.101125 + ], + "dims": [2, 3, 4, 4, 4], + "type": "float32" + }, + { + "data": [0.241661, 0.960798, 0.474727], + "dims": [3], + "type": "float32" + }, + { + "data": [0, 0, 0], + "dims": [3], + "type": "float32" + }, + { + "data": [0, 0, 0], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 1, 1], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.489082, -0.225997, 0.118068, -0.124197, -0.307105, -0.031878, -0.436439, -0.0915989, 0.16138, -0.277489, + 0.302606, 0.007258, 0.598301, 0.375807, 0.0931215, 0.354382, -0.264267, -0.619395, 0.235864, -0.355328, + 0.216252, 0.140402, -0.271669, -0.0819684, 0.321747, 0.10123, -0.0728365, -0.313746, -0.202768, 0.0411454, + 0.279085, -0.0618009, -0.142543, -0.1006, -0.230217, -0.0871152, 0.0612094, 0.105652, 0.0773867, + 0.0506533, -0.0675486, -0.132074, 0.064109, -0.259501, -0.400863, 0.271191, 0.228758, 0.198777, 0.156327, + -0.0707191, -0.178816, -0.167941, 0.323581, -0.0157677, 0.390985, 0.361745, -0.183938, -0.0487849, + -0.27895, -0.442507, -0.170435, -0.0345637, -0.031372, -0.193189, 0.162177, 0.711393, -0.362876, 0.415637, + -1.99282, -2.74067, 0.262807, 0.0511725, 1.25914, -0.739217, 0.958488, 0.817189, -1.47782, -0.00476569, + 0.894731, 0.0496784, -0.478311, 0.265053, 0.826283, 1.18895, 0.172641, 1.52901, 0.711301, 0.766913, + 2.04147, 1.67375, -0.309427, -0.0108042, 0.196507, -0.414344, -1.25999, 0.178903, 0.749965, -1.37387, + 1.5694, -0.00806138, -0.179702, 1.76844, 1.26875, -0.61199, 0.392911, -1.44932, -1.16025, -0.124055, + -0.0424004, 1.41773, 1.03046, 1.01167, -0.403818, -1.08956, -0.503507, 3.08178, -0.565845, -0.506866, + 0.56872, -1.06196, 0.500008, 0.186433, -1.16476, -0.383928, -0.269921, -0.349079, -0.00773219, 1.37492, + -0.248386, 0.558316, -0.25253, 0.43437, 1.27847, -0.245533, 0.2242, -0.220617, -0.441384, -0.29962, + 0.436609, -0.130032, 0.838785, 0.424829, 0.111939, 0.595496, 0.407781, -0.0663221, 0.168311, 0.0953618, + 0.167699, 0.0343051, -0.74092, 0.492219, 0.823334, 0.0917494, -0.410451, 0.400069, -0.176938, 0.277469, + 0.0774512, 0.762761, -0.0290194, -0.59124, 0.6331, -0.456657, -0.0746837, -0.165507, 0.171775, 0.350631, + 0.245554, 0.571595, -0.983996, -0.484139, -0.909715, 0.113902, 0.0844908, 0.22054, 0.469418, 0.13511, + -0.741041, -0.203725, 0.0527148, -0.0405267, -0.151521, -0.132817, 0.247318, -0.511752, -0.31818, + 0.0311666, 0.222426, -0.391677, 0.17568, 0.801104, -0.282569, -0.0995112, 0.39058, -0.235136, 0.639682, + -0.0531687, -0.347878, -0.0045171, 0.326198, -0.230053, -0.41707, 0.320744, 0.0679025, 0.00204798, + 0.12394, 0.00913847, 0.0335859, 0.0864127, -0.0204343, 0.251436, -0.434827, 0.314206, 0.0943579, 0.295471, + 0.231835, -0.233032, -0.0226096, 0.0698283, -0.0381934, 0.128725, -0.121069, 0.060867, -0.51808, + -0.153047, 0.301137, -0.342009, 0.0881915, -0.00172722, -0.147162, 0.122883, 0.24054, 0.292792, + -0.0408538, -0.234116, 0.369206, -0.0199082, 0.532835, 0.224419, 0.0290583, 0.220256, -0.213931, + -0.291733, 0.186246, 0.571817, 0.262095, -0.386822, 0.00833788, 0.086891, -0.189802, 0.112742, -0.189807, + 0.00601226, 0.385054, 0.274811, -1.22091, -0.253445, -1.7008, 0.195294, 1.29353, 1.07526, -1.96877, + -0.772609, 0.448463, 2.09824, 1.4113, 0.863078, -0.623505, 0.242314, 1.0809, -0.196543, 0.119722, + 0.362425, 0.859263, -0.239351, 0.681383, 0.960214, -1.3527, 0.186272, 0.631964, -0.681905, 1.05588, + -0.140077, 0.748649, 0.84913, -1.2179, 1.11172, 0.598526, -0.356353, -0.099219, -1.49835, -0.452291, + 0.0964582, 0.405776, -0.0658444, -0.417454, -0.546667, -0.246911, 1.93385, -0.887121, -0.589104, 1.44769, + 0.631779, 1.15798, 1.06067, -0.937005, 2.34467, 1.04031, 0.31744, -0.488965, -1.20623, 0.182373, + -0.150136, -0.923194, 1.04332, -0.648034, 2.65893, 1.1664, 1.64935, -0.822216, -0.525139, 0.451599, + -0.338638, 0.767087, 0.932899, 0.00123571, 0.0452554, 0.450635, -0.838136, 0.176985, -0.370868, 0.714614, + 0.60984, 0.622438, 0.347455, 0.73504, 0.176161, 0.194278, -0.0507662, -0.851639, -0.765246, -0.383905, + -0.524005, 0.563593, -0.395179, 0.0358864, -0.520076, 0.455763, 0.684801, -0.093275, -0.509682, 0.0785892, + 0.299113, 0.743272, -0.317872, 0.550556, 0.0452602, -0.0960432, 0.385776, -0.151232, -0.079013, 0.42111, + -0.0125717, -0.35164, 0.0378629, -0.523955, 0.428372, 0.538468, 0.172888, -0.954402, -0.197366, 0.0550898, + 0.175624, 0.150908, 0.251761, 0.704209, 0.354458, -0.779221, 0.107141, 0.560244, 0.625814, -0.635675, + -0.0480064 + ], + "dims": [2, 3, 4, 4, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "BatchNormalization with no attributes - NHWC", + "operator": "BatchNormalization", + "opset": { "domain": "com.ms.internal.nhwc", "version": 12 }, + "attributes": [], + "cases": [ + { + "name": "T[64]", + "inputs": [ + { + "data": [ + 2.02384, -0.935186, 0.488569, -0.513934, -1.27082, -0.131913, -1.806, -0.37904, 0.667796, -1.14826, + 1.2522, 0.0300339, 2.4758, 1.55511, 0.385341, 1.46645, -1.09355, -2.56309, 0.976015, -1.47036, 0.89486, + 0.580989, -1.12418, -0.339189, 1.3314, 0.418893, -0.301401, -1.2983, -0.839063, 0.170261, 1.15486, + -0.255735, -0.589851, -0.416289, -0.952648, -0.360487, 0.253287, 0.437195, 0.32023, 0.209606, -0.279519, + -0.546527, 0.265286, -1.07383, -1.65879, 1.1222, 0.946612, 0.822549, 0.64689, -0.292639, -0.73995, + -0.694949, 1.33899, -0.0652476, 1.61791, 1.49692, -0.761145, -0.201874, -1.15431, -1.83111, -0.705267, + -0.143026, -0.129819, -0.799425 + ], + "dims": [64], + "type": "float32" + }, + { + "data": [0.241661], + "dims": [1], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "float32" + }, + { + "data": [1], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.489082, -0.225997, 0.118068, -0.124197, -0.307105, -0.031878, -0.436439, -0.0915989, 0.16138, -0.277489, + 0.302606, 0.007258, 0.598301, 0.375807, 0.0931215, 0.354382, -0.264267, -0.619395, 0.235864, -0.355328, + 0.216252, 0.140402, -0.271669, -0.0819684, 0.321747, 0.10123, -0.0728365, -0.313746, -0.202768, 0.0411454, + 0.279085, -0.0618009, -0.142543, -0.1006, -0.230217, -0.0871152, 0.0612094, 0.105652, 0.0773867, + 0.0506533, -0.0675486, -0.132074, 0.064109, -0.259501, -0.400863, 0.271191, 0.228758, 0.198777, 0.156327, + -0.0707191, -0.178816, -0.167941, 0.323581, -0.0157677, 0.390985, 0.361745, -0.183938, -0.0487849, + -0.27895, -0.442507, -0.170435, -0.0345637, -0.031372, -0.193189 + ], + "dims": [64], + "type": "float32" + } + ] + }, + { + "name": "T[2,4,4,4,3]", + "inputs": [ + { + "data": [ + 2.02384, 0.168795, -0.523222, -0.935186, 0.740422, 1.17608, 0.488569, -0.377683, -0.53195, -0.513934, + 0.432598, 0.914993, -1.27082, -2.07414, 2.69308, -0.131913, -2.85251, -0.517211, -1.806, 0.273531, + 0.472273, -0.37904, 0.0532606, -0.464725, 0.667796, 1.31052, -0.929768, -1.14826, -0.769382, -0.631145, + 1.2522, 0.9976, 0.919709, 0.0300339, 0.850536, -0.27391, 2.4758, -1.53812, 1.76689, 1.55511, -0.00496016, + 0.894897, 0.385341, 0.931242, 0.235798, 1.46645, 0.0517056, 1.2544, -1.09355, -0.497829, 0.858985, + -2.56309, 0.275869, -0.139707, 0.976015, 0.860001, 0.354544, -1.47036, 1.23747, 0.200878, 0.89486, + 0.179686, 0.353255, 0.580989, 1.5914, 0.0722632, -1.12418, 0.740327, -1.56074, -0.339189, 0.798208, + 1.03685, 1.3314, 2.12478, 1.73434, 0.418893, 1.74205, 0.193269, -0.301401, -0.322054, -0.864609, -1.2983, + -0.0112451, 0.842739, -0.839063, 0.204525, -0.372717, 0.170261, -0.431252, 0.584484, 1.15486, -1.3114, + 0.16315, -0.255735, 0.186204, 1.60674, -0.589851, 0.780569, -0.0611289, -0.416289, -1.42994, -1.24544, + -0.952648, 1.63344, 1.33361, -0.360487, -0.00839034, -0.961942, 0.253287, -0.187035, -0.15732, 0.437195, + 1.8406, -0.348637, 0.32023, 1.32053, 0.361842, 0.209606, -0.636963, 0.7386, -0.279519, 0.408944, 0.517256, + -0.546527, -1.50846, 1.20406, 0.265286, -1.2076, -2.07277, -1.07383, -0.129118, -1.01983, -1.65879, + -0.0441307, -1.9163, 1.1222, 1.47558, 0.239934, 0.946612, 1.07251, 0.177979, 0.822549, 1.05295, 0.464564, + 0.64689, -0.420297, 0.988822, -0.292639, -1.13402, 0.284607, -0.73995, -0.524053, -1.56099, -0.694949, + 3.20754, -0.429143, 1.33899, -0.588935, 0.111043, -0.0652476, -0.527549, -0.0853688, 1.61791, 0.591928, + -0.319176, 1.49692, -1.10529, -0.279777, -0.761145, 0.520412, 0.520971, -0.201874, 0.19404, -1.078, + -1.15431, -1.21229, -0.670242, -1.83111, -0.399594, 0.065652, -0.705267, -0.280935, 0.468538, -0.143026, + -0.363324, -0.825062, -0.129819, -0.00804771, 0.370068, -0.799425, 1.43102, 1.68751, -1.16928, -1.27073, + -1.73198, -0.411782, -0.263788, -1.1062, 1.61624, -1.7702, 0.951285, -0.973004, 0.203263, -0.713336, + 2.64703, 1.34631, 1.61586, -0.220014, 1.11914, 1.96514, -1.43954, -2.04911, 0.002603, -0.018692, + -0.804137, 0.0953297, 1.34982, 0.466763, 0.949256, -0.95197, 2.18386, -1.76552, -1.72586, 1.4689, + 0.372816, 1.32725, 0.898297, -0.781229, 0.280984, -0.648948, 1.50532, 0.00847463, 0.252202, 1.28462, + 0.512869, 1.12501, 1.31116, 0.0378154, -0.204563, 0.731908, 0.13898, 0.124608, 1.54835, 0.35758, 0.377214, + 0.371081, -0.084558, 0.894327, 0.409244, 1.04045, -0.249118, -0.106938, -1.79933, 0.709188, -1.79396, + 1.3002, 0.999397, -1.61198, 0.390457, -1.4079, -0.80869, 1.22267, 0.193873, -1.10381, 0.959344, 0.657753, + 1.1872, -0.964296, -0.709732, -0.832439, -0.0935597, 1.09897, 0.0755941, 0.288953, -0.145793, -1.09553, + -0.158046, 0.779199, 0.960059, 0.532672, 0.88378, 1.44252, -0.500988, -1.2676, -0.196482, 0.25187, + 1.15709, -1.07364, -2.14384, 0.62295, 0.165547, -0.633315, -0.370894, 0.630078, 1.24612, -0.103268, + 1.56569, -1.41525, -1.55949, -0.669592, 0.36494, -0.470747, 1.15974, -0.00714732, 0.100394, 0.0953399, + -0.608963, 0.422334, -0.202313, 0.508496, -0.0685312, 0.812631, 0.995365, -0.434488, -0.318567, 1.21159, + -0.568974, -0.16644, -0.169055, -0.256987, 0.887062, -0.968783, 2.01276, -0.0264821, 1.52779, -0.923322, + -0.740725, -0.082381, -0.613144, 0.0797577, 2.2049, 1.50676, -1.1037, 0.928655, 0.65756, 0.90236, + 0.120245, 1.20524, 1.13427, 0.911429, 1.10395, 0.364186, -0.885258, -0.975241, -2.01043, -1.2072, 2.44035, + -0.415748, 0.770694, 1.08276, 0.116046, 2.36621, 0.330393, 0.369949, 1.08456, -0.508918, 0.317886, + -1.60069, -1.25545, 0.530332, 0.0345025, 0.189815, 1.48341, 0.359559, -0.156263, 0.74666, -0.785411, + -0.960866, -1.64142, 0.466532, 1.0859, 0.22569, -0.78543, -0.674478, 1.18015, 0.024879, 2.76743, 1.31827, + 1.59337, 1.21399, -1.33904, 1.13718, 1.71666, -0.101125 + ], + "dims": [2, 4, 4, 4, 3], + "type": "float32" + }, + { + "data": [0.241661, 0.960798, 0.474727], + "dims": [3], + "type": "float32" + }, + { + "data": [0, 0, 0], + "dims": [3], + "type": "float32" + }, + { + "data": [0, 0, 0], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 1, 1], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.489082, 0.162177, -0.248386, -0.225997, 0.711393, 0.558316, 0.118068, -0.362876, -0.25253, -0.124197, + 0.415637, 0.43437, -0.307105, -1.99282, 1.27847, -0.031878, -2.74067, -0.245533, -0.436439, 0.262807, + 0.2242, -0.0915989, 0.0511725, -0.220617, 0.16138, 1.25914, -0.441384, -0.277489, -0.739217, -0.29962, + 0.302606, 0.958488, 0.436609, 0.007258, 0.817189, -0.130032, 0.598301, -1.47782, 0.838785, 0.375807, + -0.00476569, 0.424829, 0.0931215, 0.894731, 0.111939, 0.354382, 0.0496784, 0.595496, -0.264267, -0.478311, + 0.407781, -0.619395, 0.265053, -0.0663221, 0.235864, 0.826283, 0.168311, -0.355328, 1.18895, 0.0953618, + 0.216252, 0.172641, 0.167699, 0.140402, 1.52901, 0.0343051, -0.271669, 0.711301, -0.74092, -0.0819684, + 0.766913, 0.492219, 0.321747, 2.04147, 0.823334, 0.10123, 1.67375, 0.0917494, -0.0728365, -0.309427, + -0.410451, -0.313746, -0.0108042, 0.400069, -0.202768, 0.196507, -0.176938, 0.0411454, -0.414344, + 0.277469, 0.279085, -1.25999, 0.0774512, -0.0618009, 0.178903, 0.762761, -0.142543, 0.749965, -0.0290194, + -0.1006, -1.37387, -0.59124, -0.230217, 1.5694, 0.6331, -0.0871152, -0.00806138, -0.456657, 0.0612094, + -0.179702, -0.0746837, 0.105652, 1.76844, -0.165507, 0.0773867, 1.26875, 0.171775, 0.0506533, -0.61199, + 0.350631, -0.0675486, 0.392911, 0.245554, -0.132074, -1.44932, 0.571595, 0.064109, -1.16025, -0.983996, + -0.259501, -0.124055, -0.484139, -0.400863, -0.0424004, -0.909715, 0.271191, 1.41773, 0.113902, 0.228758, + 1.03046, 0.0844908, 0.198777, 1.01167, 0.22054, 0.156327, -0.403818, 0.469418, -0.0707191, -1.08956, + 0.13511, -0.178816, -0.503507, -0.741041, -0.167941, 3.08178, -0.203725, 0.323581, -0.565845, 0.0527148, + -0.0157677, -0.506866, -0.0405267, 0.390985, 0.56872, -0.151521, 0.361745, -1.06196, -0.132817, -0.183938, + 0.500008, 0.247318, -0.0487849, 0.186433, -0.511752, -0.27895, -1.16476, -0.31818, -0.442507, -0.383928, + 0.0311666, -0.170435, -0.269921, 0.222426, -0.0345637, -0.349079, -0.391677, -0.031372, -0.00773219, + 0.17568, -0.193189, 1.37492, 0.801104, -0.282569, -1.22091, -0.822216, -0.0995112, -0.253445, -0.525139, + 0.39058, -1.7008, 0.451599, -0.235136, 0.195294, -0.338638, 0.639682, 1.29353, 0.767087, -0.0531687, + 1.07526, 0.932899, -0.347878, -1.96877, 0.00123571, -0.0045171, -0.772609, 0.0452554, 0.326198, 0.448463, + 0.450635, -0.230053, 2.09824, -0.838136, -0.41707, 1.4113, 0.176985, 0.320744, 0.863078, -0.370868, + 0.0679025, -0.623505, 0.714614, 0.00204798, 0.242314, 0.60984, 0.12394, 1.0809, 0.622438, 0.00913847, + -0.196543, 0.347455, 0.0335859, 0.119722, 0.73504, 0.0864127, 0.362425, 0.176161, -0.0204343, 0.859263, + 0.194278, 0.251436, -0.239351, -0.0507662, -0.434827, 0.681383, -0.851639, 0.314206, 0.960214, -0.765246, + 0.0943579, -1.3527, -0.383905, 0.295471, 0.186272, -0.524005, 0.231835, 0.631964, 0.563593, -0.233032, + -0.681905, -0.395179, -0.0226096, 1.05588, 0.0358864, 0.0698283, -0.140077, -0.520076, -0.0381934, + 0.748649, 0.455763, 0.128725, 0.84913, 0.684801, -0.121069, -1.2179, -0.093275, 0.060867, 1.11172, + -0.509682, -0.51808, 0.598526, 0.0785892, -0.153047, -0.356353, 0.299113, 0.301137, -0.099219, 0.743272, + -0.342009, -1.49835, -0.317872, 0.0881915, -0.452291, 0.550556, -0.00172722, 0.0964582, 0.0452602, + -0.147162, 0.405776, -0.0960432, 0.122883, -0.0658444, 0.385776, 0.24054, -0.417454, -0.151232, 0.292792, + -0.546667, -0.079013, -0.0408538, -0.246911, 0.42111, -0.234116, 1.93385, -0.0125717, 0.369206, -0.887121, + -0.35164, -0.0199082, -0.589104, 0.0378629, 0.532835, 1.44769, -0.523955, 0.224419, 0.631779, 0.428372, + 0.0290583, 1.15798, 0.538468, 0.220256, 1.06067, 0.172888, -0.213931, -0.937005, -0.954402, -0.291733, + 2.34467, -0.197366, 0.186246, 1.04031, 0.0550898, 0.571817, 0.31744, 0.175624, 0.262095, -0.488965, + 0.150908, -0.386822, -1.20623, 0.251761, 0.00833788, 0.182373, 0.704209, 0.086891, -0.150136, 0.354458, + -0.189802, -0.923194, -0.779221, 0.112742, 1.04332, 0.107141, -0.189807, -0.648034, 0.560244, 0.00601226, + 2.65893, 0.625814, 0.385054, 1.1664, -0.635675, 0.274811, 1.64935, -0.0480064 + ], + "dims": [2, 4, 4, 4, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "BatchNormalization non-spatial mode", + "operator": "BatchNormalization", + "opset": { "domain": "", "version": 7 }, + "attributes": [{ "name": "spatial", "data": 0, "type": "int" }], + "cases": [ + { + "name": "T[3,1,2]", + "inputs": [ + { + "data": [0.2134, 0.32434, 0.5644, 0.3234, 0.4545, 0.3445], + "dims": [3, 1, 2], + "type": "float32" + }, + { + "data": [0.5, 0.6], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [0.2, 0.1], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [0.034, 0.342], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [1, 1], + "dims": [1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.2897, 0.089404, 0.4652, 0.08884, 0.41025, 0.1015], + "dims": [3, 1, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "BatchNormalization non-spatial mode - NHWC", + "operator": "BatchNormalization", + "opset": { "domain": "com.ms.internal.nhwc", "version": 7 }, + "attributes": [{ "name": "spatial", "data": 0, "type": "int" }], + "cases": [ + { + "name": "T[3,2,1]", + "inputs": [ + { + "data": [0.2134, 0.32434, 0.5644, 0.3234, 0.4545, 0.3445], + "dims": [3, 2, 1], + "type": "float32" + }, + { + "data": [0.5, 0.6], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [0.2, 0.1], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [0.034, 0.342], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [1, 1], + "dims": [1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.2897, 0.089404, 0.4652, 0.08884, 0.41025, 0.1015], + "dims": [3, 2, 1], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/bias-add.jsonc b/js/web/test/data/ops/bias-add.jsonc new file mode 100644 index 0000000000000..e89c5dd81cc23 --- /dev/null +++ b/js/web/test/data/ops/bias-add.jsonc @@ -0,0 +1,874 @@ +[ + { + "name": "BiasAdd", + "operator": "BiasAdd", + "attributes": [], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "bias add [2,2,320]x[320]x[2,2,320]", + "inputs": [ + { + "data": [ + -0.43078827160569055, -1.3343044914862014, -1.3875186857333706, 0.7550491056943578, -0.015677493769832296, + 1.2761569600658005, -0.033221806310058, 1.3590872185896297, 1.706516873231478, 1.429932535224122, + -0.39598259309643424, -0.6735725816667406, 0.5551536414690581, 1.9642482005713608, 0.1481069760277398, + -0.6706041659325672, 1.7626582023316804, 1.8587314247747715, 0.6761988056588422, -1.2865903673587722, + 0.17451384248240753, -1.468579522325724, -1.4426371855482047, -1.101469412150644, 0.32065561169114787, + -1.7280217475756476, 0.8472414982170147, 1.089925472742105, -1.9188568763699578, -1.6897769809368715, + -0.22443847491362767, -1.592027916773227, -1.403545711986184, -0.992398507333391, 1.501364958055385, + 1.1969883087000177, -1.742362862568375, -1.3122082038099396, -1.602285316113413, 1.7122406231186638, + -1.7331729420495074, -1.1660018974894326, 1.8466385944978407, 0.8059189037532883, 0.42819875266849383, + 1.2622036788694384, 1.09253289978553, 1.7169775828443141, 0.8965924110901806, 1.731515754612638, + -0.710260858323644, 0.6949611010518701, 1.9392089898271783, -0.06855145845621369, 1.831602787077924, + 1.919258334566825, -0.1750228495679611, -1.385684174878624, 0.1864456393673155, -0.7627497248975219, + 0.5460459705476177, 1.0151263124631544, 1.0476468224745439, -1.4662684764200389, 0.8737230293064417, + -1.5018396956497346, 1.7171541562949724, 0.33089656971471637, -0.4497250989135697, 0.03155539207838842, + 0.051811401515028166, -0.6417182241860351, 0.2892848388089577, 1.7313934166665597, 0.07283972370973935, + -1.1648958342544793, 0.31235891981372976, 0.8667946196377363, -1.8955234844794324, -0.7935013829177411, + 1.7461436471789726, -0.7863083195781497, -0.24062975690975286, 0.5319240675275907, -1.093540712132092, + -0.08133035383567844, 0.24368896952864638, -0.7817379372747375, 0.09213174790435374, 0.7008130043325052, + 1.2549161788773802, -1.3617356636970648, 0.9203525826936687, -0.17473559280037954, 0.8615190853991157, + 1.7227407352985065, -1.6408918805405301, -1.2530895034442668, -1.561750342643168, 0.9122283638044753, + 1.52061857140668, -1.757751307093299, 1.9174275778315568, 0.8751522548565935, -0.40373986100628567, + -0.9081360776169474, -1.0200280593711373, 1.8684056043061243, -0.8656213251095854, 1.2016289009297738, + -0.6420074354987753, 0.8277984540094971, -0.7794691695697935, -0.3682767999575489, -0.8431108361591697, + -1.8337997630639924, -1.497840701527875, 0.3582228355620778, -1.7648364421707035, 1.2295209603744386, + -1.4630737828552771, -1.0901242069115051, 0.16944674169381635, -1.1367614122844483, 1.1108472254071025, + -1.823264607399035, -1.9393711989501217, 1.5655717488598917, -1.2791502509294475, 0.5264599738022175, + 1.2793776364048464, 0.6643881105536087, -1.8021867427977876, 1.2507197673418382, 1.5096918274472868, + -0.35787327298378635, -0.8142338946526149, 0.05554038558981045, -0.8292295885917778, -1.9802571383548457, + 0.15711486642700123, 1.0606485766817002, -0.75648306593627, 0.7821766109776371, 1.5576015081398582, + 1.274065256431233, -0.11137614589209388, 1.5523940512159609, -0.9225168834642474, -1.52372850136818, + -0.09285110141468689, 0.951570914266842, 1.4257064421442056, -1.1151817450836203, -1.0405979771695408, + -0.3512294660068509, 0.5544676624698717, -1.1641268738929513, -1.1024667332826468, -0.8470834386045532, + 1.0200780368213875, -0.4111438148101385, 1.5060564363688043, -0.9264136525377458, -1.1144625610145829, + -0.8955596783917263, 1.1298063812792858, 1.1669470152018224, -0.9624856700584665, -0.5105634488042536, + 0.3350150739180062, 1.8252865056498973, 1.8318837072890988, 1.247181799708187, 0.953253999223505, + 0.33583064317486144, 0.8060935063640473, 0.025351359439778953, -1.9020713785704064, -1.4395100399524523, + -1.4434709436802136, 1.1887525549807645, -0.7404500263025655, 0.43567229698745447, 0.697905733863994, + 1.4955234099092705, 1.8798334809531436, 0.28521714680739496, -1.8247616864327485, -0.8899636892784697, + -1.9902986052685518, -0.8612351129594646, 1.4473654540376284, -0.027164144754272534, -1.3409554794285556, + 1.8216393608213144, 1.6287547173387606, 0.34546830263617245, -1.8554217975954366, -0.1599936883574884, + -1.1517327823070804, -1.1043403244478194, 0.9959357690900248, 0.47246889924079394, 1.3040046448975406, + 1.3219564442035914, 1.5522794215260332, -1.3132538270303629, -0.6568357397541558, -0.3166961264888011, + 1.155856117058299, -0.2856093347506521, 0.37282953339474023, 1.3624325643069906, -0.5961360461760625, + 0.15430022597664106, 1.459554220213934, -0.4578169766295437, -1.5520231063597194, -0.9692019226485984, + -0.7868623417517249, -0.47648838559948903, 0.521518493577898, -0.43202750995667394, -0.2800695092149823, + -0.062073964072573595, -1.280928976710845, 1.653227202229826, -1.3266797422187144, -1.8026162116413413, + -1.6342686696555697, -1.8535208355884771, -0.6740602051435252, -1.1112082500330684, -0.2784013709506814, + -0.38477554625820254, -1.8030977138882127, 1.0792754568917307, 1.7171188079371378, 1.0025920437036717, + 0.4909327973461517, 1.7583392556519017, -0.9513053055982352, 0.9836640607205354, 0.676230592061307, + -1.532719596602865, 1.7094356339043584, 1.275239052465298, -0.7097409099122736, -0.07447816392320394, + 0.6054981041837371, -1.368606109962121, 0.5491138352978151, 0.9607716616110995, -0.8472145355857039, + -1.5962071764683028, -1.3945278279448647, 1.6804737734010837, -0.2799307641611124, -1.5199944467651827, + -0.6284155111805374, -1.387427601467464, -1.3689101475786654, 0.5479648937942034, -0.35058122336125397, + -0.01913895405776156, 1.9421209900010803, 1.521339125582216, 0.44317460300853284, 1.4673321860144632, + 1.0844686121268872, 1.5454854860881193, -0.10676223292802778, 1.304830685737219, 1.5144316149285784, + -0.9206202207515908, 1.800878209478248, -1.4022748170800003, 1.8475728691446207, -0.6049511260966813, + 1.3794138268485217, -0.5504308384418284, -1.308589620881369, -1.5124914949500843, -0.9843890898875198, + -0.47517812559034134, 0.019190610059508728, 1.3102818494016093, 0.02046481460734384, 0.8966955765481393, + -1.9544981815442322, 1.4818794231229235, 1.701164813487166, -0.1853726987447688, -0.9284129805631327, + -0.6665475855012435, 1.7174739766102922, -0.6866660477705144, 1.3741442823600236, 1.3144332615465366, + -1.6821345841035624, 1.4369774381581397, -0.40113703543299106, 1.2155358501458204, 0.3746014766056245, + -1.2797835093659478, -1.2733146708696594, 1.7592472733050606, 0.7411515442284751, -0.8417009900350942, + 0.43950532643238294, -0.4849697519720575, 0.4561330711561924, -0.9993716519294349, 1.2993450426953688, + 0.6077681589510648, -1.1884025219126162, 0.7423783440021925, 0.2164111466551839, -1.8404401979886122, + -1.6417250174486764, 0.33169566624656355, 1.4670345153781224, 0.05770292022609169, -0.7574653293187339, + -1.8332930679907182, 0.30545703812352265, 1.2645746651782588, 0.06851662509789058, -1.5007737660843672, + -0.7895770838504896, -0.16973917251686643, -1.1309776011957018, -0.5245832860094533, 1.2865060486729964, + -0.6877713804065921, 0.8603087489217494, -1.6760835939805823, -0.3421307621860663, -0.08163355960848584, + 0.8960283575910557, -0.07414831058651483, 0.06552900030788233, -0.28343554832982676, 0.8580233436387052, + 0.9756624200728901, -0.31454210863680565, -1.3961591961928228, 1.4357660237613095, 0.995011454280724, + 1.7963312095595976, 0.44298210926821735, -0.4481873251636479, 0.04661154364950004, -1.0205093044441567, + -1.416501713829339, -1.5609202237237856, -1.9678529588325686, 0.054492766820211536, 1.9234783545287293, + -1.730337423937696, 1.0002270533854967, 0.912626942921996, 1.6082280213507438, -0.0007848079092749316, + -0.16260961343095381, -0.4951675344800499, -1.358161813033476, 0.0952066467959174, 1.1451260564923738, + 0.5658192251004577, 1.638639222136547, 1.4561159943128157, 1.394720518360975, -1.6647338550020887, + -0.35227744791908844, -1.4997218325299349, 1.7269852905741923, -1.6265531305359326, 0.08177780426753678, + -0.2548418092940219, -0.5491200865930868, -1.3448390466679427, -1.4669459626844805, 1.6556888857778747, + 0.9797534441152536, 0.011435434178378223, -1.0444316805293496, -0.21951923449225497, 1.2775877064028922, + 1.3912532203430734, -0.41034439679528845, 1.6871573166908078, -0.2949753341249375, -0.19366027881779058, + -0.8063505536355873, -1.9683493413513462, 1.1343233715358245, 0.5863358659604163, -1.7161271586916902, + -1.1511237584065572, -1.026334237004506, 0.051935455424261256, -0.47999388673650234, 1.3653965992268642, + -1.6246601676055645, 1.0720439180368135, 1.7513891096415195, -1.9557074487732296, -0.6497504028268937, + 1.5657013366841293, 1.480374539039186, 0.6254160749725743, -0.1386074990036681, 1.4879074098885399, + 0.8608723264201785, 1.3707693807951804, -1.9599278528285744, -0.08178821222314703, -0.8801306486657889, + 0.2672127121047323, -1.0905213718834403, 1.2845923771997692, 0.7483674434994203, 1.692846569262815, + -1.4081461907911006, 0.4847990697153808, -1.384710972149036, -0.48477323237460457, 0.5075573472017414, + -0.469975647682892, -1.7330175330396296, 0.4900846176343876, -0.5504384887286493, -1.3105683633488505, + 1.7832520063588246, -0.008133111014310579, -0.4117632340667363, -1.364428903076842, 1.1502921666750634, + -1.297368143184646, 1.2657228492344146, -0.8339630526727912, 1.6036841813714835, 1.7388392748246462, + 1.0764044765977916, 0.8635247292187316, 0.8915437658945047, 0.048593859303998954, 0.6397243841879581, + -0.32613694518247716, -0.5352551370184315, 1.0102009591194312, 1.2995349433229801, 0.873915978492783, + 0.28688858188963984, 1.0299796543068158, -0.21558135465200134, -0.8175153872775942, -0.2228988257546698, + 1.3651108393401516, 1.7419897496979102, -0.4287431897946634, -1.7210342396761868, -1.1328295428945232, + 0.45940815613566865, 1.225728222018235, -1.8492110470394127, 0.11914217556997819, -0.5298005043189082, + 0.2548881953681379, 1.4159379393782983, 1.4882705620355736, -1.778746106600524, -0.1322461839301079, + 0.959787258509837, -1.2087764883807948, -0.656679223823172, -0.8052024037155077, 0.28857972466710535, + -1.9964605007110707, -1.618829946697912, 0.0003216720189040956, 1.3814498186817072, 1.267459171727996, + 1.7808899740756745, -0.026284995100580133, 1.8415787116931153, 1.588128771971907, -1.379239932131604, + 1.4954639033867787, -0.5365530743272542, -1.689396586297355, -1.5263365207258657, -0.15914036263121556, + -0.2236484993504826, -0.233339567785662, -0.6479528728470942, 1.7496861376407091, 1.9658328381866337, + 0.7465900143344948, 0.9538445157519684, -1.233501367636971, 1.7832767928842772, -0.05854857072903297, + -1.4741131369232807, -0.8878273465278168, 0.7857197910561684, -1.3138933804749309, -1.6218392819603826, + -1.5976871654455351, 0.7862199318904839, 0.7930452731942381, -0.3851655998300325, 1.454630867585033, + 0.07699835053088844, 1.9635265550996897, -0.9018052833921555, -1.3260562107679466, 1.591219679914813, + -1.9325066606486736, 1.6231422543543284, 0.06461353520395186, 1.938690006387282, 1.18237528878611, + 0.08913168691813134, -0.24910533784113476, -0.24270271174439095, -0.058892972285330636, + 0.3017252069250542, -1.6443343279597213, -0.35086444759952506, 1.2568748876859326, 1.2564694463685155, + 0.5001323821158774, 1.6753738435417835, 0.12536895878538168, 0.11445707906508318, -1.4674734226308148, + -1.664480357904826, 1.2455086575046002, 0.016219310045000768, 0.6915540154848818, 0.8863957336976913, + -1.902221644800517, 1.1943261373616156, 1.5343838485344943, 0.26951920872184143, 1.1650565743686334, + 0.8796643687219943, -0.9914368645961655, -0.4140197685853506, -0.31070218781242964, 0.18576121794650557, + -1.8127899542033798, 0.7364855692315055, 0.61436653795095, -1.340047295554938, -0.9787870418028497, + 1.467065674277963, -1.3989058830195544, 0.8760353392917564, 1.103715585163978, 1.5993888842634876, + -1.4487285544801765, 1.0735445052923094, -1.7221746626643117, 1.3406229033114787, -0.31808486639617684, + 0.18778714988957468, 0.3360116223452154, 0.9957693335536861, -0.7082173962327145, 0.13563127391102814, + -0.6103565638388746, -0.27366984218234336, 1.670182285266585, 0.7450288088790513, -0.7376055316855803, + -1.9155349843040552, -0.5647372929979877, 0.38985718472222164, -1.0859956865362559, -1.2774336101522819, + 1.6716179824020179, -1.8248740137755952, -0.2462417745015948, -1.6330831414998679, 1.2633365236447727, + 1.8405595764836615, -1.5041331151923663, 0.7795364178340263, 0.30228912276856246, -1.4199604845773672, + -1.9675070264789785, -1.5847567364669262, 0.8417404542568088, 1.411592224813667, -0.7798761695294161, + 0.3836369752883826, -1.8656778968593093, -0.5928302815127546, -1.9170044216421092, -0.05015509660959072, + -0.2949046704017606, 0.7706683291518104, -1.0318399043402815, 1.673370077717097, -0.014510791777835763, + -1.7993886007705928, -1.6310599365379153, 0.47517801288072725, 1.7067820519317198, 0.2398133204231856, + 0.8136128605181492, 1.7661699247824538, -1.7400434804598, 1.6440515677939507, 1.0630270197569356, + 1.604276559525923, -0.6932776253265054, -0.5527120052345769, -0.15050627035753017, 1.9862178971553677, + -0.015247839923115514, -1.973561045412188, -0.7527150187762626, -1.8069828707308844, -1.0877483004326844, + 1.5632037263398146, -0.7907173278699853, 0.020757273269920162, -0.9162761989634776, -0.871619440202827, + 0.9274209555443278, 0.2788106849606935, -1.139267389509799, 1.8554359980704245, -1.2900715232363025, + 1.5740910757851019, 1.2022687452616108, -1.3263358341556861, 1.705785771419051, -0.2966767936771886, + 1.0184489507711243, 0.9331775150786781, -0.09496306724418613, 1.6169644094277178, -0.43071270921050253, + 0.00617695499471882, 1.9046927055754255, -0.3800897035987143, -1.2677566100047573, -1.6177276055783159, + 0.41451247814580316, 1.9136674265154827, -1.4827469552090689, -1.7754603228705612, -0.6490018494468908, + 1.0832837177200183, 1.5236866225466352, -1.3663768537968632, 0.3345867923992154, -1.3264457751116838, + 1.1302163417527789, 0.7987367218594539, -0.16727587952795364, -0.7323977510251147, -1.3812007830618693, + 0.7210535269663234, 1.1937695850586403, 1.9659603230065574, -0.12903581925605057, 0.22599793243849042, + -1.055270780364089, 1.0726029698666473, -1.8921222720165147, 1.232307457573829, -1.6501289040499376, + -1.2203689042981933, -1.5782693671316723, 0.84090024006949, 1.1443588250664956, -0.8862909159261179, + -0.34912733758720993, 0.15172846490974035, 0.6801601864117464, -1.3948321257969702, 1.9080269171533608, + -1.851455291444804, -0.8525139187927957, -0.9629354426957466, -0.3970802170728076, 0.2714456125847784, + 1.9355349765888947, 1.643295056500194, 0.461049347109487, -0.819228054522112, -0.7773196244370615, + 0.27305821451894285, -0.5007383808686932, -0.5307070901422906, -1.6087255013287924, 1.708746273758031, + -0.18903643771546896, 0.4237658727537639, 0.7912222637914885, -0.06713232018529425, -0.9303076017910259, + -0.9715788400465986, 1.9239773722802864, -0.13605413264730704, 1.63369388301102, -1.4098364226383078, + -1.8177000155999794, -1.0786108587058756, 1.1467184875003493, 1.6942868827942474, 0.3743937965110735, + -0.5205992959118086, 1.9008470494268277, -0.33613881942363744, 1.408798355509151, 1.4374267676314432, + -1.6266431395044751, -1.2795792184706105, -1.6942133113310183, 0.44536668255477885, -0.5438457389607523, + -0.10839952389615792, -0.741753360770752, 1.9252204052044393, -1.298683868110981, -0.6674482773778925, + 1.7250903096066814, -0.31851968039415723, 0.09917784413898989, -1.2300134366181839, -1.6661221342983064, + -1.7392500850180967, 1.1202143474944348, -0.008191271742616912, -1.9527622539466778, -1.112511983718873, + 0.8023412387190376, -0.8667541685114779, 1.5324025634620666, -0.885176692529595, 0.28103075310068526, + 0.5555445946177473, 0.0746344154810279, 1.8279995093264354, 1.589298526844452, -1.1728977455798502, + -1.2266193029402643, -0.08324504041088154, -1.5997622864209005, -1.9929554997781622, 1.649795503298476, + -1.825621659313522, 1.2037940913995477, -0.1478578137836779, -1.791209833936434, -1.5875938503209248, + -1.54554126750838, 0.22767426425123372, -0.9730859655005233, 1.01498892879457, -0.14002274071109522, + 1.5196250773427282, -0.262915623574389, -0.8395076912698629, -1.6509617804598316, 0.4853743370864656, + -1.2426649457879062, -0.34746037671673236, -0.2668537063224701, -1.5727266735586252, -1.117657993368196, + 0.38564526206938243, 1.8759791120426446, -0.9537042932874478, 1.8291799127445367, -0.005069359588546263, + 0.9881621403906173, -1.0839433778149212, -0.4270854569893858, 1.6796287101836995, 1.4763471217127018, + -1.0142364314111942, -1.9083290608311998, 1.8813585092075078, -1.7461290703297596, -0.21674842672674544, + 0.5502685329995662, -0.4411322123830974, -1.8416200674737109, -1.0844672555196375, -1.4812521607836615, + 1.5083908944091222, -1.038918979055473, 0.7142655565242073, -1.1043893519874795, -0.9397073551320325, + 1.335714667963627, -0.39552262419583606, 0.941189318776745, -1.973265716530558, 0.29314288363790375, + 1.1149136203907473, -1.3280789878295156, 0.4344131789940944, -1.4976865930360805, -0.18683967016091696, + 1.6841429470215354, 0.9990587834251263, -0.9896683309850935, 1.7344823188063838, 1.2809894053746431, + 0.11253057753370044, 0.6202153889005659, -1.1393527545232924, -0.13267950871413792, 1.9929188473410813, + -0.5751171946115745, -1.5738159418363038, -0.722979035597497, -0.3845867260905491, 0.6726496775295967, + 1.5922485171567802, -1.2068498206133187, 0.412853384927236, -1.1136351543496117, 0.4046306514337976, + 1.2348205714395197, -0.6837477630512465, 0.4396481646202046, -1.689933367495156, 1.2059642129302333, + -0.9663178383992985, 1.1606541969900555, 1.6008140707565817, -0.49901900408756017, -0.3295636539441862, + -1.6961254597784983, -1.2594474125668986, -1.8290342261999246, -1.3195471501043112, -1.3360729012617636, + 0.4819319828437685, -0.6044650413524977, 1.5401325916637765, 1.3108212503564856, 0.6641610189431937, + -1.167107917049922, 1.3717452437078013, 1.8234968598766077, -0.059000791459610014, 0.5078759939367101, + 1.7012186950957178, -0.8153543329038886, -0.8116555600682265, 1.1042603281155614, -0.5230370384584662, + 1.5907666644047485, 1.3585126059484214, -1.1604013321546773, 0.2832653250904853, 0.6146831527317715, + 1.710136815942171, -0.4339250659200289, 0.5404568535663827, 0.9731252061576328, 1.667624932562064, + -0.8395294570873553, 0.9655900545408684, -0.8459497721445768, -0.3303605936459242, -1.0228351996527785, + 1.64618479653826, -0.17369144560780647, -0.762588585662475, -1.420812072676508, 0.04977508086731497, + 1.1411840603073538, 1.2658579855151144, -0.7695492083057207, -1.5802557368206935, 1.8063629418173504, + -1.3314629726856042, 0.926332179655863, -0.29403591774091886, -1.4368324532624133, 1.794172160269225, + -1.45190145484798, -0.6970484207532923, -0.7393596578156938, -1.3777134665955009, 1.698548607127142, + -1.5277458031231408, 0.8121596602029895, -0.4871450419173451, -0.8973481250486168, 0.999939728035903, + -1.622149524743996, -1.8097564354752569, -1.9184903777986992, 1.6699213455758075, -1.8966494026411178, + -1.5537605568683883, 1.0114197035541261, -1.7582102654477056, -1.881605728865896, 1.9424710063889181, + -1.387404030261334, 1.3858022761207254, 0.6698691050652563, -1.7882425787208467, 0.05463416162273482, + 1.151588572666963, 0.9448022669750591, 1.2079032520591921, 1.3271868932748552, 0.7459734492068639, + 1.4448931947680101, 1.741554075288172, -0.49778064496799157, -0.05357363118867564, 0.17060355537450267, + 1.8858952797885928, 1.1828850394664219, 1.4398264570692447, -1.0269803221490008, 0.12795611159499387, + 1.4338295119095292, 1.7680805378740985, 1.8889001943775678, -0.023646131688371597, -0.055364618226636075, + 1.3107732868213482, -1.3726761197824935, -0.48421640176631975, 1.8520978683112554, 0.14900451528494418, + 0.7553309487914097, 0.995210897988966, -0.2653148753757497, -1.0047335940870337, -0.15140716923573905, + 1.0342357378533045, 0.9590192011054128, 1.3276618340182669, 1.7076552004070518, 1.3639368693762144, + 1.0626034699464553, 1.0985888634186862, 1.0871213821052033, 0.3518298069849042, 0.6905794127769393, + -0.5700252629850588, -0.10814050178161683, 1.3965639143955952, 0.8292785089561896, 0.25327348015151596, + -1.334944218927732, 1.6209990328336517, -1.26244979705121, 1.771347153639546, -1.4436659851102362, + -0.5033590550326617, 1.484309499478445, -0.3804774758165417, -0.6854434358446646, 0.11814627802495625, + -1.8940220672649586, 1.9948327521339193, -0.31419774418955715, -1.708608376028823, -1.8717143806284637, + 1.4405268554284918, 1.8275766986420505, 1.613878636732296, -1.6842307903910925, 0.40437384436799473, + -1.7974817731693786, 1.1222020737272933, 0.8616912290576968, 0.7450260507858868, -1.262819663341098, + -1.3890825964448705, -0.24733521021608595, -1.9566763230316564, 1.240599031645397, 1.0005078570110628, + -1.2429059760645886, -0.8343788480444481, 1.8180187022655172, -1.2406155795725864, 1.206905860440954, + 1.0657535671563094, 1.0688089382594308, -0.1623146723818225, -0.9394369384107097, 0.1644095126054408, + -1.5836754669396766, -0.45886757503464093, 0.02309988717518241, -0.11609000844996142, -1.8627934732063016, + 1.9415986762986073, -1.1741977923279308, 1.2850766678166368, -1.6650362245910895, 0.8711235235579853, + 1.703790813181027, -0.20031635110543, 0.7709122587840396, 0.4676763407673441, 0.7333591027965438, + 0.010661459873729129, -0.6248209657856156, 1.3499622620431584, -1.010555890938674, -1.9094767924575482, + -0.4118954261411796, -1.9764407569645153, -0.11597198863706204, -1.2058671493925148, -0.9480128119239577, + -0.5293044156931037, -1.0802020289569132, -0.17428061346628443, 1.79479586490669, 0.4507608914027035, + -1.5890677193516893, -0.4180241158854212, 1.1247910122551152, -0.10769135533882057, 0.2413054436244062, + 1.3070197809453399, -0.7234463442247714, -0.9044481440724681, 0.808060474881219, -1.9681916392611978, + 0.6794353030118225, -1.8140413117066592, -0.21172484209703857, -0.3970612901969721, -0.22610168646442563, + 1.8446444889972504, -0.14161684848047962, -1.0612317380319158, 1.6805704263182024, -1.5680342684533937, + 1.6367583739045974, 1.9603810572547848, -1.9461850695662726, -1.669279137293902, -0.4582040515383534, + -0.4903387593204007, 1.700009791169296, -0.1006260106974528, 0.46356012096704813, -1.148937310426322, + 1.1291686959534264, -0.43326682883595513, -0.7294554760312604, -1.3404464141642505, 1.4428709283346048, + -0.31527585219642784, 0.5232849548965959, 0.840594052127317, 0.8605144687711537, -1.6471991161039679, + -0.32119284017874694, -1.199582906673192, 0.3080262169024959, 0.14151294810583348, -1.7354321287231, + 0.4873457081904098, 0.30837874931057474, 1.3003825539901728, 1.9636934942685267, -1.017158000827136, + -1.7596484196087827, -0.6817110071876389, 1.8933995404689652, -0.27989446483984093, -1.067564992804905, + 0.48291319348514605, -1.6740080386493696, 0.2807422201361458, -1.9110968101706796, 1.6831495761856532, + 0.11512823651793447, -0.01736208420134222, 0.4405414565435697, -1.8085718479527353, -0.9564467319845411, + -1.58340676850203, -0.647955866711154, 1.6543925512752926, -0.7724666857844449, 0.5113949921548908, + -0.39267895944443065, 0.40761631065773063, 1.7690968773142668, -0.21764226027462374, -0.5169606846714005, + 1.4412302687210286, -1.8896763215831802, 0.07979887381656514, -0.11436797170178448, 1.0634323241712869, + 0.26858767414261475, 1.582753510128553, -0.010528543666477042, -0.3613643892495162, 1.391514209117238, + -1.4700872595733046, -1.5362821122874086, -1.818586245442571, 1.9678697208900475, -0.9362374796105595, + -1.4709960962767532, -1.182374325374778, 1.6385607669867177, -0.46775448579892487, 1.9576437315276696, + 0.915531228777362, -1.0860235734385926, 1.0655104012081509, -1.7770877181643954, 0.1443659128293806, + -0.23955298993629803, 0.41725367891443366, -0.8558589757408512, 1.024674449305122, -0.8538581096220099, + -0.17121172366938264, -1.0495343198650096, -0.8461809157835463, 1.956660524400533, -0.10451516941234473, + -0.9119888509755709, 0.9341633453090434, 0.5765821236303488, -1.8017153374435075, -1.6959921212218267, + 0.3565838506194048, 1.986423720658717, -1.1810787364750697, -1.3554314442606277, 1.2292087344595828, + 1.9389467760629646, -0.06060251846881748, 0.6471281822482204, 1.028562584237319, -0.9889764039700069, + -0.30300382154607064, 1.8809742113734886, -1.7911374091327446, 0.4093234223382991, -0.692170253260544, + -0.5766217362114325, -1.234711065294488, 1.4845455677723791, 1.1142993730640143, 1.0351547495051978, + -0.2304804542756207, 0.7896680860540854, 0.7368394967498872, 0.6117647784314304, 0.9649509001647774, + -1.4794529756304886, -1.5330264541276408, -1.1347331780500776, -1.957296773370273, 1.0497217009949296, + -1.876577007676679, -1.8707142834400772, 1.0355676671507679, -1.3024864669068572, -0.8110172097035955, + -0.7956308468122133, 1.5651626086889294, -0.5950947090287055, -0.15363512205018193, -1.7408026469236138, + -1.8514840915078024, -0.40821529034130855, 1.2085979600022174, -0.953165571253348, -0.03303673441403454, + -1.6012777563202603, -1.1080907034689993, 0.4974859939502432, -1.2774517368694216, -0.20163785863448247, + 0.8261780324851822, -1.5127015843190126, 0.14100033563999403, -1.052646319316592, 1.6782929279009817, + -1.5464154280508744, 1.7486715792103427, 1.7780537663957405, 1.7209562957702067, -1.5054888719925499, + 1.5292916602749358, 1.179965119787405, 1.5170349093126205, 0.7433092687415837, 0.9522185073327041, + -1.2380413345480523, 1.9354870169011447, 0.625636243747028, -0.09000816987169546, -0.5335972012344188, + 0.9674628266238745, -0.03967494279717343, -1.8591428816092792, 0.2309446236016406, -0.9041639531030761, + -1.9026103934874206, 1.5541589920102101, -0.45696090245009824, 1.8466298423672463, -0.36055327204706167, + -1.2458073226198056, -1.4410345639464017, 0.4731557626110279, 1.2307498218360111, -0.30913913858563724, + 1.8557259220973865, 1.1724797822135766, -1.516241961681641, 1.147047572924638, 0.009295148811337306, + 0.8291735590935811, 0.9825251314963639, -1.6702374836146134, -1.3070146895265724, -0.35977729833032246, + 0.10882986028094521, -0.1545812635060546, 0.5966946312401102, -0.12463585998219351, 0.764026848253426, + -0.0653501987613172, 0.5337159207310522, -1.3783865008607394, 0.3440914524635428, -1.6128660868012537, + 1.0505520366072156, 1.4508195056160966, -0.3811605116562866, -1.0184337989154448, 0.3472432185953034, + 0.7934690008453043, -1.1871814411996455, 1.7160465328415073, -0.8932034391682153, 0.22684342695521842, + 0.006601468173127678, -0.8703158970162406, 0.7854001417512366, -1.6096149032006064, 0.5734105371918883, + 1.8323642183966413, 0.8494195484926621, 1.52159530384528, -0.43666400911265324, 0.6949749758230679, + 0.09014001060463439, 1.1181673671725294, 0.23797216532766896, 0.9091467606318648, -0.051242214293259813, + -0.3492957666583303 + ], + "dims": [2, 2, 320], + "type": "float32" + }, + { + "data": [ + 0.0018182522274203805, 0.06756509596322857, 1.889723294866065, -0.10289095754140298, 1.5711519216894745, + 0.027529292075774592, 0.9603256438495507, -1.497309631471758, 1.9251601219617065, 0.8851491878732389, + 0.05078780805071137, 0.40903741455911735, -1.6644840015459215, -1.348225759557871, 1.615832737926227, + 1.042719864089511, -1.9289326046242312, 1.3417535199012995, -1.710655801290117, -1.130165128147044, + -0.3755000776719024, 0.6155781582426902, -0.5883485771887473, -1.7159986811406176, -0.16333854572017525, + -0.06079239446971929, 1.6926064002585495, -1.8776332892248098, 1.4601970742576578, -1.3202352800423185, + -0.12899708506012164, -0.6003093613879029, -0.1726349092091164, 1.2394146364350664, -1.769629141089184, + 1.4197330981295524, -0.9504267735259635, -1.240675610662361, 1.4018548317486923, 0.5332018345413356, + -0.16073415033536875, 0.15303724703170385, -1.8037963193841238, -1.311714810716846, -0.5740602095553404, + -1.0372165240096223, -1.370949121899355, 0.29661966702940035, 0.07816374250571023, -0.41396787300651905, + -0.3694698645575212, 0.6759765867037197, 0.2952400780995772, -0.06275069272676337, -0.9130274419561628, + -1.8944701092982958, 0.33465806810173593, -0.404939193749847, 1.4043718178232805, -0.5590711165263631, + 1.2184926422968934, -0.7087036307842709, -1.6055109382118182, 1.968257767003597, 0.529695028652811, + -1.9967381817454308, 1.595078125176956, 0.9871155490120058, 1.6566751957870993, -1.609626148231829, + -1.1397801527001823, 0.02238544560446254, 1.4873497063814245, -0.4755743745599572, -1.6926423664304844, + -1.4161828320433028, 0.346372427157398, -0.18203459832580027, 1.4635583159542183, 1.5944148599650028, + 1.3186726267824955, 1.5675687012032968, -1.9754809408365706, 1.44557963549327, 0.9397875688795354, + 1.4424046221061442, -1.4458135310352649, 0.9975520389856136, -1.9027511578082317, 0.9144382540308484, + 1.052124261689804, -1.4732678674195494, 0.29024955712503164, -1.2231252144665383, 0.34787712508784985, + 0.3556934319800238, -1.7419738471239645, 0.8630538908485903, 0.5386452782458866, 1.7600516786463105, + 1.8905437777505014, -1.5744952794523028, -0.7530004157782235, 1.5678919268380707, 0.034533101389558674, + 0.7325333516090975, 0.9775064333478163, -1.6408433791748216, -0.7414398323785214, 1.6725433719876586, + 1.0072882099919305, 1.4341931058179327, 0.7139948421146176, 0.40545031341822124, -0.11478362063979386, + -0.9345270890825441, 1.4281286745225614, -1.39970554180245, -1.1485396410325386, 1.1495990036520798, + -1.2916127094423402, 1.4211660589871826, -1.0749317173140405, 1.4370776307284663, 0.7880288709576773, + 0.46732661965227873, 1.56798877542517, -0.9531716195760707, -1.3739051298849194, 0.9766290318098436, + -1.307661662111708, 1.8574559170417002, 1.35797073743995, -1.6940130226054606, 0.28491131826133387, + -0.36419491260352554, -1.3047662545854015, -0.9266815176320033, -1.2358711507932467, 0.9127887752631247, + -1.6466848327495578, 1.1607458121027339, 0.46297657760513733, -1.5495718508374514, 0.3292413137438217, + 0.7675934897387728, 1.4008121445440214, 1.4898570624591958, 1.6030744917802648, 1.2925872420362232, + -0.8421561750911684, -0.3407292616133608, 0.38924919209979336, 1.6793513775487527, -1.0373013949726966, + -1.5337353736283532, 1.5143316995909872, -1.6320472160478126, -1.3996482770156646, 0.6337864872715988, + 0.5406528347636357, 1.2967809902878562, -1.5182702863754916, -0.7399098341126589, 0.31978027899894723, + -0.4320909370805026, -0.057815767103424065, -0.42656779912656795, -0.7191461156604344, 1.732444695508872, + -0.16793165663622744, -1.029044319841585, -0.7379183254565955, 0.6335667491493258, -0.7407757474651113, + 0.737814588729532, -0.41713542698826167, -0.862992043249343, -1.8537968903371889, -0.480058608858549, + 0.04028745468513595, -0.11696118988455151, 1.2159286584219329, -1.4551651039165874, 1.8518920420484895, + 0.8324620148383071, -1.9503205997190287, 1.3118092522348013, -0.41781057862944326, 0.47025354333711356, + -0.08599400306878557, -1.398138636933056, 1.799030968066016, -0.9016154689967486, -0.44642885397970034, + -1.6161407274817075, 0.6108393015698415, -0.9652371448534662, 1.472448459030451, -0.12097411552763226, + -1.7427779621544364, 0.6772588555443013, 1.239525535102806, 1.4978793781566582, 0.9794171716198061, + 0.37480400234555056, 1.7099069435864092, 0.5339030487857208, -1.7368267186422761, 0.3401246801395308, + -1.495349576003802, 1.1154539341471592, -0.5739747352480027, 1.7719108709631328, 0.7087464471791378, + -1.407094251765498, -0.5711994993106657, 1.6197007171162792, -1.665245693725593, -1.4093290138388097, + 0.8150971020478908, 1.1565262598728276, 0.007036682898540647, -1.067724969488646, 1.1760370444772006, + -0.4660822995530971, 0.18663889657333232, 0.8600384570962394, 0.07639203983671461, 1.7055162765205303, + -1.7134292208088802, -1.3413558800873675, -1.338677372528159, 1.4246968540400653, -1.1823984287999973, + 1.4751654585472211, 0.5262834049380078, 1.5117343050060867, -1.409416488118043, -0.39544742603356386, + -0.6577586706586658, -0.5919201797053688, 0.6013534842506445, -1.1862135968111707, -1.229417973714626, + -1.803412419156234, -0.7655790098575235, 0.9128632433156794, 0.9036623476529559, -1.9831271121679324, + 0.8324308647368319, -1.759507307385337, -1.7725931616787687, -1.7039303423725647, -1.4439967872268928, + -1.7432455401143834, -0.02216033991501387, 1.2819676717165, 0.16659457648361364, -1.167642388668959, + -1.7143084152722228, -0.7289345444538382, -0.02925241516287791, 1.9566358667801342, -1.2857581699546135, + 1.0915031830445114, 0.05084200795390714, 1.083568818422366, -1.1315486700234478, -0.8881346175534794, + -0.63619987674788, 0.3799832019858531, 0.2477670922101094, -0.6132210208290614, -1.8451948781462812, + -0.22847217268867048, 0.0025467735349682386, 0.1315834394384794, 0.1776926575489597, -0.8691295174311664, + 1.6637912565242994, 0.448901769947029, 1.233816013145204, 1.7971799993597228, 1.8719614934816882, + 1.655937636621596, -0.27359273976124054, 0.08461142131684696, -0.2757947346097396, -0.9521228519499276, + 1.766034536643284, 0.8831916052200137, 0.9813027219865562, -0.322591101625501, -0.20675723380495992, + 1.0866641329284041, -0.6397672290782843, -1.9715973970816654, -0.36395252045986304, -1.4160336028155198, + -0.7487477697571272, 1.4091533113140509, 1.2152244001439598, 1.0139512253023701, -0.5841820989850488, + 0.36171343432432046, 1.1810326691265303, -0.044977125366693294, -1.5719763377131732, 1.636814383280785, + -0.8254090686593019, -0.2739258751225844, -1.5838736296117837, 0.057544692367468286, -1.6536791042504957, + 0.8676152862870037, -0.6012988236535559, 1.0789190140651197, -0.21655562768188386, 1.5865699400089461 + ], + "dims": [320], + "type": "float32" + }, + { + "data": [ + 1.535672932186043, -0.3469466691127403, 0.7594896463626952, -0.05450122463129414, 1.4639377922956394, + -0.6333990278356465, -0.8242789470237648, 0.5117653543833605, 1.6078505759993273, -1.410275750604895, + -1.6792951377646883, -1.783057576321041, 1.1956662347204423, -1.3979831191193002, 1.7644067312268517, + 0.4240762243207543, 1.986096182518743, 0.36545941859180964, -0.8774236745093011, -0.8647372274160503, + 0.8720148666725347, -1.022106286236455, 0.5503111675120635, 1.0204841436521281, -0.9254965061314904, + -1.3449432022823808, 0.006824458535456657, 0.07690008991648423, -0.8426817383905396, 0.9996621329373534, + -0.23056243949407484, -1.0440039859718286, 1.9168909615768683, 1.5600000104620682, -1.9890822883775865, + -0.3604004168107382, -0.028511959235538065, 0.8476098198214288, 0.22053970034789216, -0.42929632097288817, + 0.6599479925924321, 0.6647860485919495, 0.10175396167639938, 0.22650892002231515, 0.4701540897019987, + 0.624214514356682, -0.6652257805050041, 0.8518349008799753, -0.9562813618340789, -1.657496508881473, + -0.3312572279583206, -1.5494034812904562, 0.18877981986543801, -1.2351800795813066, -0.07918559380797063, + -0.09391536586009241, -0.2856357420391582, 1.9393958604954182, -0.7529216437305211, 1.525084648903749, + -0.07883509109638975, 1.376637107607113, 0.5783362536287875, 0.961847664027677, -1.6855455725917468, + -0.5830772019897683, 0.4271291901307981, -1.8373857521152086, -1.7965394924729141, 1.7115697467771378, + 0.2565457539488545, -1.3360260284983019, 0.4353676471582455, 1.7248708601658969, -0.9750598890096729, + 0.05312641822767361, 1.5787554531472985, 0.9667162219022503, -1.364971428290251, -0.2814850946411962, + 0.9013643208289279, -1.4725055888862961, 0.6001425944665559, 1.2723681158746203, 1.7714493392964075, + 1.4044899825398272, -1.5787548082153382, -1.1589036159974757, -0.4012478414167475, -0.06868641055197777, + -0.7534521481998526, -0.8700101449208493, -1.1662115104665567, 1.7611310737805477, -0.2501517942331226, + 0.12866215308587936, -1.4699875001512854, 1.0395395370450604, -0.5782390952646876, 0.63115653417037, + 0.10138581116634082, -0.07007439881121424, -0.4276277546360472, 0.418589841403306, 0.9267207479900215, + -1.0293664343515356, -0.1495871781602336, 1.4452889339030666, -1.7189823464809564, 1.8323799237149645, + 1.8914008693919682, 0.3829486757403364, -0.8735369861149813, 1.602486711188683, -0.39959917784662924, + -0.8673792916868024, -1.2627215362178648, 1.8597348040684398, -0.8688156300975107, 0.15713415561611388, + -0.13148226217512082, 1.883732805180382, 0.11420203807616502, -1.6552288945493094, 1.0335466032430753, + 1.9806710089769703, 1.988269693866676, 0.29427412741632075, 1.4966799360753749, 0.6937827119996989, + -1.298620046493725, -1.752952308784005, 0.46645438478103873, -0.898908219432915, 0.7139098459371658, + -0.16215199540773462, 0.07954281050960432, 0.795652990025399, 1.5967568490712063, -1.2445652996859247, + 1.9127555713254205, -0.4996844935898572, -1.1156627480592256, -1.2948343944985163, 0.18276720875230268, + -0.748683470251498, -1.5079466014120557, -0.2494558107532141, -0.9231537960141623, -1.4121241243829443, + 1.2059834829573104, 1.905725511300579, -0.39337905860681044, 1.8425190053973166, -1.6566221588247219, + -0.242919176072947, 1.2425502129492436, -1.4417507121400348, 0.015600407032383856, -0.2242098694907284, + 0.18796276556529357, 0.08107732342066765, 0.7149451467441841, -0.20769007081368773, 0.4421202004832834, + 1.5233025839787455, -0.6642431462292846, 1.5564028464468986, -0.1586815058735116, 0.6099306071219655, + 0.8180887224937807, 1.9911546339103818, -0.005984685083011421, 0.6777759409892354, 1.7289968623869099, + 0.5264262640237458, -0.511038272902959, -1.7235775305068346, -1.138944875679032, -0.9623892814614488, + -0.6380738572168294, 1.8832106250881075, 0.028541651530706424, 0.7394956760616829, -1.5455450050824036, + 0.598697699776686, 1.44227094769795, -1.842926293477114, -1.9786511228960357, -1.774125089606943, + 0.4273755931309067, -1.1833540770674968, -0.29742688579612864, -1.7932368057978882, -1.3999703979662605, + -0.5494229951060436, -0.7692231154827809, -1.0112160791506497, 1.633993910846237, 1.3945699010195831, + -0.8649776103569309, 1.921348771224042, 0.832322610715301, 1.3754060709990767, 0.3497018723561567, + -0.7191957838389857, -1.9794221990125722, -1.84384806203993, 1.2324522851211803, 1.7698494016317143, + 0.006624102198243165, 1.5911519918365267, -0.762455861009844, 1.4479210196035108, 0.7818151145500849, + -1.9876926272814606, 1.8202062970885162, -1.6446357331454369, 0.8692666690506741, -0.7358532212979823, + -0.8444806659707744, -1.6015224446062994, 0.7479281419258141, -1.6523782603794155, 0.6710725185977369, + -1.1710932073100304, 1.4784513737588512, -1.212966513263102, -1.3741809040280142, 0.25437428444308896, + 0.8440351752665407, -1.1722116121570672, 1.104161389783421, -1.645735790976162, -0.4286533806712738, + -0.37044520217626875, -1.574330285391767, 1.4899314272896893, -0.8495642882336822, -1.714377156019676, + -0.4893435327563349, -0.7616337581393848, -0.5339391487933929, -0.3003289730553087, 1.3489307896261735, + -0.14680109166432054, 0.1026969670558735, 0.32430953678969043, 0.1795871726769951, 0.9696062238740311, + -1.5296687271207166, -0.2770372362376037, -1.0409472934130868, -0.17306368093190905, -1.1960408781183967, + 0.984219061951209, -0.4077661181651919, 1.7423047847942446, 0.5608878908901787, 1.4329493489434109, + 1.8986413512869937, 0.19154199669760352, -1.3315756935180012, 1.8870822754754517, 0.5674631415439482, + 1.1017148980678542, 0.7256621357674105, -1.8682573426264009, -1.2687446906641284, 0.5430939279068951, + 1.8279931413962558, 0.15890686973919443, 1.394841983124743, -0.8330211159668224, -1.2412716683059033, + -1.1755274256803165, 0.3146422214936937, 0.5127310756940888, -0.6223681329826247, -1.3728009148038876, + -0.5073994704733549, -1.1727465329222264, -0.07518002339175833, -1.6218851358655701, -1.3314808424730247, + 0.2696107099425271, -1.353815758928219, 1.6070801592460056, -0.7018653814032136, -1.594649470877921, + 1.8662880614030657, 0.009632792539534307, 0.885433106263176, 0.7081454198732997, 0.12480572241808119, + -1.9002637028711113, -0.8823815470565757, -0.12794198437065507, 0.3682196882451354, -1.1962414622570767, + -1.101920787984521, 0.1703217046774217, 1.2755057257388405, 1.2757461273763866, -1.7253598839195732, + -0.3935586680170111, -1.6790297555951925, 1.1726640337873802, -0.7187759606615787, -1.5997974808572053, + 1.9512036824878933, 0.8991982283799391, -1.516998597379371, -1.0918962406357053, 0.12845929863120276, + 0.387447437135819, -1.6766371647631972, 0.4172435231617522, -0.8587881195399367, 0.8973805509978297, + 0.5384910477202398, 0.22290981983700497, -1.9824980848037859, -0.19789410371539873, 1.0396641208977249, + 1.9498654847750698, 1.752979186273122, -0.10251547854421528, -1.7031576116596918, -0.6422947693243835, + -0.5947775282776488, 0.25094777162345583, 0.5519773563378578, 1.1845669608153342, 0.07011886849473115, + -1.5689347607142432, -0.9068208446502926, 0.4518736648271817, -1.1908598340431444, 0.9123278060019366, + 0.3808045721687314, -0.5161116183400685, 1.4633312728276353, -0.24955275031843804, -1.9270793627181808, + 0.5510310380033525, 0.002103402836195478, -1.9722027133603266, 0.8207770388309132, -1.2709862666051333, + 0.03660015849392373, -0.08721025552259398, 0.1480447971653538, 1.3975878551198289, -1.8688681862560603, + -0.2735983144132472, -0.29150197793885635, -1.349553505848272, -0.14289894302424067, -1.0632608448362548, + 0.9197316019995538, 1.6766092374653363, 1.4333994578157911, 1.8497508886723608, -1.8365902161760914, + 0.3329875047259945, 0.28711035354851955, 0.018743287980965917, -0.47550704561352664, 0.026002587809994537, + -0.9815518239812109, -0.30422655490353545, 1.1236748290508274, -1.996970334350796, -1.663190926732148, + -1.4930228184840004, 1.2293686779591093, -0.11228295031816327, -1.399262159949875, -1.2745774075202778, + 1.0404471355251506, -0.9042932188930193, -0.483855773240883, -0.051899262666108115, -1.1517591694487734, + 1.631117268451015, -0.4341760983538707, -1.5093199848977354, -1.524695207871412, -1.179033179719653, + 1.203939869858461, 1.2278371883112191, -0.7764972190751465, 0.12469436067847539, 1.1254668267275294, + 0.1253270059252225, 1.02529025377972, 0.37477534712132243, -0.816607896481754, 0.7652933238577306, + 0.0816252203587613, -1.6877073529228523, 0.282188424454314, -0.48899417877023144, -1.10579595806544, + 0.4180711569457314, 1.239608967084651, -0.8553327976952234, 0.7601553028351749, 1.017191993054694, + -1.561711107008871, 0.18166558203866234, 1.4575039351725838, -0.7919992885427041, -0.05528739747934974, + 0.5393789182198327, 0.9208003955213648, 0.8037584630910892, 0.2508199691349171, 0.5025718274381168, + 0.40223725437742086, 0.43401128486340124, -1.918673978558985, 0.38895512761013773, 0.9647875436316022, + 0.356426573504554, 0.7676218046110401, 0.15946706730485438, 1.2727737024033576, -1.1428215846133938, + 0.36778995418490545, -0.41392909578544224, -1.550642999283478, -1.7016418383565188, -0.3516276355010701, + -1.6424434547903983, 1.2296355686757101, -1.3262048004001983, 1.9866748350391505, 1.9039145370701833, + -0.4605978047947623, 0.37289955561548194, 1.2909351136100344, -1.4775326769813537, 1.8608708474080071, + 0.6440656172393684, 1.4358923542702868, 1.6635530454398575, -1.7844300247360296, -1.470415868795488, + 0.21864396672047892, 1.5488664195436606, -1.0864322992770177, -1.550780881959068, -0.8331945313037004, + -0.9367699280324953, -1.9013228249406309, 0.7807264098375688, 0.06677827961955263, -0.4865949947067687, + -1.9079733463147095, -0.8445233464370387, 0.4065074139836655, -0.3310839064029283, -0.2904445573034993, + -1.1753367420636245, 1.3721435340052208, 0.36660883645931097, 1.27723053302687, -0.9359216637576937, + -1.732231846976478, 1.0644600709999477, -1.6378422934868384, -0.8826400850725795, 1.950622879844948, + -1.9911319096225792, -0.6598073662934398, 1.8955996856482145, -1.4071132961709223, -0.8795225115767629, + -0.5029228970810946, -0.7734268477225967, -1.716542237524993, 0.04010671043366898, -0.7937284158037281, + 1.030026939297609, 0.9801808342123648, 1.2953427689382302, 0.20610803631475605, -1.672761300291775, + -0.690673451769495, 1.6609033000524338, -1.897131087105456, 1.2029533984904228, -0.5681454803874688, + -1.3646956682920965, 0.22071074912276334, -1.4735886916157908, -0.9695144027680014, 1.626222864433485, + -1.8694559899308487, -1.8879003983306655, -1.0176033048635613, 1.7915586444709328, 1.8810192124623084, + 0.5319984718680058, 0.0113238596202212, -0.09805090157632446, -0.5444299501215024, 1.1135935258682927, + 1.17684427133796, -1.85426568001437, -1.7530946944132086, 0.3038089938756876, -1.5870230070820002, + 1.9106333020042747, 1.3937407603560725, 0.8591788216145968, -1.612956779000272, -1.7151209190289016, + -0.13707423626294535, -1.4389728179178984, -0.05236093609874359, 0.9751452825232896, 1.744306648935904, + 0.7254929535860901, -0.09824503868926815, -1.4925208247531838, -1.985227418605298, -1.405540454178178, + 1.692915817031472, 1.2230668144021735, 0.04262811065188643, -1.1756894733009666, 1.0222275091190456, + 0.4934708666464802, 0.08979456736565261, 0.10059671914562518, 0.7155249975927536, 0.04082949674837977, + -1.715826873724553, -0.979189481763262, -1.1065843508804214, 1.481429410565739, 0.5278383608268999, + 1.4941027771635946, 1.4151786058577498, 0.07974076288029774, 0.3167509060420519, -1.269619345887964, + 1.8667680276727765, 0.527367815431, -1.874110045435497, -1.9373013120064702, 1.5330729150450173, + -1.7833509822122444, -1.7182607692067773, 0.7561591559894678, -0.1056962696530368, 0.13014948563496898, + 1.804101947913626, 0.7276195691909635, 0.021465712639121115, -0.5163553036182069, 1.1855106734783103, + -0.532372100695512, -0.5871635445412506, 0.161643292508721, 0.61018489160484, -0.9869821416193743, + -0.5318766940780302, 0.9532631042147388, 0.4597709134353236, 1.2142228742259942, -0.8224515402258339, + -0.879922983657166, -1.4710925151016916, -0.29851124917883975, -1.6631372706933156, -1.417993373545026, + 0.6364481896978704, -0.6013255938328603, -0.046835161333119935, -1.7247175181005758, 0.9825982199711403, + -0.7264776248635592, -1.463988875635824, -0.5013956201257255, 1.1933878395314643, 1.3056455851087287, + -1.4398688148432273, 1.7038722585040453, 0.46568790654958114, 0.481485333420693, -1.2873930064513237, + -1.1475617778051763, -1.9673375617555031, 0.39874490557435127, 1.0960170584095357, -1.8987243885981488, + 0.36983554057526735, 1.9718844590478293, 0.7894176749822801, 0.06983603687412288, -0.9000466156841869, + 0.6428129286371904, 1.704798225993037, 0.5950045030048496, -1.1678955586471442, 1.237662010594594, + -1.6482921001228146, 0.7270614937877813, -1.0006186813130835, -0.5305400798817805, -1.5252716548293819, + 0.18855276048488978, 1.5437352372976703, -0.9397215004565727, -0.4258934153954881, -1.0950445559616853, + 1.1844079915434298, 0.024990774215178035, -0.16149461270780652, -1.4078837300903269, 0.09499589792836627, + 0.516842370641422, 0.4800833347119191, -0.539291197739594, -1.5117979701954605, 1.354396143788092, + 1.28278689745333, 1.8488206619245648, 1.3022599053953776, 1.6609548614809775, 0.3269713781203789, + -0.38485903666664, -1.464958277436181, -0.3992461504929734, 1.0699961189397085, -1.0135843210651023, + -1.458604697589653, 1.490121083969428, -0.2595359932483561, 0.20854389182544342, -1.7482190390121701, + 1.3007127316326326, -1.9884730509986825, -1.4952841032131454, -1.3179011133758536, -0.388478076479009, + -1.1589100485370416, 0.8387145536985532, -1.3384696651494759, 1.4683529176022008, 1.303145953986827, + -1.3041819891109316, -0.03449749547681513, -0.6608734038387656, 1.1683787754166381, 0.5655509145236746, + 0.37738607602963814, 1.2152762148898635, -0.29353655718583926, -1.0509092280694636, 0.7139081884804019, + 1.5196106141395527, -0.530586207320952, -1.558831387258346, 0.01131046295330318, 1.9344117181061735, + -0.6850503993030497, 0.9331665418290909, 1.1688357095654807, -0.42466684124295995, -0.49121961262440816, + -0.11897540791552874, 0.5942162255141863, 1.7548838522451646, 0.4438013028171106, -0.28183936745813476, + -0.07495854303862437, 0.9303587326961971, -1.3198776631733748, -0.591718773961956, 1.1127108159764676, + 1.3939520197540327, 0.6360105102962654, 1.3722503898910716, 0.1757960098808633, -1.8297470389548955, + -1.9205472381959057, 1.4666198830651629, -0.7830326162911714, -1.4564248566278515, -1.0967977812614587, + 1.164770039819981, -0.5760771475874042, -0.7667709006388028, 1.371522788043694, 1.7326600398634024, + -1.8193902025531763, -1.6090197630011929, 0.09836987546776577, 1.0677415637460363, -1.5307232030478781, + -1.5599516580470505, -0.2609675276531007, 1.0598276568453162, 1.8794380814113483, 0.5316994667209949, + -0.7552146503779023, -0.4617287817040179, -0.3819745586646004, -0.42575961119349426, -1.9237942552613312, + -0.8825058198571423, -1.8728798417790404, 1.7802885739921077, 1.8333435291969842, -0.3098252256281784, + -0.029956863413143964, -1.0772837825116914, -0.08180463340649524, -0.6945910113459792, 1.1668128443816146, + -0.02738480437430635, 0.39293059281590104, -1.6704359314356383, 1.6869956995774205, -0.2294375604199823, + -0.32757809443951746, 1.9764189201357567, 1.5201484938081151, 1.087504317388186, -1.6272710803209698, + 1.0397868469069298, -0.22176854941092294, -1.8820468396186323, -0.5303897107068192, 0.594170569473933, + 1.6960001373937432, 1.9644545152850057, -1.7960342834770175, 1.7873883299813267, 1.3489957623885935, + -1.6820391707003042, -1.5713129762775537, -1.3637851932919034, 1.5936068708950781, -1.2089638711610604, + 1.322028643928432, 0.8929678781012855, 1.8401053408016272, 1.5452683829326075, -0.9171427145163484, + -0.06745535875434427, 1.9379586035273615, 0.5855503730756357, 0.03549855059545948, 0.6527698319031092, + 1.6754602207349976, 0.7728323704391817, -0.9665588877441182, 0.6173545510334506, 1.3120695172827377, + 1.181821226786317, 0.1841309435168954, 0.32318631702986167, -0.790159398034489, -1.385019609574396, + -0.7118643666238835, -1.0439971536099275, 0.017584768122861583, 1.7536303032255187, -1.1965922071808155, + 1.4548082915973595, 1.562828560283652, 0.0920828524560271, 1.892000960124009, 0.3648061373597171, + -1.9613669287159263, 0.841563763070833, -0.9118328355251464, 1.9226565574363006, 0.3988462224271192, + -1.970188432590363, -0.8264337415665439, -1.9090851263430704, 1.8428915650288547, 0.28596991752644385, + -1.6708643684685667, -1.0762549708362332, 1.9492472760488564, -0.17802109704659852, 1.9236671687550047, + 1.9548632849049623, 1.9566450030001414, 1.3303550873049677, -0.5915124672929295, -0.0037832544010933944, + -1.6026781861800838, 1.7578833516354813, -0.2956678774270767, 1.4060455643195402, -0.7370157759032727, + 1.8789198126787916, 1.1123902493105078, -0.8769185681462304, -1.2618214191177506, 1.8674610245111278, + 0.5103075356648485, -0.020685118611023512, -1.407221324818173, -0.7491588381751608, -1.3743460812306214, + -1.7710712130536228, -0.19455369352318552, 1.3434990212660862, -0.5544338320325721, -1.324247058015053, + 0.8874849369101687, 0.6838871095643375, -1.5617313105262172, 1.204432716341258, -1.4981479923604955, + -0.06499977622096687, -0.8060264147106553, -0.36092597365795775, -1.307326777195418, 1.6399424900785, + 0.429157912433868, -0.9915570262096942, 1.5128426032058089, 1.6375586318255548, -0.1737010535017669, + -1.21285453753765, -1.8844155037723906, -0.2590630754224348, 1.7328565249414716, -1.260633142919116, + 0.3637043664955444, -0.48087365705468965, -1.7001295586898113, -1.0775016378447075, 0.2620695698901221, + 0.5015363913767086, -0.42080100290276246, 0.5338170065286052, -0.43568151602634764, -0.744286733837793, + 1.57647267103789, -1.2491109283310529, 0.10032144655805375, -0.46919353377702855, 1.9415827315644636, + 1.2111393515469855, 1.744725164783687, -0.6871612254817352, 1.406736078990102, -1.063724178982385, + -0.904699966390976, 1.5681407930006221, 0.79849818604837, -1.7759907970834616, -1.6325947440964974, + -1.0309733602826086, 0.29563414198237936, -1.7157737037653629, 0.2876568188935451, 0.21411659926835913, + -0.1601632043965786, -0.02605079418095091, 1.2041639219664182, -0.6351323647136597, 1.1149646585336592, + 0.6657515663650084, -0.4672646227384094, -0.5117766415018226, -1.4643244157794149, 0.39081520672097003, + -1.502649477455031, -1.637368884151761, 0.34542161036123176, 0.060151105688381, -0.5045040651104555, + -1.761988723037204, -1.9197872473179176, -1.3665270161331975, 1.3928026939637972, 0.39218445250695577, + -0.8470063024385848, 0.009038121027233892, 0.46871439485211575, 1.459827780771806, 1.4853766551455694, + -1.2321752545416356, 0.3748806345040103, 0.20582729258619814, -1.4266279966077402, 1.2950255786963805, + 1.7125611822808544, -1.407545517068188, 0.5169179018491512, 1.8595592751857541, 0.9487671455033482, + 1.9467423989905699, -1.5919149626150748, -0.4630901723451881, -0.698284068975914, 0.6197574561950008, + -0.8199869405915381, -1.7196626702920055, 0.6036024034626806, -0.8348164600145518, -1.7650166093756132, + 1.5829990521620996, -1.588645487863901, 0.7633248861408699, 0.5800948434762754, -1.7159447523887836, + -0.3836837699904496, -0.9746560067630572, 1.4480442893861705, -0.24403527878135645, -0.6397662241706819, + 0.956203271386264, -1.4601856308265049, -1.5649468816584298, 1.731664582215319, 0.9679933953420496, + -1.9722379093946607, 0.24076423675934056, -0.19244242272389211, -1.3854799949067935, 1.3744990882455346, + -1.121046645776083, -0.4342567706309435, -1.7159646482293107, -0.9317859666979054, -1.698219647396134, + -0.8288368620433939, -1.3875410583085985, -0.16399331338641066, 1.4667160798353667, -0.020345764369364083, + -1.9585520591695529, 0.9886716666217517, -1.0701744437434098, -1.3248249591382057, -1.3272246201915312, + 0.906046259715148, -0.9554587301398687, 0.16744253332193715, 0.40874734944503466, -0.7237514235199383, + -0.8028952942996463, 0.9478548199038599, 1.6268191787625108, -0.7376232063503751, -0.6874490141085632, + 0.20469380737641973, -1.940886393624119, -0.9715176541080677, -0.11409081023343237, 1.8208884259904847, + -0.05753377002269655, -1.2533113228725696, 0.18235199190840046, 1.3670427559403047, 0.7183594427747524, + -1.9834311091439476, 0.09488256814231644, 0.07406140049599319, 1.0427950622016802, -0.7928805141629418, + 1.7221214208634228, -0.06548459693275177, 1.6984102601031559, 1.8777510809050026, -1.735259524674964, + -1.1416240368033357, -1.7612022583614682, 1.721880360655705, 0.14372177475853665, 1.9269311955654835, + 0.19978809107127216, 1.0299566806165856, -1.7617419918814026, -1.2737765895488096, 0.7789099525859564, + 1.9816257384474012, 0.4482897887627919, -0.9051913536142644, -1.152506387584042, -1.8817136441487783, + 1.1054935295772461, 1.577999662025542, 1.600449927735128, -1.4919075331081064, -1.9550057574515671, + 0.1306184124670624, 1.4754764229533928, -0.808023880270273, -0.21695285993080393, -0.539628797891055, + -0.7836468765498132, -0.8815668388678288, 1.8917264703112755, 0.028934119940069003, 0.06879472114883711, + 0.4407647131615784, -1.702696302284755, 1.7815067716148931, 1.7950168026349171, 1.1405438335719111, + -1.1434283018085534, -1.720238715793207, -0.7729623733152229, 0.17672006075090962, -0.14942755614865622, + -1.6229777838891115, 0.3793725781830055, -1.0113407389657345, -1.9280392460441265, 0.7422498462017861, + 0.8331559663193939, 1.3063596659922263, 1.8113167679814106, -0.1401093760534291, -0.24674083884906395, + 0.15509679692376732, 1.8667087827355466, 1.1906398286118094, -1.673307806924564, 0.41063702592861695, + 0.2436862014560477, 0.24067383021132027, 0.22686603382511628, -1.7295093225806442, 0.6565075922933001, + -0.5514412373381097, -0.5236684516031653, 1.8733248509057603, 1.082970345098504, 0.3340204503283841, + 0.5450315229688343, 1.0954041212853163, 1.565649272477267, -0.5469992182522905, 0.7193108029242588, + -0.9050070254533322, -0.5121370718971949, 0.962566205706815, -0.24631520617082092, 1.3340325816997325, + -0.8820024080231894, 0.22736077826137535, 0.2252389330707789, -1.947448723525529, -0.9518843625899782, + 1.7502182429516546, 1.558646352665332, -0.838440251378624, 1.541757246903681, 0.44677553405529213, + 0.9918545507928869, 1.060951650228274, -1.3653319701374311, 0.2635328688559797, -1.6894618652561055, + -1.9316398959917604, -1.6545844047461316, -0.8374565390669062, 0.5467667551875302, -1.1703334497283162, + 0.79898936445238, -0.48742537394255603, 0.05126348262407365, -1.0630031367885744, -1.6755563384807575, + -1.7470496911251123, -1.4045037572548411, 1.697678496203098, 0.541058257415223, 1.9355948975325852, + -0.8470115353500489, -1.2030885197848056, 0.8919323754916997, 0.0702516207867685, 0.5155253592422371, + -0.3579514965668338, 1.7112737380442171, 1.9947965056065957, 1.2741397687110538, 0.09885151767767297, + -0.9770807797341039, -0.11682236263324342, 0.7272198637411007, -1.987824039940028, -1.1358258258310752, + -0.11090836034305251, -1.9915135816887366, 0.39056056844969866, 1.2932859858303178, 1.7109978939050503, + 1.1846928384025448, -1.7330449982026206, -1.1525984164585106, 1.104166927781134, -0.28565377527521196, + -0.9685059019914002, 1.7093828969134002, 1.9709107005494806, 0.049031526597832276, 0.4472417265612556, + 1.0921859039999235, 0.8763632205063123, 0.8161138478535914, -1.0720275802414108, 0.7266994153226873, + 1.233185460886041, -1.127435043988318, 1.0918127239321773, -1.8540096367958645, 1.9681192361925266, + -1.176325917090126, -0.30265014266672097, -0.44524467812690727, -0.9978154618024702, -0.667478816738317, + 0.15065079333379305, -1.0302715841959227, 0.829863553229278, 0.8134089689909292, -0.6474889076993566, + -1.079618527738825, -1.783292588379826, 0.748112963221554, -0.6286053844150628, 0.48331041409284303, + 1.663305348437456, 0.18479680885937455, 0.6293791717008288, -0.6005275360880811, 1.5747695362530774, + 1.5708476785905807, 0.5755861487097542, 1.2041008720516082, 1.6685888824542738, -0.10278064261508835, + 0.9057150675313927, 0.6510335974298398, -0.10744692672758216, -1.7129461062136837, -1.4064873182457918, + 0.4781316094642234, -0.37189635012197275, 0.16614793992804522, 0.03645962620683285, -1.6224894209420242, + -1.8138940006820983, 0.5069783696842336, 0.6849365239989318, 0.8037589654894051, 1.979213666270276, + 1.6861127134381242, 1.2233661916798626, -0.3986509966310168, 1.274497735591801, 1.605857523477285, + -1.2118797206236485, -1.0066619307873124, -1.358968189183389, -1.9510798049888383, -0.9808314235618916, + 1.742926920936518, -1.1022645984613817, 1.0929594394621382, 0.48488158650621127, -0.32877770055973077, + -0.47650834081572935, -0.5160849885006016, 1.3738126318494883, 0.8827072110361662, 0.48644110690758247, + 1.0382179714335322, 1.6721919595070132, 1.341715329629717, 1.7295025892939409, -1.522344995293861, + -1.5965490751871654, -1.7983927509857223, -1.0759710407128011, -1.3793282201703079, -1.443902375079908, + -0.9426382639949571, 0.5602210832754357, -1.0965977429851064, -0.19857124750589872, -0.7431182359930233, + -1.2699260459939286, 1.4876549876992726, 1.6274319173214051, -0.3309529465344534, -1.9454352826883534, + 0.12935585140981676, -1.0093456723551508, 1.7243377444859647, 0.10127369924344443, 0.697537788222963, + -1.521134755613331, -0.442714777461525, -1.6896188579102178, -1.9330985764980921, 1.9140786772267155, + -1.2925077312416482, -0.9509978830442094, -1.1889787216631778, 0.795835379830006, 1.4837581515063887, + -0.8597344665233413, -0.7448025823499504, 1.7455825639820093, 0.33723505300912304, -1.8208678041990423, + -0.12753920031860666, -1.757360720716986, 0.8256076807737855, -1.9972549760931448, 0.844750409785961, + -0.9594803067513551, 0.7862268813645183, 1.7393046013815212, -1.161126895447727, -1.6347790700422697, + -1.3348870119154936, -1.1621632421015011, 1.2696646718252413, 0.4845759791644788, 1.0668299384975475, + -0.6942334010657198, 1.4734240949292259, 0.4282074397978146, 1.6699946816827183, -0.6802086123370898, + 1.9208442056043609, -1.8532082289660545, -0.1592261674427098, -1.2431462763761214, -0.7286614982674164, + -1.522868872353036, -0.3825873577199159, 1.431979005569458, 0.43719966684470446, 1.6478260330278633, + -0.06620691473965401, -0.36945868484144917, -0.3809516652838498, 1.6855172591399752, 0.31969027259376137, + 0.09157179754578149, -1.3138870107882425, 1.4208147318276607, -0.03157398665509881, 0.03702456744844529, + 1.4819698957492982, -1.6015809663105944, 1.8331399913105164, -0.6094891041007129, 0.9393020005799118, + 0.6313754553821562, -0.3128111370670492, -1.324295564232262, 1.7609361120800635, 1.5935407847968914, + -1.280640014867119, 1.4668684416985176, 1.460389700948717, 1.0299991397017587, 1.2266139129378075 + ], + "dims": [2, 2, 320], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.1067028045654297, -1.6136860847473145, 1.261694312095642, 0.5976569056510925, 3.0194122791290283, + 0.6702871322631836, 0.10282492637634277, 0.3735429048538208, 5.239527702331543, 0.9048061370849609, + -2.0244898796081543, -2.0475926399230957, 0.08633577823638916, -0.7819606065750122, 3.528346300125122, + 0.796191930770874, 1.819821834564209, 3.565944194793701, -1.9118807315826416, -3.2814927101135254, + 0.6710286140441895, -1.8751076459884644, -1.4806747436523438, -1.796984076499939, -0.7681794166564941, + -3.1337573528289795, 2.5466723442077637, -0.710807740688324, -1.3013415336608887, -2.010349988937378, + -0.5839980244636536, -3.2363412380218506, 0.34071028232574463, 1.8070162534713745, -2.2573466300964355, + 2.2563209533691406, -2.721301555633545, -1.7052738666534424, 0.020109310746192932, 1.8161461353302002, + -1.2339590787887573, -0.3481786847114563, 0.14459623396396637, -0.2792869210243225, 0.3242926299571991, + 0.8492016792297363, -0.9436420798301697, 2.8654322624206543, 0.018474817276000977, -0.33994853496551514, + -1.4109879732131958, -0.17846572399139404, 2.4232289791107178, -1.366482138633728, 0.8393897414207458, + -0.06912710517644882, -0.1260005384683609, 0.14877259731292725, 0.8378958106040955, 0.2032637596130371, + 1.6857033967971802, 1.683059811592102, 0.020472168922424316, 1.4638370275497437, -0.2821274995803833, + -4.081655025482178, 3.739361524581909, -0.5193736553192139, -0.5895893573760986, 0.13349902629852295, + -0.8314229249954224, -1.955358862876892, 2.2120022773742676, 2.9806900024414062, -2.594862461090088, + -2.5279524326324463, 2.2374868392944336, 1.651476263999939, -1.7969365119934082, 0.5194283723831177, + 3.9661808013916016, -0.6912452578544617, -1.615968108177185, 3.2498717308044434, 1.6176962852478027, + 2.765564441680908, -2.780879497528076, -0.943089485168457, -2.211867332458496, 1.5465649366378784, + 1.5535883903503418, -3.7050137519836426, 0.04439067840576172, 0.36327028274536133, 0.9592444896697998, + 2.2070963382720947, -4.852853298187256, 0.6495039463043213, -1.601344108581543, 3.303436756134033, + 3.5125482082366943, -3.4023211002349854, 0.7367992997169495, 2.8616340160369873, 0.557513952255249, + -1.2049691677093506, -0.19210883975028992, 1.6728510856628418, -3.3260436058044434, 4.706552028656006, + 2.2566816806793213, 2.644940137863159, -0.9390113353729248, 1.6396602392196655, -1.3574936389923096, + -3.6357059478759766, -1.3324334621429443, 0.8182520866394043, -3.782191753387451, 2.5362539291381836, + -2.8861687183380127, 2.2147746086120605, -0.7912830114364624, -1.3549126386642456, 2.932422637939453, + 0.6247330904006958, 1.6168872117996216, 0.9066742658615112, -1.156375527381897, 2.196871757507324, + -1.3269041776657104, 0.7688918113708496, 0.02223837375640869, -1.3422014713287354, 2.5085129737854004, + -0.8842201828956604, -2.039457321166992, -0.0754881501197815, -0.4683438539505005, -2.3120336532592773, + 0.4231855869293213, 1.7217100858688354, -1.4091691970825195, -2.062229633331299, 2.0696098804473877, + 1.2929754257202148, -0.21851062774658203, 2.792795181274414, -0.24259614944458008, -1.6432653665542603, + 0.2709762454032898, 2.5165672302246094, 1.4215764999389648, 2.406688690185547, -3.7345216274261475, + -2.1278839111328125, 3.311349868774414, -4.237924575805664, -2.4865145683288574, -0.4375068247318268, + 1.7486937046051025, 0.9667145013809204, 0.7027313113212585, -1.8740135431289673, -0.3525621294975281, + 0.19565200805664062, 0.40774744749069214, 2.2967820167541504, -1.8403133153915405, 1.831811785697937, + 0.9851721525192261, 2.7873969078063965, 1.0879806280136108, 2.5585243701934814, 1.9414751529693604, + 1.6000714302062988, -0.12208014726638794, -2.56121826171875, -4.894813060760498, -2.881957769393921, + -2.041257381439209, 2.9550018310546875, 0.5040202736854553, -0.27999716997146606, 1.0042527914047241, + 2.926683187484741, 1.3717838525772095, -0.24589979648590088, -4.2212233543396, -2.1938352584838867, + -1.6489169597625732, -3.442727565765381, 2.948969602584839, -2.7220163345336914, -3.187354803085327, + -0.34392428398132324, 1.470370888710022, -1.630984902381897, 1.2510205507278442, 1.1136020421981812, + -3.759488344192505, 1.4942673444747925, 3.067783832550049, 3.345754384994507, 2.6331236362457275, + 0.9775646328926086, 1.2827643156051636, -2.623198986053467, -1.1612101793289185, 1.7932779788970947, + -0.332869291305542, 2.42099666595459, -0.9636011123657227, 4.5822649002075195, 0.8944255113601685, + -3.2404866218566895, 2.7085609436035156, -0.4827519655227661, -2.3480019569396973, -3.114384174346924, + -0.8162459135055542, -0.9214845895767212, 1.2764832973480225, -3.152130603790283, 1.567040205001831, + -1.699249505996704, 0.3841613531112671, 1.300299048423767, -2.6244685649871826, 0.1572742760181427, + -2.503662586212158, -4.367088317871094, -0.9085763692855835, -1.3322471380233765, -1.8894531726837158, + 0.7199447751045227, -2.851144790649414, 4.080941200256348, -0.541861891746521, -1.1072325706481934, + -0.6561694145202637, 0.40478527545928955, -0.8838909864425659, -0.5028785467147827, 0.7957435250282288, + -3.4829330444335938, 1.046553611755371, 2.5124118328094482, 0.3735085725784302, -1.0879991054534912, + -0.09173977375030518, -3.4051504135131836, -2.2644267082214355, -0.9162223935127258, -3.4872522354125977, + -2.355233669281006, -1.8244541883468628, 4.704746246337891, 0.4475516676902771, -1.2546875476837158, + -0.44408249855041504, -1.924820065498352, -2.729738235473633, 4.391683101654053, -1.0688762664794922, + 2.174078941345215, 2.718625068664551, 0.7366507053375244, -1.9571187496185303, 1.1222915649414062, + 2.276261806488037, 2.0843756198883057, 1.5358469486236572, -0.14141148328781128, -1.5720349550247192, + -2.324619770050049, 2.1180672645568848, -0.757960319519043, 1.402897596359253, -2.846881628036499, + 2.5358057022094727, -1.274275541305542, -0.14995357394218445, -1.3371965885162354, -0.4439084529876709, + 1.4503703117370605, -1.6082179546356201, 3.0019733905792236, -0.9571952819824219, -1.6500767469406128, + 1.6778243780136108, 2.374703884124756, 3.5679006576538086, 0.20018166303634644, -1.0103645324707031, + -1.480147123336792, 0.19532525539398193, -2.786205530166626, 1.3784115314483643, -1.2978419065475464, + -3.5328032970428467, 3.0164525508880615, 2.0895931720733643, 3.5052330493927, -1.9349405765533447, + -1.311628818511963, -1.7713117599487305, 2.886934280395508, -1.5496007204055786, -0.8046841025352478, + 1.5652999877929688, 0.1403025984764099, -2.6447391510009766, -2.0337233543395996, -0.22587481141090393, + 1.8628309965133667, -3.466338634490967, 2.2385408878326416, -0.858932614326477, 0.6435102820396423, + -1.1014156341552734, 0.6221705675125122, 1.3742595911026, -0.24308213591575623, 1.8533508777618408, + 0.14410161972045898, 3.0187618732452393, -0.33525052666664124, 0.290519118309021, -1.2579193115234375, + -1.3335667848587036, 0.4902459979057312, -2.2434842586517334, -0.6882419586181641, 2.9724576473236084, + -1.2139863967895508, -1.9754445552825928, 0.11754357814788818, -3.2436463832855225, -0.29947084188461304, + 0.9013328552246094, 0.025318264961242676, 0.9405116438865662, -2.2489869594573975, -1.2323944568634033, + 1.4659011363983154, 1.380167841911316, -5.245995044708252, 3.716740131378174, -1.5962101221084595, + 1.7039341926574707, -0.24453751742839813, -0.47277745604515076, 2.6836142539978027, -4.659006595611572, + -0.2703670263290405, -2.802849054336548, -4.558082103729248, 1.3134486675262451, 1.3934195041656494, + -0.9713399410247803, 2.829873561859131, 0.5422300696372986, 2.14626407623291, -2.411435127258301, + -0.8668385744094849, -1.579006314277649, -1.0427988767623901, -0.3021366596221924, 0.7571608424186707, + -0.7852025032043457, 2.0103890895843506, 2.875030994415283, -0.6650004386901855, -4.240952491760254, + -3.7397704124450684, 0.06430482864379883, 1.2097631692886353, -1.621443510055542, -1.7518706321716309, + 2.0040979385375977, -2.1621170043945312, -3.4342057704925537, 0.4494125247001648, 1.0336246490478516, + 0.6141325235366821, 1.1723374128341675, -1.566636085510254, -0.0875391960144043, -1.5110716819763184, + 1.4554129838943481, 0.839878261089325, 2.398009777069092, -0.6458553671836853, -0.7608357667922974, + -2.0972063541412354, -0.596686601638794, 1.327064037322998, 1.2332861423492432, 0.643580973148346, + 0.2491741180419922, -1.1464729309082031, -1.6413570642471313, 0.4765915870666504, 1.1993881464004517, + 0.2358156442642212, 0.8658393621444702, 1.8936083316802979, -3.0983033180236816, 1.2818799018859863, + 1.0561144351959229, 0.18877224624156952, 2.373169422149658, -2.1537320613861084, 1.7804971933364868, + 1.7559447288513184, 0.5495958924293518, -0.29311543703079224, 0.7076770067214966, 1.3824928998947144, + 2.5599937438964844, -2.2310054302215576, -1.3870820999145508, 2.705214500427246, 2.692167282104492, + -0.3191862404346466, 2.2299273014068604, -2.8660874366760254, 0.04656076431274414, 1.0372791290283203, + 0.9051024913787842, -0.7127535343170166, -0.346563458442688, -1.8466299772262573, -1.776979684829712, + -0.7937185168266296, 2.6496312618255615, -3.1376733779907227, -0.5262937545776367, 4.203805446624756, + -3.0495786666870117, 3.059788465499878, -0.6179596185684204, 1.5632293224334717, 4.387739181518555, + 2.1877965927124023, 3.867405891418457, 1.6019251346588135, -3.1097412109375, 0.14593756198883057, + -1.4151546955108643, 2.8710670471191406, 1.281739354133606, -1.9452589750289917, 0.3256327509880066, + -1.0140762329101562, -2.1761093139648438, -0.36153650283813477, -1.9866083860397339, 0.20329490303993225, + -2.189547300338745, 2.0582122802734375, 0.44074079394340515, -3.6016898155212402, -1.0940327644348145, + 0.05166494846343994, 3.9986839294433594, 0.007254809141159058, 2.9994473457336426, -0.17313486337661743, + -2.319499969482422, 2.1396687030792236, 0.23967742919921875, -0.9820348620414734, 0.7810753583908081, + -2.565080165863037, -0.3542521595954895, -0.39312660694122314, -3.611963987350464, 0.042843759059906006, + -1.9587305784225464, -1.0954759120941162, -3.2344908714294434, 0.6816467046737671, 0.7935110926628113, + 2.3788259029388428, 0.8960800766944885, 2.7103538513183594, 1.0750906467437744, -1.3195565938949585, + 0.6368587017059326, 0.09530603885650635, -4.324446201324463, 0.31018364429473877, -1.4680615663528442, + -0.8505295515060425, -0.4297642111778259, -2.9845335483551025, -1.073625087738037, 3.111997127532959, + -1.082578420639038, -1.0510170459747314, -1.0351759195327759, 2.1196703910827637, 3.6743626594543457, + -0.10965263843536377, -2.8268239498138428, 1.9994782209396362, -2.2761340141296387, -0.037992119789123535, + -0.5068368911743164, -2.466184377670288, 0.8389816284179688, -0.9829720854759216, -0.578821063041687, + 0.3714909553527832, 3.968106746673584, -1.0078635215759277, -1.4665645360946655, -0.24487531185150146, + -3.812358856201172, 0.8614283800125122, 1.251778244972229, 4.411714553833008, 3.906099319458008, + 1.1894285678863525, 1.3625565767288208, -1.2013204097747803, -3.780947208404541, -0.7636905908584595, + -1.4467679262161255, 1.9876563549041748, 0.7255282998085022, 1.8526909351348877, 2.2311062812805176, + 0.7617504596710205, -0.3560359477996826, 1.834754467010498, -2.417194128036499, -3.032979965209961, + 0.3447788953781128, 0.193556010723114, -0.4079936146736145, 1.300100326538086, -0.19834625720977783, + 2.222346782684326, 3.1362013816833496, 1.2092983722686768, 1.5581995248794556, 1.3155611753463745, + -0.8380979299545288, -1.2280077934265137, -3.5234897136688232, -0.32684326171875, -1.4621152877807617, + 0.428300142288208, -0.5776108503341675, 0.9278461337089539, -2.4938998222351074, 1.2017678022384644, + -0.2525625228881836, 1.0117347240447998, 1.7265347242355347, -0.10318005084991455, -1.492635726928711, + -1.2622400522232056, -3.0749173164367676, 2.4151294231414795, 1.1957623958587646, -2.7823221683502197, + 0.6365658044815063, 0.18952512741088867, -2.0210397243499756, -0.3540761470794678, -2.876804828643799, + -2.8968381881713867, 0.17692947387695312, 1.728485107421875, -2.2341482639312744, -4.501170635223389, + -1.6425974369049072, -0.9404029250144958, -1.1620832681655884, -1.0455152988433838, 1.3684580326080322, + -1.4598485231399536, -1.6593886613845825, -1.0509099960327148, 1.3251757621765137, 2.258070468902588, + -3.5802016258239746, 2.863391876220703, 1.0157440900802612, -1.5516963005065918, -5.100094795227051, + -2.9607906341552734, -1.1230504512786865, 1.9419206380844116, 0.4938334822654724, -2.3842170238494873, + 0.1679488718509674, 1.827955961227417, 0.10622924566268921, 1.8168610334396362, 0.677010178565979, + 3.0694189071655273, 0.3993656635284424, 2.3529860973358154, -1.4582010507583618, -1.5138496160507202, + -1.5133174657821655, 2.0854310989379883, 1.6874661445617676, -0.6133178472518921, -0.9184160232543945, + 3.041386842727661, -0.8360755443572998, -1.2672674655914307, 0.27318108081817627, -0.906801700592041, + -0.2576174736022949, 0.8814321160316467, 0.9032235145568848, 1.5922852754592896, -0.5044339895248413, + -1.0950052738189697, 0.9084010124206543, -2.3912510871887207, -4.171522617340088, 4.554413795471191, + -0.3333394527435303, 1.5956521034240723, -1.197889804840088, 0.8468800783157349, -0.39928677678108215, + 0.7615669369697571, -3.205524444580078, 2.535108804702759, -0.4366309642791748, 2.1470766067504883, + -0.25451767444610596, 0.23135042190551758, 3.335973024368286, -0.19102385640144348, 0.8413820266723633, + 2.2614195346832275, -1.1231105327606201, -1.3756293058395386, 0.17654633522033691, 0.5028480291366577, + 0.7965704202651978, 0.867662250995636, -4.270709991455078, -1.4976004362106323, 3.333491325378418, + 1.6522053480148315, -3.4461770057678223, -1.0945802927017212, -1.1912789344787598, 0.5186694860458374, + 1.525572657585144, 0.4644775390625, -0.5472983121871948, -4.093353748321533, 1.6807860136032104, + 2.2575550079345703, 0.9947443604469299, -4.168862342834473, 0.09030676633119583, 1.3352301120758057, + 0.37972205877304077, 2.2988173961639404, 0.8671650290489197, 1.040745735168457, -3.316119432449341, + 2.3733606338500977, -2.248332977294922, 1.7465157508850098, 0.19552722573280334, -0.9690064191818237, + -1.8139621019363403, 1.9242961406707764, -1.9793150424957275, -2.789724349975586, 0.18952327966690063, + 0.5084639191627502, -0.054778456687927246, 0.2740379571914673, 2.1619865894317627, -4.095170497894287, + -3.142530918121338, 1.1796610355377197, -0.8848727345466614, -1.2477298974990845, -0.07429039478302002, + 0.9135949611663818, 0.21963024139404297, -1.9909381866455078, 1.99857497215271, 1.4466471672058105, + -1.1016359329223633, -2.8484303951263428, -3.1158666610717773, 4.74474573135376, -1.1900646686553955, + -3.1329240798950195, 2.125332832336426, 1.9798109531402588, 2.6058056354522705, -2.0495054721832275, + 0.028982579708099365, -0.5753974914550781, 2.7390692234039307, -2.3111703395843506, -5.434136390686035, + -3.3772997856140137, -0.37978899478912354, 3.2925407886505127, 3.671295642852783, 0.7639904022216797, + 3.1895627975463867, 0.15414607524871826, -0.6484872102737427, 2.18841552734375, 0.4799572825431824, + 0.1354406625032425, -2.747096300125122, -0.22751712799072266, -0.7596011161804199, 0.5766011476516724, + -0.017207175493240356, 2.4283714294433594, 0.5117142200469971, -0.8030692338943481, 0.44569623470306396, + 1.076960563659668, -1.8645645380020142, -2.2490062713623047, -1.6578664779663086, 0.6149722337722778, + 4.706758499145508, 0.38176798820495605, -4.501796722412109, 2.2427682876586914, 2.1858701705932617, + -1.8162599802017212, 0.9385958909988403, -3.889805316925049, 1.1331977844238281, 1.0191240310668945, + 2.4039511680603027, 4.155160427093506, 4.143398284912109, 0.7778210639953613, -2.2585456371307373, + -1.085227608680725, 1.7663249969482422, -2.8071107864379883, 0.5367544293403625, -0.02325284481048584, + 1.5876415967941284, 2.046140670776367, -3.832700490951538, 0.46683841943740845, 0.5545571446418762, + 1.8768221139907837, 0.7790337800979614, 0.38500359654426575, -2.3040874004364014, 1.1112343072891235, + -2.2824416160583496, -0.026048898696899414, -0.27540627121925354, 0.5449916124343872, -2.154345989227295, + 0.7431529760360718, -0.008791446685791016, -2.407325506210327, -0.46152830123901367, 1.6632401943206787, + -1.7320727109909058, 1.0486053228378296, 1.3803236484527588, 0.3680152893066406, 1.716249704360962, + -2.2865381240844727, 0.14729297161102295, 1.260400652885437, 4.922313213348389, 0.5643207430839539, + -4.42134952545166, 0.464374303817749, 0.59236741065979, 1.2845817804336548, 1.4366343021392822, + -0.0200042724609375, 1.6293566226959229, -1.3861595392227173, -3.4724128246307373, 2.1383941173553467, + -2.1009442806243896, 3.7689297199249268, -2.918327569961548, -0.27357161045074463, 0.9184791445732117, + 1.0513062477111816, 1.9957637786865234, -3.276752233505249, -1.6878246068954468, 4.714818954467773, + -0.9857031106948853, -0.6153162121772766, -3.6428263187408447, -0.30243179202079773, -0.4309789538383484, + -0.03419780731201172, -1.6013574600219727, 2.214989185333252, -1.1272412538528442, -1.6917750835418701, + 1.547987699508667, -0.5724269151687622, -0.47848212718963623, 1.742186427116394, -0.2213730812072754, + -0.8063536882400513, -3.479326009750366, 0.5662966370582581, -1.0524877309799194, 3.702444553375244, + -0.8636859059333801, -1.9768422842025757, 2.1982383728027344, 1.1405737400054932, 0.6146906614303589, + -3.5127429962158203, -0.8339279890060425, -2.914233446121216, 4.411269187927246, -2.3479251861572266, + -0.2184194028377533, 1.7971992492675781, -0.9596229791641235, 0.09081411361694336, -0.4546387791633606, + -0.38310706615448, -0.5399283170700073, -0.2518271207809448, -3.5085813999176025, 0.077769935131073, + -0.5233420133590698, 1.4064757823944092, 0.8371680378913879, 1.9668782949447632, 1.483221173286438, + 1.1757903099060059, 2.9970226287841797, 0.8735387325286865, 0.24936652183532715, -0.7718344926834106, + -0.9049572348594666, 1.9130113124847412, 1.9097952842712402, -3.3667526245117188, 1.1342090368270874, + -0.1385430097579956, -0.6781283020973206, -0.57246994972229, 0.9787319898605347, 3.6297695636749268, + -2.3075175285339355, -0.8269498944282532, 0.8386117219924927, 2.4571895599365234, -0.9069632291793823, + 3.1065073013305664, -0.786931037902832, 0.6695969700813293, -3.896576166152954, 1.6415526866912842, + -2.334099531173706, -2.991877555847168, -0.4740370512008667, -1.0762873888015747, -0.5927379131317139, + -2.2995433807373047, -3.4549155235290527, -2.033919334411621, 4.102828025817871, -2.922405481338501, + 0.9117567539215088, -2.0445048809051514, -2.740710973739624, 1.5500695705413818, -1.4105217456817627, + -3.672469139099121, -0.38663938641548157, 0.1100814938545227, 0.43851518630981445, -1.4003627300262451, + 0.8104124069213867, -2.6236252784729004, -0.40968263149261475, 4.816134929656982, -1.9591403007507324, + 1.2284891605377197, -3.4595632553100586, 2.2904000282287598, -3.7264821529388428, -1.8221375942230225, + -0.44476717710494995, -3.0978899002075195, -1.0302362442016602, 0.49443352222442627, -4.997615814208984, + 2.7403292655944824, -0.9162295460700989, -0.8933342695236206, 0.8124216198921204, -1.433485746383667, + 2.224909782409668, 0.6821490526199341, 4.009047508239746, 2.2991182804107666, 2.677088499069214, + 4.353694915771484, -2.2315590381622314, -1.5339090824127197, 1.626473307609558, 1.9017658233642578, + 0.9766815900802612, 2.563782215118408, -0.2381199598312378, -1.5801150798797607, 2.601571559906006, + 1.7336980104446411, 0.8148760795593262, -3.7112083435058594, -1.3511030673980713, -1.8034800291061401, + -2.950260877609253, -0.4626041054725647, 2.9033288955688477, 2.629671812057495, 0.1508030742406845, + -0.6016277074813843, 1.9043893814086914, -2.119884967803955, -3.048208236694336, 1.3438254594802856, + 1.039656400680542, 0.0982772707939148, 0.29122409224510193, 1.8302289247512817, -1.3148270845413208, + 1.16330885887146, 1.4336774349212646, 3.057568073272705, -0.2635994255542755, 0.3290955424308777, + 0.09837156534194946, -0.4767574071884155, 1.7474840879440308, 0.036291711032390594, 2.057096004486084, + 1.5909944772720337, -1.5554354190826416, 0.45638948678970337, 1.8485369682312012, 1.1001496315002441, + -0.448333740234375, 0.12344249337911606, -2.2758660316467285, -0.18728435039520264, -1.0710699558258057, + 4.759674072265625, -2.308614730834961, 1.3315553665161133, -1.7046191692352295, -1.4248977899551392, + 0.31045258045196533, 0.4682546854019165, -0.5506991147994995, -1.167902946472168, -0.033889174461364746, + 1.2611976861953735, 3.584254264831543, -2.8943490982055664, -1.0763990879058838, -1.9304077625274658, + 1.6052935123443604, -2.1086959838867188, 0.16277271509170532, 1.087416172027588, -4.894248962402344, + 1.6908477544784546, 2.445591688156128, -0.8808413743972778, 1.1168533563613892, -0.35605037212371826, + 1.0386931896209717, 1.4661989212036133, -3.5512571334838867, -1.364258050918579, -2.697364568710327, + -2.279731035232544, -2.2294161319732666, 2.072256088256836, -1.7556955814361572, 1.5964255332946777, + -1.102902889251709, 0.25835680961608887, 0.41171061992645264, 2.6033897399902344, 1.931307077407837, + -3.2382147312164307, -0.6146683692932129, -0.7102252244949341, 2.314451217651367, -0.697837233543396, + -1.0293060541152954, 1.020631194114685, -3.6274075508117676, -1.869258165359497, 0.8600494861602783, + -3.1400227546691895, 2.785465717315674, 1.5925650596618652, 0.5685530304908752, -2.385671377182007, + -2.064885377883911, 1.7148135900497437, 4.472785472869873, -1.6981213092803955, -2.8710732460021973, + -1.5905208587646484, 1.7118372917175293, -0.0628599226474762, -0.024645566940307617, 3.5579423904418945, + 0.043785035610198975, 0.1394520401954651, 0.705904483795166, 0.5603584051132202, 1.9532432556152344, + 0.17339491844177246, -0.5621342658996582, 2.166140079498291, -2.675852060317993, 3.4783935546875, + 0.005500435829162598, -3.0466365814208984, 2.9333863258361816, -3.0374748706817627, 3.3186678886413574, + 1.4340720176696777, -3.4607982635498047, -0.5809862613677979, -1.8670074939727783, 0.31782853603363037, + 5.3407721519470215, -0.11647498607635498, -1.127880573272705, 1.9607118368148804, 1.6104774475097656, + 1.291121006011963, 1.3090026378631592, -4.346621990203857, 0.9649640321731567, -0.3321942090988159, + -0.40106678009033203, 0.6202027797698975, 0.737052857875824, -0.6949820518493652, -1.6063098907470703, + -1.335120677947998, 1.2487294673919678, -1.206929087638855, -3.946974754333496, -0.038611650466918945, + -2.730283737182617, 1.3170448541641235, 2.586440086364746, 0.9609778523445129, 0.9639753699302673, + -1.0613958835601807, 2.2582998275756836, -0.341133713722229, -2.37121844291687, 1.9750676155090332, + -3.339621067047119, 3.8494720458984375, 1.4416704177856445, 1.2632763385772705, 0.49889105558395386, + -1.358637809753418, -0.9810472130775452, -2.008033514022827, -4.180141925811768, -1.8064439296722412, + -2.4055490493774414, 0.8236247301101685, 0.08107048273086548, 0.2551090717315674, 1.6475603580474854, + 2.3599026203155518, 1.4368640184402466, 0.11961638927459717, 2.1902809143066406, 2.4481637477874756, + -3.700338363647461, 1.4484524726867676, 2.2457919120788574, 2.7918150424957275, -0.3214719891548157, + 1.1412039995193481, 2.3801662921905518, -1.1772977113723755, -1.080161690711975, -0.10960137844085693, + 0.23755615949630737, -1.6492403745651245, 1.5414122343063354, -3.5301568508148193, 0.8169034719467163, + -2.7907910346984863, 1.27809476852417, -1.339566946029663, 0.24068212509155273, 1.980497121810913, + -1.103304386138916, -0.9938054084777832, -1.6851425170898438, 1.5913416147232056, -1.6278176307678223, + 0.07544970512390137, -3.0562868118286133, 0.909795343875885, -3.3362603187561035, -0.16795992851257324, + 0.654058575630188, -0.7783452868461609, 3.801968574523926, -2.160207748413086, 2.5146727561950684, + 3.337472915649414, -0.7981523871421814, 0.7141947746276855, -0.44521379470825195, 0.7240567803382874, + 2.8061447143554688, -1.9281837940216064, 0.33615267276763916, -1.3853528499603271, 0.08603048324584961, + -1.1986116170883179, 0.8860710859298706, 0.22947335243225098, 0.5199316740036011, -2.0464673042297363, + -1.6756978034973145, -0.6069002151489258, 2.3337855339050293, 1.6094681024551392, 2.3820090293884277, + 0.8262057304382324, 4.417818546295166, 1.2495514154434204, 0.5728256702423096, 1.7155016660690308, + -1.9175611734390259, 0.8456315994262695, -1.3211780786514282, 0.7857818603515625, -1.7515380382537842, + -1.1971937417984009, -2.808197259902954, 0.7553633451461792, -0.1306423544883728, -3.6146838665008545, + -1.5321255922317505, 1.676008939743042, 0.07836294174194336, -0.9960349798202515, 0.8668472766876221, + 2.137298345565796, 1.2637362480163574, 2.0481185913085938, 0.06509196758270264, -1.6683127880096436, + -3.718193531036377, -1.9311506748199463, -3.367814064025879, 0.012331843376159668, -4.227578639984131, + -0.5755635499954224, 1.583990454673767, -0.86231529712677, -1.809625506401062, -0.3123876452445984, + -3.4403862953186035, 1.0367351770401, 1.1761391162872314, 0.16112631559371948, -4.721268653869629, + 0.07461494207382202, 0.003129124641418457, 4.358157157897949, -0.5005528330802917, 0.24370229244232178, + 0.4912611246109009, -0.6851872205734253, -2.718902587890625, -2.6848104000091553, -0.7679593563079834, + -1.1002662181854248, -0.3475220203399658, 0.002980828285217285, 0.4288327097892761, 1.283578634262085, + -2.613717794418335, -3.3328800201416016, 1.9472748041152954, 1.3897069692611694, -5.0092363357543945, + 2.518123149871826, -2.1634795665740967, 3.7558064460754395, -3.4893569946289062, -1.289191484451294, + -3.036714792251587, 1.6393659114837646, 3.2178215980529785, -2.083487033843994, -1.6242481470108032, + -0.39087945222854614, -2.675858497619629, 2.548295497894287, 0.6715638041496277, 1.0268739461898804, + 1.9520831108093262, 0.2520883083343506, 0.3550087511539459, 3.1073975563049316, 1.3005826473236084, + 3.4222006797790527, -1.5301063060760498, -0.19925060868263245, -0.7549142241477966, -1.7461345195770264, + 0.7768814563751221, -0.8777822256088257, 2.757373332977295, -1.4982573986053467, 2.4916207790374756, + 2.4712767601013184, -1.3903863430023193, -3.3709828853607178, 1.6688079833984375, -0.3028743267059326, + -1.8443574905395508, 1.8113127946853638, 1.7428357601165771, 1.2092206478118896, -0.5405560731887817, + 0.97336745262146, 0.36485183238983154, 0.17854809761047363, -1.6080548763275146, 4.408480644226074, + 0.6553859114646912, 0.9348583221435547, -3.3448331356048584, 2.513455867767334, 0.03000164031982422, + 0.705142617225647, 1.1035417318344116, 3.448455572128296, 0.7622013092041016, 2.463888168334961 + ], + "dims": [2, 2, 320], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/bias-split-gelu.jsonc b/js/web/test/data/ops/bias-split-gelu.jsonc new file mode 100644 index 0000000000000..23fcb488ca4d9 --- /dev/null +++ b/js/web/test/data/ops/bias-split-gelu.jsonc @@ -0,0 +1,1332 @@ +[ + { + "name": "BiasSplitGelu", + "operator": "BiasSplitGelu", + "attributes": [], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "bias split gelu [1,1,2560]x[2560]", + "inputs": [ + { + "data": [ + -0.2565546032426438, -0.4308542731494329, 0.7196725122004919, 0.049034255098233004, -0.6348569496555116, + 1.5952359184580631, 0.8451805251092992, -0.31838310590966934, 0.8187686985927041, 0.5208222841347814, + -1.36437690702164, 1.9883897131851596, 1.5564695381513953, 1.6179857166855847, -1.6925162826813818, + 1.7367654350206285, 0.7210294356326798, -0.16665830399749915, 0.7502629374213177, 0.49174373254887005, + 0.05228599187298144, -1.8492674426703974, -0.13963175206123601, 0.3689713052652124, -0.539726857153676, + 0.4047910328979869, -1.5364223249042084, 1.5401613243237753, -1.5533776810895263, -1.503725108181306, + 1.4980778036916167, 0.3954579111685037, 0.6720461504101429, 1.2082071287930374, -1.9053414457057452, + 0.5161807133471896, 1.876123128617479, 1.4026319996913186, -1.0578832074721918, -0.5667177472093234, + 0.7427152553125103, 1.3309336325346122, -1.250604470605933, -1.5581827636425984, 1.6094888834151977, + 0.8346945311402969, 1.964734881798945, -1.1480686695741893, -1.4692802823913134, 1.443324467502948, + -1.5059857923356113, -0.4324256636772654, 0.47787327572253613, 0.010830153648134555, -0.552568788591528, + 0.976425831078207, -0.3964598095089391, -0.8550178676766054, 1.7829908691025604, -1.4386851327204164, + 1.2586393384114123, 0.6282746108828974, -1.9228284428211788, -1.723501329370512, 0.07875604943128245, + -0.6805248059926141, -0.09444157676724618, -1.2578296605984862, 1.4748594927305314, -1.1326833840447499, + 1.2899356994911972, 0.36076376234182295, 1.0687687434117796, 0.016260539139403285, 0.8790330834375988, + 1.958147414219165, -0.6744379666162423, -0.11384066716133123, -0.8573301336636696, -1.6026243400276634, + 0.947103941481469, -1.9714924616237095, 1.4248854638955484, 0.20743958585530908, 1.1632144973105554, + -1.755170686870751, -0.7194149118428639, 0.14148466161285445, -0.2721811711092048, 1.5603318294054445, + 0.4281402047349676, 0.9777768965070042, -0.4340528216185948, -1.4027853075880978, -1.786451668170261, + -1.0399446061083921, 0.5016682459975055, 0.47912646038882833, -1.4994325065634833, -0.5813174761669684, + -1.7675659877499355, -0.5295448113440351, -0.185677936488422, 0.6133721122115032, 0.8342835422965917, + -1.8851704197237282, 1.8145028466620214, -0.6693817855471478, -1.2825600860901352, -1.9614837268208936, + 1.7498654134131364, -0.9137222020716864, -0.3419963114086366, 0.4718774290086678, -1.087807747894451, + 1.7131835407851748, -1.3262901364654693, -1.3116975227056313, -1.3834647528821646, -1.6786541344436587, + 1.7754973215348864, -0.2608287138885421, 0.7649932766417926, -1.620043859042676, 1.6639963441569705, + 0.8487054193287991, -0.0440434794759188, -1.7095242256175753, 1.285898353046436, -0.8541456852270608, + -1.0975953883519463, 1.6456192019739255, -0.9829899938198103, 1.5115291047537394, -0.1356101800325593, + 0.3163471293050524, -0.7350393764769096, -0.2635909692970628, -0.8961556079942969, 1.35594272476665, + 1.5346085770720466, -1.561332626472022, -0.7482744786530882, 0.7415275042470979, -1.3455806550694929, + -0.17403414417677876, 0.13174239791099396, 1.617942860808114, -0.21372789574955942, 0.1252618392038798, + 0.8356601282506393, -1.4645614042274238, -1.2750059490075998, 0.30114108212769697, 0.718759572394414, + -1.7476419557101002, 1.3774845116886816, -1.5548787928223922, 1.2843809545104072, 0.40166155990403407, + 0.7182073117046643, -0.9178049843649907, 0.23022861413351325, -0.16782915827518163, 0.809538498880066, + 1.5368543357363587, -1.9144180623212463, 0.22370787233170475, 1.754935494591769, -0.300812686337701, + 1.6324016835327004, 0.42942625235999277, -1.821577324695582, 0.7065235323323327, 1.6359230558031488, + -1.6698325562461855, -1.0322767487026727, 0.230813812551105, 0.10476742241238135, 1.8969745358419585, + 1.6692111883958614, -0.13495950982519744, -1.3187891429974368, 0.33479728114140084, -0.3567915710179923, + -0.757581852807693, 1.0471229062934597, -0.15293599163291116, 0.8603181880710826, 1.5915256233625152, + -1.146248153475943, 1.4333844348196614, -0.4549776123350959, -0.24793256123671004, -1.1103266883866416, + -1.8671773666157563, -1.2229905095092333, 0.06266819913029753, -1.0062234212958696, 0.8035035300131232, + -0.7305732410705756, -0.18203039831679124, -0.31014428798975313, 0.3622048186467879, 0.45985097885235504, + 1.4809283400531248, -0.43833106689297807, 0.04401934126909346, 1.5587830597086887, -1.6615450592134344, + -1.0549727490794032, 0.10619243811841983, 1.6457168234317425, 0.8542761080095067, 1.0005810856142707, + -1.5553877748221447, -0.9212057116120924, -0.8894180684757362, 1.2348207479698639, -1.867513894231462, + -0.3703762605562, 0.6409515313866843, 0.07519639585361748, -0.6450469579811191, -1.1954961398506443, + -1.1204107450288525, -1.7282576070578317, 1.3286514210595293, 0.9221546638897049, 0.3164348496087497, + 1.1124119320099153, -0.1657137161424238, 1.8878952460812632, 0.1293077579568429, 0.23358535468938246, + 1.6758942267934263, -1.4201635109571313, -0.15778983434253657, 0.7416701308494114, 0.5883762821820557, + 1.4927448431941999, 0.6935037789859804, -0.7735384747712253, -0.003653232296928266, 1.016177625769398, + 1.663989224847544, -1.860463076082154, 0.586684836283415, 0.35116588158246387, 1.1544623264859073, + 1.900422405885342, -1.4569177282686354, 0.04845063476308198, 1.3176042838894286, 0.7208418723320795, + -1.2204473994940885, -1.4000622968397813, 0.21650376984443565, 1.060631946423273, 1.5077365306547108, + -1.0212630681047514, -1.9198532775330452, 0.7154901442715236, -0.2676631096599271, 1.3808441670529703, + -0.09885367904367648, 1.4777610290353689, 0.056817122949397, -0.21052160911698614, 1.163420787671865, + -0.9278068609037406, 1.6649534188139832, 0.651092543699634, -1.9558728896680586, -0.17393499677517266, + 1.4142540331221616, -0.6414265768865439, -1.796476122987272, 1.3019592923170675, 1.8063975233357983, + -0.09846336861938809, -0.44669537007711746, -1.8509683249130777, -1.2535113000292872, + -0.26272617023889033, 1.618948771545468, -0.24333144562387243, 1.684567206060013, -0.0684671993602386, + -1.2513800864625626, -0.42209759253339296, -0.5204624726492675, -1.2354597847242017, -1.8741410257015954, + -0.4694216976292269, -0.6222814585320586, 0.21189389441688178, -0.07825818775306104, 1.7163253327182595, + -0.5828872666159741, 1.2859942412787504, -1.6213500187165488, 1.3292413342594758, -1.8919626255232016, + -0.5376250474966104, 1.8689278008038874, -1.7144714737277686, -1.0258522713372873, 1.7337707335280257, + 0.7736026666583085, 0.5294771168151202, 0.018158706442022776, 0.06181604275240726, 0.39921503387415935, + 1.1876014889699817, -0.3054069392890115, 0.16369159901033914, -1.7827759163094807, -1.972081887839714, + 0.6481860379313247, -0.35352594703309315, -1.3849570316603321, 0.623606163848514, -0.05008698162477465, + -0.7604090786231117, 0.018133973172791862, 1.4085369489144304, 0.0006579664080073044, -0.945901823508092, + 1.9768973551218298, 0.8031916889887203, 1.787407754969756, 1.4724476919448328, 0.4025600955058959, + -0.8188218839399735, 1.9450326825091242, -0.7081203957970317, 0.2756639708632225, 0.5758621242189532, + -1.3400477780804723, 0.9643250467598792, 1.6902983107628202, -0.9456991797498446, -1.214924995943468, + -1.7096508074160557, 0.6280377071410248, 0.043386564374169545, 1.4018487692877626, 1.9387558240379494, + 0.6404760132931342, 1.9907639837009397, -0.9629506658962566, 0.44160372290058625, 0.5311677453788013, + -1.98443782839251, -0.8098531120098231, 0.6492275695305256, -0.8908778962502977, 1.3356991895363226, + 0.22938352294803988, -0.7322473322214602, 1.5046263929054788, 1.3645550724874669, 0.9961702988823644, + 0.5139435687003742, 1.4501909090527283, -0.7800082224423388, 0.35291863289434566, -1.146090063932382, + 0.30350456267852977, 0.6439700368534709, -1.9972080095325557, -1.6099893116712707, 0.8052080484401971, + -1.0954234253610906, -0.7337556567039494, -0.31964730584596346, -0.0775965423638807, 1.9583401341749038, + -1.2551023178822023, 0.8562333066218226, 0.7147503883958679, -0.7448967487912448, 1.935551654646078, + -0.8599046782377142, 0.2970146461500134, 0.8930246286483605, 1.1182036059156406, -1.197201583834416, + 0.07161352111517694, 1.9408265571293892, -1.0932315538260982, -1.599500716435764, -1.3924326798020594, + 0.11983524432330928, -0.7679553338056699, 1.8370796383852186, -1.2704307623609719, -1.1997771708610943, + -1.888947914144012, 0.55632142669847, -0.34454450486308286, 1.4998341008691707, 1.5090724650893934, + -0.39294559939271156, -0.5834694562549148, -0.5369747247041268, 0.09250855982787431, 1.2568497810992714, + 0.3053435200201413, 1.9380606411981267, -1.7300384067590375, 1.8430090411386875, -0.09705704918122926, + -0.9748664840480341, -0.13310751475233928, 1.8695019507218804, 0.6192212840529407, -1.8594506513740239, + -0.5527467955831558, 0.42368164563035027, -0.3623000694050589, 1.4671722318441134, -1.9926015337743834, + 0.6390178007629022, 0.3930378886525858, -1.999173071586763, -0.10690790413254625, -1.8526590162958945, + -1.7626087675511721, -0.37087181562117433, 0.8387521856704474, -0.19069686876676517, -0.8188898303583647, + 0.12674602912986455, 1.1094188774808789, -1.3557441716004606, -0.6605764711080546, 1.5447276204253857, + -0.8769307424824397, -0.2545000118362921, -1.0291738091907403, 1.7938016119427465, -1.9984894749397677, + -1.2827802315328531, 0.9376743174843778, 1.6630184567373671, 1.3959796712419958, 0.9312602680557571, + 1.1022358675981447, -1.8339474355859497, -0.4903489710565081, -0.8895348660922506, 0.6526243952924284, + -1.3544145682501627, 1.7693021030400526, 0.16080394788547636, -0.012568942703699904, -1.15491020041351, + 1.2537510862653143, -1.0969062449531704, -1.2759047604132565, 1.5734543561168284, -0.10691180117407129, + -0.26934435340525376, -0.9933427134070101, 1.7043012450253494, 0.02549752010787465, -0.08701034102069993, + -0.2674176298500974, -0.08499649536913978, -1.9065384537204872, -1.8548124018210599, 1.4114174077197168, + -0.1003529232852527, 1.4168803569119728, 1.2546078925285231, 0.2836297019652072, -1.741854513232557, + 0.4335127344631058, -1.7722508801895138, 0.8906617987453824, -1.1907389082962334, -0.4570346317067928, + -1.65487466866219, -1.6799510105998428, -0.7345507991250928, 1.0450995087483994, 1.6718914380761127, + 1.9492487109001955, -1.603110762963735, -1.1386896023097748, -1.2314080609240623, 0.9155801473532383, + -1.7678615945378713, -1.2302067186131085, -0.008026479019882515, 1.7076934044219962, 1.2976204179783615, + 0.23979601472188072, -0.31452550079775676, 1.237656674724093, -0.5073967065142844, -1.4684831151454167, + -0.31751104037333455, -0.18632018989018295, 1.9021272175218185, -1.605133256102862, 0.7057429809617588, + 1.131130449558933, -0.37696933268230914, -1.9999062126002407, 1.2947486718643644, -1.5494680428486927, + -0.619492250822522, -0.7119383385868714, -1.2267826883342403, -0.05700963867832698, -1.6002949090802074, + -1.235668500595783, 1.0038351588940309, 0.9591804998296345, 0.3500270456646062, 0.9985378178389439, + 0.37800440518144107, -1.2954921521047265, -0.06652125976806822, -0.5529207993209635, 0.799277104478481, + -0.3195922296721463, 1.278513926080179, 1.7667762251394539, 1.6741322500537104, -0.7688735332563814, + 0.3718086223078023, -1.8559862073263451, 0.6460500799921309, 0.9538174223011966, -0.9599765113376018, + 0.2363354873707415, 1.1507344321172663, -1.5062754878549311, -0.6521456232305631, 0.26796953563726245, + -1.6029382384422037, -0.13330793525728346, -0.4700350464240364, 1.2426614036874923, 1.8410794059261688, + -0.29943068600730616, 0.23080019611542912, 1.1289988352417586, -0.9307260459011948, 1.5075366130642047, + -0.75430635423512, 0.7786360347486365, 0.43281144997343457, -1.9253410424489141, -1.7129308932126097, + -0.7115090993554389, 0.09728563961266712, 0.9612524095722907, 1.2677697737234288, 0.17580981533895557, + -0.07099813019347945, 0.6420914403628677, 1.929710466274961, 0.8224262711223203, 1.197715527419664, + -1.5306699759366271, -0.4379817520480751, -0.6193479789854166, -1.4632019512465932, -1.259260972161262, + 0.8218577982330961, -0.5872036188804506, 0.9619144554878618, 1.9136356497081524, -0.0054634952730427955, + 0.918959735070624, -0.5711442865104566, -1.7438335683612998, -0.09911336983085484, 0.9296115985185578, + 1.4616523373203654, 0.46274661162826547, 1.0360815568791644, 1.2212946749865532, 0.15313173395745583, + 0.822676634908933, 0.4284785502131063, -1.003861397431101, -0.1736541682765118, 0.6790160543077883, + -0.14098755750254632, 0.3123446972161563, -1.1989246192108354, -1.8129870085835993, 0.32605203312965525, + -1.1125265808505294, 1.7561842733362214, 1.7739812007749807, -1.4043301791912688, -0.8943531190371674, + -0.6297367970706924, -0.7981479808539875, 0.8794315514521189, -0.3994743473829594, -1.401573672109106, + -0.8912997118364743, 0.6495541504851907, 1.9230113410471903, -0.6785063630402766, 0.5089535993984642, + -1.8974271070828808, -0.2620702810213116, 0.1991815065696141, -1.207848636028186, -0.8986511371196979, + -0.7061642892592088, 1.9332451504308255, -0.9830761395756227, -1.244182804715014, -0.017993897992697683, + 0.8786296671983029, 0.540637680299664, 0.8969343374913903, 0.7955401708949532, -1.529657800511421, + 1.8639189521351138, 1.067177379712641, 0.7401683858489143, 0.7294620382453241, 0.5904153067983531, + -0.49843243603051146, -1.6804352670210738, -1.0693615378859294, -1.9090165306219946, -1.4036168596502474, + -1.2392401335977272, -0.7753989839018987, 1.8934648648263472, 0.3583307892541949, 0.1338886754962525, + 1.8690213745085558, 1.455690685933316, 1.1699346906107797, 1.6059636790463596, -1.2445702023047094, + -0.26540435557319864, 0.7367777535407951, -0.06848824820724975, -0.35847294758367365, -1.8422274132326661, + -1.0257628610133205, -1.6669678072109697, 0.37716994278958094, 0.238235954755047, -0.2999780255928437, + -1.8840499434617515, -0.2146515363892023, 1.683440503823257, 0.11413961404491424, -1.6059835147695, + 1.1951777078503847, -0.5174290916747344, 1.3440252628311207, 0.21500397686099326, -0.9421891710579917, + 0.0593410986318057, 0.34006850000048683, 1.644097326727004, 1.74676874874047, -1.6742906231992345, + -0.33329412020743643, -1.6254771174048148, -1.6581767039373991, 0.791705460659057, 0.9383035214148672, + 1.7805390345327456, 0.5776366760158806, 1.587860436382865, -1.5903762069130911, 0.16034052878776794, + -0.2414627652627388, 1.2751236768892227, 0.3209997221960421, 0.31176177950076234, 0.6234148156263783, + 1.144504541840126, 1.1535423529138784, 0.11665655599341473, 1.9697764827003628, 0.14558312336598078, + 0.36578124791517297, -1.769415346962682, -0.8303724165129278, 0.44703666963932154, 0.35095942056362794, + -1.197063815711486, -1.5390788457973201, -0.8313129454097989, 1.952907404456571, 0.30612523411761394, + -0.9380264922530621, 0.9822286847681259, -0.12269281399330456, 0.6557752769532215, 0.48870196679108435, + 1.347011507573625, 1.6519563808835915, -0.7385795429014586, -1.048048723619047, -1.0859902684402032, + 1.7187556784188924, -1.8663335499394762, 1.3448325108921972, 1.5973182779732955, -1.9246562924691855, + -0.896157435511153, 0.5381287932952938, -1.7180528810790427, 1.6998135385965343, -1.99001646343652, + 1.7545133161159834, 1.9333853932807186, 0.509746393965961, -0.0675591816104566, -0.5487596842885525, + 1.1654829640998834, 1.2418276538086106, 1.1046551528849635, 1.1541227103848648, 1.3688916432757363, + 1.8682574813085164, -0.9855654818965549, 0.5606398966348412, 1.9470096222814695, -1.9903658034181957, + -0.48344153717960925, 0.6353768282033556, -0.0513621772163555, 1.097950015862886, 0.24567995461843584, + 0.8059244051805585, 1.3521742423136072, 0.47230481196519936, 0.6133818327341913, 1.9135550317308576, + -1.8499763818701833, 0.5894967733509713, -1.920326871370679, 1.6405640996262383, -1.1695280806136248, + 1.4241423736983672, -1.67349053378969, -0.8405344375321118, 1.4551541878625045, -0.9884334361969964, + -1.274246357510064, -0.5067514453663833, 0.5492300144792424, -1.063358889380532, 0.3303561470509724, + -0.9579434760073298, 0.057620205260851876, -1.4742255794760677, -1.6308108759776871, 0.050309644830520917, + -0.4069039662995735, -0.046184249642958086, 0.07949429360802096, 0.12024378717844808, -0.7878588154579758, + 1.472302474285109, -0.29877233638272127, 0.3469939919860652, 1.0030914450113793, -0.033399023382078674, + 1.337733463292298, 1.322404140400704, -0.9344843531853249, -0.7431360466648114, 1.6844959250870701, + 1.361728870488129, 0.9343674315961268, -0.027258361373763584, 1.838512510349391, 1.6503912849225708, + -0.4099533673620952, 1.0611278108159041, -1.6127864143155026, 0.24393478918735934, -0.44450380494664365, + -1.6597118492450527, -0.6490315790577066, -1.7612694260424497, -1.4276602313620197, 0.940659328987655, + -1.263148442418843, 0.5147963537125868, -1.3438202624356137, -1.2843267233008957, -1.2655953556300048, + 1.8008689611198623, -1.0529595001338974, -0.8423440348538298, 1.4068621033910276, 0.007761032105414678, + 1.0935924640537378, -0.09622562614772612, 1.5915223843314257, -0.6702176163118629, 0.773624899360053, + -1.0994916609949348, 1.8469641173001987, 0.8777243936892871, -0.7039703784245424, -0.6127726098737645, + -0.8456557917329288, -0.21546583787614892, -0.9360415939958875, 0.5705354560747677, 1.4330326228154542, + 1.5996006008539831, -0.07367645458791117, -1.5628497471538445, 0.6808136908868017, -0.6639347562269995, + 0.9713336233567063, 0.914773615335891, -1.4236420984162779, -0.02242621653471044, -1.898261801965826, + 1.8572128094883267, 1.5674319640876826, -0.825255227018233, -1.2984530858686814, -1.8947198959796596, + 0.587966233462053, 0.7276082291256891, -0.7118337876069631, -1.9491744048376294, -1.9749840187715382, + 0.6636341374951193, -0.8911042565498537, 1.1064802004273853, 0.5698018751524465, 0.31161591659845556, + -1.4542202550786563, -0.4973178781500307, -1.261090074552551, 1.202428109165787, -0.6502124259985935, + 0.875011031676804, -1.0459657545700463, -0.42208724851625856, -0.056326319293169114, 1.033820082809723, + -1.7520891696720033, 0.8087980713817391, 1.4090272899721707, 1.7934272795205217, 1.625604511965781, + -1.1093024540507015, -1.1836309931833142, 1.2726582815916636, -1.348432635892622, -1.3318038843301023, + 1.8025664988508767, -0.9447266569517438, -0.5044961285117484, -1.6608791173499942, -0.5404247801346038, + -1.8417295161873213, 1.3703119773103376, -1.762654389000578, 0.654516908110395, 1.7444654879377754, + 0.06700313474848674, -0.11772438211787417, -1.4346934108510236, -1.3471639350541027, -1.8846130174975517, + 1.4341206352848488, -0.1781484331915868, 1.8660723422398098, -0.14067413462914313, 0.8690053039391579, + 0.14439977774786517, 0.39859516757572777, 1.338326821281048, -0.3540839451368738, -0.3799467769042604, + -1.6627115692361185, -0.27677435345149703, -0.9817434354820271, -1.7561599198792992, -0.09529627982909172, + 1.0892976295039958, 0.8013191504371955, -1.6378286446982058, -0.2422910554223776, 1.192900022577974, + 1.552603422632739, -1.9305123473184267, -0.20948858983751073, 1.9616644595376727, 1.6353585267617419, + 0.8302939534322658, -1.3824308029410997, 1.4671017215695894, 1.075525400633838, -0.8516882354421424, + -1.764852757300301, -0.6508588489877889, -0.5864767969763536, 1.1318668208548859, -1.2774769138035325, + 0.28054164298575834, -0.7225628233832841, 1.0930686104208345, 1.3757645119448334, -0.8436086923316299, + 0.15317111418019813, 1.0696425581898357, 1.2977798557148477, -0.9586985895916751, -0.9572006108424791, + -0.5173013825178812, 0.9558842805311762, 1.4588133070282732, 1.4300025626225263, 0.9236625391445488, + 0.6245988744867041, -1.4555968079681971, -1.9081528171774265, 1.4969305861374398, 1.8013611610843867, + -1.215264739954538, -0.6460349115684512, -1.7385666591205897, 0.4952215434575882, 1.7813317720717237, + 1.4779804705753685, -1.6448372710981527, -1.7914377349318746, -0.8701351514587072, -0.6086582613844049, + -1.7386736832736416, -1.6630595398426538, 0.06404471619092522, 0.12472538361116214, 0.25173024087120677, + -1.4925192750724054, 1.1326320024353294, -0.6723638830002354, 0.9276409891081574, 0.5160132920113458, + -1.7619800169226787, 0.7729306209435105, 1.4852229135771582, -1.4043797098218525, -0.26979894999654697, + -1.5676449182307888, -0.8285263778416416, 1.7968643376162596, 1.1503154963149846, 1.3682632091264955, + 1.9014644171021402, 1.2035803257561772, 1.5807589620662768, 0.6681530530278019, 1.7430055744872055, + 1.4516895295108698, -0.6636088362253298, -1.3265544726243066, 0.7260245399794885, 1.0068518018712478, + 0.2208570840730273, 0.8459119656002851, -1.8915833254450725, 1.2162158433713186, 0.9752766886721753, + -1.2607916054285697, 1.9003684087234687, -1.6824694825149118, -1.3545700227621689, 0.5912583336167279, + -0.7008114462062913, -0.022208913414461406, 1.4871167887266532, -0.47220337297808346, + 0.001402495408155957, -1.6337795432062068, 1.6557142707874517, -0.21911880468117495, 1.994215681564901, + -1.4675327906481472, 1.9744129850181764, -0.10991781070844464, -1.5582267736493964, 0.9509729601966104, + 0.27383527366630744, -1.0109967848840293, -0.2652951445752292, -0.31773126890169845, 1.9347214689284513, + -0.48900940865557896, 1.0348946328564068, 1.6101647718098393, 1.2224869337691553, -0.24528963586677577, + -0.8282995437227312, -0.74214677104667, 1.9022077987920571, 1.065511429772795, 0.5557978714258756, + -0.8552846035431614, 0.14131421568194025, 1.6415849500887356, -1.0979455229862056, -0.1899406250116744, + 1.882935380340533, -0.47086245203046495, 1.1173180765349162, -0.38373005169056196, -1.9204322585926832, + -1.9947555620271595, 0.25610856180969677, 1.672078838942273, -1.9323275104124873, -0.7955457526088345, + 0.9709167319971037, -0.6356155194456985, 1.6590519014057472, 0.2576466445092809, 0.06826734181142502, + -0.8077944062086111, 1.6116753503934271, -0.360670679232701, 0.665216992360695, -1.3443131827622485, + 1.004588252277319, 1.4090953833875757, -1.6465558763926547, 1.5390983488230878, -1.3071804786368446, + 0.6990024492525402, 1.0093027361688671, 1.097146827869869, -1.1912001568366906, 1.630219037886489, + 0.5744645582929406, 0.6113886454941007, 1.9913520749447864, 0.44093025553359233, 0.08768839917963245, + 0.8058107685571905, -1.033178784078995, -0.8225351475861098, -1.033391812065922, -1.9518439963887113, + -1.4013652343418483, -0.3575891026292508, 0.5391177792156094, -0.589608397581415, -0.8183723923536403, + 1.2663801923646707, 0.2780137273448853, -0.6282934713194672, -0.3515950748704757, 0.32366450506041744, + 1.0378037684314076, -1.4606869437722425, -1.1777315953586012, 0.20077305098986375, 1.1132819638088014, + -1.3403580753455646, -0.8471624166708764, -0.78670588443378, 0.411075409745707, -1.1567621787661597, + -1.644706858279382, -1.6923798602467333, 0.7295344690619672, -1.6826221173466598, -0.5883451091973848, + 0.401937953067649, -1.9630301043637237, 0.9577847247199625, 1.2919304030896734, 1.225292029861409, + 1.0377797459888836, -1.4163731758310805, -1.1848812939548496, -0.6833661697787141, 1.8924034115938815, + 1.0815655745578274, 1.6514910861689147, 0.13638193195368675, -0.7978052236681465, 1.534221720841157, + -0.9153275868493269, 1.8162414196916217, -0.18899261449512395, -0.41163974536752157, -1.7273888453908217, + -0.2259883780746561, -1.5648477168095223, 1.483239033708478, -1.3942275569974942, -0.47490997296002035, + -1.7533267576457146, 1.885960859958221, -0.2403666825901727, 0.5586086137789854, -1.0921597903975728, + 0.31058170316775424, 1.3277885858021783, 0.7148623072607876, 1.557240774068096, -0.37204592237067047, + 1.296024489070251, 1.0544578912529046, 1.6854705704653057, 1.511922732706835, -1.7363032773317322, + -0.2092964598440208, 0.4481058099152593, -1.6737966633530679, -1.090792800318261, 0.6450021408332125, + 0.20592552116749374, 0.8156159228622641, -1.568521435345816, -1.9264895113126421, 1.3354691459199177, + 0.5736596388344699, -1.1328195208816618, -0.9748407260828982, 0.12745284809908775, 0.12755406203831754, + -0.402313518792381, -1.2084251106645967, -0.766591500306526, 0.5831135977679764, -1.9502950151238094, + 0.02368696030780093, 0.9474278651806163, 1.200170490724969, -0.23626565831359603, -0.00155175398078633, + -1.4962274409033034, 1.4172670942494623, 1.6336646415640272, 0.9015591359514898, 0.0028804858578190817, + 1.8517288251219313, -0.5845464836204135, 1.4764716372203344, -1.3353282541360834, 1.2177029719595032, + 1.990053698845105, -1.5060495488639427, 0.3106662987296689, 0.2870946491789992, -1.0746935708007666, + 0.895601302779335, -1.5268871259283276, 0.2678390558787376, 1.864513144299627, -1.509732497733654, + -1.5904836915336782, 1.677740063904718, -1.5200319590355935, -0.7931901348349131, -0.057131201288623146, + -0.5299918958977132, 1.5844048394762984, -0.9074859645216646, -0.6250324408575869, -1.7483344039680153, + -1.648406830815511, 1.790525414138699, 1.1087046319299025, 0.42706999037150517, -1.1252987787944102, + -1.7926103436944185, -1.661929519642099, 1.6441197132107996, -0.7994467312378033, 0.4924994603048294, + -0.08241669477843683, 1.3633290277505736, -1.5418640660707128, 0.05946840692833444, -1.6001798975438088, + 1.1048132720023895, 1.0870814820606292, 0.5831092120375949, -0.4133915802678638, -0.9142815961432058, + 1.9472407631770015, -1.4219510012962004, -1.1714172709742634, 1.4079274160525728, -0.34506772652017137, + -1.0158502812481522, -0.8168949547547397, -1.1457275079452227, 0.12910364073916902, -0.4662867248454887, + -1.9437241965128846, 0.07261938805201762, 1.600763502162561, 1.0777174066413018, -1.2723083195923186, + 1.3113857387077417, -0.5228664205198017, 0.4450488409424249, -0.5683762553586282, 1.8256298065201282, + -0.5555324792306022, -0.5028443495682451, 1.550965251170056, -0.47857481650681066, -1.008285169693293, + 1.9029801635553145, 0.7739617661198315, 1.8099201835531415, -0.3994817059435505, -0.8127756747385817, + 1.4033307810724054, -1.359844813448376, -1.3846355261466767, 1.8540201060398234, 1.0970430821179962, + 1.3778953118217157, -0.6311210216871839, -1.8270928773238353, -0.29073753592093343, 1.063407723752193, + 0.769348666705115, 1.069807859635052, -0.13297318999054397, -0.7495627942438086, 0.4278305696495597, + -1.7534013899377605, 0.20503621122624516, 0.8877917416885026, 1.9219368972107151, -0.7795858126195832, + 1.8045722365205155, 0.01848994995789255, 1.9822395081707462, 0.5682615557282436, 1.096590333951183, + -1.1060317246730396, 1.0869871276155, -1.8681569307257382, 1.9498301468214843, -0.5725242199723457, + -0.754441782550737, -0.0400922717249097, 0.010590885689596874, -0.6969977940409491, 0.6620666861327669, + -1.5969982725685883, 0.17340909047153819, 0.47755863566996126, -1.6291589696000264, 0.5780359168220688, + -0.5173306807336635, -0.9514848124225095, 1.6169705288679577, -0.42893373795490586, -1.1528283547930025, + -1.0400955977716713, 1.9827399061918518, -1.327376858845513, 1.6043081593845727, -0.11039938533269034, + -0.24997046912904874, 0.47014693724974954, 0.729145170631436, 0.7014015249399357, -0.210704593378912, + -0.8898579600723409, 1.127679820439794, 1.379686041589359, -0.781681363239616, 0.2858562618428895, + -0.7131287792008063, 1.7878252016142655, -1.0588662101910593, 0.4893786031875278, -0.8406649057821527, + -0.1534439859760095, 0.2331374640695536, 0.469582749149386, 0.4137313563589631, -0.9769266749700227, + -0.37870419628070984, -0.778681560279427, 0.9076631441596534, 0.9624606550623103, -0.10250544734763523, + -0.2518054637298146, -1.8389925418951307, -1.7523906632013846, -1.8251290048632933, -0.49859600328817244, + -1.5645964425382966, -0.46692843392327266, 0.3025335867203003, -0.8785897670006184, -1.7537151065566459, + 1.9360855593038684, 0.03479121265661611, -1.9402430169146303, -0.8981491188900641, -0.9720525655542991, + -1.2872361339345169, 1.9657361835316278, -1.9654227152525223, 1.4590349841798576, 1.7417951527725704, + -0.7636287264036836, 1.6938802231015364, 0.3969017158868704, -0.9308527980280088, 0.44396078845267084, + 0.8114974124677037, 0.22323733905690712, -1.157000049324795, -1.2116172131012615, -0.9832275983234169, + 1.773233656033006, -1.6481062641009663, -1.9471951041445985, 1.1654342998679619, -1.8679076405021187, + 1.9134708504745168, 1.9270958489182615, -0.4877809076980144, -1.4674512268342745, 0.006115322418878577, + 1.5523105881305073, 1.008791751555858, 1.7292932521498168, -1.2446660428848375, -1.32058622408507, + -1.6942157582592943, 0.9218514004458749, -0.823621328629307, -1.0203195530541063, 0.07206341884947509, + 1.214451058931207, -0.40454188129729296, 1.0066091638178039, 0.0801907243894604, -1.3250420419558173, + 1.6740542746900093, 1.1525840942242223, -1.7538751715267296, 1.2289357346449874, 0.44273632243705485, + -1.480515264351281, 0.7203216915034076, 1.736757457268701, -1.6126702540429125, 1.3353291202473017, + -0.3386246414186953, -0.15824184756675486, -0.9082701908024067, 0.19739090770367174, -0.6056382353292964, + -1.5021205949442713, -0.13004055376809376, -1.9680841369920756, -1.0085004366482355, -1.4660753620146698, + 0.5310372051600734, 0.6252656799282139, 0.28834705715856845, 1.9582690133559009, 1.0284365097891248, + 1.5791162728965862, 1.5890315798131152, -1.5740074592032895, -1.2346249557276874, 0.09869464843015763, + -1.9888659899819219, 1.6245510003031853, -1.8240356816949115, -1.1160775047609688, -0.9717920085031029, + 0.026054846056297265, -1.5990410562524549, 0.6191498034442082, 1.4318181978005144, 1.1449640789498945, + 1.435002701727269, -1.8991365043582489, 0.9679929619919578, -1.9806014397574234, 1.4536549482052994, + 1.2369898149063783, 0.0942097559544548, -0.44988290575276135, 0.6393419762034132, -0.6790983093222227, + -0.33932133359722183, -1.3323692893954657, 0.051614649124500644, -1.4850113224086279, -0.6288795685974646, + 1.6283138375987471, -1.7789341475324125, 0.9641353407036455, -0.8859758584446853, -1.3905521072955613, + 1.248121111974715, -0.148429098769566, -0.4602995590886545, 1.1003026112521521, -0.6827850627285645, + 1.9771039126131766, 1.227966495460028, -1.9054616874551513, -0.17250407406937818, -1.5976498513957544, + 0.04439303257666172, 1.51163887169748, 0.646404644877455, -1.721204116533527, -1.491970732730513, + 0.2989738629420451, -0.9550137794523614, 1.2025118418310514, 0.17278672215112767, -1.7073672613962216, + -0.6869493466963412, 1.4221175269968072, 0.8136078027936655, -0.09457917718456343, -0.6500956499585442, + -1.0403599385456666, -1.4197371934035345, 1.2190245805380533, 1.4761477510781154, 0.8005307377799946, + 1.5249741316932477, 1.0855648961952538, 1.9224829076060752, 0.2869108546051615, -1.5091598215762332, + 1.4501191712612789, 1.8870984229831107, -1.7302116354922958, -0.3545753940936196, 1.442437478066842, + -0.7823996675200249, -1.131407428223576, -0.5013183847734197, -1.4979613863674297, -0.24415013272007524, + -0.6627159008601602, -1.0181273908903332, 1.0328362758216585, 1.260250651113478, 1.0138156752044916, + 0.5371615864333288, 1.8553073783399325, 1.951357678034606, 0.7194607934648083, -1.0589630923826236, + -0.9620500497996112, 0.67763129815489, -1.6212754527879456, 0.8824272362591019, -1.8034693359581802, + 1.7422096047831568, -1.010660238079712, -1.8120346724245788, 0.7326343612924848, 0.8912492318672749, + 0.669928902828123, 1.5191540472733163, 1.9408210365763843, -0.2675149665538239, -1.5368478010395314, + 1.1378158248590227, -0.14340274845268297, -1.0420396266499052, 1.4238975359786146, 1.0186548966434374, + -0.37302282044451296, -1.6665521929489486, -1.9414179347758749, 1.9845099037831098, 0.25190468995929116, + -1.3565033558826896, -1.430084412385555, -1.9049229412063653, 1.221645795300951, -0.5219823891712627, + 1.9368378730562315, 0.7035479701902823, -0.3754047433995442, -0.14526093695990294, 0.17885379911634125, + -1.7453384743240203, 0.8465052490601339, -1.6823293227525946, -0.8161516604165566, 1.8312443096922442, + -1.9192510920287678, 0.37723942145518574, 1.725107407940751, 0.21381615826643507, -1.1716001801855649, + -1.0611345679669162, -1.9732355910502637, -0.4461777828323239, -0.23052068805350423, 1.218942575723136, + -0.22572812681057552, 1.9428383183668947, -0.06997110252961747, 0.2238383470247678, -1.178747770654164, + 1.289511874981887, -1.9756722906104285, -0.42650188553983703, -0.5494388087356263, 0.6619450518994654, + -0.8233262341940337, -0.551580552700945, 0.5817278377322888, 1.8685269618613036, -1.3000953227319512, + -1.4283838275880294, 0.5999358561474102, 1.0645958312240245, 0.014353697508937557, 1.2413161019277963, + -0.6897291610817131, 1.12297456609934, -1.8432752527139797, 0.9084575027035138, -0.1243916597867818, + 0.33130042087381195, 0.66895259903393, 1.6557830983974302, 0.08287276029065538, -1.5357530396168295, + 1.2343785017416087, 0.12902705924095592, -1.3495271524839367, -0.8728580889008626, -0.5244139433889465, + -0.3879006179068245, 0.9188768930207276, 0.7793805226934829, -1.4393071368258976, 1.613705376668844, + -1.0843958887438374, 1.7124014800074736, -0.9964777990871019, -0.8670020603088275, -1.779081809150778, + 0.12646904349631516, 1.6571805108463433, 0.12600417813102993, 0.49996900162664826, 1.9525284377831484, + -1.9652455426976498, 0.9494015016667543, 0.38443667884960586, -1.7724098362330505, -0.4082684743558227, + -1.7879879942659969, 0.08275123875162116, 1.7475525868209036, 0.6792010104118127, 0.039437186277630154, + 1.2836761397306624, 0.43674745284168726, -0.758347022092078, -1.7870493658991657, -1.7978809807072968, + 1.0584586913661846, -1.1157899056211305, 1.2701741054323072, 0.4374863843249699, -0.8214325870445185, + -0.7728421127264511, -1.5577282427526624, -0.23594830217496732, -0.5391955305244664, -0.5624792843906574, + -0.1938002353617021, 0.5223396103139288, 1.538375944848121, -1.9694136326451632, 0.8633772521280756, + 0.7208433609240901, 1.7666620769537023, -0.43532694305019426, 0.411879568237544, -1.2098790166305973, + -0.36423075107954883, 1.8891146830740784, 1.748054742284471, -0.24142532135673545, 0.5927554266793296, + -0.9152877032152276, -1.8063969967782656, 0.48207466271184707, 1.213706670479203, 1.6604107491663145, + 1.7558340275933135, -1.7932219074543312, 1.9856023833866967, 0.4913618217526574, 0.6789406744471576, + -1.2508048490603496, -1.3830358311750315, 1.217800508245622, -1.363187904834711, -1.9078698738816984, + -0.28177773798420525, 0.5582837083561785, -1.3049002248929513, 0.08913900413869147, 1.6980987887729269, + -1.1912978387544912, -0.8931915280846656, -0.18855808865450108, -0.45803053755017054, -0.8712638141863449, + 1.7094200997790034, -1.663338331843227, 1.2015315071432422, 1.7919441084478605, 0.31409629120692095, + -0.2597694212249051, -0.9168044364289045, 1.216621830723616, -1.6662413518480443, -0.8080880254762635, + 1.4073145445825483, 0.545639136243703, 1.4013255263406696, -0.5546097014600111, 1.0974522886316729, + 1.5316488607435108, 1.834526451907207, -1.9144935270991548, -0.7645822781961176, 0.06654825465795611, + 0.33190077251492234, 1.5544507881908975, -0.3538304760036741, -0.9572664339018262, 0.9528358413200575, + -1.6745715540252952, 0.507318989134915, -0.20165453900363595, -0.785857858404909, 0.8550126703052818, + -0.0012624172438071568, -0.9598028440876272, 1.2244820355277088, 0.2255077411064299, 1.8056720264962882, + 0.037756919076751494, -1.343540974005494, -0.5094599890534655, -1.9469287581463544, -1.9209187756131971, + -1.9807700974543714, 1.6311835876110612, 1.3993893579005068, -0.43399651753379054, 0.5299120477626555, + -1.9670191857782173, -0.5248122259552108, 1.7175353497229144, 1.7289010829074938, 0.5729177740933844, + 1.5352457320381498, 0.4553403180302915, 0.29793068106143483, -1.5797876923914638, -0.40357306624290423, + 1.265178047524694, 1.0072829549062359, -0.1769350094041604, 0.6540198866379798, -0.568023907344485, + -0.9385372669098198, -0.9547551373309231, 0.9690083247728642, -0.49448270214331114, -0.23248401423341658, + 0.8559950134661438, 1.5269704580997994, -0.6731018952946073, 1.4821498012885321, -0.086653738716163, + 1.5513273624777053, 0.4908490299053412, 0.608900001002695, 0.9347951454225933, -1.8764423825502883, + -0.5383639900183281, -0.9057735867914714, -0.4471390378994089, -0.9345110033957944, -1.2485918565222498, + 1.3231857428964897, 0.14498731120726926, 1.2862055594556612, 1.3345878394286323, -1.2756222907563037, + -1.3083824064683531, 1.5045780173000916, -0.3988401099234551, -1.8345796826462557, 0.9521512093444242, + 1.37476939261499, 1.7326371780327987, -1.1085753439494548, -1.9216876869475268, 0.32444794625794593, + -0.35462190108024316, -1.6069584401996142, -0.6840247431355166, -0.6774334428214628, -0.6167820243082769, + -0.7218673863781015, 0.4950955340603276, -1.9104643839392264, -1.8658682955390713, 1.6439842658028825, + -1.9665747219603489, 0.877772486257161, -1.447763157358751, -1.602252241120052, 0.44016487530366266, + 1.010332237856809, 1.1627057025546774, 1.3589381505180222, 1.6636605093405308, -0.30457004904529494, + 1.2269612708937165, 0.6454558582632632, -0.3281883947244264, 0.8131649340163314, -1.0598316735137594, + -0.42299178838843243, 1.2938482948248247, -1.9502364601006974, 0.33584798243720115, 1.960552892008372, + -1.5357606602986342, 0.7325579192102154, -1.7068807453175427, -0.27946385324092926, -1.8639767385643475, + 0.026428176272640158, 1.8041446105579393, -1.9753854510739615, -1.2061334845302873, 0.7372916274629793, + 1.2179795802972748, -1.3143781251446196, 0.682430888983494, 1.765010566861033, 0.7860745956875865, + 1.5788062358182353, -1.575242377037637, -1.6479554611466636, 1.2726425469891218, 1.3214340571366705, + -1.1176914021551418, -0.28104090150176475, -0.8644022366981297, -1.0077418811855443, 0.015316356685453059, + -1.8951979061169633, -0.5196653139834044, 0.8125953302477926, -1.823298987094664, -1.772714760583118, + -0.42385130914514946, 0.8535556140476261, -0.21002977910714105, -1.0489096773038966, 1.693320576211045, + -0.7552501574944035, -1.70051398713119, 1.7993076735058722, -1.5431556753936784, -0.44776937868310007, + -0.0034131973262736537, -0.08787851222156462, 0.1325414438300383, -1.9039783696720445, 0.9518867483514475, + -0.9603800917394327, -0.0648448360480609, 0.5218985531710549, -1.6635189965029777, 1.9924108002877654, + -0.26339506136937363, 1.1928641143061727, 0.3810845521550812, 1.1123779174325765, 1.2834151041594684, + 1.7494397068262924, 0.7406099190962179, -1.0134737497144135, 1.055889762638893, -0.8394462038526926, + -0.28477578596133046, -1.3393678408508727, -0.012496365797505682, -0.03748743063321758, + -0.09254003926345167, -0.19080760280617248, 0.34473162850229677, 1.3043786112928455, 0.33958283219176444, + -1.0099885582662163, 1.5669995819775036, -0.2714415980354179, 1.3346740594279858, 0.602475697679103, + -0.27309507620431717, -0.4346043241348738, 1.12017383631626, -1.6915194091713959, -1.1923394655016413, + -1.8515166462293298, -0.591963112025546, 0.5023793945680568, -0.4921032884548957, -0.4819170803548376, + -1.4966363280675248, -1.421677402946293, -1.6850285211377178, -1.9944566759402251, 0.23292198758882154, + 0.8553269921007791, -1.948229679796313, -0.13143026366451505, -1.6538966244541689, -0.9216045649630527, + 1.7264765451863706, 1.8222181199058705, 1.9306597140261852, -0.42071370405153363, -0.994060577839833, + 1.3461298706581069, 0.6155835113186514, -0.4233898713404436, 1.0628290316641253, -1.882099316617622, + 1.0589656835123407, -1.1563854484115064, 0.06709310867074869, 0.6384730812451886, 0.10574897066907951, + -1.702659493082189, 1.532461866391916, -1.4099555819463054, 0.9274818832867693, 1.5037867997803227, + 0.7264189517035895, 1.739697756396633, -0.6088742760292636, -1.6856099921156718, -1.772602909046058, + 0.3867733521889116, 0.43414404525556716, -1.4477323883785695, -0.5550413636141851, -1.0684298188676138, + 0.49963854275697717, -0.6428666361617692, -1.2461600431717166, 0.139396634056558, -0.8713484091078527, + 0.9279324928307542, -0.860886224082325, 0.6052214303020085, -0.07407732844211701, 0.8385159528403303, + -1.2883611258333367, -1.594397303034099, -1.288952302189137, -1.5793499078181599, -0.6082222788183742, + 0.35912648776686495, -0.607260423160894, -1.253926812856471, 0.5102013763540754, -1.6947762434136582, + -1.675889179914285, -1.7314605301275563, 1.5448045270744153, -0.6686717741699395, -0.5356827185210964, + 0.12358215343481938, -0.6340124361185859, -0.2817760753415781, 1.164433875881067, -1.0118173150641265, + 1.8268866563720367, 1.7413521495755342, -1.3276318599091654, -1.7238317321272323, -1.7921370418547173, + 0.3558056543449375, 1.6918022479955095, 0.9222053317838856, -0.052028097382029515, 0.5122787494307435, + -0.32626022494247575, 0.15032399126484997, 0.8080660425425252, 0.8796007677651145, 1.6881141214849356, + -1.9923518519205325, 1.1594510791129853, 1.8780756936993601, -1.7055973395367428, 0.7022611016052407, + -1.2390916075946476, 1.489326601979406, 0.09129782431581734, -1.0503781321438632, -1.830361985060848, + -1.6234239674810462, -1.9400691667730579, -1.775525052900898, 0.49354389704319335, -0.294835573216341, + 1.6593428114830902, 1.1178351651841174, 1.392043569467992, -0.23893431171828805, -1.4623844403945752, + -1.5343016105239773, 1.1865854046143252, -0.030035182509464242, -0.23319636854325143, 0.16118837392623053, + -1.4249606736799842, -0.10348851980420637, -0.512221808349385, 1.3638877591460936, 1.6510345373891697, + 0.5403817230171999, -1.7568167360776457, 1.582586423119313, 1.8459420743954364, -0.0937741201677813, + -0.07681514921378785, -1.6485771749332665, 1.405045845579937, -1.380639120572158, 1.3920918243796292, + 1.8133826455276356, -1.6212352715972855, -1.6619339109060558, 0.5860174449496753, 0.8058352539157623, + 1.2357143629814047, 0.22041032204430522, 0.7808171688721632, -1.8606664108725086, 1.8543721730560465, + -1.2809372837528121, 1.8923388155485306, 0.7298982025806318, 0.08659163513354162, -1.7687987574607682, + -0.4763155379116144, -1.5473458996242258, 1.5866611468929568, -1.5386305359062193, 0.2081267635605304, + -1.381273405047125, 0.29199706105659384, 0.5721917482420489, -1.6925391452117218, 0.7916231259738984, + -1.4231755623012, 0.3405385083200416, 0.12077341714245371, -1.203092512782769, 0.9840212059218176, + 0.8659573692212614, -1.7404045592057091, -1.200177869243232, 1.9929041194760506, 1.6176296085808595, + 1.15438166620096, 0.41001019693662766, 1.545005651638757, 1.4880131034252333, 0.6483595755736316, + 1.9934284869908785, -0.4425968131492626, 1.1431733762266791, -1.5542365498656396, -1.4707761026218984, + 1.6864275163379014, 1.8816548331946308, -1.896526297348939, -1.2767368243378723, 1.7876805565314937, + -1.9714294114028945, -0.4825746666154789, -1.7490480229694754, -1.6790035262680618, -1.89004754358142, + 0.37907912536042243, 1.2624866059593396, 0.10773840252143074, -0.6872687299448277, 1.9925905119168483, + 1.6613584791100244, 1.3497294256582544, 0.14770827679848963, -0.386793508812298, 1.6398868686034165, + -0.41758827391224607, 1.7067458043950365, 1.0278221550032498, 1.7159317753776575, 0.5096805329124043, + 0.9834561119760092, 0.09498575572277979, 1.9951803301511655, -0.378982487863734, -0.8263817670300764, + -1.0229038853801748, 0.3913502332282688, -0.6996774883339203, 1.8060537458340287, -1.5307723308680314, + 1.8274255271921316, 0.3374273766186473, 0.504303338314112, -1.6080869793956403, -1.158919549492441, + 1.5813373512787257, 0.9432537197048445, 0.7842306706160116, 1.3546056018625912, -0.9933122900788947, + 1.0300108626811424, -1.8830535281855107, 1.2388444130694714, 1.5820432556020423, 1.7375339367053177, + 1.6102618826367072, -0.8665078865132916, -0.6732392396345555, -0.6529773281138276, 0.10547233393450206, + -0.6941730912707627, 0.910541741327461, 0.9098240508320066, -1.6363195940333455, 0.5463776877988993, + -0.33425194180314666, 1.0461836058570837, -0.9752798490633703, 0.8152462082002794, -1.5875859561094057, + -1.9077898642385165, -0.6579383417965277, -1.3678615233634694, 1.285752623992095, -1.258025248987475, + -1.734869037268262, 1.0429220147507374, 0.7326509890514723, 1.9986880192758498, 0.6509156089882868, + 1.3547313230945077, -0.5885137816765393, 0.17851497147549722, -0.48206201174941743, 0.7757155223228231, + 1.5797039076642765, -1.8365491055202483, -0.5557915214323543, 0.8098649490633258, -0.9482838270042748, + 1.7240038491867988, -1.727136737786945, 1.504375516257281, 1.8508968488624733, -1.781048252815335, + 0.29406576456043076, 1.3929052384310125, -0.5672101671346796, 1.1612683698036195, 1.0365664134132775, + 0.52678187508665, 0.3071827595777501, 0.7041562182045773, -1.1798720159212701, -0.3576278976339484, + -1.816947298419736, 1.1348446198592024, -1.5847397567297339, 0.3604508734069167, 0.46593859852538255, + -1.7777395155730495, 1.6287619990007203, 0.05390138452125459, -1.6910054822794427, -1.1955797005332123, + 0.1527150449913588, 0.7396002223852989, -0.18108018677808335, 1.5581070045362102, -1.8797418765958547, + 0.8650557162994215, -1.611948653627663, 1.1857525788641103, -1.0890741767763679, -0.7623497273737962, + -0.5452689372225361, 1.4851221090925062, -1.5985252342075071, 0.020384445673001572, -0.9046335836897486, + 0.553216643980825, -1.9081035245013105, 1.1735365048310928, 0.1195026810678641, -0.13729942676911122, + -1.5259805046535062, 0.18920225926555112, 0.6705947977532496, 1.0030469424172317, 0.7610916391668443, + -1.2463291870058537, -1.0249840040963818, -0.19173741736380734, -0.6665958901111466, 0.4839557261866174, + -0.3485528787244352, -1.5242944187930236, -0.14435052683644933, 1.4529331704829067, 0.6458683077621643, + -0.20459459850901585, 0.631062246518785, -0.8791486672647375, -1.2386308323664368, -0.7619627858914138, + 1.2107017989790583, -0.33473019147897354, -1.9130923107708755, -1.7056520026340571, 0.8880970113501672, + 1.3775061802643274, 1.6842143744397662, 0.8956663807740703, -0.14207737810877585, 1.15151215723502, + 1.9554721163330937, -0.7192433376047749, 1.9055716193191357, -1.7324310669463001, -0.9426609664206191, + 1.4304692916943642, -1.4208060554924176, -1.6703510665846304, -0.11789403764850537, 1.1035841327225944, + 1.19193146120024, -1.735646676252749, 1.9517684584286812, 1.800446091627931, 1.2643875465217338, + -1.2016658470474093, 1.096389364153712, 1.5201049232112753, 0.30449137858893494, -1.1869295719311896, + -0.8315768582782672, 0.4346498942421677, 1.8028550790017723, 1.8857086005150672, 1.1520127103731062, + -0.1519769293665103, -1.9810749314865586, 0.6722940736723881, -0.8875375147759952, -0.7891504144449426, + -1.6946353432210914, -0.4359020089062122, -0.6980599672246557, -0.07982954815920884, 0.19084722828248868, + 0.6845828835105898, 1.361502619024912, -0.6198314937755827, 0.43206457200540616, -0.4275542957416594, + -0.5942951907386576, -1.2680930727916406, 0.8768595741043272, -1.621734829642608, -0.5341763533991353, + 0.5480403433778003, 0.04432939753449716, 0.6820566847717622, 1.9586624282689744, 0.32087743739982866, + -0.6031550453761696, -1.9650369588045296, -0.01282494430097092, 0.39651817911017506, -1.3672271702505698, + 0.5253796884817046, -0.9414239928629815, -0.8381706914260612, -1.9383642756327601, -1.035705422287875, + -1.206785655560016, 1.2381881481085397, 1.2976332585562043, 1.6406249791463123, 0.10308205600753784, + -0.5475238811744738, 0.29899086302238675, -0.8482038855976217, -0.5137345687600776, -0.9065955517315878, + 1.3104099207076656, 0.5025101714926583, -0.19511985691542133, -1.7979503287642702, -1.902372134245554, + -0.8024648307653406, -0.8296523085064926, -0.9765719040493881, -0.32443266986702035, 0.6539733476893286, + 0.5973369569721942, 0.9620679364522129, -1.1326092851006697, -0.5203786703966227, -1.9445684352154542, + 1.120817472266788, -0.7832944208518358, 1.7728569415904172, 0.4746637992232019, -0.6331056497466134, + -0.9560458440421451, 0.774229832283905, 0.9430691283749084, 1.7415084291941643, 1.7617290689369067, + -1.241891773067345, -0.4745318258222131, 0.170886449126332, 0.280332317095926, 1.797798588989031, + 1.4955462166754243, -1.9714454758512092, -1.0474834990256507, -1.4515647324997412, 0.5231202814417619, + -0.8247619693573318, 0.842919528017017, 0.3442132040114725, -0.7990310505752118, -0.3629900475529535, + 0.3295033891840484, 1.7075620006843595, -1.7886707971887938, -0.4229652804748518, -1.2955282842854743, + -1.178898100821427, 1.0539302135104265, 0.07818911036924803, -1.572166269506143, -0.9256373147747796, + -0.4523743452766471, 0.870552332090627, -0.14850807893624118, 0.4287510203485754, 0.7234083474785322, + -1.8786486629917283, -1.3259155700503378, -0.347806801338562, 1.3631756709064557, -1.8448244492904076, + 1.0352554853447788, 1.2734756957434143, -1.8995503975543802, 1.3929138458956896, 0.4683038046394792, + -0.2679051570033124, 0.8144995104717001, 1.4932193857012948, -1.6953990593425603, 1.402756253554589, + 0.3808795564720828, 0.8990244719837941, 0.40911995577962923, -1.1361078826263524, 0.17385823702487802, + -0.8186745793721117, 1.7076078317166345, -1.236408907822983, 1.6848698673082323, 1.1057305891182168, + 0.684172881489129, 0.6339043483376958, -0.5196970144327304, 0.3521577727883862, 1.650116506377742, + 0.2682799961840745, -0.4735884200746394, -0.37060613769970896, 1.1147917782653982, 1.7887865852421685, + 1.6978458641851324, 0.27152326646567104, -0.7479948142297932, -1.2016639963842959, 0.49444674806749767, + -1.3781311297932053, 1.6512147519128018, 1.7416186949120336, 0.045136221165612334, 0.5209559439848652, + 1.259816622507529, -1.4608419771090535, 0.044609757110229076, 0.9224148305655282, 1.001777160557932, + 0.6923994259395219, -1.7332027007479764, 0.1740919172843176, -1.6691024022127667, 0.020765839823362775, + -0.5411621161873468, -1.9512100832757255, 1.2205713100210618, -1.1974515945322022, -1.8219003860967407, + -0.20156369520917128, 0.8953615730514599, -0.6851191939445931, -0.9784599914417944, 0.6357787809646211, + 1.2543221592701617, 1.154804559643825, -1.3175939582140597, -0.6021071761808416, -1.3594268073464137, + 1.445454919786025, 1.6373224793127816, 0.4773095450916047, -0.5842274784088719, -0.7097055492598594, + 1.221870791596726, 0.5841175201744218, -0.6088866249685312, -1.0239540299091434, -1.5776221817975893, + -1.0528530592443435, -1.7599610522525682, -0.009359693820525372, -1.5849216381687343, 1.1123817614821956, + 0.48090947307860965, 0.5755896514757204, -0.49340183569225626, 0.38999288587119985, 1.9076108980193105, + 1.7293269863481244, -1.3443206598919062, 0.001652373471082491, -0.3612982015066102, -1.666179661360145, + -1.5652606963428637, 1.2847799597537284, -1.8746367656966019, -0.5051948920106453, 1.480902450916811, + -0.6574579095355038, 0.41299775479925227, 0.0038296277371880905, -0.49511696025555096, + 0.33076498894339057, 0.07531759452269782, 0.3925650285564686, 0.04553443036132965, -1.2835239354793773, + 1.022597041467466, -0.14452142564653414, -0.6834698246070756, -0.26296934920727555, -0.1553100097881268, + -0.7486942835272403, 1.6436947866485108, -1.453108500403955, -1.142814165110103, 0.8561776338025142, + 0.9114113939243937, -1.6289171667242108, -0.8765736035046707, -1.3197920400322358, -0.3734912789205094, + 1.5370462570492807, 1.0649239894263536, 0.8535930082816217, -1.6961129271355224, 1.4179270256764838, + 1.4041314424292972, -0.0700614584636794, -0.5005397839840926, 0.23183520631057597, -0.10990588024738823, + -1.147089599229715, 1.8964022235002318, 0.03743472262799674, 0.15742108441832148, -0.7185738786020632, + -1.1856642140796065, -1.044955084073341, 0.6700862122512099, -1.6025569167524276, 0.5441522921772082, + 1.756702564317389, 0.1676615923964384, 1.9361822628230643, 1.7064197592361863, 1.6880082988048626, + 1.4936942432907419, -0.5799520435760606, 0.990660775656627, -1.9229488942439295, 0.3155933315458741, + 0.27140426916735727, -0.5632551338077167, -1.4590940918768025, 1.5983827694310806, -1.4330645192158764, + 0.8003275119389475, -1.9841806470403869, -1.8395944973457397, -1.5552105471701392, 1.6842460594625583, + -1.7017713102134175, -1.3427034266215978, 0.6775219782398878, -0.3351302524035349, -0.022448660326078063, + -1.3500598715209335, -0.5202469644373027, 1.921619202644619, 1.3872706825409837, 1.396529188979435, + 0.07083965303226236, 0.6914877624468483, -0.29224764274698156, -1.751452537465684, 1.0064318828131888, + 1.5123878406285325, 0.4033025380645654, -0.9820341018265148, 1.7861648401668955, -1.0181815370408254, + -1.903647522119213, -0.6561248619205413, -0.8699254533593299, 1.7173299560789754, -1.9229485367712744, + -0.690953419069066, 1.2737422710672446, 1.1066566790733132, 0.8606336717720096, 0.8621209988717888, + -0.10083318929855434, 1.3453155024821513, 0.10215147045998396, 1.3751499812523518, -1.2192061255665605, + -1.0236206808587962, -1.533123935767816, 0.15060837984637665, 0.20593075721578913, -0.1420114293640644, + 0.17633888228692118, -0.5175320445499185, 1.3087124943884705, 1.8324200965520019, -0.4738758668846961, + 0.5584879952099433, 1.9477363800804062, -1.8628690990218209, 1.42344054481267, 1.9288039965215793, + -1.8566049760751664, -1.2884135845361993, -1.19038243691399, 0.7683237483024614, 0.4802985717653341, + 1.4329557197969898, 0.9128985393650337, 0.9461407015680665, -1.65078779921165, 0.4192615449543542, + 0.32096438114763437, -0.5129450234494479, 0.631644434863107, -0.47335834996781756, 1.3100891597589213, + -1.4790022015467565, 1.725986057822876, 0.8701026300053698, -0.6819596038865807, 0.4664857353751577, + -0.7112013772427535, 0.12572492597099938, -1.7769733290901675, -1.6440093209585278, 1.125533492349331, + 0.19218031047016204, 0.02403105386135085, -1.7090825450471296, -1.312565935939367, 0.03858791868454148, + -1.2641036387182139, 0.904155292993349, 1.3561478854255968, -1.539635445223043, 1.9235783529452943, + -1.874310049415433, -1.7773733191470518, 0.11118483759793829, -1.0886413092825036 + ], + "dims": [1, 1, 2560], + "type": "float32" + }, + { + "data": [ + -1.0123639551600512, -0.1262791332695521, -0.5528788189121379, -0.9891722578914903, -0.8541177111117033, + -1.842327110585285, -0.8753504664996115, 1.645494290881846, 0.12876899266900654, 0.5739158499727086, + 1.1027206966256946, -1.3155458981065662, -1.1433211051679475, 1.4367855916029315, 1.4674402192278242, + -1.3373231059554618, 0.8170172046647917, 1.1074697240531268, -1.6004007249577086, 1.4644646696571568, + -1.827927680383385, 1.4548611857965401, -0.8614990138298273, 1.6706016048131325, 1.5096827794979042, + -0.4651782953949448, -0.8577219028996828, -1.299490913831483, -1.339145989117756, -1.0017632367283333, + -1.3586772742922255, -1.8799261724983776, 0.059417093938870735, -0.6646734157727456, 0.5388638764003799, + -0.6378909942726629, -1.1647516486356562, -0.058721485739401835, -1.814477796499844, -1.189167849669282, + -0.6380012350722168, 0.5285662507696332, -0.534701982091482, 0.5570990437739303, -1.755585696977759, + 1.27238726136709, -1.9571057028071568, 0.04195657651183726, -1.9047024137637942, -0.5116506294039525, + 0.9926189908254566, 1.7759871807627992, -1.301591689492625, -0.9524123108659142, 0.524043944088068, + 0.5401946307447156, 1.2036398911372181, 0.3219194319137575, 1.9433711281899884, 0.33648919584354076, + 0.9772519308773946, -0.5575502080369272, -0.7345410843307336, 0.5778449333097511, -1.9408006240005902, + 1.9819202164371932, 0.8468700855540847, -1.3899691404202503, 0.15850835128301544, 0.49059781858603113, + 1.7764286491001, 1.7946165130578535, -1.16168298050607, 1.032789240397304, -0.7706982026329863, + -1.1038032957670127, 1.1838096309519006, -1.6323318982074788, 1.1064057844588042, 0.391546384023683, + 0.8726810963884386, -0.28563916045025906, -1.960002018275822, -1.3867181381833626, 1.921900210546779, + 0.23545042298506935, 1.84976268271061, 1.0525315891685612, -0.7472942377034979, 0.7577710804753881, + 0.3083793892222504, 1.327301485825914, -1.5659874146221533, -0.8978311834083046, -0.62907789098844, + 0.4870403076019034, 0.630783738162723, -1.9216326727288857, 0.9012236985202584, -0.7852645198565309, + -0.6937624671663629, -0.17653350051115613, -0.5561473457869717, 0.6698782481609369, 0.23362849702887267, + -1.2054404784680655, 0.7698259603739315, -1.1714183267177214, 1.1184512448333477, -1.6690588449303805, + 0.9147154841361029, -0.8955722282828678, 0.5231382201829353, -1.0773616491852902, -1.8410207920577744, + -1.1337885739036437, 0.5219438601865445, -0.5731516829203676, 1.5757208446602124, -1.8470713081756802, + -0.14985389360149082, -1.9692981274100303, -0.9532002408436595, -0.9539842568916299, 0.17378022347200517, + -0.41623035688267596, 1.6481595173848254, -1.026339675363599, -0.9699532421892103, 1.9461597162595536, + 1.3640648912196438, 1.1391500889364954, -0.5595329932725566, 0.43069270118926184, 0.9216291855227485, + -1.440051073691266, -1.47987236009598, 0.9087155921703598, 1.2787682017568338, 1.2363303394454128, + -0.49711585260182556, -0.9387884495809562, -0.19011148798368716, -0.08611760612433539, 0.8656095244085549, + -1.4316244544821588, -1.558915825119862, -0.6738541489070196, 0.7358531504667214, -1.5716157542957179, + -0.3210549969144969, -0.20072745672949832, -0.2416365577548918, 1.5081023632879136, -1.547604775100064, + -0.6334244403609635, 0.14810360033380032, -0.7978288773497635, 0.6204344672116999, 0.4773642826492761, + 1.2087249539596163, -0.6643075818626354, -0.4170560884596144, -1.6192321024457321, 0.7844847722786517, + -0.629690651133866, -0.7380758723814482, 0.303414620658204, 1.5479875220490822, -1.198103302774089, + 0.9760982659095188, -0.7574001500859886, -1.2724614749813545, 1.0658176639069543, -0.8843666652730136, + -0.6427064732600343, -1.4416669867869603, 1.473657450100351, -0.8344994942004691, -1.4224435385472942, + -1.0338023533751777, -1.933568422908838, -1.0802998481520287, -0.38091309180010224, 1.6199945117010506, + -1.702101910236685, -1.4725385504086255, -1.413591341417039, 0.540278745507603, 0.3517718642795238, + -1.590795907883174, -1.765499823368284, 1.7366923492439614, 0.4582221558773192, -0.10581682008337268, + -0.18516227544796227, -0.21097779387988158, -0.1428735544745896, -1.97510241493318, 0.32449001731988947, + -0.1832218746003349, 0.26181337286546746, -1.1227369967165552, 0.35351574098454375, -0.21205956319428498, + -1.3866497212089195, 0.5946688412485415, -0.3417425538750871, -0.33633058083047906, -1.7852940300594273, + 1.9919312461265548, -0.7629882135863388, 1.4620310920385196, 1.1115061446942711, -0.9057539166302142, + 1.775862903430335, -1.1324751374031425, -1.5851970376790732, 0.9843604936500894, 1.5734177841900427, + 0.9515914445205862, 0.034323622285483246, 0.8075573695504703, 0.5332420240003275, 0.7767308358623826, + 1.0329131214994085, -0.9838298807872725, -1.7429868813963063, 0.03740922197745089, -1.3794671490283932, + 0.9772124799843054, 1.6546060756751624, -1.345806362676182, -1.620585515255308, 1.3498272448019941, + -0.25283974040314394, 1.8309785362540882, 0.8336568766196351, -1.407378961144727, -0.8870392599067882, + 1.5455801491463914, 1.1404840595611354, 1.9778865957841072, 1.9026326233043243, -1.2919286899267508, + 1.5194536255103763, -0.40024201189426734, -0.38767200629106124, 0.37883119550528654, -0.8971148848399899, + 1.1472506966060552, 1.6769048658537242, 0.39834963390946676, -0.8584979863189526, 1.4851684858856853, + -1.9898922021489387, 1.7271323520062838, 1.8848497191146922, 1.7439889790675318, 0.5311134881425099, + 0.2302960173495876, -1.6217864910988453, 0.28492260856667784, 1.3550969896689358, 1.6762026245924515, + 0.45464402605973575, 1.8447468286497628, 1.7489125819896838, -0.7745650567526248, 0.6473255323813127, + 1.88270574713889, -1.4231865592800421, 1.406236181817195, -0.05820536366515672, -1.9830176146937077, + -0.927096735728453, -1.37521200952383, 1.6293084827507869, 0.18916714867483186, -0.3559388864834574, + -0.0626685044384443, 1.4510888124049117, 0.00665671994549033, 0.5852250937009087, -1.964947150735517, + 1.086994355276114, -1.5545621604146378, -0.6702039017668291, 0.15273130009205538, -1.354848404243989, + 0.8081822753111396, -0.2990136329330131, -0.1334268300545549, 1.2936295445817017, -0.5276761138383153, + 0.06209853112125252, -0.35227980331045927, -0.6683541541821878, 1.5365781152706175, -0.5227637702649135, + -0.43751245897261537, 0.39166051967309556, 0.6145502882685348, 0.6764920150128493, -0.46478346293163764, + 0.40093484640123567, -1.4385605602950564, 1.5318810200296449, -0.7902920012169599, -0.22815329205907098, + 1.5159148518766017, 1.7440445423086697, -0.7705868478778743, -1.0446035338845894, 1.4407728607631372, + -1.7690868678646723, -1.9956594357087072, 0.9165504950260104, 1.1647979922386025, 1.7626373785022524, + -0.3262003585763962, -0.7291462643423232, 1.2691368673965409, 0.9833027096614515, 0.7052758987187504, + -1.4080008451270958, 0.2004861907693547, 1.92413536100345, 1.8633978379666134, -1.5597901041000588, + -1.232525418601906, -1.9326471509575835, -0.23851047841947803, -1.745957663852197, -1.027455630245683, + 1.7842373183009093, 0.7098705198166604, -1.3523419086313861, 0.2493915779920206, 0.5836072040016118, + 0.8452857075528275, -0.6044200471234227, 1.335947146234287, 1.9535634253874816, 1.8737477649440653, + 1.6787256628480751, 1.3475059469256392, -0.9023420902836907, -1.815324493360138, -1.7487338231501415, + -0.08107787176718784, 1.178869071718574, 0.5869021791922719, 0.1289991861916615, 0.9871714466975163, + -0.7828180891664971, -0.9162265218952319, 0.7883323334301799, -0.7738825207321494, 1.0578051800827781, + -0.48483804389576335, -1.7003938250158095, -1.7474401518911815, -0.6024807198720463, 1.470072074418848, + -0.809852698248462, -1.8087803758512981, 1.1275613461510172, 0.9110052554791794, -1.4827836388852713, + -1.3641845213240744, -1.9188108209559402, 0.8859208024949954, 1.7438050845669144, 0.14476912919518536, + 0.6121128834023981, -1.7692670213619586, 0.023661688752081744, 1.1007625036098432, 1.1758330104763122, + 0.4546664062325476, 0.022499008403786824, 0.8120850018523171, -0.7886059301759083, 0.8107426171843777, + 0.015751759753425354, 1.9515003227015306, -0.3285612629290764, 1.758730602588753, -1.9178063288185045, + -1.3319225925070368, 0.5970900239608552, 1.8634221473873263, -0.7483844730402502, 0.0851383845623852, + -0.10037389959678844, 1.8601880295663413, -0.5358906627108242, 1.5311027975069011, -0.7567148434480719, + 1.7810484758849983, 1.5004941791198378, 1.238866744077014, 0.019238796977725237, 0.7314924609545477, + -0.6404106749076366, -0.30544348502988683, -0.7754562102568752, -1.1903829239480253, -0.7557972926946839, + 1.418804956107497, 1.6841275666684883, -0.35403092145419013, 0.3072276436064163, 0.4941160183076647, + -0.010460638654985033, -0.7496577784263767, -0.05957826320949966, -0.5349743628709929, + 0.44780861823397355, -1.9548584156880642, -0.6834407857845042, -0.6574778495500677, -0.2568872307434864, + -0.8179424074332058, -1.9399599284052886, -1.7438777236599172, -1.9046697213047699, 0.33576417528481173, + 1.0390831565369494, -1.2867357835981865, -0.9105779330773034, -1.2600968940701254, 1.7546113033912878, + 0.8638193166816803, 1.915934439034265, -0.18936860893703056, 1.6490561383957179, -1.7404200826424407, + 0.15942118157817387, 1.174512061322961, 1.087287672904493, 1.182852158431765, -0.12430741231511089, + 1.6711861366379157, 0.7124940145742862, -0.3946246470773911, -0.7754542725640272, 0.9539784330716907, + -0.5716889107776746, 1.4262896723570924, 1.4675456840569163, -1.7077488525524833, -1.6888666810589683, + 1.2108429896458865, -0.30524840414522547, -0.18167408305726607, -0.2569749511019337, 1.2912167486614727, + 0.6208472747047127, 0.9472464500515958, -1.1302136544927563, -1.478282134349313, 0.4848945322578242, + 0.8298435424742152, 1.6932133553283863, -1.4458048451455756, 1.4088139925156833, 0.505348371415975, + -0.21105001864112882, -0.04858175142791943, 1.3570555900503694, 1.2673714070205957, 1.1469844853077413, + 1.2000011591622064, -1.0780533577358637, 0.37698814259115565, 0.6997609434842227, -0.7604196150995675, + 1.8410835681246196, -1.6836663805912915, -1.6352482015283725, -0.4811456756273813, 1.3045848106454878, + -0.823583139102726, -0.646527859067727, 1.5092372843244393, -1.0424659042584983, -1.9448695809676995, + 0.2678845821239566, -0.6194354091338141, -0.3172475643478627, 0.16481577119936563, -1.1026554846901258, + -0.8352503899270465, -1.8149755146432849, 0.4839677208020152, -1.9367901959501284, 0.8680859459275245, + -1.2035834761537227, -0.8748603808576707, -1.8417628093555134, -1.0429294120821577, -0.2520578638761588, + 1.7833216800296539, 1.4367696159460968, 1.8669976567111535, -1.405562858069989, 1.0576377264778563, + -0.4569713929987014, 1.3255011842556819, 1.6171166029195225, 1.0403552739195874, 1.2264321768656314, + -0.47396544132443275, 0.7118492170263346, 0.6260191876547241, -1.2179712214091793, 1.5120789908822676, + -1.657525645319189, -1.2991286032659461, 0.22202239748400387, -0.5389051448124622, -1.594992260705033, + -0.0487688918807363, 1.1512759563478916, 1.4679486318272383, 0.8813613284468369, 1.4328674044139706, + 1.9999268579039367, -0.47950159568339323, -1.2281571927849866, -0.4554451947054856, -0.15429012922622043, + 0.19786052464284776, -0.3680312279906497, -0.18825645901610866, 0.13700608028054084, 0.11417316734012051, + -0.6463349589003959, -0.3770634118502656, -1.9950465240002844, 1.4676192894281632, -1.6060448800367215, + -0.6182395877160713, 1.2695963682598732, 1.4459649727588744, -1.88317964468765, 1.4240536704934144, + -1.5317465035623874, 1.9497915396745666, -1.991995390985421, -1.4828801801030478, -0.03471214257257316, + 1.0775554630217314, 0.38086611278178495, -0.1958126129950628, 0.3711869657718534, 1.7307000355063105, + 1.3370240564962907, -0.6892941432270163, -0.6176252997554714, -1.2761391889511922, -1.9463694295411074, + -0.820841598715079, -0.42139807880407787, -0.7976620746519121, 1.934915457164431, -0.4497028214466532, + 0.5258289450636102, -0.3002339930414486, 1.4317770429007526, -0.10670432773391081, 1.477187387671167, + 1.6422292268536607, 0.30393726544465505, -0.16649028812518285, -1.1690831963968895, -0.7043985973685203, + 0.47350023974648625, 1.657836032561474, 0.16219091081621606, 1.2307861721904754, -0.7270831655242516, + -1.0574142264570137, -1.9134290652692378, 0.1585901752075669, -1.922458865955397, -0.5216475421535529, + 1.4438431375673586, -0.2874803852531551, 0.3370681487492808, 0.9173203757850725, -1.3751125170831138, + -1.014212305918492, -1.5475568694685897, -0.5834419852252983, -0.022709263779811195, 1.4145035718404255, + -1.9267536984073965, -1.871498186038706, 0.525620783057489, -0.48663912480579086, -1.0308451661848164, + -0.1369351560707246, -0.7876105221422698, 0.6955249722293555, 0.29453585260072757, -0.06514154829755281, + -1.6429080966047698, 0.15520901396599296, -1.518429991046033, 0.6839405241853216, -1.8300625086431346, + -0.15898426442765246, 1.9278290352285792, -0.7150445644634695, 0.34034454186145346, -0.2506887167667946, + 0.2912251513885442, 0.10434269155791664, -0.5637887420304368, 0.031008416285043694, -1.0816174134360272, + 0.05203114530680164, 0.03172694813978172, -0.7646387549793285, 0.36213414786228526, 0.0060869909269349876, + -1.0367311632092022, -1.0684702277942222, 1.1874407786461294, -1.9290032593242623, -1.8550268296137276, + -1.161269082907582, -0.18656240236501098, 1.1070044180055767, 1.6261946461581385, -0.5698373516978554, + 0.9631347920513802, -1.0985201941795912, -0.45509125721838917, 1.6535643092193615, 1.9696290288271951, + 1.0266341388473261, -0.23790773845234447, 0.2828088466454748, -1.6537561622154024, -0.2286353308074256, + 0.8766588558049788, -0.8195788808423616, 0.7718354518479451, 0.46796484124644344, -1.7212327769837446, + 0.0658435971926874, -0.6407624425160288, -0.5647885630487526, -0.5284202936353299, 1.8818650199438771, + 0.29252062160862025, 0.09136912052125101, 0.4321630239196912, -1.1485094277982304, 1.6036235307678686, + 0.3318334927588493, 1.7219946936827109, -0.09166860362313312, -0.9321185046623404, -0.5842230824766759, + 0.5762857089716649, 0.6237761258836967, 1.3257989135149089, 1.65675048758645, -1.0060167288419342, + -0.08448091333478214, -1.2793076427969634, 0.7514972175750367, -0.4193024725154899, 1.427794959994305, + 0.9558973375817734, -0.00039143542951691757, 1.7030425931606343, 1.8219801925309609, 1.7260980421968792, + -1.9249357614979115, 1.6285038041870772, -1.6118493301059527, -0.4294907666236245, -0.1659993953929053, + -0.5722726383494532, 1.2935829105972072, -0.06859448172655114, 0.6177602091273879, -1.3370886026529494, + -0.8003712871381898, 1.4776462171750593, 1.4184671800982604, 0.5433276418773598, -1.1103872044287346, + -0.7572146251109908, 1.1710857107940438, -1.8705799333769377, -0.31024900334903194, 0.34866139709491595, + 1.4866168061361806, -1.9774625782466435, -1.9891386648785518, -1.7735018436923857, -1.5751766748778406, + 0.7218521749338684, -1.9390531947989702, 0.10502871018805493, -0.6908737000286438, -0.583334654840761, + 1.465181746808006, 0.9232784443998208, -0.37400862804876933, 0.6661364913985333, 1.8688403166211724, + -0.8717922772171374, 1.40123243258902, -1.9913513342271694, 0.7369262986601814, 1.2562396521830914, + 0.7638029152143444, -1.8164465814226718, 1.7184240047901733, -0.911895923498772, -0.43161449539234464, + -1.09011215721989, -1.865570383600069, -1.2232212962752618, -1.943030725162366, -1.3198980808588407, + 0.0564583162685901, -1.1298601037432707, 0.6392469941959655, -1.8442933136946733, -1.0692331296192306, + -0.29525834417416963, 1.1184299311108798, -1.6129180448223925, -1.2727580333965411, -0.9415967651718447, + -0.2597646669604643, -1.511150740860189, -1.769101860129168, -0.9185489030191736, 1.6841872338621604, + 1.7266136417112579, -0.6047332956995355, -0.4784036452377798, -0.6987121488961696, -0.3950169895430573, + 0.29099877073820757, -1.5250611710167732, -0.5876293953125105, -0.5938486494168753, -1.0021656820999798, + -1.9666708037201044, -1.272140943592933, 0.7880251982149247, 0.11964755378902137, -0.3901422866566291, + 0.7163616669643105, -0.04395212244207691, -1.0791955402155438, -1.2675936648298745, 1.0655012879795382, + -1.2960016160353804, 1.1268487593724137, -1.4561611267402474, 1.7853064994708516, 0.9817883639607627, + -1.509195648143982, -0.5791158763214721, 0.2952226835939813, 0.978029471962218, -1.7020877610480865, + 1.2949154364706787, -1.8669978207007674, 1.9804087372291983, -0.1920592681769353, -1.3129464854527964, + -0.13084211958976688, -0.25279655392730405, -1.8414897577067046, -0.7363208735547797, -0.6909260581968182, + -1.8811392695885178, -1.5901068180742568, 1.758878672856656, 1.5387787983193055, 1.6713805051822828, + 0.28500585759464503, 1.3914306792968247, -1.112480424362695, 0.43326162263712487, -1.8142315585546145, + 0.04023793859339708, -0.29805331377366073, -1.5940199056342657, 0.598129666067309, 0.625004812581027, + -0.911960460437454, -1.973405398299871, -0.7574758313972065, 1.6261060948310595, -1.0639316504874738, + -0.8549167612511983, -1.6781924250988283, -1.2461164334164385, -1.4396767476893544, 0.588676315376695, + -1.3923513282535058, 1.960640665522404, -1.1216598084556173, 0.29702774865635373, -1.1441990771482464, + -0.9733601129567919, -0.13533496827900127, 1.2875809157665516, -1.004348467034836, 1.501216437295625, + -1.7257128690349832, 1.3038540536955745, -0.23514094567048183, -1.0545846838443325, -1.5628126421353628, + 1.758225843292891, -1.752217343717482, 0.8827182187480176, 1.9633500079396518, 1.4124055174643644, + -1.6009057792139894, -0.124257691420528, -1.6361563376854855, -0.7163857270237415, 0.5991086423774714, + 1.7584781739562239, -1.625063774845441, -0.37572359945414213, -1.9506995916793395, -1.951072499257542, + -0.8315895595505731, -0.7002195813028456, 0.31048848147933406, 0.19118037223773499, -0.836819966187166, + 0.8992259497849764, 0.2769262236848036, -0.8190887502725346, 0.5908744335005158, 0.8309308801063224, + 1.828667115346997, 0.050270920754735826, -1.4675310474376078, 1.6474157726281966, 1.4412270025481906, + -0.071799132580602, 1.2723657902431542, 0.9847345744230269, -1.7967618044169429, -0.38502605748464447, + -0.9911154903546722, -0.9911306363398822, -0.9822148063420846, 1.404660627884402, -1.7223894428279198, + -1.906376932077218, 1.1267093944315762, 0.005733157947431344, 1.5499657009743384, -1.2918427917389055, + 1.8878866260898972, -1.450986879628605, -1.0670858020560647, 1.012435839252528, -1.0904895043171203, + -0.7636238525274122, 1.8215658692720762, 0.802604215057829, 1.6666057955071523, -0.6857256630224935, + -0.5501356674470292, 0.810089459752044, -1.6169276394201413, -1.7364078843810304, -1.1867030927097977, + -0.5172860730134063, -0.8556026745046195, 1.7395171980402528, 0.8977518661224195, -0.715248100272647, + 0.42642471199620147, -0.1359154018671509, -1.017818228497351, -0.00905895348889274, -1.6541703500137865, + 0.5001119548133026, 1.7346626988667904, -1.1674654589916598, -0.735219697011062, -0.1962670855393469, + 0.21602710767932987, 1.634800475118543, -1.614549402431435, -1.86469031751599, -1.0722793725135675, + -1.6751255258166085, -0.34784828468612705, -1.333517864681233, 0.2286247270719235, -0.9686327464308748, + 1.8030391927221219, -1.260653677947687, -1.3291707209191737, -1.7464317874557151, -0.6022055677476486, + 0.4234332187817236, -1.3942184957445614, -0.49460322127068856, 0.846493215661404, -1.3779621390020038, + 1.6170651737934367, -1.5949106936017516, -1.993666650794867, 0.517274246668932, 0.636335391123283, + 0.09728792236496986, 0.21871120388837983, 1.9033150412817141, 0.3761639225094582, -1.7448084623773052, + 1.001122609393847, -0.44673552515686765, -0.6566748795770616, 1.0022029703662332, 0.49152185517814306, + -0.19632140501675632, 0.6593309755584418, -1.42607069406814, -0.2499327686935615, -1.4645970035455589, + -1.8214929827258137, 1.7849263457214155, -0.46930999932929396, 0.930852011498847, 1.4657054090327062, + -0.8598379219960437, 0.21923117934684644, -0.2719980917718665, 1.1382814204088554, -1.4437234293121408, + -0.08437654030814734, 0.9230522551879243, 0.17552818859532504, -1.3982719952892557, 0.7727609217240659, + -1.3512364654797944, 0.4217546307725639, 1.8959084151076748, -1.9009721763514387, -0.2084686501701638, + 1.34507209606525, 1.4812500373319821, -0.25452451515050356, -0.7547650359872655, -1.5901912558006952, + -0.617303822265475, -1.568626713241147, 1.3511951372228292, 0.46680417658438866, -0.9843992974612537, + 0.21141544095185427, 1.9555502838890186, -0.5622924926922526, -0.7074175111692584, 1.6856741408764497, + 1.4329492504371748, -0.6904233032298688, -0.16570616327044707, -0.5819404191250754, 1.5298308400435117, + 1.2873904282242874, 1.75253332340609, -0.3969229369696805, 0.9712496560090953, -0.984449102903409, + 1.845837132921134, -1.23000834955623, 0.25823037305241403, 0.6562595586377551, 1.426434937488695, + 1.2050365141327637, 0.35112023386450986, -1.423781157416867, -1.22442697877245, 0.8857806584751993, + -0.27941851495344316, 0.383573200806417, -1.6546531309712131, -1.0620419037179136, -1.6487673588042684, + -1.1583303816085477, 0.11883432462925647, 1.8623629910270258, 0.7814730455397738, 1.3892839510915138, + -1.2109955091247775, -1.1531820955625154, 1.4249824872214445, 1.878872910977651, 1.96640914460413, + -1.7574195668520565, -0.17080472539130565, -1.2334517024508829, -0.042203796430729135, + -0.5119900340796173, 1.8245076363940829, -1.9504677488809792, -0.15838646478579133, 1.8653261691127963, + 0.9818831615417336, -1.1498049891062543, 1.8177635453973222, 1.0307183336158117, -0.9459693602670747, + -0.6967235469867861, -0.6765683802163993, -0.25117711682611343, 1.7085642032810782, 1.2354232326963057, + 0.16304953424899615, 0.8288006493055686, 0.7426908331423769, -1.8567733398412107, 1.9363614866712426, + -1.1491819532508236, 0.4456745704746172, -1.5200940589370164, -1.2921806881240192, -1.8425344342113679, + 1.385258197718982, -1.4468713069952912, -1.7194028306161009, 0.11112342320496005, -0.46993304971516636, + -0.3532497592673005, -0.20705867891006324, 1.1503266436423507, -0.8615322336734987, 1.4243814245202273, + 0.44605877897824087, 1.4291899741133527, -0.5503339260133986, 1.6845285529242942, -0.11203488733585854, + 0.7838364265704332, -0.3478757678382811, 1.4617786521240204, 0.6797786349237454, -1.4070556983478548, + 1.5368835556999283, 0.7270943692880465, -0.3943204095728534, -1.0159540349760245, 1.3420827647018054, + 1.4944969205796248, -1.3846348875424548, -1.1070204938214987, -1.5623431163391235, -0.7285340046827864, + -1.6146739312377756, 0.7914876850628412, -1.4285275663878165, -0.2102551967847095, -0.6036343290031185, + -0.7863519667198808, -0.5232027574551195, -1.5248951664568793, 0.6403374226115135, -1.3332654904167054, + -1.6013714748847017, 1.5264343369773945, -1.4659567584783701, -0.1255742527269854, -0.05787364846985188, + 1.0630262325298858, -1.0251739144160261, 0.6341878529485214, -0.37708094500223766, 1.5820017651752964, + -0.9754158194342759, 0.021470140570154506, -0.270780715342573, 0.8513662098629382, 1.812913501299759, + -0.5507306480213909, 1.6252304329937193, 1.6166807780171624, -1.0015815783816109, -1.6525598491008084, + -1.7640141719015405, 1.4567526655551655, 1.4314017477260261, 1.1147464550993922, -0.8609069210724725, + -0.8835445478716384, -0.8807052221352922, -1.054638273386189, -0.9307483447707048, 0.6532758793412583, + 0.2251469563444708, -1.4983410691150256, -1.355149970924261, 1.825265403016327, 0.3882557252462826, + -1.2005370275411478, 0.5167632305655108, 1.258505468633845, 0.09615276706317388, -0.11253109876707157, + 1.301050271249565, -0.5926127701907813, 1.6730534476647492, 1.2040207312610214, -1.5220497985479415, + 1.7064305519329883, -0.7793546956972763, 1.0964366887160475, -0.8642894018251441, -1.9411871186407836, + 0.074716954760353, -0.14369204309870298, 1.8393941069756865, -1.3769308839327685, -0.30862200557660646, + -1.993793823331563, 0.3709006233167482, -1.6557247604820402, 0.32053951400820324, 1.418554947267494, + -0.14801920801346657, 0.25882446183466357, -0.30227778350472967, -1.4281993644549162, -1.2907922764091362, + 0.9110864171884971, 0.613974184551024, 0.6697305289032087, 0.8489421527088039, 1.498148243683703, + -1.4269397350154973, 0.6132189565263042, 1.8741137765083877, -0.05705777446194382, -0.8855796810622429, + -0.8656995527854097, -0.5082467483431357, 0.4332387677470342, -0.7541429782381526, 1.305940642158534, + 1.3554774725998202, -1.2111195490457929, -1.4381676776657422, -0.5207599119467634, -1.8228914923142483, + -0.1877702958456453, 0.36230894851909135, 0.17851993376959374, 1.7927379289246463, 1.1368406732884377, + -0.9950802434664396, 1.6043789056058433, 0.6484661390901874, 0.933455445947418, 1.0965484754420745, + -0.6939481831648528, 1.2328397545284568, 0.9378872541025238, 0.34445900641106775, 0.1278365294249939, + -1.0258774152176224, 1.0892898808290514, -1.067964672512037, 0.44689053769430487, 1.9413539674195102, + 0.9528679183933839, 1.7587111895733223, 0.5576643641512362, -1.9897364013123235, 0.47036230007388813, + -1.7897636522061902, 1.5010421043239726, -0.6901446605239645, 0.45393761765945406, 0.2829324542693943, + 0.7192447802434581, -1.1331672455879191, 0.17337802502719768, 0.49963004068736083, -0.7815511951172738, + 0.9636628323801011, -0.8323427533238403, 1.2095900671544157, 1.934951536397083, 0.5200093238971917, + 1.158643992086569, -1.5891421117437572, 1.1934443042649656, 0.8276249819736909, 1.4212639972350827, + -0.8419641775597322, 1.2661381994885206, 0.06158022749540493, 0.46246628371777465, 0.7573951522202043, + 0.7619102543908491, -1.9354772481952196, -0.08541307148121735, 0.32286401465719994, 1.170684883004844, + 0.7749781307966064, 0.07936620410936168, 0.2315148010116328, 1.263046672689227, 1.031153500041193, + -1.4759863901212418, 0.9873606549190814, 0.7777062346878534, -1.7228292836209809, -1.5782061917261343, + -1.9224252408205036, 1.6523756827872802, 0.2993123495389698, -1.0626357773248643, 1.0197875082835361, + 1.9808878982391658, -1.7335070696659765, -1.3148099108168987, -1.9494880550804297, 0.17268926636423831, + 1.242098464223531, 1.1180045679110648, -0.6736825181756254, 1.0154763824746373, -0.19957976202073535, + 1.1193883202782464, -1.1478784045364039, -1.1330343367225462, -0.8689855158445736, 1.0116971724295913, + 1.2506562934116516, 0.9967792918784815, -1.0610583651118048, -0.532297674201299, -0.7150119770938783, + -1.2232688695213954, -0.4743547399060102, -0.46434726873258025, 1.0115464987548544, -1.0177528333698103, + 1.733606846210808, -0.8155729251191701, 0.5232276995816267, 1.7914758177269183, 1.0157926433716051, + -1.8351739772356694, 1.2012667794467289, 0.8545667843161366, 1.294851140223634, -1.3948286719024399, + -1.1466399375859053, 1.3041194739119701, 0.7377935849631942, -0.7995103884432107, 1.6159710967964172, + 1.1620542442509896, 0.0055755030499886615, -1.0334307489311154, -1.5841651529973335, 1.4111075346503084, + -1.2009959445559417, 1.9720845657548844, -0.8035695524416591, -0.14331494930025102, -1.4264826197967428, + 1.7586100751820988, -0.739427319576853, 1.0666567765029669, -1.5916344843881065, -1.4350385771936356, + 0.42668754843352463, -1.6484548280031257, -0.37728075139212613, -0.7687371980269919, -0.8179781982796062, + 0.465506361112749, 1.3840790073759015, 1.0880219954714985, -0.7036856119504717, -0.6963478247947235, + 1.0979632285661367, -1.3810178288903288, 0.8156134298411253, 0.10032276394202011, -1.6008032936016479, + 0.18932441617959217, 1.3551150619453578, 1.3534267176398442, -1.1276635081675348, 0.6608005967002919, + 0.793182461268664, 1.398769261405878, -1.2123058565244103, -0.08803245110073288, -0.2893447181014084, + -0.9972961711861705, 1.5618332897004663, 1.1591927593779072, 0.511047619279922, -1.970138349713964, + -1.3628804504805752, 0.2782295809685751, 0.30358322411230354, 1.9514398542744873, 0.25960063763317454, + 0.4976234205537926, 0.012047558099780531, 1.79534431915478, -1.7723391117592922, -1.9992623394682916, + -0.4524505481322034, 1.3804610881366495, 1.1587664810582172, 0.5111716739430667, -1.6928217537440542, + -1.0278751605013383, 0.20893412968684988, 0.5871739815665329, -0.8412950581167742, -0.4077765748738731, + 1.7498754266646204, 0.8583271186920243, 0.5762482317954367, -1.8599099610537024, -0.19242912582490845, + 1.2512291284228754, -0.8441763329152305, 0.26980735485206075, 1.4044456507515894, -0.8516268695811835, + 1.4493090144656193, -1.3915783403234894, 0.35557624127716814, 0.17226516309619733, -0.5021504124493701, + -0.766188811190383, 1.1332244159180078, 0.011135590774230764, 1.8851307343362258, 0.9148262788018782, + -0.8299956158151707, -1.6057691996197043, 0.5678238711359924, 1.8008767630518667, -0.586304639586193, + 0.47029839294118503, -1.463460016599707, 0.20856103503853962, 1.2545845494965118, -1.0729668619560213, + -0.14947337388785709, -0.20035875199434283, -0.07202935566940027, 0.05533721254453372, 1.3677731442776313, + 1.7011893855187177, -1.7202195328563636, 1.9488792451860384, 1.3096386167232117, -0.5132153326702822, + 0.5616165083457831, 0.4157359121447879, 1.8006481124839855, 0.230442477572935, 1.1686013774607265, + -1.7879670674147912, -0.842370723742838, -0.7927388944332199, -0.9586442598316518, -1.54708954114047, + -1.2956507442445577, -1.4031204732951874, -1.6120562181795481, 0.6283505387959369, 1.9223686678649798, + 0.12298814371626143, 1.6278360280654836, -1.6223557147160461, 0.43054457669015456, -1.7908842288361821, + 0.5775385836169233, -0.15097004219870414, -1.2290692851647318, -0.620782793926316, 1.1062043604891931, + -1.9746433547898716, 0.6382174626600765, 1.749692571795931, -1.8339775549081967, -1.464038954875173, + -0.9639795224425223, -0.990228592162139, -0.9403728223487793, -1.6685943188697578, 0.07041255085387288, + 1.5308882897823413, -0.47253846241489494, 0.8106189739961147, -0.2270261582976314, 1.8799983165866934, + -1.2472611795739992, -0.15798247303785384, 1.1021702350596438, 1.554345756888078, -0.555551779439396, + -0.2519441527853594, 1.6402787480890648, 1.0543147407284197, 1.6335920152443828, 1.6300453015691483, + -1.0481818985064661, -0.7279789339473091, 0.4888000242214119, 0.48732772667936697, 0.8584068837243475, + -0.2507103566018847, -0.4408909179300613, 1.1233796101976283, -0.4632288097116861, 0.2476082637987398, + 1.0440154957174563, -0.8869375992382329, -1.4091245527531013, 1.1244796013617524, 0.6892639663640718, + -0.022508581368523295, -0.1947662213695871, -0.03308275174489772, -1.5224685354120124, -1.749461912732114, + -0.8082369514527556, -0.25696601764524907, -0.29739460489203395, -0.7761972897539025, 0.11904461166545932, + 0.12301125764257481, -0.42555313462267197, 1.193050472577613, 0.41286096527825666, 0.19893639265378305, + 0.8311360551038502, 0.07310473388978878, 0.6685202630904961, 1.3372821732384246, -0.12066792119153114, + 0.05609759368011602, 1.0451437108149522, 1.662035784060775, -1.1922702285472315, 0.6416834232307735, + 0.6055133359661493, 0.7942484561725927, 1.7276358502409623, -1.2864330076965746, -1.0918326505522646, + 1.1998891150473536, -0.28628797182602295, -0.5716447940943157, -0.5689380516868878, -0.12795378416769676, + -0.6791193619117468, 0.10038501986448622, 1.7063771482852008, 1.194191767284221, 0.6647694331145066, + -1.7348358815181681, 0.2817755499501846, -1.8525340342841066, 0.9851618492112006, -1.1059877587552673, + -1.5295669583116638, -1.0505227522752554, -0.9487207502273058, 0.809796237150997, -1.8498482833968302, + 0.48019511430536, 1.1226579267420975, 0.742760641350614, 0.29074715598422696, 1.3377685110252564, + 1.0867564725571857, 0.18478374884500237, -0.3164295444608616, 1.622419371685548, -1.1918941739540632, + -1.85979992788012, -0.12624583269268985, -0.7280639736725645, 1.7648123921702625, -0.2906781194024406, + -0.09262137204418774, 0.4138200526064937, 0.04327935477419942, 0.9797051661471672, 1.6815036980821159, + 1.1825812431197686, -0.4448889424945657, -1.7444130322963716, -1.2294019413222204, 0.050868336661046065, + -0.28346687593229625, 1.76715189128082, 1.611534349398629, -1.1862000782338198, -0.18760651204489776, + -0.6355147050302099, -1.0693954841485978, -1.3343928935517813, -0.027843745181175272, 1.8949354550001445, + 1.1947313752564401, -1.705187723165774, 0.7263276294399761, -1.7205254544798727, 1.3224746522778261, + 1.1708741480141782, -0.521723626523988, 0.14495711998638772, 0.1539508663177509, 1.3474086945015884, + 1.629428728782094, -0.4727414254013107, 0.6353417064465656, -0.03451981234038648, -0.5836884681507799, + -0.20106395276242583, -1.7734283674131142, 0.32803960400104515, 1.9133097530382273, -0.24540943135655446, + 0.5453897283838289, -0.5427539051784143, -0.6173702928258038, 0.9933437009783477, -0.3068382522087365, + -0.9721176288546554, -0.2527049737425271, -1.273295878249181, 1.0233037474978097, -0.3711521481099638, + 0.5595202642561681, 0.8557760663482039, -1.8239035041840888, 0.9147544038362989, 1.7978278449861724, + -1.4194290997912544, 1.963345131841125, 0.23090855457281645, -1.3144424697360408, -0.1798082266955383, + -1.2650654602972473, -0.18679991460173095, -0.17629662209988428, -0.18349530363472422, + -0.5475428774726536, -1.04729696306776, -0.6411208082297168, -0.7308839648617003, -1.3455838296857312, + 1.3597877917216312, -1.8340993142240434, -0.6466367039451653, -1.8498009040527492, 1.50126492806735, + -0.8473870152631546, -0.566754366425382, -0.11291722444101016, -1.204932180016593, -1.3014629360740422, + 0.8120263473720843, -1.8930799484123453, 0.5895700763514027, 1.7580882681692787, 1.477127335430259, + 1.2047061381469248, 0.16485767274321272, 0.23788298792881424, 1.2386497309565767, -0.05281444305869076, + 0.17334853252953497, -0.9480401485516037, -0.12016072866190708, -1.5973543471706693, 0.9977273313654775, + -1.859593065673394, -1.22159503592711, 1.1674399144996856, 1.2941842856062413, -1.1135548933000354, + -0.9788839012477277, 1.5718286586532253, -0.1759982975417227, 1.1031262106075763, -0.8778538891016368, + -1.0912009563680698, -1.4342020269338933, 0.7224191131816049, -1.2061468497546022, 1.5502197738160985, + -0.4251181088474105, -0.3206510906592923, -0.20288873333289725, -0.3560455119064194, -1.7508196425056557, + 1.5986301909131457, -1.4438601542872123, -1.7892166215086185, 0.7616375057554672, 1.6373087559633968, + 0.384206177448279, 0.4567136418012572, 1.2972172920819673, -1.1093690558377691, 1.6750979445730616, + 0.90908880550879, 1.2770950805973325, -0.30072973792624325, -0.30858954557714835, -1.9794583286836787, + 1.4463537661478734, 1.0183415951718624, 0.19309738632155593, -1.9449394663020856, 0.7108312298345272, + -1.131148529447878, 1.5259401710667637, -1.2736225271446795, 1.449852979043179, -0.5704566653927854, + -0.3074713127723667, 1.3993673001014262, 1.7718101827800963, -1.4492416161385497, 1.204820691151994, + 0.537241600964693, -1.5810854626566924, -1.5966679426112087, -0.6946214128685, 0.7401275132249641, + -1.7160959560539784, 0.4050021853660404, -1.2105835684442203, 1.7944565918560151, -0.6818406209243566, + -1.2359644949792958, 0.49907683448423157, -0.8405824207686488, 0.4689270476935219, -1.4699087331797918, + 1.43264169315706, 0.10499119180317251, -1.9830520821430992, -1.5927469176409472, -0.7632048947095695, + 1.303303968850459, 0.5773554905161333, 0.6761130632322967, 0.9023788989770569, 1.8960847479504084, + 0.8846144527507596, 0.9891128512774987, -0.7137249307539442, 1.267181508288493, -1.5113665535004523, + -0.17564834166799148, 0.9032525164747707, -0.25795858643541525, -0.2153155122989885, 0.14650372443070392, + -0.9984704344431465, 0.19261182317116177, -1.568804815745514, -1.8496400745041237, 0.7219284702138937, + 0.47816525621959105, -0.3273800096758981, -1.3390312793543542, -1.9838474983553418, 1.4066433632074071, + 1.9258390074704312, 0.4520281509654822, -0.30119846185025256, -1.8086624212334463, 0.851460380433025, + -0.4149432793401253, 1.4970655678791776, -0.9139304175330043, -1.1721517169571731, -1.9882366923130537, + 0.20630155701555886, 0.0891351853539204, -0.18485046053672338, -0.5253430902979694, 1.1136150007281822, + -1.072256739674419, 0.5677711226994742, -1.6986682182236068, -1.373143853609375, 0.5391705517446521, + 1.615483488858379, -0.18222418110590155, 1.5270125615115786, 1.4186450525284275, 0.6856859039097802, + 0.5948037341597869, -1.0097732940745248, -0.7260016082299225, -0.1705798585617213, -1.4460592122059417, + 0.2804912469966476, -0.0574570618149588, -1.4226509159038505, -1.3490817825559507, 0.7561887451342573, + -1.0315310280500372, 0.7865868802852711, -1.7955739447423928, -0.20476732094967787, 0.6532859024525468, + -1.398626809307176, 0.31416850473475755, -1.0474173751356446, 0.049027524534579925, 1.335442264825483, + 0.5839485880852768, 0.8416818491436544, 0.7729008830376998, 1.7957935152184445, -0.20047560204525272, + 1.9653799460331678, -0.756998178675067, -0.12357101901807699, -1.4272827743751613, 0.7149414745051672, + 1.4783565252719182, -1.2368177109511205, -1.4571248051607144, -0.7948678149157731, -0.6295946982419727, + -0.022851757488315805, -0.07947620035768654, -1.3106359681202076, 0.1591438592300909, -1.4970586188027868, + 1.3181273904865316, -1.508591213967403, 0.5722257787143228, 0.774539967054146, -0.5579675263215638, + -0.801690277809052, -0.8966439545169163, -0.2168181087774288, 1.8549965661558616, 0.7870136331314779, + 1.0426166176054243, 1.2052992540989846, -0.6116512580549873, -1.7800528483131748, 0.6162047118916432, + -1.1406795391578877, -1.3126212462178328, -0.1255252753148266, -0.048214851156274996, -1.7513823416941525, + -1.9966724157135571, 1.468282137353885, 1.1596808879879097, 1.848952713577705, 1.9276331797246486, + -0.6082295997412146, 1.9590194651252002, -0.6705403599782791, 0.8982591946264264, -1.8582005994721253, + -0.6224103416017206, 1.3118474535601639, 1.927285880838153, -1.3435831019941835, -0.02035775798119932, + -0.258091815197548, 1.5685276792778557, -1.7504336743073416, -1.3270808448193447, 1.9609655615175043, + -0.5002114597187894, -0.8302889305621663, 0.6662682285835677, -1.3588868202703237, -1.4263374077454936, + -1.117653746556062, 1.6959423725848142, -0.3368698386266633, 0.6329184444264122, 1.4360518922995382, + 0.2209792086889042, -1.7826312330601093, -1.9055378329489479, 1.8363537758423742, 1.8612237061845747, + -1.163857834211714, 0.38823573714522475, 0.9933133475252713, -1.769852560129741, 0.6303163049709841, + 1.6352278260339865, -1.4220707937174062, 0.4996182092181929, 0.1748538915264719, -1.389807604688972, + -0.5041547053983226, -0.7755917479953034, 0.33822942573796055, -1.3957767536429841, -0.16066323457963172, + 1.8426173458683097, 1.0912529333551886, 0.04454407634104118, 1.4585397734066836, 1.314915917164475, + 1.0930141444320949, -0.9720567164640972, -0.5831452038265033, 0.8082335756515109, 0.4358913655339238, + 0.8310387682873994, -0.8242800720840835, -0.47497624245619896, 0.000058968841639917, -1.6746583184349388, + 0.2586283765233146, -0.03952361428650608, 1.9572062803747263, -1.364317103129661, 0.16484595584710782, + 0.6889848970954304, 0.33625779127527444, 0.28142293472509294, -1.5510992496482494, 1.6785313595707674, + 0.4921495479711657, -0.42294403727168906, -0.10192465238332815, 1.583070264702826, 0.6464143795128816, + 0.7706704090619576, -0.45316577898360944, -0.6156337052461307, 0.2949317256431403, -1.1153946167003506, + 0.23143632095143918, -1.187495465719234, -0.754948635807529, -1.090644714217727, 0.8562387289761135, + 1.4209567719285578, -1.867698005779011, 1.3320884849513037, 0.5619380450950349, 1.8886416226851166, + -1.7314027359692306, -1.0362482885730966, 0.9807231768105664, -1.3689591083054822, 1.5694772951886131, + 0.4400722090716478, 0.9539178709143741, 0.4832872148319014, 0.23471769113792984, 1.9643745055943542, + 1.1325801513292664, -0.43752654225713705, 0.4538975778222154, -1.6157155513403065, -0.961125955159364, + -1.1751535270699955, 1.1277536127856669, 0.11594556933087752, -1.9276503102738447, -1.5774089828974898, + -1.0029301039964427, 0.4455245589428616, -0.4739643281334569, -0.8513671370845639, -1.0336436816615233, + 0.6626347865920605, 0.7885413873550009, -0.013439463013608766, 0.6488139123172507, 1.4291110296253855, + -1.9761431450732667, -0.5012957954679527, -1.1585910698227027, -1.1021436093929422, 0.036919383239228054, + 1.8089329170710071, -0.005231354013648826, -1.2082234644042886, -0.3456887578591781, -0.8017405353429492, + -0.5345375675659492, 0.8534420279659507, 1.7447905633469585, 0.43817127727920724, -1.8499205219957702, + 1.4797731845522186, -1.5443888715914138, 1.0131225647292235, 0.34701885989022063, -0.41455116200553377, + 0.40209313291363724, 0.11900781713274, -0.9935386808363758, -1.8322340002578068, 0.9811839836103111, + 1.502154507354498, 0.8891949169357574, 0.899071159318308, 1.752337905147142, -0.04599799842734953, + -1.6347681983052045, -0.5522741247690393, 1.505215487771519, 0.8504281241898015, 1.8693941265525265, + -1.1512863792577441, -1.8748160415118837, -1.879939448107617, 0.9353149913960506, -1.077101932112896, + 1.2322050595012843, 1.2672982902122678, 0.9384368132472858, 1.7274119921052788, -0.9601726232935137, + 0.19420343716687505, -0.7830049935581602, -1.9099470296794694, -1.213386784368356, 1.6800660417837605, + -0.9282638481321719, 0.5088239004955142, -0.5528513330962577, -0.4235136044745138, 1.6316021980530238, + -0.3087654696690505, -0.10527992793999896, -1.4364007982935343, -0.4455364976497762, -1.3433044303003099, + 0.6517505064656408, 1.6050250028051813, 1.6490276577492855, 1.9140119353414144, -0.7684496098140174, + -1.01738188731548, -1.1250647161193914, -1.6586222112755102, 1.1599068196677091, 0.795751774794466, + 0.5733174614685748, -1.3655937932875277, -1.507254849973065, -0.2831083653801638, 1.3241227396573514, + 1.574957068221127, -0.31194765030973937, 0.4008126582755933, 0.43635579619776443, -1.3214048572867325, + -0.8447194221435215, -1.1526249262582748, 0.7073544609451421, 1.17078844004073, 0.2425026449956018, + -1.7518561882120753, 1.6591407848437605, 0.06616038448738504, -1.928680221520497, -1.0504809684365677, + -1.0974712342176778, 1.6344494477175475, -0.4129201382527832, -1.7111789594333953, -1.6070549808753904, + -1.2456702084965565, -0.012663680475193395, -1.1305840149083926, 0.734392120651302, 0.18651679771884844, + -0.22974141381305735, 1.9415149817194726, 1.9280078232850126, -1.2072428658632042, -0.14782869942839927, + -1.6523593328098034, 0.4844141001145905, 1.0492278525622805, 0.5924539450553175, 0.848097235977705, + 1.8881210898619676, 0.20004070245023797, 1.4305799893712425, -0.9082660328332564, -0.14268688754147085, + -1.1201061991671262, 1.1399839045134712, -1.5579448101377515, 0.5516078322124933, 0.3365679579810825, + -0.636402425334972, 0.7364990374614768, -1.2657328423109204, 1.4084870144147636, -0.5538274490613713, + 0.43684201943536305, -0.706532199493215, -1.7678543182116737, 0.5086879667154935, -0.8888826793267235, + -1.0510640830474856, -1.775013227468511, -0.7345226367397419, 0.9474796694127203, -0.649964939391042, + -1.5189099534245498, -1.9260526549789567, -0.457330394781688, -0.9340352374682741, -1.3868164983748459, + 0.553888202560878, 0.36818698767921365, -0.9382183717778192, -1.0829596839250488, 1.3646658325042367, + -1.3240476722940633, 1.9816923192707012, 0.5300123477141927, 0.0790085088366057, -0.4455760575475125, + 0.48463653297167486, -0.7788483158746828, 0.8253771773416885, -0.6823431576948558, 1.7776737534704772, + 0.5497713214586923, 1.4452464852137838, 0.23037431004796094, 0.31188142524786766, -1.7543267850797761, + 0.6063452856820941, 1.207122989395999, -1.926907332363795, 0.45038239145265724, 1.0988284911574286, + -1.6007436457047142, -0.4728890687538678, 0.3195037474199047, 0.8855124961325762, -0.2555993730577626, + 1.8813620496087493, 1.8900177166377103, 0.09592367474164032, 1.8974568987778628, -1.1058972953708501, + -1.1512017435907103, 0.40201549011430693, -0.1831060132280804, 0.22245091899613723, -1.1866541831479385, + -0.5451040730392975, 0.9199451579519691, -0.42060255461704177, -1.3791747236441925, -0.3024448490768936, + -1.6611455107283604, -0.3106541240888907, -0.9498356682876157, 1.769660410309836, 1.598216213741022, + 0.4623205859503434, 0.03664778458072249, -0.6655252973523123, -0.7325818423601653, -1.591871681771024, + 0.9451427301297981, -0.8203468674560934, -1.5069504221011005, 0.7243170638862324, 1.1839749702725175, + -0.7128348341329511, -1.5076965090949397, -1.59865172895221, 0.08910680490749368, 1.5717880586471278, + 1.8951504684652152, 0.42550207471805646, -0.128409822054123, 0.9896766313315162, 0.34808462644009275, + -1.7082990487472571, -1.0459982270685435, -1.1132292311691874, -0.3022325459842996, -1.7274216318536348, + -0.11775921716410043, 0.7403577685290532, 1.188824227090608, 1.387282721223393, 1.0688709331799577, + -0.6395615121564866, 0.8142138261269114, -1.4467483751545576, 0.8996177593321626, -1.8193866881462766, + 0.08924208518874632, -1.405297919708996, 0.31754790231458685, 0.9823851818369507, -0.49590144528424585, + 1.4194064220328588, 0.9729299634967079, -0.46170090347918613, -1.634203532024186, -1.3139980454214522, + -1.8469876250843802, 0.710926322864931, -0.4029599381569682, 1.7246833539931403, -1.4088169680807, + -1.9165388068372708, 0.21804317359714798, 0.3898186987610348, -1.6118063668363405, 0.28583673086194583, + 1.7015683175211391, -1.2836168642070582, 0.8463494611619371, -1.1625839799245696, 0.2640138032690018, + 0.5041551687310717, -1.755824925370514, -1.177748867346489, -0.3829444120449246, 0.8360805034202388, + 0.05022254918868896, 0.4276469609032256, -0.8235567451730139, -1.94062145827791, 0.35097020666890355, + -0.6636358495150212, -1.2587452298002964, -1.6575362996910616, 1.060467707494971, 0.9661831305087887, + -1.77149852066976, 0.18844425542251564, 0.0897431507375952, -0.8281105706620897, 1.5632959207508623, + 0.07141359825195703, -1.0844701043735414, -1.905443802421968, -0.27588161311845294, -0.40342423607415956, + -0.34332304727825136, 0.0176022295550462, 1.94359831375926, 0.09702777017089215, -0.11695098349068722, + -1.2810374187149858, 0.37597456306160204, 1.7631374725308877, -0.7830266259108773, -0.5605784036815882, + -0.4409773606270875, -0.49636250754717803, -1.549108447216227, 0.6261185797820117, -1.260881611110821, + 1.691411217905281, 0.899655093658585, -1.0875528162122174, -0.7120948701980732, 1.8214705523154269, + 1.3010380076854968, -0.4492643980075144, 0.9914465230608682, 1.7590027691290615, 0.8514661670055963, + 1.5263431803492642, -1.7260779024351258, -0.14589666108296218, 0.18011804793376918, 0.7175880982696512, + -1.0399388762140145, -1.0480376846250712, 1.5656146512942648, 0.4435540525930799, -0.4175857955829816, + -1.8218436496980575, -0.9346408060646185, 1.40089015453285, 1.6926667426168764, 0.4187248632147291, + -0.6755275086264145, -0.7011229448771363, -0.9528087286614646, 0.0730922604589237, 0.6252467328216804, + -0.5573518555770702, 0.9864121888624755, 0.6646486706800783, -0.8405364163020792, -1.0505815213688878, + 0.8989991238262265, -0.5022516947851985, -1.784806766373471, 1.2637708002659025, 0.5065772818030325, + 0.9973024415787677, 0.08348064671549338, -1.757630249437522, -0.2016005631449005, -1.0360120513086803, + 1.786128872822113, -1.2720919225213843, 1.3430514506545927, -1.0762325516117865, -1.0578995596104255, + -0.47242526972271204, 0.05539697038437996, -1.08558757453563, -1.404337586710036, -0.059247790489052043, + 1.3998034069978171, 1.6067367856721155, -0.5185826391883994, -0.051896682542695416, 0.11112005542023429, + -1.2231398633939348, 0.2886372299242277, 0.6564469519248641, 1.67404118804063, -1.538487261793886, + 0.18551945213331145, -1.1342837192256061, -0.3318725405647953, -1.4531152273595227, -0.6934713826285721, + -0.24436235286417052, 0.6776292484438171, 0.8871814678850702, 0.41826798275898014, 1.0161513742931785, + -0.13947907673300097, -0.7736759327375049, -0.43981678279829683, 1.635191807530191, 1.3044401854805878, + -0.3097446711021723, 1.8125726195847056, -0.26127912212234694, 0.8564630403854094, -1.519521818793156, + -0.5727391479884938, 1.7015469847109976, 0.663240965083145, 0.31064120951508656, 1.4030451184981052, + 0.3325065959732836, -0.7178902057747756, 0.6090652378284442, 0.8426138183122633, -0.580146652112278, + -0.6076938097212707, -1.6599273271373782, 0.29960912457791444, -1.6741835731853065, 1.5428301790607195, + 0.8970548194971704, 1.6066845600081736, 0.5404165757730146, -1.9537941867764292, -1.5234595572340748, + 0.5293735217702951, -0.64620260665742, 1.8818640992235771, -1.7237764606754276, -0.8040024538741264, + 0.0642546885214017, 1.4395299343641659, -1.462587128675942, 0.011882540823848764, 0.12033421748154716, + 0.5458210215408705, -1.5141295301316422, 1.8809343680577308, -1.8801856621666753, -1.901376259472575, + -1.4374202976060095, 0.8473513507453765, -0.896351895119154, 0.457751001832321, 1.876657552919962, + 1.267733433184599, -0.30894648094866195, -1.8178016120669414, -0.7711776919446018, 0.29038564786361576, + 1.6396189720781438, 0.9597929848181161, -0.34227788522140834, -1.4450087753233527, -1.7068508353679626, + 0.8426935759536303, 0.7173810205823674, -0.6580236891322881, 0.8663322021812405, -0.38112472550089915, + -0.3331447946260786, 1.8551673806318556, -0.6731525492100126, -0.009001319785657103, 0.41039833755685784, + 0.025091358839535616, -0.49823213412251555, -0.9827448714264726, 1.1077851800046377, 0.5740585983905078, + -0.7235926614954762, 0.5059901875180826, -0.850177898505664, -0.05453987892121592, 0.8633840127545733, + -0.3153969644106205, -1.3028092681229868, 0.7523083527030074, 1.413775575813558, 1.0697458650110754, + -1.1839403319780022, 1.5167022074836893, -0.36486781099211996, -1.7835462010879564, -0.6061285803342944, + 1.9969466536022722, 1.9531672204642883, 0.7967381388222403, 1.0934589095880973, 1.6405590176012312, + 0.3501113568054244, -0.8786338692108497, -0.0545508019996932, 1.4464849975584952, -1.8853956921596513, + 0.25983013847132774, 0.8440414107184964, -1.8818826057620326, -0.22674619971532906, -0.35951513414106007, + 1.6757192364875237, -0.15819503713874195, 1.6691357866915144, 1.8534771980207108, -1.8297709996602967, + -1.6514392305036534, -0.2343385561012088, -1.820925987823773, -1.2556074451315578, 1.0621490016055715, + 1.2109203955756476, -1.1500152481919903, 1.0466452723330733, -0.5833431814034746, -1.0817348313127493, + -0.40295742758029984, 1.0986483593098368, -0.42704879342020696, 0.7065668399686658, -1.7278295290305223, + -0.7183051438456021, 0.15944089245801774, -1.0276995188812146, -1.6398474518024653, -1.2318071869917695, + 0.23333723988807886, 0.4626060172442088, 1.2255286520425894, 0.9652309086097803, -1.5254810192280601, + -0.6683416541099767, -1.9000628944332894, -0.7244291249780632, 1.0347731523086239, -0.9009629081875952, + -0.11734057204667625, -0.6698286923748489, -0.9472592207823913, -1.8286232413501864, 0.2898215560382935, + -1.422921306299517, 0.6696091233175077, -1.014229444534946, 0.7139087775492996, -0.23241859018004174, + 0.9620272535910068, 1.1497473812544134, -0.8640235723632799, -1.3121563547519584, -1.0396316763878488, + 1.7769188425035614, -1.014039070903868, 0.4489833453895331, 1.0763456360970807, 0.6229672422660295, + -1.7464692834364435, -0.3663893352922596, -0.8769410861076103, 1.0608706111705635, -0.9262810932490861, + -0.5468726859924029, 1.7966956706569919, -1.0663859467239112, 0.7378193848221777, 1.068192632208282, + 1.3312476842417755, -1.5902814653913628, -1.0131078061920382, -0.6235296781383814, 0.37504339068020176, + 0.9126508132330242, 0.3999532385546267, 0.6552059838941, -1.6053942866342332, 1.7900258342291853, + 0.08171912833062756, 1.6137883979635745, 1.3466843147948442, 1.8505801094553158, 1.3728813966930913, + -0.4473279660140852, -0.20009909626620814, 1.4067472413245437, 0.36658966226851764, 1.4566800897303072, + -0.11633958899045194, -1.9458410018060368, 0.5651869174802018, -0.9885077925334622, 0.24385043055374833, + -0.4407908079663816, -1.7015252126482139, 1.5396273477916198, -0.9801103159055833, 0.9331708410017399, + 0.058036076446482454, 0.29277070369481883, 1.6896333641554682, -0.2872886303585469, 0.2981100430160728, + 0.1670720357805502, -1.6828245857476496, -1.4681960401028125, -0.9933436100210251, -1.827639383468739, + 0.08433714147463611, -1.1318904274562795, 0.9840669856671846, 0.8204547128989219, 0.5959008566248984, + -0.22424536381303728, 1.765380932910376, 1.050492887173749, 0.8249285352430338, 1.5823516671950122, + -1.4695844512182843, 1.4009128159343485, 1.0886951647082785, -0.4963319371911856, -1.8633848779197413, + 0.660465126445569, -1.2319891082878298, 1.6547000157065659, -1.6403428022350113, -1.2308283749125177, + 0.9142339764828238, 0.18691349086990705, -1.148271069003111, -1.266859733272054, -1.4482873768560758, + 1.6888579757850914, 1.5392518897570104, 0.41499451567073464, -1.0517290742419663, -0.9856143466540894, + 0.704611691207357, -0.27871441123648655, 0.445828139270918, -0.8125969294930622, 1.521716695437079, + 0.5657668386735519, -1.813374372841099, 1.0076529676525672, -0.5864288471977783, 0.5855480422270194, + -1.8330974772064481, 0.9782157266479414, -1.6230556142249775, 0.5265126362718373, -1.6878701852107563, + -0.3955226747487526, -1.3888929741627605, 0.2905034183357449, 1.0489208524387843, -0.3118857187498678, + -0.6289506096761981, -0.05735383950307149, 1.8668941791416147, 0.8898345005884769, 1.7147482078759548, + -0.12387314928310289, 0.2298818139402634, 1.9294076224252024, -0.43580099597679656, 1.7512542893273144, + -1.258214124547644, 0.9779750741630782, -0.2566261319632144, -1.9813300069235993, -1.3498734101224414, + 0.7506344777083953, 1.8867470646651894, -1.918953273635191, 1.7429571494233906, 0.7638060343526085, + -0.44782770384121484, -1.1300950570142518, -1.4753506380821149 + ], + "dims": [2560], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.027313262224197388, -0.005701353773474693, 0.1959753781557083, 0.10011828690767288, -3.6098804473876953, + 0.00864929985255003, -0.011981655843555927, -0.11036527156829834, 0.6647213101387024, 0.276733934879303, + -0.6354819536209106, -0.014075735583901405, -0.059462033212184906, -0.3388662040233612, + -0.017422985285520554, -0.043299876153469086, -0.16756349802017212, -0.07582926005125046, + -0.16514767706394196, 2.9962074756622314, -2.600733757019043, 0.04413439333438873, 0.07896167039871216, + 1.1207873821258545, -0.032255738973617554, -0.09964963793754578, -2.1782073974609375, + -0.01814177632331848, 0.08586198836565018, 0.380964457988739, -0.01918521337211132, 0.006902141962200403, + 0.0669674500823021, -0.09234043955802917, -1.0496017932891846, 0.020094068720936775, -0.11474193632602692, + -0.056350305676460266, -2.2275612354278564, -2.648808240890503, -0.017779357731342316, + -0.2514607608318329, 0.008559616282582283, 0.010673644952476025, -0.32376542687416077, + -0.16903237998485565, 0.026010606437921524, -2.163571357727051, 0.35461699962615967, 1.5194188356399536, + -1.1094666719436646, -0.012471643276512623, 0.0767873078584671, -0.21644049882888794, + -0.043257202953100204, 0.001341399853117764, -0.1367240697145462, 0.005313852336257696, 2.144134759902954, + 0.11904949694871902, -0.26428619027137756, 0.014375614002346992, 0.06913577765226364, -4.196413516998291, + -5.172718524932861, 0.06162356957793236, -0.0010976337362080812, 0.21020400524139404, 4.567638397216797, + -0.059758733958005905, 5.990215301513672, 0.19405193626880646, 0.003011247143149376, -0.1036064475774765, + -0.016247211024165154, -0.12790939211845398, -0.08561908453702927, 0.25051021575927734, + 0.07514326274394989, -0.6767844557762146, 1.5661166906356812, -4.326471328735352, 0.07481537014245987, + -0.7969828248023987, -0.45468205213546753, 0.21233250200748444, 3.420551061630249, -0.0759267508983612, + 0.16086462140083313, -0.3939729928970337, 1.3957020044326782, -0.2972649931907654, -0.31666669249534607, + 0.35118427872657776, 0.3117898404598236, 0.088602215051651, 0.17165301740169525, -0.8542330265045166, + -0.06893759965896606, 0.08126193284988403, 0.02327258512377739, -0.1314769983291626, 0.035079699009656906, + 1.2096712589263916, 0.9461245536804199, -6.337772846221924, 5.575413703918457, -3.9876515865325928, + 0.01430205162614584, -2.093717098236084, -0.056584782898426056, 0.05612698942422867, -0.01935030147433281, + 0.0010159225203096867, -0.38109132647514343, -0.000587565591558814, 0.12273997068405151, + -1.4854758977890015, 0.016024703159928322, -0.05192752555012703, -0.257480651140213, -4.023406982421875, + 0.03150588274002075, -5.065948486328125, -0.07601942121982574, -0.04482676833868027, -0.01937261037528515, + -6.9667582511901855, -0.05368780344724655, 0.6142992377281189, 0.3128206431865692, -0.3862888216972351, + 0.053061991930007935, -0.24360240995883942, -0.018439287319779396, 0.1868235021829605, + 0.005632609128952026, -0.10385553538799286, 1.2077943086624146, -0.07107469439506531, 1.771382212638855, + 0.3696843981742859, -0.31587034463882446, 0.0002820117224473506, 0.055834002792835236, + -0.7621694803237915, -3.773604393005371, -0.12602387368679047, 0.8626934289932251, -4.139935493469238, + -0.08643748611211777, -2.25795841217041, -0.025201046839356422, -0.28647178411483765, -0.5088312029838562, + -2.3566224575042725, -0.20447342097759247, 0.3922976553440094, 0.047735944390296936, -0.09984598308801651, + 4.436963081359863, 0.17725177109241486, 0.01968466490507126, -0.4080508351325989, 0.600350558757782, + -0.1489681750535965, -2.3178586959838867, 0.010645782575011253, -0.5052445530891418, 0.12876634299755096, + 2.72904109954834, 0.007315368857234716, 0.503023624420166, 0.9695355892181396, 0.4959081709384918, + -3.562389612197876, 0.3780525028705597, -0.194877028465271, -1.0815603733062744, 0.6436595320701599, + -0.10088582336902618, -0.06308454275131226, -3.7394943237304688, -0.0011674398556351662, + -0.19378826022148132, -0.2329375147819519, 0.029814809560775757, 0.20438098907470703, + -0.23114298284053802, 0.026816120371222496, -7.350013256072998, -0.011900502257049084, 2.0180928707122803, + 0.20987474918365479, -3.209254503250122, -7.5496602058410645, 0.232008695602417, 0.0027162893675267696, + -0.4211888611316681, 0.287914901971817, 0.028367964550852776, 0.015583046711981297, 0.07393462210893631, + 0.6514078974723816, -0.04090245068073273, -0.004561522509902716, 0.2931022346019745, -0.4355356991291046, + -0.1867547184228897, -5.984931945800781, 0.044270407408475876, -0.35987964272499084, + -0.033762961626052856, -0.2677021622657776, -0.013161826878786087, -0.010206296108663082, + -4.528798580169678, 0.4174078106880188, 0.12906667590141296, 0.04690857604146004, -0.08034832775592804, + 1.5188398361206055, 1.3247699737548828, 0.011872933246195316, 0.055544108152389526, 0.0025585023686289787, + -7.696174621582031, 0.030730921775102615, 0.039231084287166595, -0.4407111704349518, -0.3110845386981964, + 2.284346342086792, -0.027610689401626587, 0.09054349362850189, 1.7885178327560425, -0.11802572757005692, + 0.03795969486236572, 2.373623847961426, 0.11311819404363632, 0.009557336568832397, -0.02887658029794693, + -0.28853726387023926, -0.17708882689476013, -0.22821268439292908, 0.0237746462225914, 3.257477283477783, + 0.2507217526435852, 0.17421714961528778, -0.12231585383415222, 0.18179824948310852, -0.3428541123867035, + 0.024907970800995827, 0.2441745400428772, 1.13312828540802, -0.0009440237190574408, -0.594701886177063, + -0.008615869097411633, 4.071537017822266, 0.6198470592498779, -0.3097928464412689, -0.4404515027999878, + -7.008431911468506, 0.024559520184993744, 0.1267288327217102, 0.2140975296497345, 0.5778637528419495, + -0.03296203166246414, 0.8842242360115051, 0.16367295384407043, -0.3035202920436859, -0.09384048730134964, + 0.6805808544158936, -0.2706672251224518, -1.429656982421875, -0.1497703641653061, 0.4302230775356293, + -1.864505648612976, 0.01007054653018713, 0.23598365485668182, -0.08086620271205902, 0.001842734171077609, + 0.08458849042654037, 0.3059651553630829, -0.06515960395336151, -3.803208589553833, -0.41865429282188416, + 0.2828770875930786, 3.459416151046753, -0.00129605398979038, -9.578699111938477, -0.06560757756233215, + 0.026055261492729187, 0.0672057718038559, 0.08423102647066116, -1.3624160289764404, 0.013521464541554451, + -0.027731282636523247, -0.9650477766990662, -0.012694457545876503, -0.2116907835006714, + -0.10714730620384216, 0.0034909709356725216, 1.5338910818099976, 0.0006434338865801692, + 0.1618947833776474, -0.10659407079219818, -6.774624347686768, -0.08567759394645691, 0.5162889361381531, + 0.11074300855398178, -0.09961605817079544, -0.005474632140249014, 0.1132681593298912, 0.10878968983888626, + -0.4140564203262329, -6.274385452270508, -3.410104274749756, -0.2155490219593048, 0.13330507278442383, + -0.2973288297653198, -0.5738739371299744, 0.3465871810913086, -0.2567448318004608, -0.13507360219955444, + -0.014550707302987576, 0.039058394730091095, -0.25891509652137756, -0.30598220229148865, + -0.14163219928741455, 1.2217881679534912, -0.2967555820941925, 0.024605438113212585, -0.03864026814699173, + -1.4379907846450806, -3.0257911682128906, 10.609665870666504, -0.0002576113329268992, 0.1658751666545868, + -0.01822504773736, 0.1141287237405777, -0.3072766363620758, 2.9927172660827637, 0.42983293533325195, + 0.9799204468727112, -0.007520963903516531, 3.565046787261963, -0.18206597864627838, 1.1247198581695557, + -0.0011717785382643342, 0.0026591955684125423, 3.689824104309082, 0.03598639369010925, + -0.09997520595788956, 0.06576227396726608, -1.7916548252105713, 0.030312752351164818, -0.4527510106563568, + -0.26613515615463257, 0.025749003514647484, -0.17866003513336182, 0.18729515373706818, + 0.003528681118041277, -0.1579633355140686, 1.070467472076416, -0.20637144148349762, -0.10882926732301712, + -6.439236640930176, -0.25033196806907654, -0.26708030700683594, 0.036800775676965714, -1.9130735397338867, + 0.11082696169614792, 0.10686857253313065, 7.136363506317139, -2.1805343627929688, 0.002802944276481867, + -1.0081117153167725, -0.08366546779870987, -0.07263432443141937, -0.011199882254004478, + -0.015524221584200859, -0.008838756941258907, -0.005488056223839521, 0.6502953767776489, + -0.010726823471486568, 0.41685575246810913, -0.23590049147605896, -0.0868658497929573, + -0.07914192229509354, 0.22732190787792206, 0.0985199362039566, 0.013477811589837074, 0.5970719456672668, + -0.12020514905452728, -0.0009808604372665286, 0.1139480322599411, -0.5872443914413452, -5.610537528991699, + 0.14893069863319397, 0.44541916251182556, 0.599539041519165, -0.028194887563586235, -0.42580458521842957, + -0.24352392554283142, 0.25486475229263306, -3.251058578491211, 0.042388759553432465, -0.15446388721466064, + -0.01016155257821083, 0.07647261768579483, 4.707305431365967, -0.0834866315126419, -0.21240641176700592, + 0.34789028763771057, 0.0710633248090744, 0.013448074460029602, -0.18779638409614563, 0.022113602608442307, + -1.8543815612792969, 0.012882213108241558, 0.0508059561252594, -2.1125378608703613, 1.00347900390625, + 0.34287792444229126, 0.023498279973864555, 0.2604916989803314, -0.854418158531189, 0.3368889391422272, + -3.5361156463623047, -0.3238249719142914, 0.09940877556800842, 0.011137581430375576, -0.09505806118249893, + 1.4575674533843994, 0.18798890709877014, 0.13481135666370392, -3.1009016036987305, -0.0046508763916790485, + -0.002944883657619357, 0.008598391897976398, -0.05753857269883156, -0.007956058718264103, + 0.7023902535438538, -2.114570140838623, 0.6187217235565186, 4.448208808898926, 2.5069539546966553, + -0.10476846992969513, -0.04466601461172104, 0.32297447323799133, 0.06604880094528198, + -0.0016604098491370678, -2.8530216217041016, -1.2369211912155151, -0.02766953594982624, + -0.025159431621432304, 0.0029653196688741446, 0.04569535329937935, 0.03927958756685257, + -0.0021295847836881876, 0.024881726130843163, 0.028491219505667686, -0.0042065782472491264, + -0.05266435816884041, -0.08988969027996063, -0.04083021357655525, -0.040847159922122955, + 3.154191732406616, -0.06132543459534645, -0.7507759928703308, -0.029571423307061195, 0.03537856787443161, + -1.4017058610916138, 0.3888748586177826, 0.9719987511634827, -0.010947618633508682, 3.847195863723755, + 1.015498161315918, 0.012234801426529884, -0.3849196434020996, 0.5072981119155884, -0.07829593122005463, + 0.2524659037590027, -0.13102610409259796, 0.020525088533759117, 0.15267324447631836, -0.11044808477163315, + 0.008630136027932167, 0.0009689829312264919, 2.615210771560669, 0.3638320863246918, 0.2452821582555771, + 0.01092306338250637, 0.03127167001366615, -3.899691104888916, 0.16573800146579742, -0.06733611971139908, + -0.39246127009391785, 5.207749843597412, 0.05021298676729202, 0.17778877913951874, -1.4260956048965454, + 0.19870443642139435, -0.27705708146095276, -0.11092191934585571, -0.09528861939907074, + -0.5703115463256836, 0.5077508687973022, 0.3938415050506592, 0.47991737723350525, 0.7821948528289795, + -0.2891596853733063, -0.3829837143421173, -0.010832893662154675, 0.15224608778953552, + -0.00014581253344658762, 0.00025647180154919624, 0.02536843903362751, -0.06366542726755142, + -8.023703575134277, 0.027589797973632812, -0.1799485832452774, -0.2505863904953003, -0.3841714859008789, + -0.00031740960548631847, -0.04642002657055855, 0.38759565353393555, -0.05341910943388939, + -0.37632811069488525, 0.6983012557029724, -0.10781889408826828, -0.0007781427120789886, + -0.0877101942896843, -0.5221861600875854, -0.07871037721633911, -2.2496471405029297, + -0.042697690427303314, 2.38197922706604, 0.035262834280729294, 0.0695495679974556, 1.6927565336227417, + -10.396214485168457, 0.05338706448674202, -2.813828468322754, -3.691652536392212, -0.008508461527526379, + 1.570121169090271, -0.4011033773422241, -0.24479898810386658, 0.30835238099098206, 0.1998486965894699, + 0.0945337787270546, 0.39656326174736023, -0.23758645355701447, 0.1661674976348877, -0.04912934452295303, + 0.024212513118982315, -1.0319569110870361, 0.04704924300312996, -0.058226123452186584, + -1.6492913961410522, 0.6406868100166321, -0.005447663366794586, 0.19865849614143372, -0.3373563289642334, + -0.03675329312682152, -0.19241032004356384, -0.43262550234794617, -0.08300381153821945, + -0.014068910852074623, 0.11309102177619934, -0.02719084918498993, 0.2096254676580429, + -0.02292095310986042, 0.4072689712047577, 0.0003724964044522494, 0.20711149275302887, 1.0793871879577637, + 0.06120060756802559, 0.11688049882650375, -0.0023522432893514633, -0.9283630847930908, 2.477475881576538, + 0.26047614216804504, 0.7143173813819885, -1.4795730113983154, -0.15119962394237518, -1.4587875604629517, + -0.03378799930214882, -0.3518248498439789, 0.1747346669435501, 0.002720446093007922, 0.865147590637207, + 0.015568590722978115, 0.1952929049730301, 0.1818414330482483, -1.4265116453170776, 0.2012042999267578, + -0.2151491343975067, 0.11098571866750717, -0.16003955900669098, 7.798532962799072, 0.299221396446228, + -1.0280503034591675, -0.1838797777891159, 0.005458994302898645, -0.13420982658863068, + -0.06905319541692734, -1.8678100109100342, 0.40917104482650757, -0.09650467336177826, 0.2953720986843109, + 0.008414564654231071, 0.1998010128736496, 0.34882158041000366, -0.17196929454803467, 0.031611330807209015, + 0.08629407733678818, -0.1856321394443512, -0.22879824042320251, -0.09241079539060593, 0.2628664970397949, + 0.03050280548632145, 0.15829861164093018, -0.06391621381044388, -0.01048242673277855, + -0.010927671566605568, 1.0013593435287476, -0.15796290338039398, -0.10746872425079346, + -0.0013137627393007278, -0.2024063915014267, 0.0005700114998035133, 0.03609214723110199, + -0.4168614447116852, 0.12660957872867584, -0.005800928454846144, -0.5319929718971252, 0.32967525720596313, + 0.028021158650517464, -1.217489242553711, -0.09096843004226685, 2.344956636428833, 5.432365894317627, + 0.7993219494819641, -0.15543963015079498, -0.0007157247746363282, -0.08481337875127792, + -0.12131065130233765, 1.1516586542129517, -0.01504613272845745, 0.03704383224248886, 0.004402496851980686, + -0.581766664981842, 0.07592090964317322, 0.3745843768119812, -0.4989067614078522, 0.04438084363937378, + 3.175107479095459, -4.4077911376953125, -0.002988559426739812, 1.0332038402557373, 0.006027049385011196, + -0.0018258332274854183, 2.3033533096313477, -0.10134122520685196, 0.02520211599767208, + 0.005497010890394449, 0.0003968894889112562, -0.00029831190477125347, 1.2718113660812378, + -0.34650272130966187, 9.8225736618042, -6.33831787109375, -0.9639870524406433, -4.028343677520752, + 0.016925739124417305, -0.0449683852493763, 0.6271276473999023, -0.13903772830963135, -0.06179821118712425, + 5.860689640045166, -0.005071749445050955, 0.5026626586914062, 0.3309268057346344, 2.2567200660705566, + -4.23521089553833, -0.01613122597336769, -0.02665529027581215, 0.04668727517127991, 3.150425434112549, + -0.0052042026072740555, 1.067151427268982, 0.025044972077012062, -1.325771689414978, -0.1094195619225502, + 0.26904425024986267, 0.6204038262367249, 0.006285298615694046, 0.002915250603109598, -0.6165238618850708, + -4.090943813323975, 2.8669519424438477, -0.09453117847442627, -0.09316729754209518, 0.034191157668828964, + -6.707476615905762, -0.20231620967388153, -1.6191682815551758, 2.0373117923736572, -0.10501966625452042, + 0.0019581259693950415, 0.21420015394687653, 0.0156276635825634, 7.224427223205566, 0.1236666664481163, + 0.294806569814682, -0.0061331382021307945, -0.10612531006336212, -0.8333144187927246, + -0.001029952079989016, 0.38204053044319153, -0.03597458079457283, -1.41422438621521, -0.2833155691623688, + -0.0006075138808228076, -0.3701440095901489, 0.1309424191713333, 0.06839437037706375, + -0.0017361472127959132, -1.69569993019104, -0.20629459619522095, -0.5999218225479126, 0.114132359623909, + 6.6828436851501465, -0.36263933777809143, 0.41539111733436584, 0.022192703559994698, -0.06610587984323502, + -1.683022141456604, -0.2835130989551544, 0.27643388509750366, -0.6247501373291016, -1.421617865562439, + -0.08159351348876953, 0.005017416086047888, -2.026592493057251, 0.0009393739746883512, 1.760980486869812, + 0.00019237841479480267, 0.0022294363006949425, 0.22415778040885925, -0.09657209366559982, + -3.056180953979492, -0.24515365064144135, -1.7638490200042725, -1.900456428527832, 1.7747641801834106, + -0.9473960399627686, 0.27619242668151855, -0.11893711239099503, 0.7769895792007446, 0.09835439175367355, + 0.0019296495011076331, -0.043601375073194504, 0.03626292571425438, 0.1591210663318634, + 0.45964139699935913, -0.06853707879781723, -2.4563350677490234, -0.13421472907066345, + 0.040955424308776855, -0.2855738699436188, 0.11433675140142441, 0.00306497560814023, -0.48573875427246094, + -0.046301428228616714, 0.6907474994659424, -3.983771800994873, 2.3131954669952393, 0.05256381258368492, + -0.0911293551325798, -1.8945766687393188, 0.03453084081411362, 1.8747694492340088, 0.3433213233947754, + -1.1485600471496582, 1.6418366432189941, -0.13894057273864746, -3.1275031566619873, -0.9726752638816833, + -0.0012102212058380246, 3.898921251296997, 0.2646528482437134, 0.01665414497256279, -0.06312943994998932, + -3.0655016899108887, 0.024803519248962402, -0.25584203004837036, -1.3387784957885742, + -0.03684002161026001, -0.524848461151123, -0.9969499707221985, -1.8777778148651123, -0.14820723235607147, + -8.182234764099121, 0.015234949067234993, -0.010302969254553318, 0.10785042494535446, + -0.02237902209162712, -2.500221014022827, -0.006121153011918068, 0.054380882531404495, 2.607618808746338, + -0.48403942584991455, 1.7271841764450073, -0.054084569215774536, 0.04733904451131821, 0.29113033413887024, + 0.3090323507785797, -0.4069989323616028, 0.827186644077301, -0.8676308393478394, -0.18980173766613007, + 0.017093969509005547, 0.05046425387263298, 0.025303032249212265, -0.9938563704490662, 0.0307749193161726, + -0.003506980137899518, -2.145794153213501, 0.08889109641313553, -0.2744760513305664, 0.02533753775060177, + -0.008416163735091686, 1.6139867305755615, -6.39102840423584, -4.842134952545166, -0.7291613817214966, + 0.9694556593894958, 0.07247202843427658, 0.005149913020431995, 0.0029090321622788906, -0.1867554932832718, + 0.0015806574374437332, -0.04847263917326927, -0.18512502312660217, -0.04184968024492264, + -1.2331782579421997, -0.6159178018569946, 0.025481248274445534, 0.11850030720233917, -0.2734290063381195, + 0.26392263174057007, -0.2278929203748703, -2.4300358295440674, 0.0007563972030766308, 1.2603007555007935, + -0.009525062516331673, -2.5598459243774414, -0.1015859916806221, -0.3136966824531555, + -0.0023580349516123533, 3.0281076431274414, 0.0851983055472374, 0.18700700998306274, + -0.008541906252503395, -0.007119827438145876, 0.42274990677833557, -0.06235692277550697, + 0.3246764540672302, 0.047069381922483444, 0.00011004792031599209, -0.49105241894721985, + 0.041874051094055176, 0.013326031155884266, -2.525364875793457, 0.5126351714134216, -0.01582668349146843, + -0.3391125500202179, -0.1049877479672432, -0.36534854769706726, 0.027926098555326462, + 0.004374756012111902, -0.10876958072185516, 0.579942524433136, 0.34367814660072327, 0.12710949778556824, + -0.28762391209602356, 0.028134111315011978, 0.001072783605195582, -0.430772066116333, 0.21052536368370056, + 0.09690351784229279, 0.000786184798926115, 0.06906910240650177, -3.5896573066711426, 0.24118882417678833, + -3.176041841506958, 1.3121603727340698, -0.40836477279663086, -7.590582370758057, -1.9390276670455933, + -0.06406442821025848, 0.00011302023631287739, 0.013246525079011917, 0.21886053681373596, + 0.090825155377388, -0.06342892348766327, -0.14027893543243408, 0.017751706764101982, 0.11045858263969421, + -0.05397825688123703, 0.2152465432882309, 0.14184458553791046, -1.6443814039230347, -1.023624300956726, + 0.050706081092357635, -0.8185511231422424, -0.009972692467272282, -1.6231411695480347, + -0.010527506470680237, 1.5382870435714722, -2.6943516731262207, 0.965884804725647, -0.5423170924186707, + -2.0661613941192627, -0.4436858892440796, 0.0058816042728722095, -0.665194034576416, 0.8273401260375977, + 0.10996203124523163, -0.1316700130701065, 0.027179520577192307, -0.2735114097595215, -0.10301132500171661, + -1.906333565711975, -0.32074108719825745, 0.4478001892566681, -1.1052520275115967, 0.009423047304153442, + 0.5322814583778381, -0.004648196045309305, -0.009632693603634834, -0.7735386490821838, + 0.005249344743788242, 0.11850841343402863, -0.0034776863176375628, -0.1439099758863449, + 0.2767007648944855, -2.8716399669647217, -0.16290035843849182, -0.1801692247390747, 0.19117145240306854, + -0.7634338736534119, -0.29985561966896057, 0.009378351271152496, -0.6186265349388123, + -0.13845475018024445, 0.03558935597538948, -0.20145508646965027, -0.5337783694267273, 0.28876203298568726, + -0.5732369422912598, 0.03304499760270119, 0.8687714338302612, -0.2524224817752838, 0.4371426999568939, + 0.03568745777010918, 0.4382450580596924, 0.03245728462934494, -0.14247629046440125, 0.7598915696144104, + 0.30114904046058655, -0.21331092715263367, -0.0028205476701259613, 0.09227168560028076, + 0.008056613616645336, 1.635034203529358, -0.166751429438591, -0.020675446838140488, -1.2244166135787964, + -0.10547340661287308, 2.802537441253662, -0.004014655947685242, 3.690307140350342, 0.0017954192589968443, + -0.45281466841697693, 0.020796259865164757, 1.5265557765960693, 0.20084713399410248, -0.21376214921474457, + -0.025406286120414734, 1.9211277961730957, -0.7583361268043518, 4.267587661743164, -4.551294803619385, + -0.08887865394353867, 0.07695532590150833, -0.17959536612033844, 0.5096666216850281, -9.957548141479492, + -0.11618410050868988, 0.09543278813362122, 0.270590603351593, 0.024046115577220917, -0.245524600148201, + -0.307966023683548, 2.2781827449798584, -0.14958485960960388, -0.06977607309818268, 2.3428077697753906, + 0.8067795038223267, 0.9448233842849731, 0.35110360383987427, 0.4814533293247223, 0.00956026278436184, + -0.03395213186740875, 0.2255835384130478, 0.4806722402572632, 0.0005861452082172036, -0.5671629309654236, + 0.5004423260688782, 7.04985237121582, -0.41439759731292725, -0.2847898304462433, -0.10965365916490555, + 0.3427604138851166, 0.12897160649299622, -0.6046913266181946, -0.1840457171201706, 0.002393739065155387, + 0.41798320412635803, 3.0662004947662354, -0.0002512158825993538, 3.1039047241210938, -0.6795744895935059, + 0.3395420014858246, 0.33144497871398926, 6.939244270324707, 0.0011752373538911343, 0.09660591185092926, + 0.4894343912601471, 4.00507116317749, 0.005761822685599327, 0.5680277943611145, 3.315598726272583, + -0.5373169779777527, 0.44444069266319275, -0.0027646978851407766, 1.6829670667648315, -2.6327450275421143, + 7.026787757873535, 1.5361366271972656, -0.3418486714363098, -0.21611177921295166, -0.05756537616252899, + -0.023443035781383514, 0.23109422624111176, 0.21568432450294495, 1.236294150352478, -0.4581376612186432, + -4.188883304595947, 0.28338590264320374, -0.04076884314417839, -0.0198111142963171, -1.1872543096542358, + -0.062372151762247086, 2.7870607376098633, -0.6517040729522705, -1.7516529560089111, -1.8091800212860107, + -0.15286022424697876, 0.09354468435049057, 0.11334053426980972, -1.8259668350219727, + -0.017136069014668465, 0.12429703027009964, 0.00773972412571311, -1.0446529388427734, 0.2356342375278473, + 0.9886437654495239, 0.18150633573532104, 0.4682118594646454, -0.11415664851665497, 0.003153775818645954, + 0.03332524746656418, -1.1180988550186157, 4.163827896118164, -0.18159173429012299, 0.0999181717634201, + 2.058738946914673, 4.2595014572143555, 0.010485258884727955, 0.16270849108695984, -0.9842506051063538, + -0.0003203578235115856, 6.040156364440918, 1.308574914932251, -0.36853301525115967, -0.29669252038002014, + 0.2741331160068512, -0.3422505855560303, -0.7587988972663879, 2.9686927795410156, 0.8209773898124695, + -0.0007857486489228904, 0.02395622618496418, 0.05102493613958359, 0.15041151642799377, + -0.002741719363257289, -4.980742454528809, 0.2880830764770508, 0.24828404188156128, 0.15224777162075043, + 0.13059845566749573, 1.5662918090820312, -0.8474074006080627, 0.08810389041900635, -0.2630780041217804, + 0.43268874287605286, -0.001932788989506662, 0.012391841039061546, 3.4719245433807373, + -0.0024825239088386297, -0.11508434265851974, 0.40480512380599976, 0.4639693796634674, 1.0610097646713257, + -0.3626452386379242, -0.18480324745178223, 0.2711804509162903, 0.21260038018226624, -0.02785545215010643, + -0.11340377479791641, -0.027071641758084297, -0.22035427391529083, -0.10667331516742706, + 0.16908612847328186, 0.10428506135940552, 0.1233300194144249, -0.06643304973840714, 0.2004912942647934, + 0.09342359751462936, 0.03175133094191551, -1.5805491209030151, -0.4311752915382385, 0.5455954074859619, + 0.15516425669193268, -0.04940091818571091, -0.1447768211364746, 0.21044854819774628, 4.243553161621094, + 0.2046128362417221, -0.2688649892807007, 8.694140434265137, 1.6753796339035034, 0.05555065721273422, + -0.16497156023979187, 0.1828891634941101, 1.505862832069397, 0.08261094242334366, 3.104039430618286, + 5.931700706481934, 0.4487259089946747, -0.011016261763870716, 0.012441087514162064, -0.5082470774650574, + -0.11641799658536911, 0.01356368139386177, -0.01659458689391613, 7.667582988739014, -0.346441388130188, + -5.981383323669434, 7.6806206703186035, -0.0383722260594368, 0.4603561460971832, -0.16010521352291107, + -3.173022985458374, -0.4581749737262726, 0.06485684961080551, -0.4382486939430237, 0.25564998388290405, + 0.21537242829799652, 8.642731666564941, -0.00621763477101922, 1.9488575458526611, -0.030582817271351814, + 0.00024394349020440131, -0.015434199944138527, 1.1591683626174927, 0.29453498125076294, + -0.06397286057472229, 0.2931598126888275, 3.056126594543457, 0.35364240407943726, -0.07242457568645477, + 0.19602720439434052, 3.9850471019744873, -0.141653373837471, 0.7269659042358398, 0.020487932488322258, + -1.1716344356536865, -0.13872867822647095, 0.004658155608922243, 0.0002524183946661651, + -0.3965473771095276, 0.058615222573280334, -0.005210256204009056, -6.64661169052124, + -0.003978024236857891, -0.11946024745702744, -0.15917553007602692, -2.0323069095611572, 3.694988250732422, + -0.2484176903963089, 1.117063283920288, 0.04259220138192177, 5.167333126068115, -4.669308185577393, + 0.0845528095960617, 0.011936561204493046, 5.875088691711426, -0.032051537185907364, -1.2815053462982178, + -0.010812301188707352, -0.021371034905314445, -0.0037929099053144455, 0.09110360592603683, + -0.2871546745300293, -0.2895969748497009, 0.2745248079299927, -2.462489366531372, -1.4751100540161133, + 0.6493473052978516, 0.105231374502182, -0.11311662197113037, 5.5921311378479, -0.11590596288442612, + -4.284017562866211, -3.0435032844543457, 0.2767241597175598, -0.2098703682422638, -0.008056361228227615, + 3.738365888595581, 0.03918733820319176, 0.5250627994537354, -0.03754068538546562, -1.1869362592697144, + 0.0016376300482079387, -0.19201913475990295, -0.12353025376796722, -0.038338642567396164, + -0.5192811489105225, -0.07935589551925659, 2.1531660556793213, -0.002360888756811619, 0.3615436553955078, + -1.499021291732788, 0.6402538418769836, -2.886809825897217, 2.502922296524048, -0.0014745928347110748, + -0.09228605031967163, -0.14953434467315674, 0.2779182493686676, 2.071781635284424, -0.2248198240995407, + 0.5830495357513428, -0.1257641464471817, -0.06734845042228699, 0.003910396713763475, -1.285413146018982, + -5.392889976501465, -0.0003311980399303138, 8.632763862609863, 0.1819709688425064, -0.013432486914098263, + -0.019152071326971054, -0.026376325637102127 + ], + "dims": [1, 1, 1280], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index a249dc807fa0b..7038e2a4f8766 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -28,6 +28,37 @@ } ] }, + { + "name": "ConvTranspose without bias addition A - NHWC", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [10, 40, 40, 60, 200, 160, 90, 240, 160], + "dims": [1, 1, 3, 3], + "type": "float32" + } + ] + } + ] + }, { "name": "ConvTranspose without bias addition B", "operator": "ConvTranspose", @@ -74,26 +105,22 @@ }, { "data": [ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 ], "dims": [4, 4, 2, 2], "type": "float32" }, { - "data": [0.1, 0.2, 0.3, 0.4], + "data": [65, 66, 67, 68], "dims": [4], "type": "float32" } ], "outputs": [ { - "data": [ - 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.19999694824219, - 100.19999694824219, 100.19999694824219, 100.19999694824219, 100.30000305175781, 100.30000305175781, - 100.30000305175781, 100.30000305175781, 100.4000015258789, 100.4000015258789, 100.4000015258789, - 100.4000015258789 - ], + "data": [3365, 3465, 3565, 3665, 3766, 3866, 3966, 4066, 4167, 4267, 4367, 4467, 4568, 4668, 4768, 4868], "dims": [1, 4, 2, 2], "type": "float32" } @@ -115,7 +142,43 @@ "type": "float32" }, { - "data": [1, 1, 1, 1], + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [5], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [11, 25, 28, 19, 32, 86, 99, 55, 40, 114, 131, 67, 29, 73, 80, 41], + "dims": [1, 1, 4, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose with bias addition B - NHWC", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [6, 8, 7, 9, 15, 11, 8, 12, 9], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], "dims": [1, 1, 2, 2], "type": "float32" }, @@ -127,7 +190,7 @@ ], "outputs": [ { - "data": [11, 19, 20, 12, 20, 43, 46, 23, 22, 49, 52, 25, 13, 25, 26, 14], + "data": [11, 25, 28, 19, 32, 86, 99, 55, 40, 114, 131, 67, 29, 73, 80, 41], "dims": [1, 1, 4, 4], "type": "float32" } @@ -251,7 +314,6 @@ } ] }, - { "name": "ConvTranspose- pointwise", "operator": "ConvTranspose", @@ -285,5 +347,50 @@ ] } ] + }, + { + "name": "ConvTranspose with bias addition C", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [{ "name": "kernel_shape", "data": [1, 1], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ], + "dims": [1, 4, 4, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [4, 4, 1, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1021, 1049, 1077, 1105, 1133, 1161, 1189, 1217, 1245, 1273, 1301, 1329, 1357, 1385, 1413, 1441, 1122, + 1154, 1186, 1218, 1250, 1282, 1314, 1346, 1378, 1410, 1442, 1474, 1506, 1538, 1570, 1602, 1223, 1259, + 1295, 1331, 1367, 1403, 1439, 1475, 1511, 1547, 1583, 1619, 1655, 1691, 1727, 1763, 1324, 1364, 1404, + 1444, 1484, 1524, 1564, 1604, 1644, 1684, 1724, 1764, 1804, 1844, 1884, 1924 + ], + "dims": [1, 4, 4, 4], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/data/ops/conv.jsonc b/js/web/test/data/ops/conv.jsonc index 928192bb219f2..2e8eaaba191d0 100644 --- a/js/web/test/data/ops/conv.jsonc +++ b/js/web/test/data/ops/conv.jsonc @@ -125,6 +125,72 @@ } ] }, + { + "name": "conv with bias addition C - NHWC", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 3, 4, 5], + "dims": [2, 1, 2, 2], + "type": "float32" + }, + { + "data": [5, 6], + "dims": [2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [15, 46, 31, 102, 47, 158], + "dims": [3, 2, 1, 1], + "type": "float32" + } + ] + }, + { + "name": "inChannel = 3, outChannel = 4", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 10], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [ + 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 1, 2, 3, 4, 5, 6, 7, 8 + ], + "dims": [4, 3, 2, 2], + "type": "float32" + }, + { + "data": [5, 6, 7, 8], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [360, 334, 271, 323, 909, 963, 1024, 1028, 683, 655, 576, 650, 473, 508, 570, 677], + "dims": [1, 4, 2, 2], + "type": "float32" + } + ] + } + ] + }, { "name": "conv - group - A", "operator": "Conv", diff --git a/js/web/test/data/ops/expand.jsonc b/js/web/test/data/ops/expand.jsonc index 460122b4e085c..35888e2fc3709 100644 --- a/js/web/test/data/ops/expand.jsonc +++ b/js/web/test/data/ops/expand.jsonc @@ -85,5 +85,34 @@ ] } ] + }, + { + "name": "Expand 5D - float32", + "operator": "Expand", + "attributes": [], + "cases": [ + { + "name": "Expand 5 - float32", + "inputs": [ + { + "data": [1], + "dims": [1, 1, 1, 1, 1], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 6], + "dims": [5], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 1, 1, 1, 1, 1], + "dims": [1, 1, 1, 1, 6], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc new file mode 100644 index 0000000000000..812e9d7c2def0 --- /dev/null +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -0,0 +1,112 @@ +[ + { + "name": "conv without bias addition A", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [370, 470, 670, 770], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "T[1]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 390, 430], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv without bias addition A", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 11 }, + "cases": [ + { + "name": "T[2]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [370, 470, 670, 770], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + }, + { + "name": "T[3]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 390, 430], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/multi-head-attention.jsonc b/js/web/test/data/ops/multi-head-attention.jsonc new file mode 100644 index 0000000000000..05687bd482e24 --- /dev/null +++ b/js/web/test/data/ops/multi-head-attention.jsonc @@ -0,0 +1,194 @@ +[ + { + "name": "MultiHeadAttention Basic, one head", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 4.973228454589844, 5.973228454589844, 6.973228454589844, 7.973228454589844, 4.999990940093994, + 5.999990940093994, 6.999990940093994, 7.999990940093994 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 4.571832656860352, 5.571832656860352, 6.971858501434326, 7.971858501434326, 4.998325824737549, + 5.998325824737549, 6.999900817871094, 7.999900817871094 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic with bias", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4], + "dims": [12], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 5.943336009979248, 7.94333553314209, 9.999799728393555, 11.999798774719238, 5.9997992515563965, + 7.9997992515563965, 10, 11.999999046325684 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention two heads", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 8.99963665008545, 9.99963665008545, 10.99963665008545, 11.999635696411133, 13, 14, 15, 16, 9, 10, 11, 12, + 13, 14, 15, 16 + ], + "dims": [1, 2, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention two heads", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[1]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 1, 8], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 1, 8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 8], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/slice.jsonc b/js/web/test/data/ops/slice.jsonc index 9c90817a80c36..beef154a29932 100644 --- a/js/web/test/data/ops/slice.jsonc +++ b/js/web/test/data/ops/slice.jsonc @@ -21,6 +21,29 @@ } ] }, + { + "name": "Slice float32 with input[0] dim > 4", + "operator": "Slice", + "attributes": [], + "cases": [ + { + "name": "T[1, 1, 1, 1, 5] T[1] T[1] T[1] (float32)", + "inputs": [ + { + "data": [ + 0.3964604139328003, -0.8916832804679871, -1.6578896045684814, 1.960708737373352, 1.181204915046692 + ], + "dims": [1, 1, 1, 1, 5], + "type": "float32" + }, + { "data": [3], "dims": [1], "type": "int64" }, + { "data": [4], "dims": [1], "type": "int64" }, + { "data": [4], "dims": [1], "type": "int64" } + ], + "outputs": [{ "data": [1.960708737373352], "dims": [1, 1, 1, 1, 1], "type": "float32" }] + } + ] + }, { "name": "Slice int32", "operator": "Slice", diff --git a/js/web/test/data/ops/softmax.jsonc b/js/web/test/data/ops/softmax.jsonc index 85d4096ee0493..98573fcd73ba2 100644 --- a/js/web/test/data/ops/softmax.jsonc +++ b/js/web/test/data/ops/softmax.jsonc @@ -5,7 +5,7 @@ "attributes": [], "cases": [ { - "name": "T[2,4]", + "name": "T[2,2]", "inputs": [ { "data": [1.0, 2.0, 3.0, 4.0], @@ -22,5 +22,32 @@ ] } ] + }, + { + "name": "Softmax with no attributes", + "operator": "Softmax", + "attributes": [], + "cases": [ + { + "name": "T[2, 2, 2]", + "inputs": [ + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + "dims": [2, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.2689414322376251, 0.7310585975646973, 0.2689414322376251, 0.7310585975646973, 0.2689414322376251, + 0.7310585975646973, 0.2689414322376251, 0.7310585975646973 + ], + "dims": [2, 2, 2], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/data/ops/transpose.jsonc b/js/web/test/data/ops/transpose.jsonc index 285d14018e74d..e1edfa7e41513 100644 --- a/js/web/test/data/ops/transpose.jsonc +++ b/js/web/test/data/ops/transpose.jsonc @@ -166,5 +166,29 @@ ] } ] + }, + { + "name": "Transpose 5D - perms:[4, 3, 1, 0, 2]", + "operator": "Transpose", + "attributes": [{ "name": "perm", "data": [4, 3, 1, 0, 2], "type": "ints" }], + "cases": [ + { + "name": "T[3, 1, 2, 1, 4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [3, 1, 2, 1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24], + "dims": [4, 1, 1, 3, 2], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/data/ops/where.jsonc b/js/web/test/data/ops/where.jsonc new file mode 100644 index 0000000000000..047fd6fd7511b --- /dev/null +++ b/js/web/test/data/ops/where.jsonc @@ -0,0 +1,172 @@ +[ + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] float32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4.0, 8.0, 7.0, 2.0, 4.0, 8.0, 7.0, 1.0], + "dims": [8], + "type": "float32" + }, + { + "data": [1.0, 3.0, 9.0, 6.0, 1.0, 3.0, 9.0, 2.0], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [4.0, 3.0, 7.0, 6.0, 4.0, 3.0, 7.0, 2.0], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] int32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4, 8, 7, 2, 4, 8, 7, 1], + "dims": [8], + "type": "int32" + }, + { + "data": [1, 3, 9, 6, 1, 3, 9, 2], + "dims": [8], + "type": "int32" + } + ], + "outputs": [ + { + "data": [4, 3, 7, 6, 4, 3, 7, 2], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] uint32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4, 8, 7, 2, 4, 8, 7, 1], + "dims": [8], + "type": "uint32" + }, + { + "data": [1, 4294967295, 9, 6, 1, 3, 9, 2], + "dims": [8], + "type": "uint32" + } + ], + "outputs": [ + { + "data": [4, 4294967295, 7, 6, 4, 3, 7, 2], + "dims": [8], + "type": "uint32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] bool T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [true, true, true, true, true, true, true, true], + "dims": [8], + "type": "float32" + }, + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3 3] T[3 3] T[1] float32 broadcast", + "inputs": [ + { + "data": [true, true, true, true, true, false, false, false, false], + "dims": [3, 3], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, -1, -1, -1, -1], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/where_broadcast.jsonc b/js/web/test/data/ops/where_broadcast.jsonc new file mode 100644 index 0000000000000..ad97177bb101b --- /dev/null +++ b/js/web/test/data/ops/where_broadcast.jsonc @@ -0,0 +1,84 @@ +[ + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + // This failed due to: https://github.com/microsoft/onnxruntime/issues/17405. + "name": "T[3 6] T[3 6] T[1] float32 broadcast", + "inputs": [ + { + "data": [ + true, + true, + true, + true, + true, + false, + false, + false, + false, + false, + false, + true, + true, + true, + true, + true, + true, + true + ], + "dims": [3, 6], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, -1, -1, -1, -1, -1, -1, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + } + ] + }, + { + // This failed due to: https://github.com/microsoft/onnxruntime/issues/17405. + "name": "T[3 1] T[3 6] T[1] float32 broadcast", + "inputs": [ + { + "data": [true, false, true], + "dims": [3, 1], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 6e65645ef4756..a313adef7151b 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -885,10 +885,10 @@ // // "test_qlinearmatmul_3D", // // "test_quantizelinear_axis", // // "test_quantizelinear", - // "test_range_float_type_positive_delta_expanded", - // "test_range_float_type_positive_delta", - // "test_range_int32_type_negative_delta_expanded", - // "test_range_int32_type_negative_delta", + "test_range_float_type_positive_delta_expanded", + "test_range_float_type_positive_delta", + "test_range_int32_type_negative_delta_expanded", + "test_range_int32_type_negative_delta", "test_reciprocal_example", "test_reciprocal", "test_reduce_l1_default_axes_keepdims_example", @@ -1336,6 +1336,10 @@ "add_int32.jsonc", //"and.jsonc", "asin.jsonc", + "attention.jsonc", + "batch-norm.jsonc", + "bias-add.jsonc", + "bias-split-gelu.jsonc", "ceil.jsonc", "concat.jsonc", "concat_int32.jsonc", @@ -1360,6 +1364,7 @@ "matmul-broadcast.jsonc", "mul.jsonc", "mul_int32.jsonc", + "multi-head-attention.jsonc", //"neg.jsonc", "neg-int32.jsonc", "not.jsonc", @@ -1386,7 +1391,10 @@ "tan.jsonc", "tile.jsonc", "transpose.jsonc", - "transpose_int32_uint32.jsonc" + "transpose_int32_uint32.jsonc", + "where.jsonc" + // Turn on this when https://github.com/microsoft/onnxruntime/issues/17405 is fixed. + //"where_broadcast.jsonc", //"xor.jsonc" ] }, diff --git a/js/web/test/test-main.ts b/js/web/test/test-main.ts index 49d0ac225be2f..24ab0694b32b8 100644 --- a/js/web/test/test-main.ts +++ b/js/web/test/test-main.ts @@ -1,9 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// Load onnxruntime-web and testdata-config. +// Load onnxruntime-common and testdata-config. // NOTE: this need to be called before import any other library. -const ort = require('..'); +import * as ort from 'onnxruntime-common'; + const ORT_WEB_TEST_CONFIG = require('./testdata-config.json') as Test.Config; import * as platform from 'platform'; @@ -57,6 +58,9 @@ if (options.globalEnvFlags) { if (flags.webgpu?.profilingMode !== undefined) { ort.env.webgpu.profilingMode = flags.webgpu.profilingMode; } + if (flags.webgpu?.validateInputContent !== undefined) { + ort.env.webgpu.validateInputContent = flags.webgpu.validateInputContent; + } } // Set logging configuration diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 46d80a9f56f35..29acc07e118f9 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -14,7 +14,8 @@ import {Operator} from '../lib/onnxjs/operators'; import {onnx} from '../lib/onnxjs/ort-schema/protobuf/onnx'; import {Tensor} from '../lib/onnxjs/tensor'; import {ProtoUtil} from '../lib/onnxjs/util'; -import {tensorDataTypeStringToEnum} from '../lib/wasm/wasm-common'; +import {createView} from '../lib/wasm/jsep/tensor-view'; +import {getTensorElementSize, isGpuBufferSupportedType, tensorDataTypeStringToEnum} from '../lib/wasm/wasm-common'; import {base64toBuffer, createMockGraph, readFile} from './test-shared'; import {Test} from './test-types'; @@ -136,8 +137,8 @@ async function loadTensors( } async function initializeSession( - modelFilePath: string, backendHint: string, profile: boolean, sessionOptions: ort.InferenceSession.SessionOptions, - fileCache?: FileCacheBuffer): Promise { + modelFilePath: string, backendHint: string, ioBindingMode: Test.IOBindingMode, profile: boolean, + sessionOptions: ort.InferenceSession.SessionOptions, fileCache?: FileCacheBuffer): Promise { const preloadModelData: Uint8Array|undefined = fileCache && fileCache[modelFilePath] ? fileCache[modelFilePath] : undefined; Logger.verbose( @@ -146,8 +147,14 @@ async function initializeSession( preloadModelData ? ` [preloaded(${preloadModelData.byteLength})]` : ''}`); const profilerConfig = profile ? {maxNumberEvents: 65536} : undefined; - const sessionConfig = - {...sessionOptions, executionProviders: [backendHint], profiler: profilerConfig, enableProfiling: profile}; + const sessionConfig = { + ...sessionOptions, + executionProviders: [backendHint], + profiler: profilerConfig, + enableProfiling: profile, + preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined + }; + let session: ort.InferenceSession; try { @@ -157,7 +164,10 @@ async function initializeSession( session = await ort.InferenceSession.create(modelFilePath, sessionConfig); } } catch (e) { - Logger.error('TestRunner', `Failed to load model from file: ${modelFilePath}. Error: ${inspect(e)}`); + Logger.error( + 'TestRunner', + `Failed to load model from file: ${modelFilePath}. ` + + `Error: ${e.message} @ ${e.fileName}:${e.lineNumber}`); throw e; } @@ -181,6 +191,7 @@ export class ModelTestContext { readonly session: ort.InferenceSession, readonly backend: string, readonly perfData: ModelTestContext.ModelTestPerfData, + readonly ioBinding: Test.IOBindingMode, private readonly profile: boolean, ) {} @@ -232,8 +243,8 @@ export class ModelTestContext { this.initializing = true; const initStart = now(); - const session = - await initializeSession(modelTest.modelUrl, modelTest.backend!, profile, sessionOptions || {}, this.cache); + const session = await initializeSession( + modelTest.modelUrl, modelTest.backend!, modelTest.ioBinding, profile, sessionOptions || {}, this.cache); const initEnd = now(); for (const testCase of modelTest.cases) { @@ -244,6 +255,7 @@ export class ModelTestContext { session, modelTest.backend!, {init: initEnd - initStart, firstRun: -1, runs: [], count: 0}, + modelTest.ioBinding, profile, ); } finally { @@ -481,6 +493,130 @@ export class TensorResultValidator { } } +function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor { + if (!isGpuBufferSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) { + throw new Error(`createGpuTensorForInput can not work with ${cpuTensor.type} tensor`); + } + const device = ort.env.webgpu.device as GPUDevice; + const gpuBuffer = device.createBuffer({ + // eslint-disable-next-line no-bitwise + usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, + size: Math.ceil(cpuTensor.data.byteLength / 16) * 16, + mappedAtCreation: true + }); + const arrayBuffer = gpuBuffer.getMappedRange(); + new Uint8Array(arrayBuffer) + .set(new Uint8Array(cpuTensor.data.buffer, cpuTensor.data.byteOffset, cpuTensor.data.byteLength)); + gpuBuffer.unmap(); + + // TODO: how to "await" for the copy to finish, so that we can get more accurate performance data? + + return ort.Tensor.fromGpuBuffer( + gpuBuffer, {dataType: cpuTensor.type, dims: cpuTensor.dims, dispose: () => gpuBuffer.destroy()}); +} + +function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) { + if (!isGpuBufferSupportedType(type)) { + throw new Error(`createGpuTensorForOutput can not work with ${type} tensor`); + } + + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(type))!; + const size = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + + const device = ort.env.webgpu.device as GPUDevice; + const gpuBuffer = device.createBuffer({ + // eslint-disable-next-line no-bitwise + usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, + size: Math.ceil(size / 16) * 16 + }); + + return ort.Tensor.fromGpuBuffer(gpuBuffer, { + dataType: type, + dims, + dispose: () => gpuBuffer.destroy(), + download: async () => { + const stagingBuffer = device.createBuffer({ + // eslint-disable-next-line no-bitwise + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + size: gpuBuffer.size + }); + const encoder = device.createCommandEncoder(); + encoder.copyBufferToBuffer(gpuBuffer, 0, stagingBuffer, 0, gpuBuffer.size); + device.queue.submit([encoder.finish()]); + + await stagingBuffer.mapAsync(GPUMapMode.READ); + const arrayBuffer = stagingBuffer.getMappedRange().slice(0, size); + stagingBuffer.destroy(); + + return createView(arrayBuffer, type) as ort.Tensor.DataTypeMap[ort.Tensor.GpuBufferDataTypes]; + } + }); +} + +export async function sessionRun(options: { + session: ort.InferenceSession; feeds: Record; + outputsMetaInfo: Record>; + ioBinding: Test.IOBindingMode; +}): Promise<[number, number, ort.InferenceSession.OnnxValueMapType]> { + const session = options.session; + const feeds = options.feeds; + const fetches: Record = {}; + + // currently we only support IO Binding for WebGPU + // + // For inputs, we create GPU tensors on both 'gpu-tensor' and 'gpu-location' binding testing mode. + // For outputs, we create GPU tensors on 'gpu-tensor' binding testing mode only. + // in 'gpu-device' binding mode, outputs are not pre-allocated. + const shouldUploadInput = options.ioBinding === 'gpu-tensor' || options.ioBinding === 'gpu-location'; + const shouldUploadOutput = options.ioBinding === 'gpu-tensor'; + try { + if (shouldUploadInput) { + // replace the CPU tensors in feeds into GPU tensors + for (const name in feeds) { + if (Object.hasOwnProperty.call(feeds, name)) { + feeds[name] = createGpuTensorForInput(feeds[name]); + } + } + } + + if (shouldUploadOutput) { + for (const name in options.outputsMetaInfo) { + if (Object.hasOwnProperty.call(options.outputsMetaInfo, name)) { + const {type, dims} = options.outputsMetaInfo[name]; + fetches[name] = createGpuTensorForOutput(type, dims); + } + } + } + + const start = now(); + Logger.verbose('TestRunner', `Timestamp before session run: ${start}`); + const outputs = await ( + shouldUploadOutput ? session.run(feeds, fetches) : + session.run(feeds, Object.getOwnPropertyNames(options.outputsMetaInfo))); + const end = now(); + Logger.verbose('TestRunner', `Timestamp after session run: ${end}`); + + // download each output tensor if needed + for (const name in outputs) { + if (Object.hasOwnProperty.call(outputs, name)) { + const tensor = outputs[name]; + // Tensor.getData(true) release the underlying resource + await tensor.getData(true); + } + } + + return [start, end, outputs]; + } finally { + // dispose the GPU tensors in feeds + for (const name in feeds) { + if (Object.hasOwnProperty.call(feeds, name)) { + const tensor = feeds[name]; + tensor.dispose(); + } + } + } +} + /** * run a single model test case. the inputs/outputs tensors should already been prepared. */ @@ -491,12 +627,11 @@ export async function runModelTestSet( const validator = new TensorResultValidator(context.backend); try { const feeds: Record = {}; + const outputsMetaInfo: Record = {}; testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor); - const start = now(); - Logger.verbose('TestRunner', `Timestamp before session run: ${start}`); - const outputs = await context.session.run(feeds); - const end = now(); - Logger.verbose('TestRunner', `Timestamp after session run: ${end}`); + testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor); + const [start, end, outputs] = + await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding}); if (context.perfData.count === 0) { context.perfData.firstRun = end - start; } else { @@ -575,6 +710,7 @@ export class ProtoOpTestContext { private readonly loadedData: Uint8Array; // model data, inputs, outputs session: ort.InferenceSession; readonly backendHint: string; + readonly ioBindingMode: Test.IOBindingMode; constructor(test: Test.OperatorTest, private readonly sessionOptions: ort.InferenceSession.SessionOptions = {}) { const opsetImport = onnx.OperatorSetIdProto.create(test.opset); const operator = test.operator; @@ -713,6 +849,7 @@ export class ProtoOpTestContext { model.graph.name = test.name; this.backendHint = test.backend!; + this.ioBindingMode = test.ioBinding; this.loadedData = onnx.ModelProto.encode(model).finish(); // in debug mode, open a new tab in browser for the generated onnx model. @@ -729,8 +866,11 @@ export class ProtoOpTestContext { } } async init(): Promise { - this.session = await ort.InferenceSession.create( - this.loadedData, {executionProviders: [this.backendHint], ...this.sessionOptions}); + this.session = await ort.InferenceSession.create(this.loadedData, { + executionProviders: [this.backendHint], + preferredOutputLocation: this.ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined, + ...this.sessionOptions + }); } async dispose(): Promise { @@ -739,10 +879,11 @@ export class ProtoOpTestContext { } async function runProtoOpTestcase( - session: ort.InferenceSession, testCase: Test.OperatorTestCase, validator: TensorResultValidator): Promise { + session: ort.InferenceSession, testCase: Test.OperatorTestCase, ioBindingMode: Test.IOBindingMode, + validator: TensorResultValidator): Promise { const feeds: Record = {}; - const fetches: string[] = []; - testCase.inputs!.forEach((input, i) => { + const fetches: Record> = {}; + testCase.inputs.forEach((input, i) => { if (input.data) { let data: number[]|BigUint64Array|BigInt64Array = input.data; if (input.type === 'uint64') { @@ -756,7 +897,7 @@ async function runProtoOpTestcase( const outputs: ort.Tensor[] = []; const expectedOutputNames: string[] = []; - testCase.outputs!.forEach((output, i) => { + testCase.outputs.forEach((output, i) => { if (output.data) { let data: number[]|BigUint64Array|BigInt64Array = output.data; if (output.type === 'uint64') { @@ -766,11 +907,11 @@ async function runProtoOpTestcase( } outputs.push(new ort.Tensor(output.type, data, output.dims)); expectedOutputNames.push(`output_${i}`); - fetches.push(`output_${i}`); + fetches[`output_${i}`] = {dims: output.dims, type: output.type}; } }); - const results = await session.run(feeds, fetches); + const [, , results] = await sessionRun({session, feeds, outputsMetaInfo: fetches, ioBinding: ioBindingMode}); const actualOutputNames = Object.getOwnPropertyNames(results); expect(actualOutputNames.length).to.equal(expectedOutputNames.length); @@ -821,7 +962,8 @@ async function runOpTestcase( export async function runOpTest( testcase: Test.OperatorTestCase, context: ProtoOpTestContext|OpTestContext): Promise { if (context instanceof ProtoOpTestContext) { - await runProtoOpTestcase(context.session, testcase, new TensorResultValidator(context.backendHint)); + await runProtoOpTestcase( + context.session, testcase, context.ioBindingMode, new TensorResultValidator(context.backendHint)); } else { await runOpTestcase( context.inferenceHandler, context.createOperator(), testcase, new TensorResultValidator(context.backendHint)); diff --git a/js/web/test/test-shared.ts b/js/web/test/test-shared.ts index 68d7852ce86da..7c327e7c97ac4 100644 --- a/js/web/test/test-shared.ts +++ b/js/web/test/test-shared.ts @@ -2,8 +2,7 @@ // Licensed under the MIT License. import * as base64 from 'base64-js'; -import * as fs from 'fs'; -import {promisify} from 'util'; +import * as fs from 'node:fs/promises'; import {Attribute} from '../lib/onnxjs/attribute'; import {Graph} from '../lib/onnxjs/graph'; @@ -19,7 +18,7 @@ export function bufferToBase64(buffer: Uint8Array): string { export async function readFile(file: string) { if (typeof process !== 'undefined' && process.versions && process.versions.node) { // node - return promisify(fs.readFile)(file); + return fs.readFile(file); } else { // browser const response = await fetch(file); diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts index 1f95d1cd8e682..5bdc8d84cc7a5 100644 --- a/js/web/test/test-types.ts +++ b/js/web/test/test-types.ts @@ -43,6 +43,18 @@ export declare namespace Test { */ export type PlatformCondition = string; + /** + * The IOBindingMode represents how to test a model with GPU data. + * + * - none: inputs will be pre-allocated as CPU tensors; no output will be pre-allocated; `preferredOutputLocation` + * will not be set. + * - gpu-location: inputs will be pre-allocated as GPU tensors; no output will be pre-allocated; + * `preferredOutputLocation` will be set to `gpu-buffer`. + * - gpu-tensor: inputs and outputs will all be pre-allocated as GPU tensors. `preferredOutputLocation` + * will not be set. + */ + export type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location'; + export interface ModelTestCase { name: string; dataFiles: readonly string[]; @@ -54,6 +66,7 @@ export declare namespace Test { name: string; modelUrl: string; backend?: string; // value should be populated at build time + ioBinding: IOBindingMode; platformCondition?: PlatformCondition; cases: readonly ModelTestCase[]; } @@ -82,6 +95,7 @@ export declare namespace Test { inputShapeDefinitions?: 'none'|'rankOnly'|'static'|ReadonlyArray; opset?: OperatorTestOpsetImport; backend?: string; // value should be populated at build time + ioBinding: IOBindingMode; platformCondition?: PlatformCondition; attributes?: readonly AttributeValue[]; cases: readonly OperatorTestCase[]; @@ -114,6 +128,7 @@ export declare namespace Test { wasm: Partial; webgl: Partial; webgpu: Partial; + logLevel?: Env['logLevel']; } /** diff --git a/js/web/test/unittests/backends/webgl/test-conv-new.ts b/js/web/test/unittests/backends/webgl/test-conv-new.ts index 0fddddf58181c..8c186b9b36451 100644 --- a/js/web/test/unittests/backends/webgl/test-conv-new.ts +++ b/js/web/test/unittests/backends/webgl/test-conv-new.ts @@ -14,7 +14,7 @@ import {conv2d} from './test-conv-utils'; function createRandomArray(size: number): Float32Array { const randomTable = [0, 3, 6, 9, 2, 5, 8, 1, 4, 7]; return new Float32Array( - Array.from({length: size}, (v, k) => randomTable[k % 10] * 0.1 + randomTable[Math.trunc(k / 10) % 10] * 0.01)); + Array.from({length: size}, (_v, k) => randomTable[k % 10] * 0.1 + randomTable[Math.trunc(k / 10) % 10] * 0.01)); } interface TestData { inputShape: number[]; diff --git a/js/web/test/unittests/backends/webgl/test-pack-unpack.ts b/js/web/test/unittests/backends/webgl/test-pack-unpack.ts index 0b70144733227..61c21d4b689fb 100644 --- a/js/web/test/unittests/backends/webgl/test-pack-unpack.ts +++ b/js/web/test/unittests/backends/webgl/test-pack-unpack.ts @@ -291,7 +291,7 @@ describe('#UnitTest# - unpack - Tensor unpack', () => { webglInferenceHandler.session.textureManager.glContext.checkError(); const webglTexture = createTextureFromArray( webglInferenceHandler.session.textureManager.glContext, testData.rawData ? testData.rawData : inputData, - gl.RGBA, inputTextureShape[0], inputTextureShape[1]); + inputTextureShape[0], inputTextureShape[1]); webglInferenceHandler.session.textureManager.glContext.checkError(); const packedShape = inputTextureShape; const textureData = { diff --git a/js/web/test/unittests/backends/webgl/test-utils.ts b/js/web/test/unittests/backends/webgl/test-utils.ts index acb3f0002ce2f..092d63cd2ade4 100644 --- a/js/web/test/unittests/backends/webgl/test-utils.ts +++ b/js/web/test/unittests/backends/webgl/test-utils.ts @@ -4,7 +4,7 @@ import {WebGLContext} from '../../../../lib/onnxjs/backends/webgl/webgl-context'; export function createAscendingArray(size: number): Float32Array { - return new Float32Array(Array.from({length: size}, (v, i) => (i + 1))); + return new Float32Array(Array.from({length: size}, (_v, i) => (i + 1))); } // Returns an array by injecting 3 zeros after every element in the input array to be used for creating unpacked @@ -19,7 +19,7 @@ export function generateArrayForUnpackedTexture(input: Float32Array): Float32Arr // create a webgl texture and fill it with the array content export function createTextureFromArray( - glContext: WebGLContext, dataArray: Float32Array, type: GLenum, width: number, height: number): WebGLTexture { + glContext: WebGLContext, dataArray: Float32Array, width: number, height: number): WebGLTexture { const gl = glContext.gl; // create the texture diff --git a/js/web/tsconfig.json b/js/web/tsconfig.json index 0a4d19af9981f..d60d746e9328d 100644 --- a/js/web/tsconfig.json +++ b/js/web/tsconfig.json @@ -1,12 +1,10 @@ { "extends": "../tsconfig.json", "compilerOptions": { - "module": "CommonJS", "downlevelIteration": true, "declaration": true, - "declarationDir": "./types", "typeRoots": ["./node_modules/@webgpu/types", "./node_modules/@types", "../node_modules/@types"] }, - "include": ["lib", "script", "test"], + "include": ["lib", "test"], "exclude": ["lib/wasm/proxy-worker"] } diff --git a/js/web/types.d.ts b/js/web/types.d.ts index c6cff64c8a732..b9d12cf47b5c5 100644 --- a/js/web/types.d.ts +++ b/js/web/types.d.ts @@ -5,6 +5,26 @@ declare module 'onnxruntime-web' { export * from 'onnxruntime-common'; } +declare module 'onnxruntime-web/experimental' { + export * from 'onnxruntime-web'; +} + +declare module 'onnxruntime-web/wasm' { + export * from 'onnxruntime-web'; +} + +declare module 'onnxruntime-web/wasm-core' { + export * from 'onnxruntime-web'; +} + +declare module 'onnxruntime-web/webgl' { + export * from 'onnxruntime-web'; +} + declare module 'onnxruntime-web/webgpu' { export * from 'onnxruntime-web'; } + +declare module 'onnxruntime-web/training' { + export * from 'onnxruntime-web'; +} diff --git a/js/web/webpack.config.js b/js/web/webpack.config.js deleted file mode 100644 index 81c69ffdcf6bf..0000000000000 --- a/js/web/webpack.config.js +++ /dev/null @@ -1,330 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -const path = require('path'); -const webpack = require('webpack'); -const BundleAnalyzerPlugin = require('webpack-bundle-analyzer').BundleAnalyzerPlugin; -const NodePolyfillPlugin = require('node-polyfill-webpack-plugin'); -const TerserPlugin = require('terser-webpack-plugin'); -const minimist = require('minimist'); - -// commandline args -const args = minimist(process.argv); -const bundleMode = args['bundle-mode'] || 'prod'; // 'prod'|'dev'|'perf'|'node'|undefined; -const useAnalyzer = !!args.a || !!args['use-analyzer']; // -a, --use-analyzer -const filter = args.f || args.filter; - -const VERSION = require(path.join(__dirname, 'package.json')).version; -const COPYRIGHT_BANNER = `/*! -* ONNX Runtime Web v${VERSION} -* Copyright (c) Microsoft Corporation. All rights reserved. -* Licensed under the MIT License. -*/`; - -function terserEcmaVersionFromWebpackTarget(target) { - switch (target) { - case 'es5': - return 5; - case 'es6': - case 'es2015': - return 2015; - case 'es2017': - return 2017; - default: - throw new RangeError(`not supported ECMA version: ${target}`); - } -} - -function defaultTerserPluginOptions(target) { - return { - extractComments: false, - terserOptions: { - ecma: terserEcmaVersionFromWebpackTarget(target), - format: { - comments: false, - }, - compress: {passes: 2}, - mangle: {reserved: ['_scriptDir', 'startWorker']} - } - }; -} - -const DEFAULT_BUILD_DEFS = { - DISABLE_WEBGL: false, - DISABLE_WEBGPU: true, - DISABLE_WASM: false, - DISABLE_WASM_PROXY: false, - DISABLE_WASM_THREAD: false, -}; - -// common config for release bundle -function buildConfig({filename, format, target, mode, devtool, build_defs}) { - const config = { - target: [format === 'commonjs' ? 'node' : 'web', target], - entry: path.resolve(__dirname, 'lib/index.ts'), - output: {path: path.resolve(__dirname, 'dist'), filename, library: {type: format}}, - resolve: { - extensions: ['.ts', '.js'], - alias: { - 'util': false, - }, - fallback: { - 'crypto': false, - 'fs': false, - 'path': false, - 'util': false, - 'os': false, - 'worker_threads': false, - 'perf_hooks': false, - } - }, - plugins: [ - new webpack.DefinePlugin({BUILD_DEFS: build_defs}), new webpack.WatchIgnorePlugin({paths: [/\.js$/, /\.d\.ts$/]}) - ], - module: { - rules: [ - {test: /\.ts$/, use: [{loader: 'ts-loader', options: {compilerOptions: {target}}}]}, - {test: /ort-wasm.*\.worker\.js$/, type: 'asset/source'} - ] - }, - mode, - node: false, - devtool - }; - - if (useAnalyzer) { - config.plugins.unshift( - new BundleAnalyzerPlugin({analyzerMode: 'static', reportFilename: `${filename}.report.html`})); - } - - if (mode === 'production') { - config.resolve.alias['./binding/ort-wasm-threaded.js'] = './binding/ort-wasm-threaded.min.js'; - config.resolve.alias['./binding/ort-wasm-threaded-simd.jsep.js'] = './binding/ort-wasm-threaded-simd.jsep.min.js'; - config.resolve.alias['./binding/ort-wasm-threaded.worker.js'] = './binding/ort-wasm-threaded.min.worker.js'; - - const options = defaultTerserPluginOptions(target); - options.terserOptions.format.preamble = COPYRIGHT_BANNER; - config.plugins.push(new TerserPlugin(options)); - - // add a custom plugin to check whether code contains 'BUILD_DEFS' - config.plugins.push({ - apply: (compiler) => { - compiler.hooks.afterCompile.tap('Check BUILD_DEFS', (compilation) => { - for (const filename of compilation.assetsInfo.keys()) { - if (filename.endsWith('.js')) { - const asset = compilation.assets[filename]; - if (asset) { - const content = asset.source(); - if (typeof content !== 'string') { - throw new Error(`content for target file '${filename}' is not string.`); - } - if (content.includes('DISABLE_WEBGL') || content.includes('DISABLE_WASM') || - content.includes('DISABLE_WASM_PROXY') || content.includes('DISABLE_WASM_THREAD')) { - throw new Error(`target file '${filename}' contains data fields from "BUILD_DEFS".`); - } - } - } - } - }); - } - }); - } else { - config.plugins.push(new webpack.BannerPlugin({banner: COPYRIGHT_BANNER, raw: true})); - } - - return config; -} - -// "ort{.min}.js" config -function buildOrtConfig( - {suffix = '', target = 'es2017', mode = 'production', devtool = 'source-map', build_defs = DEFAULT_BUILD_DEFS}) { - const config = buildConfig({filename: `ort${suffix}.js`, format: 'umd', target, mode, devtool, build_defs}); - // set global name 'ort' - config.output.library.name = 'ort'; - return config; -} - -// "ort-web{.min|.node}.js" config -function buildOrtWebConfig({ - suffix = '', - format = 'umd', - target = 'es2017', - mode = 'production', - devtool = 'source-map', - build_defs = DEFAULT_BUILD_DEFS -}) { - const config = buildConfig({filename: `ort-web${suffix}.js`, format, target, mode, devtool, build_defs}); - // exclude onnxruntime-common from bundle - config.externals = { - 'onnxruntime-common': {commonjs: 'onnxruntime-common', commonjs2: 'onnxruntime-common', root: 'ort'} - }; - // in nodejs, treat as external dependencies - if (format === 'commonjs') { - config.externals.path = 'path'; - config.externals.fs = 'fs'; - config.externals.util = 'util'; - config.externals.worker_threads = 'worker_threads'; - config.externals.perf_hooks = 'perf_hooks'; - config.externals.os = 'os'; - } - return config; -} - -function buildTestRunnerConfig({ - suffix = '', - format = 'umd', - target = 'es2017', - mode = 'production', - devtool = 'source-map', - build_defs = DEFAULT_BUILD_DEFS -}) { - const config = { - target: ['web', target], - entry: path.resolve(__dirname, 'test/test-main.ts'), - output: { - path: path.resolve(__dirname, 'test'), - filename: `ort${suffix}.js`, - library: {type: format}, - devtoolNamespace: '', - }, - externals: { - 'onnxruntime-common': 'ort', - 'fs': 'fs', - 'perf_hooks': 'perf_hooks', - 'worker_threads': 'worker_threads', - '../../node': '../../node' - }, - resolve: { - alias: { - // make sure to refer to original source files instead of generated bundle in test-main. - '..$': '../lib/index' - }, - extensions: ['.ts', '.js'], - fallback: { - './binding/ort-wasm.js': false, - './binding/ort-wasm-threaded.js': false, - './binding/ort-wasm-threaded.worker.js': false - } - }, - plugins: [ - new webpack.DefinePlugin({BUILD_DEFS: build_defs}), - new webpack.WatchIgnorePlugin({paths: [/\.js$/, /\.d\.ts$/]}), - new NodePolyfillPlugin({excludeAliases: ['console', 'Buffer']}), - ], - module: { - rules: [ - {test: /\.ts$/, use: [{loader: 'ts-loader', options: {compilerOptions: {target}}}]}, - {test: /ort-wasm.*\.worker\.js$/, type: 'asset/source'} - ] - }, - mode, - node: false, - devtool, - }; - - if (mode === 'production') { - config.plugins.push(new TerserPlugin(defaultTerserPluginOptions(target))); - } - - return config; -} - -module.exports = () => { - const builds = []; - - switch (bundleMode) { - case 'prod': - builds.push( - // ort.min.js - buildOrtConfig({suffix: '.min'}), - // ort.js - buildOrtConfig({mode: 'development', devtool: 'inline-source-map'}), - // ort.es6.min.js - buildOrtConfig({suffix: '.es6.min', target: 'es6'}), - // ort.es5.min.js - buildOrtConfig({suffix: '.es5.min', target: 'es5'}), - - // ort.wasm.min.js - buildOrtConfig({ - suffix: '.wasm.min', - build_defs: { - ...DEFAULT_BUILD_DEFS, - DISABLE_WEBGL: true, - } - }), - // ort.webgl.min.js - buildOrtConfig({ - suffix: '.webgl.min', - build_defs: { - ...DEFAULT_BUILD_DEFS, - DISABLE_WASM: true, - } - }), - // ort.wasm-core.min.js - buildOrtConfig({ - suffix: '.wasm-core.min', - build_defs: { - ...DEFAULT_BUILD_DEFS, - DISABLE_WEBGL: true, - DISABLE_WASM_PROXY: true, - DISABLE_WASM_THREAD: true, - } - }), - // ort.webgpu.min.js - buildOrtConfig({ - suffix: '.webgpu.min', - build_defs: { - ...DEFAULT_BUILD_DEFS, - DISABLE_WEBGPU: false, - } - }), - - // ort-web.min.js - buildOrtWebConfig({suffix: '.min'}), - // ort-web.js - buildOrtWebConfig({mode: 'development', devtool: 'inline-source-map'}), - // ort-web.es6.min.js - buildOrtWebConfig({suffix: '.es6.min', target: 'es6'}), - // ort-web.es5.min.js - buildOrtWebConfig({suffix: '.es5.min', target: 'es5'}), - ); - - case 'node': - builds.push( - // ort-web.node.js - buildOrtWebConfig({suffix: '.node', format: 'commonjs'}), - ); - break; - case 'dev': - builds.push(buildTestRunnerConfig({ - suffix: '.dev', - mode: 'development', - devtool: 'inline-source-map', - build_defs: { - ...DEFAULT_BUILD_DEFS, - DISABLE_WEBGPU: false, - } - })); - break; - case 'perf': - builds.push(buildTestRunnerConfig({ - suffix: '.perf', - build_defs: { - ...DEFAULT_BUILD_DEFS, - DISABLE_WEBGPU: false, - } - })); - break; - default: - throw new Error(`unsupported bundle mode: ${bundleMode}`); - } - - if (filter) { - const filterRegex = new RegExp(filter); - return builds.filter(b => filterRegex.test(b.output.filename)); - } - - return builds; -}; diff --git a/js/webpack.shared.mjs b/js/webpack.shared.mjs deleted file mode 100644 index d1b95722ff4de..0000000000000 --- a/js/webpack.shared.mjs +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -import webpack from 'webpack'; -import TerserPlugin from 'terser-webpack-plugin'; -import {resolve, dirname} from 'node:path'; -import {readFileSync} from 'node:fs'; -import {fileURLToPath} from 'node:url'; - -/** - * ECMAScript version for default onnxruntime JavaScript API builds - */ -export const DEFAULT_ES_VERSION = 'es2017'; - -// how to use "__dirname" in ESM: https://shramko.dev/blog/dirname-error -const __dirname = dirname(fileURLToPath(import.meta.url)); - -const terserEcmaVersionFromWebpackEsVersion = (target) => { - switch (target) { - case 'es5': - return 5; - case 'es6': - case 'es2015': - return 2015; - case 'es2017': - return 2017; - default: - throw new RangeError(`not supported ECMA version: ${target}`); - } -}; - -const getPackageFullName = (name) => { - switch (name) { - case 'common': - return `ONNX Runtime Common`; - case 'node': - return `ONNX Runtime Node.js Binding`; - case 'web': - return `ONNX Runtime Web`; - case 'react-native': - return `ONNX Runtime React-native`; - default: - throw new RangeError(`unknown package name: ${name}`); - } -}; - -/** - * Get package version by reading the file "package.json" under the package folder - * @param {'common'|'node'|'web'|'react-native'} name - the package name - * @returns a string representing the package version - */ -const getPackageVersion = (name) => { - const normalizedName = name.replace('-', '_'); - const packageJsonFileContent = readFileSync(resolve(__dirname, normalizedName, 'package.json')); - const packageJson = JSON.parse(packageJsonFileContent); - return packageJson.version; -}; - -/** - * - * @param {'development'|'production'} mode - specify webpack build mode - * @param {'common'|'node'|'web'|'react-native'} packageName - specify the name of the package - * @param {'es5'|'es6'|'es2015'|'es2017'} esVersion - specify the ECMAScript version - * @returns - */ -export const addCopyrightBannerPlugin = (mode, packageName, esVersion) => { - const COPYRIGHT_BANNER = `/*! - * ${getPackageFullName(packageName)} v${getPackageVersion(packageName)} - * Copyright (c) Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. - */`; - - if (mode === 'production') { - // in 'production' mode, webpack uses terser to minimize the code. - // we set options.format.preamble to make sure terser generates correct copyright banner. - return new TerserPlugin({ - extractComments: false, - terserOptions: { - ecma: terserEcmaVersionFromWebpackEsVersion(esVersion), - format: { - preamble: COPYRIGHT_BANNER, - comments: false, - }, - compress: {passes: 2} - } - }); - } else { - // in 'development' mode, webpack does not minimize the code. - // we use the webpack builtin plugin BannerPlugin to insert the banner. - return new webpack.BannerPlugin({banner: COPYRIGHT_BANNER, raw: true}); - } -}; diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index fd147eaa11f3f..57219c50f39aa 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -42,6 +42,7 @@ from onnxruntime.capi._pybind_state import get_build_info # noqa: F401 from onnxruntime.capi._pybind_state import get_device # noqa: F401 from onnxruntime.capi._pybind_state import get_version_string # noqa: F401 + from onnxruntime.capi._pybind_state import has_collective_ops # noqa: F401 from onnxruntime.capi._pybind_state import set_default_logger_severity # noqa: F401 from onnxruntime.capi._pybind_state import set_default_logger_verbosity # noqa: F401 from onnxruntime.capi._pybind_state import set_seed # noqa: F401 @@ -60,7 +61,6 @@ from onnxruntime.capi.onnxruntime_inference_collection import OrtDevice # noqa: F401 from onnxruntime.capi.onnxruntime_inference_collection import OrtValue # noqa: F401 from onnxruntime.capi.onnxruntime_inference_collection import SparseTensor # noqa: F401 -from onnxruntime.capi.training import * # noqa: F403 # TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end try: # noqa: SIM105 diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc index 945c3aebce579..d0abf58922f88 100644 --- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc @@ -32,8 +32,10 @@ Status ATen::Compute(OpKernelContext* p_ctx) const { aten_ops::ATenOperatorExecutor::Instance()(op_name_, overload_name_, input_size, dlpack_inputs.get(), output_size, dlpack_outputs.get()); for (size_t i = 0; i < output_size; ++i) { - ORT_RETURN_IF_ERROR( - p_ctx_internal->SetOutputMLValue(static_cast(i), dlpack::DlpackToOrtValue(dlpack_outputs[i]))); + if (dlpack_outputs[i]) { + ORT_RETURN_IF_ERROR( + p_ctx_internal->SetOutputMLValue(static_cast(i), dlpack::DlpackToOrtValue(dlpack_outputs[i]))); + } } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h index be9650d96b004..d72868cd8fa9f 100644 --- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { namespace aten_ops { -typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index); +typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input); typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size, DLManagedTensor** dlpack_inputs, size_t output_size, DLManagedTensor** dlpack_outputs); @@ -22,17 +22,17 @@ class ATenOperatorExecutor { return instance; } - void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) { - ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw); - p_is_tensor_argument_func_ = reinterpret_cast(p_is_tensor_argument_func_raw); + void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) { + ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw); + p_is_cpu_argument_func_ = reinterpret_cast(p_is_cpu_argument_func_raw); p_execute_aten_op_func_ = reinterpret_cast(p_execute_aten_op_func_raw); } bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; } - bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index) { - ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized."); - return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index); + bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) { + ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized."); + return p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input); } void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size, @@ -43,7 +43,7 @@ class ATenOperatorExecutor { } private: - IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr; + IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr; ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 4c9c15d07a9b8..a7f83469a768d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -55,6 +55,7 @@ struct AttentionParameters { int v_hidden_size; // hidden size of V int v_head_size; // hidden size per head of V int num_heads; + int num_splits; bool is_unidirectional; bool past_present_share_buffer; bool do_rotary; @@ -82,6 +83,27 @@ struct PackedAttentionParameters { bool broadcast_res_pos_bias; }; +// Parameters deduced from node attributes and inputs/outputs. +struct GroupQueryAttentionParameters { + int batch_size; + int sequence_length; // sequence length of input query, key, value + int seqlen_past_kv_cache; // sequence length of past kv tensor + int seqlen_present_kv_cache; // sequence length of present kv tensor + int hidden_size; + int num_heads; + int head_size; + int kv_hidden_size; + int kv_num_heads; + int num_splits; // number of splits for splitkv + bool is_unidirectional; // causal + int local_window_size; + bool kv_share_buffer; + bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor + float scale; + AttentionQkvFormat qkv_format; + AttentionQkvFormat past_kv_format; +}; + namespace attention { // Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled). constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION"; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 0b55cb7804c61..694c40bf3eda6 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -16,7 +16,6 @@ #include #include -#include using onnxruntime::concurrency::ThreadPool; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 1dc85e6d345d7..00e82c9844b3d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -103,7 +103,8 @@ Status CheckInputs(const T* query, } if (past_key_dims[2] != past_value_dims[2]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length)"); + "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). ", + past_key_dims[2], " vs ", past_value_dims[2]); } if (past_key_dims[3] != head_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -205,6 +206,7 @@ Status CheckInputs(const T* query, } } + int total_sequence_length = past_sequence_length + kv_sequence_length; AttentionMaskType mask_type = AttentionMaskType::MASK_NONE; if (key_padding_mask != nullptr) { mask_type = AttentionMaskType::MASK_UNKNOWN; @@ -215,13 +217,21 @@ Status CheckInputs(const T* query, } else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) { mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; } - } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && mask_dims[1] == static_cast(kv_sequence_length)) { + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(kv_sequence_length)) { + mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(total_sequence_length)) { mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(sequence_length) && + mask_dims[2] == static_cast(total_sequence_length)) { + mask_type = AttentionMaskType::MASK_3D_ATTENTION; } if (mask_type == AttentionMaskType::MASK_UNKNOWN) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key_padding_mask' shape shall be (batch_size) or (batch_size, kv_sequence_length)"); + "Input 'key_padding_mask' shape shall be 1D, 2D, or 3D"); } } @@ -256,7 +266,6 @@ Status CheckInputs(const T* query, } } - int total_sequence_length = past_sequence_length + kv_sequence_length; bool broadcast_res_pos_bias = false; if (relative_position_bias != nullptr) { const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc new file mode 100644 index 0000000000000..47f462d75fcc4 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/bert/rotary_embedding.h" +#include "contrib_ops/cpu/bert/rotary_embedding_helper.h" + +#include "core/platform/threadpool.h" + +using onnxruntime::concurrency::ThreadPool; +using namespace onnxruntime::contrib::rotary_embedding_helper; + +namespace onnxruntime { +namespace contrib { + +// These ops are internal-only, so register outside of onnx +ONNX_OPERATOR_TYPED_KERNEL_EX( + RotaryEmbedding, + kMSDomain, + 1, + float, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + RotaryEmbedding); + +template +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { + scale = info.GetAttrOrDefault("scale", 1.0); + interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); +} + +template +Status RotaryEmbedding::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* position_ids = context->Input(1); + const Tensor* cos_cache = context->Input(2); + const Tensor* sin_cache = context->Input(3); + + RotaryParameters parameters = {}; + ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(input, + position_ids, + cos_cache, + sin_cache, + ¶meters)); + + Tensor* output = context->Output(0, input->Shape()); + + if (parameters.sequence_length > parameters.max_sequence_length) { + // Launch update_cos_sin_cache kernel with scale + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); + } + + const T* input_src = input->Data(); + const int64_t* pos_ids_data = position_ids->Data(); + const T* cos_cache_data = cos_cache->Data(); + const T* sin_cache_data = sin_cache->Data(); + T* output_dest = output->MutableData(); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int head_size = parameters.head_size; + const int position_ids_format = parameters.position_ids_format; + const int half_head_size = head_size / 2; + // Default input tensor shape is [batch, seq_len, hidden_size] + int head_stride = head_size; + int seq_stride = num_heads * head_stride; + int batch_stride = sequence_length * seq_stride; + if (parameters.transposed) { + // Transposed input tensor shape is [batch, num_heads, seq_len, head_size] + seq_stride = head_size; + head_stride = sequence_length * seq_stride; + batch_stride = num_heads * head_stride; + } + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + auto* tp = context->GetOperatorThreadPool(); + + const int loop_len = batch_size * sequence_length * num_heads; + const double cost = static_cast(head_size); + ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { + const int b = static_cast((ptr / num_heads) / sequence_length); + const int s = static_cast((ptr / num_heads) % sequence_length); + const int n = static_cast(ptr % num_heads); + + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; + + const T* input_data = input_src + block_offset; + T* output_data = output_dest + block_offset; + + // Cache is (M, H/2) + const int position_id = (position_ids_format == 0) + ? static_cast(pos_ids_data[0]) + s + : static_cast(pos_ids_data[b * sequence_length + s]); + const int cache_offset = position_id * half_head_size; + const T* cos_data = cos_cache_data + cache_offset; + const T* sin_data = sin_cache_data + cache_offset; + + int cache_idx = 0; + T sign = 0; + int j = 0; + for (int i = 0; i < head_size; i++) { + if (interleaved) { + cache_idx = (i / 2) % half_head_size; + sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); + j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign + } else { + cache_idx = i % half_head_size; + sign = (i < half_head_size) ? static_cast(-1) : static_cast(1); + j = (i + half_head_size) % head_size; + } + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; + } + } + }); + + return Status::OK(); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h new file mode 100644 index 0000000000000..be834a66cdc69 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { + +template +class RotaryEmbedding final : public OpKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + protected: + float scale; + bool interleaved; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h new file mode 100644 index 0000000000000..7b2e8289f7b06 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace contrib { +namespace rotary_embedding_helper { + +// Parameters deduced from node attributes and inputs/outputs. +struct RotaryParameters { + int batch_size; // Batch size used by input + int sequence_length; // Sequence length used by input + int hidden_size; // Hidden size used by input + int head_size; // Head size used by cos/sin cache * 2 + int num_heads; // num_heads = hidden_size / head_size + int max_sequence_length; // Sequence length used by cos/sin cache + int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) + bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) +}; + +template +Status CheckInputs(const T* input, + const T* position_ids, + const T* cos_cache, + const T* sin_cache, + void* parameters) { + // input : (batch_size, sequence_length, hidden_size) + // position ids : (1) or (batch_size, sequence_length) + // cos cache : (max_sequence_length, head_size / 2) + // sin cache : (max_sequence_length, head_size / 2) + + // Check input + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() != 3 && input_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 or 4 dimensions, got ", + input_dims.size()); + } + // Check position_ids + const auto& position_ids_dims = position_ids->Shape().GetDims(); + if (!onnxruntime::IsScalarOr1ElementVector(position_ids) && position_ids_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 ", + "dimensions, got ", position_ids_dims.size()); + } + // Check cos_cache and sin_cache + const auto& cos_cache_dims = cos_cache->Shape().GetDims(); + if (cos_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ", + cos_cache_dims.size()); + } + const auto& sin_cache_dims = sin_cache->Shape().GetDims(); + if (sin_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ", + sin_cache_dims.size()); + } + if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ", + "the same shape"); + } + + // Get attributes from inputs + int batch_size = static_cast(input_dims[0]); + int sequence_length = static_cast(input_dims[1]); + int hidden_size = static_cast(input_dims[2]); + + bool transposed = false; + if (input_dims.size() == 4) { + // input is [batch, num_heads, seq, head_size] + sequence_length = static_cast(input_dims[2]); + hidden_size = static_cast(input_dims[1]) * static_cast(input_dims[3]); + transposed = true; + } + int max_sequence_length = static_cast(cos_cache_dims[0]); + int head_size = static_cast(cos_cache_dims[1]) * 2; + int num_heads = hidden_size / head_size; + int position_ids_format = -1; + + // Check position_ids input shapes + if (!onnxruntime::IsScalarOr1ElementVector(position_ids)) { + if (batch_size != static_cast(position_ids_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size ", + "batch_size, got ", position_ids_dims[0]); + } + if (sequence_length != static_cast(position_ids_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size ", + "sequence_length, got ", position_ids_dims[1]); + } + position_ids_format = 1; + } else { + position_ids_format = 0; + } + // Check cos_cache input shapes + if (max_sequence_length != static_cast(cos_cache_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ", + "max_sequence_length, got ", cos_cache_dims[0]); + } + if ((head_size / 2) != static_cast(cos_cache_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ", + "head_size / 2, got ", cos_cache_dims[1]); + } + // Check sin_cache input shapes + if (max_sequence_length != static_cast(sin_cache_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ", + "max_sequence_length, got ", sin_cache_dims[0]); + } + if ((head_size / 2) != static_cast(sin_cache_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ", + "head_size / 2, got ", sin_cache_dims[1]); + } + + // Set rotary parameters + if (parameters != nullptr) { + RotaryParameters* output_parameters = reinterpret_cast(parameters); + output_parameters->batch_size = batch_size; + output_parameters->sequence_length = sequence_length; + output_parameters->hidden_size = hidden_size; + output_parameters->head_size = head_size; + output_parameters->num_heads = num_heads; + output_parameters->max_sequence_length = max_sequence_length; + output_parameters->position_ids_format = position_ids_format; + output_parameters->transposed = transposed; + } + + return Status::OK(); +} + +} // namespace rotary_embedding_helper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 0ec5088808656..f9d9b13f0fedc 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -13,12 +13,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, BeamSearch); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, WhisperBeamSearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); @@ -27,6 +29,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordC class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4); #ifndef ORT_MINIMAL_BUILD class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4); #endif @@ -122,6 +126,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu); @@ -244,12 +250,14 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -262,6 +270,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifndef ORT_MINIMAL_BUILD BuildKernelCreateInfo, #endif @@ -295,6 +305,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h new file mode 100644 index 0000000000000..cb8e97a592d8c --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +#if defined(_MSC_VER) +#define FORCEINLINE __forceinline +#else +#define FORCEINLINE __attribute__((always_inline)) inline +#endif + +typedef enum Bnb_DataType_t { + FP4 = 0, + NF4 = 1, +} Bnb_DataType_t; + +FORCEINLINE uint8_t QuantizeOneFP4(float x) { + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assum input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to noice if you add an extra + // zero somewhere! + + uint8_t sign = x < 0 ? 0b1000 : 0b0000; + x = fabsf(x); + if (x > 0.29166667f) { + if (x > 0.583333f) { + if (x > 0.8333333f) { + return 0b0011 + sign; + } else { + return 0b0010 + sign; + } + } else if (x > 0.4166667f) { + return 0b101 + sign; + } else { + return 0b100 + sign; + } + } else if (x > 0.0859375f) { + if (x > 0.20833333f) { + return 0b0111 + sign; + } else { + return 0b0110 + sign; + } + } else if (x > 0.00260417f) { + return 0b0001 + sign; + } else { + return 0b0000 + sign; + } +} + +FORCEINLINE uint8_t QuantizeOneNF4(float x) { + if (x > 0.03979014977812767f) { + if (x > 0.3893125355243683f) { // 1 + if (x > 0.6427869200706482f) { // 11 + if (x > 0.8614784181118011f) { // 111 + return 0b1111; + } else { + return 0b1110; + } + } else if (x > 0.5016634166240692f) { // 110 + return 0b1101; + } else { + return 0b1100; + } + } else if (x > 0.2035212516784668f) { // 10 + if (x > 0.2920137718319893f) { // 101 + return 0b1011; + } else { + return 0b1010; + } + } else if (x > 0.1202552504837513f) { // 100 + return 0b1001; + } else { + return 0b1000; + } + } else if (x > -0.33967943489551544f) { // 0 + if (x > -0.13791173323988914f) { // 01 + if (x > -0.045525018125772476f) { // 011 + return 0b0111; + } else { + return 0b0110; + } + } else if (x > -0.23460740596055984f) { // 010 + return 0b0101; + } else { + return 0b0100; + } + } else if (x > -0.6106329262256622f) { // 00 + if (x > -0.4599952697753906f) { // 001 + return 0b0011; + } else { + return 0b0010; + } + } else if (x > -0.8480964004993439f) { // 000 + return 0b0001; + } else { + return 0b0000; + } +} + +template +FORCEINLINE uint8_t QuantizeOneBnb4(float x) { + if constexpr (DATA_TYPE == FP4) + return QuantizeOneFP4(x); + else + return QuantizeOneNF4(x); +} + +template +FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) { + float local_absmax = 0.0f; + + int32_t block_len = std::min(block_size, numel - block_idx * block_size); + int32_t src_offset = block_idx * block_size; + int32_t dst_offset = block_idx * block_size / 2; + + for (int32_t idx = 0; idx < block_len; idx++) { + const float v = static_cast(src[src_offset + idx]); + local_absmax = fmaxf(local_absmax, fabsf(v)); + } + + absmax_block = static_cast(local_absmax); + const float reciprocal_absmax = local_absmax ? 1.0f / local_absmax : 0.0f; + + for (int32_t idx = 0; idx < block_len; idx += 2) { + const float v0 = static_cast(src[src_offset + idx]) * reciprocal_absmax; + const uint8_t vi0 = QuantizeOneBnb4(v0); + + const float v1 = (idx + 1 < block_len) ? static_cast(src[src_offset + idx + 1]) * reciprocal_absmax : 0; + const uint8_t vi1 = QuantizeOneBnb4(v1); + + dst[dst_offset + idx / 2] = (vi0 << 4) | vi1; + } +} + +static float fp4_qaunt_map[16] = {0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f, + 0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f, + -0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f, + -0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f}; + +static float nf4_qaunt_map[16] = {-1.0f, + -0.6961928009986877f, + -0.5250730514526367f, + -0.39491748809814453f, + -0.28444138169288635f, + -0.18477343022823334f, + -0.09105003625154495f, + 0.0f, + 0.07958029955625534f, + 0.16093020141124725f, + 0.24611230194568634f, + 0.33791524171829224f, + 0.44070982933044434f, + 0.5626170039176941f, + 0.7229568362236023f, + 1.0f}; + +template +FORCEINLINE T DequantizeOneBnb4(uint8_t x) { + if constexpr (DATA_TYPE == FP4) + return static_cast(fp4_qaunt_map[x]); + else + return static_cast(nf4_qaunt_map[x]); +} + +template +FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) { + int32_t block_len = std::min(block_size, numel - block_idx * block_size); + int32_t src_offset = block_idx * block_size / 2; + int32_t dst_offset = block_idx * block_size; + + for (int32_t idx = 0; idx < block_len; idx += 2) { + const uint8_t val = src[src_offset + idx / 2]; + + dst[dst_offset + idx] = DequantizeOneBnb4(val >> 4) * absmax_block; + if (idx + 1 < block_len) dst[dst_offset + idx + 1] = DequantizeOneBnb4(val & 0xF) * absmax_block; + } +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h new file mode 100644 index 0000000000000..5ddb77e5b5ee3 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "blockwise_quant_block_bnb4.h" + +#include + +#include "core/common/safeint.h" +#include "core/framework/float16.h" +#include "core/platform/threadpool.h" +#include + +namespace onnxruntime { +namespace contrib { + +template +void QuantizeBlockwiseBnb4( + uint8_t* dst, // shape: [(N * K + 1) / 2] + const T* src, // shape: [N, K] + T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + int32_t numel = N * K; + int32_t total_block_count = (numel + block_size - 1) / block_size; + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + total_block_count, + [&](ptrdiff_t block_idx) { + QuantizeBlockBnb4( + src, + dst, + absmax[block_idx], + static_cast(block_idx), + numel); + }, + 0); +} + +#define QuantizeBlockwiseBn4DataTyped(block_size, quant_type) \ + if (quant_type == FP4) \ + QuantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); \ + else \ + QuantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); + +template +void QuantizeBlockwiseBnb4( + uint8_t* dst, // shape: [(N * K + 1) / 2] + const T* src, // shape: [N, K] + T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + if (block_size == 16) { + QuantizeBlockwiseBn4DataTyped(16, quant_type); + } else if (block_size == 32) { + QuantizeBlockwiseBn4DataTyped(32, quant_type); + } else if (block_size == 64) { + QuantizeBlockwiseBn4DataTyped(64, quant_type); + } else if (block_size == 128) { + QuantizeBlockwiseBn4DataTyped(128, quant_type); + } else if (block_size == 256) { + QuantizeBlockwiseBn4DataTyped(256, quant_type); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef QuantizeBlockwiseBn4DataTyped + +template +void DequantizeBlockwiseBnb4( + T* dst, // shape: [N, K] + const uint8_t* src, // shape: [(N * K + 1) / 2)] + const T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + int32_t numel = N * K; + int32_t total_block_count = (numel + block_size - 1) / block_size; + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + total_block_count, + [&](ptrdiff_t block_idx) { + DequantizeBlockBnb4( + src, + dst, + absmax[block_idx], + static_cast(block_idx), + numel); + }, + 0); +} + +#define DequantizeBlockwiseBn4DataTyped(block_size, quant_type) \ + if (quant_type == FP4) \ + DequantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); \ + else \ + DequantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); + +template +void DequantizeBlockwiseBnb4( + T* dst, // shape: [N, K] + const uint8_t* src, // shape: [(N * K + 1) / 2)] + const T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + if (block_size == 16) { + DequantizeBlockwiseBn4DataTyped(16, quant_type); + } else if (block_size == 32) { + DequantizeBlockwiseBn4DataTyped(32, quant_type); + } else if (block_size == 64) { + DequantizeBlockwiseBn4DataTyped(64, quant_type); + } else if (block_size == 128) { + DequantizeBlockwiseBn4DataTyped(128, quant_type); + } else if (block_size == 256) { + DequantizeBlockwiseBn4DataTyped(256, quant_type); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef DequantizeBlockwiseBn4DataTyped + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc new file mode 100644 index 0000000000000..b898c956b6e6a --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/common.h" +#include "dequantize_blockwise_bnb4.h" +#include "core/mlas/inc/mlas.h" + +namespace onnxruntime { +namespace contrib { + +class MatMulBnb4 final : public OpKernel { + public: + MatMulBnb4(const OpKernelInfo& info) : OpKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("quant_type", &quant_type_)); + ORT_ENFORCE( + quant_type_ == FP4 || quant_type_ == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + is_training_mode_ = static_cast(info.GetAttrOrDefault("training_mode", static_cast(0))); + transB_ = static_cast(info.GetAttrOrDefault("transB", static_cast(1))); + } + + Status Compute(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t quant_type_; + bool is_training_mode_; + bool transB_; +}; + +Status MatMulBnb4::Compute(OpKernelContext* ctx) const { + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + const Tensor* a = ctx->Input(0); + const Tensor* b_quant = ctx->Input(1); + const Tensor* absmax = ctx->Input(2); + + const float* a_data = a->Data(); + const uint8_t* b_quant_data = b_quant->Data(); + const float* absmax_data = absmax->Data(); + + AllocatorPtr allocator; + auto status = ctx->GetTempSpaceAllocator(&allocator); + ORT_RETURN_IF_ERROR(status); + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + DequantizeBlockwiseBnb4( + tmp_b_data_ptr.get(), + b_quant_data, + absmax_data, + static_cast(block_size_), + static_cast(quant_type_), + static_cast(N_), + static_cast(K_), + thread_pool); + + constexpr bool transa = false; + const bool transb = transB_; + TensorShape b_shape({N_, K_}); + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(transa); + const size_t ldb = helper.Ldb(transb); + + // TODO: implement with native kernel + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsPacked = false; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; + } + MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, data.data(), max_len, thread_pool); + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..320a05bb97dac --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/narrow.h" +#include "core/common/safeint.h" +#include "core/framework/op_kernel.h" +#include "core/mlas/inc/mlas.h" +#include "core/mlas/inc/mlas_qnbit.h" +#include "core/mlas/inc/mlas_q4.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace contrib { + +class MatMulNBits final : public OpKernel { + public: + MatMulNBits(const OpKernelInfo& info) + : OpKernel(info), + K_{narrow(info.GetAttr("K"))}, + N_{narrow(info.GetAttr("N"))}, + block_size_{narrow(info.GetAttr("block_size"))}, + nbits_{narrow(info.GetAttr("bits"))} { + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + } + + Status Compute(OpKernelContext* context) const override; + + private: + const size_t K_; + const size_t N_; + const size_t block_size_; + const size_t nbits_; + const bool column_wise_quant_{true}; +}; + +Status MatMulNBits::Compute(OpKernelContext* ctx) const { + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + const Tensor* a = ctx->Input(0); + const Tensor* b = ctx->Input(1); + const Tensor* scales = ctx->Input(2); + const Tensor* zero_points = ctx->Input(3); + + const auto* a_data = a->Data(); + const uint8_t* b_data = b->Data(); + const auto* scales_data = scales->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + + TensorShape b_shape({static_cast(N_), static_cast(K_)}); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) + return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t batch_count = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(false); + + if (MlasIsSQNBitGemmAvailable(nbits_, block_size_)) { + // number of bytes or elements between adjacent matrices + size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes; + MlasBlockwiseQuantizedBufferSizes(static_cast(nbits_), static_cast(block_size_), /* columnwise */ true, + static_cast(K), static_cast(N), + b_data_matrix_stride_in_bytes, b_scale_matrix_stride, + &b_zero_point_matrix_stride_in_bytes); + + const size_t b_matrix_size = K * N; + + InlinedVector data(batch_count); + for (size_t i = 0; i < batch_count; ++i) { + const size_t b_matrix_offset = helper.RightOffsets()[i] / b_matrix_size; + + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].QuantBData = b_data + b_matrix_offset * b_data_matrix_stride_in_bytes; + data[i].QuantBScale = scales_data + b_matrix_offset * b_scale_matrix_stride; + data[i].QuantBZeroPoint = zero_points_data != nullptr + ? zero_points_data + b_matrix_offset * b_zero_point_matrix_stride_in_bytes + : nullptr; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + } + + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, data.data(), thread_pool); + + return Status::OK(); + } + + const size_t ldb = helper.Ldb(true); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + // dequantize b, only 4b quantization is supported for now + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + zero_points_data, // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + +#if 0 // for debug + auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); +#endif + + std::vector data(batch_count); + for (size_t i = 0; i < batch_count; i++) { + data[i].BIsPacked = false; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; + } + MlasGemmBatch(CblasNoTrans, CblasTrans, + M, N, K, data.data(), batch_count, thread_pool); + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index e86a12d9fb873..4e103c2556a7a 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -20,20 +20,29 @@ namespace contrib { kCpuExecutionProvider, \ KernelDefBuilder() \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - SkipLayerNorm); + SkipLayerNorm); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SkipSimplifiedLayerNormalization, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + SkipLayerNorm); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) -template -SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) +template +SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) { ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); ORT_ENFORCE(epsilon_ >= 0); } -template -Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { +template +Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { const Tensor* input = p_ctx->Input(0); const Tensor* skip = p_ctx->Input(1); const Tensor* gamma = p_ctx->Input(2); @@ -102,10 +111,16 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { } mean = mean / hidden_size; - mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_); + if (simplified) { + mean_square = sqrt(mean_square / hidden_size + epsilon_); + } else { + mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_); + } for (int64_t h = 0; h < hidden_size; h++) { - if (nullptr == beta_data) { + if (simplified) { + p_output[h] = p_output[h] / mean_square * gamma_data[h]; + } else if (nullptr == beta_data) { p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h]; } else { p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h]; diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h index 7723541cb6b18..69edf4609e340 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { -template +template class SkipLayerNorm final : public OpKernel { public: SkipLayerNorm(const OpKernelInfo& op_kernel_info); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index c391f47e1927b..93cda00e5a3c3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -52,28 +52,38 @@ namespace contrib { kCpuExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - transformers::BeamSearch); + transformers::BeamSearch); \ + \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + WhisperBeamSearch, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + transformers::WhisperBeamSearch); REGISTER_KERNEL_TYPED(float) namespace transformers { void BeamSearch::Init(const OpKernelInfo& info) { - parameters_.ParseFromAttributes(info); + parameters_->ParseFromAttributes(info); // Model_type could be either 0 (GPT-2), 1 (encoder-decoder like T5), or 2 (Whisper) - ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt || - parameters_.model_type == IGenerationParameters::kModelTypeT5 || - parameters_.model_type == IGenerationParameters::kModelTypeWhisper); + ORT_ENFORCE(parameters_->model_type == IGenerationParameters::kModelTypeGpt || + parameters_->model_type == IGenerationParameters::kModelTypeT5 || + parameters_->model_type == IGenerationParameters::kModelTypeWhisper); ONNX_NAMESPACE::GraphProto proto; - if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) { + if (parameters_->model_type != IGenerationParameters::kModelTypeGpt) { // Make sure the encoder sub-graph attribute is present for the T5 and Whisper models. ORT_ENFORCE(info.GetAttr("encoder", &proto).IsOK()); } - if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { + if (parameters_->model_type == IGenerationParameters::kModelTypeGpt) { // Check if the init_decoder sub-graph attribute is present for the GPT2 model. if (info.GetAttr("init_decoder", &proto).IsOK()) { has_init_decoder_ = true; @@ -90,11 +100,11 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) { const auto& node = Node(); - if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { + if (parameters_->model_type == IGenerationParameters::kModelTypeGpt) { if (attribute_name == "decoder") { ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name, - subgraph_session_state, parameters_); + subgraph_session_state, *parameters_); auto status = res.first; if (!status.IsOK()) { @@ -109,7 +119,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, // updated once for the 'decoder' attribute). In future, find a way to update 'parameters' only once based on only one subgraph // attribute. auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name, - subgraph_session_state, parameters_); + subgraph_session_state, *parameters_); auto status = res.first; if (!status.IsOK()) { @@ -119,7 +129,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, init_run_gpt_subgraph_ = std::move(res.second); init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager(); } - } else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { + } else if (parameters_->model_type == IGenerationParameters::kModelTypeT5) { if (attribute_name == "encoder") { ORT_ENFORCE(t5_encoder_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); @@ -129,7 +139,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state)); encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager(); - if (parameters_.decoder_start_token_id < 0) { + if (parameters_->decoder_start_token_id < 0) { ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2, "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty"); } else { @@ -144,12 +154,12 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, subgraph_session_state.GetGraphViewer()); ORT_RETURN_IF_ERROR(t5_decoder_subgraph_->Setup(session_state, subgraph_session_state)); decoder_feeds_fetches_manager_ = t5_decoder_subgraph_->GetFeedsFetchesManager(); - parameters_.SetSubgraphParameters(t5_decoder_subgraph_->vocab_size, - t5_decoder_subgraph_->num_heads, - t5_decoder_subgraph_->head_size, - t5_decoder_subgraph_->num_layers); + parameters_->SetSubgraphParameters(t5_decoder_subgraph_->vocab_size, + t5_decoder_subgraph_->num_heads, + t5_decoder_subgraph_->head_size, + t5_decoder_subgraph_->num_layers); } - } else if (parameters_.model_type == IGenerationParameters::kModelTypeWhisper) { + } else if (parameters_->model_type == IGenerationParameters::kModelTypeWhisper) { if (attribute_name == "encoder") { ORT_ENFORCE(whisper_encoder_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); @@ -169,10 +179,10 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, subgraph_session_state.GetGraphViewer()); ORT_RETURN_IF_ERROR(whisper_decoder_subgraph_->Setup(session_state, subgraph_session_state)); decoder_feeds_fetches_manager_ = whisper_decoder_subgraph_->GetFeedsFetchesManager(); - parameters_.SetSubgraphParameters(whisper_decoder_subgraph_->vocab_size, - whisper_decoder_subgraph_->num_heads, - whisper_decoder_subgraph_->head_size, - whisper_decoder_subgraph_->num_layers); + parameters_->SetSubgraphParameters(whisper_decoder_subgraph_->vocab_size, + whisper_decoder_subgraph_->num_heads, + whisper_decoder_subgraph_->head_size, + whisper_decoder_subgraph_->num_layers); } } @@ -197,9 +207,9 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); // Make a copy of parameters since we will update it based on inputs later - BeamSearchParameters parameters = parameters_; + BeamSearchParameters parameters = *parameters_; - if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { + if (parameters.model_type == IGenerationParameters::kModelTypeGpt) { if (!gpt_subgraph_->IsOutputFloat16()) { // Output float32 BeamSearchGpt impl{ *ctx_internal, @@ -253,7 +263,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { ORT_ENFORCE(encoder_session_state, "Subgraph SessionState was not found for 'encoder' attribute."); ORT_ENFORCE(encoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph."); - if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { + if (parameters.model_type == IGenerationParameters::kModelTypeT5) { // Subgraph has constraint that the output is either float or float16 if (!t5_decoder_subgraph_->IsOutputFloat16()) { BeamSearchT5 impl{ @@ -303,7 +313,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { } // Change the CreateEncoderInputs function for Whisper shapes - if (parameters_.model_type == IGenerationParameters::kModelTypeWhisper) { + if (parameters.model_type == IGenerationParameters::kModelTypeWhisper) { // Subgraph has constraint that the output is either float or float16 if (!whisper_decoder_subgraph_->IsOutputFloat16()) { BeamSearchWhisper impl{ @@ -319,7 +329,10 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds, expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer, expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer, - create_beam_scorer_func_}; + create_beam_scorer_func_, + update_decoder_cross_qk_func_ ? update_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::UpdateDecoderCrossQK, + finalize_decoder_cross_qk_func_ ? finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK}; + #ifdef USE_CUDA ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif @@ -340,7 +353,10 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { update_decoder_feeds_fp16_func_ ? update_decoder_feeds_fp16_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds, expand_buffer_float_func_, expand_buffer_float16_func_, - create_beam_scorer_func_}; + create_beam_scorer_func_, + update_decoder_cross_qk_func_ ? update_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::UpdateDecoderCrossQK, + finalize_decoder_cross_qk_func_ ? finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK}; + #ifdef USE_CUDA ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif @@ -354,6 +370,10 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { ORT_THROW("Model type is not supported."); } +Status WhisperBeamSearch::Compute(OpKernelContext* ctx) const { + return BeamSearch::Compute(ctx); +} + } // namespace transformers } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index 93b7e08fabf94..fad7dcc75bcab 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -25,11 +25,12 @@ using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel class BeamSearch : public IControlFlowKernel { public: - BeamSearch(const OpKernelInfo& info) + BeamSearch(const OpKernelInfo& info, std::unique_ptr param = std::make_unique()) : IControlFlowKernel(info), encoder_feeds_fetches_manager_(nullptr), decoder_feeds_fetches_manager_(nullptr), dumper_(nullptr) { + parameters_.swap(param); Init(info); } @@ -88,12 +89,16 @@ class BeamSearch : public IControlFlowKernel { const GenerationDeviceHelper::UpdateDecoderFeedsFunc& update_decoder_feeds_fp16_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_int32_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float_func, - const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float16_func) { + const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float16_func, + const GenerationDeviceHelper::UpdateDecoderCrossQKFunc& update_decoder_cross_qk_func, + const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc& finalize_decoder_cross_qk_func) { update_decoder_feeds_func_ = update_decoder_feeds_func; update_decoder_feeds_fp16_func_ = update_decoder_feeds_fp16_func; expand_buffer_int32_func_ = expand_buffer_int32_func; expand_buffer_float_func_ = expand_buffer_float_func; expand_buffer_float16_func_ = expand_buffer_float16_func; + update_decoder_cross_qk_func_ = update_decoder_cross_qk_func; + finalize_decoder_cross_qk_func_ = finalize_decoder_cross_qk_func; } #ifdef USE_CUDA @@ -101,7 +106,7 @@ class BeamSearch : public IControlFlowKernel { int cuda_device_arch_ = 0; #endif - private: + protected: // Device specific functions GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::TopkFunc topk_func_; @@ -172,9 +177,21 @@ class BeamSearch : public IControlFlowKernel { IConsoleDumper* dumper_; - BeamSearchParameters parameters_; + std::unique_ptr parameters_; bool has_init_decoder_ = false; + + GenerationDeviceHelper::UpdateDecoderCrossQKFunc update_decoder_cross_qk_func_; + + GenerationDeviceHelper::FinalizeDecoderCrossQKFunc finalize_decoder_cross_qk_func_; +}; + +class WhisperBeamSearch : public BeamSearch { + public: + WhisperBeamSearch(const OpKernelInfo& info) + : BeamSearch(info, std::unique_ptr(new WhisperBeamSearchParameters())) {} + + Status Compute(OpKernelContext* ctx) const override; }; } // namespace transformers diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h index 8832b4314bad3..29b38fc234de5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h @@ -17,34 +17,35 @@ struct BeamSearchState : IBeamSearchState { BeamSearchState(const IGenerationParameters& parameters, AllocatorPtr allocator, int has_decoder_masked_attention, - bool use_position) { + bool use_position, + Stream* stream) { size_t batch_beam_size = SafeInt(parameters.batch_size) * parameters.num_beams; size_t next_token_size = SafeInt(batch_beam_size) * parameters.vocab_size; - this->next_token_logits = AllocateBuffer(allocator, next_token_logits_buffer_, next_token_size); - this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size); - this->next_tokens = AllocateBuffer(allocator, next_tokens_buffer_, SafeInt(2) * batch_beam_size); - this->next_indices = AllocateBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size); - this->next_scores = AllocateBuffer(allocator, next_scores_buffer_, SafeInt(2) * batch_beam_size); + this->next_token_logits = AllocateBuffer(allocator, next_token_logits_buffer_, next_token_size, stream); + this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size, stream); + this->next_tokens = AllocateBuffer(allocator, next_tokens_buffer_, SafeInt(2) * batch_beam_size, stream); + this->next_indices = AllocateBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size, stream); + this->next_scores = AllocateBuffer(allocator, next_scores_buffer_, SafeInt(2) * batch_beam_size, stream); constexpr size_t max_parts_of_vocab = 128; size_t topk_buffer_size = SafeInt(batch_beam_size) * (max_parts_of_vocab + 1) * parameters.num_beams * 2 * 2; - this->topk_buffer = AllocateBuffer(allocator, topk_temp_buffer_, topk_buffer_size); + this->topk_buffer = AllocateBuffer(allocator, topk_temp_buffer_, topk_buffer_size, stream); if (allocator->Info().device.Type() == OrtDevice::GPU) { size_t sequences_elements = SafeInt(2) * batch_beam_size * parameters.max_length; - this->sequences_device = AllocateBuffer(allocator, sequences_device_buffer_, sequences_elements); + this->sequences_device = AllocateBuffer(allocator, sequences_device_buffer_, sequences_elements, stream); } if (use_position) { - this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_beam_size); + this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_beam_size, stream); } - this->beam_scores = AllocateBuffer(allocator, beam_scores_buffer_, batch_beam_size); + this->beam_scores = AllocateBuffer(allocator, beam_scores_buffer_, batch_beam_size, stream); if (parameters.output_scores) { size_t elements = SafeInt(parameters.max_length - parameters.sequence_length) * parameters.batch_size * parameters.num_beams * parameters.vocab_size; - this->scores = AllocateBuffer(allocator, scores_buffer_, elements); + this->scores = AllocateBuffer(allocator, scores_buffer_, elements, stream); this->remaining_scores = this->scores; } @@ -68,35 +69,38 @@ struct BeamSearchState : IBeamSearchState { } private: - BufferUniquePtr next_token_logits_buffer_; - BufferUniquePtr next_token_scores_buffer_; - BufferUniquePtr next_tokens_buffer_; - BufferUniquePtr next_indices_buffer_; - BufferUniquePtr next_scores_buffer_; - BufferUniquePtr next_positions_buffer_; - BufferUniquePtr beam_scores_buffer_; - BufferUniquePtr scores_buffer_; - BufferUniquePtr topk_temp_buffer_; - BufferUniquePtr sequences_device_buffer_; + IAllocatorUniquePtr next_token_logits_buffer_; + IAllocatorUniquePtr next_token_scores_buffer_; + IAllocatorUniquePtr next_tokens_buffer_; + IAllocatorUniquePtr next_indices_buffer_; + IAllocatorUniquePtr next_scores_buffer_; + IAllocatorUniquePtr next_positions_buffer_; + IAllocatorUniquePtr beam_scores_buffer_; + IAllocatorUniquePtr scores_buffer_; + IAllocatorUniquePtr topk_temp_buffer_; + IAllocatorUniquePtr sequences_device_buffer_; }; struct BeamSearchCpuState : IBeamSearchCpuState { Sequences sequences; - BeamSearchCpuState(const IGenerationParameters& parameters, AllocatorPtr allocator, bool is_cuda) + BeamSearchCpuState(const IGenerationParameters& parameters, AllocatorPtr allocator, bool is_cuda, Stream* stream) : parameters_{parameters} { - sequence_lengths = AllocateBuffer(allocator, sequence_lengths_buffer_, batch_beam_size_); + sequence_lengths = AllocateBuffer(allocator, sequence_lengths_buffer_, batch_beam_size_, stream); size_t sequences_bytes = SafeInt(2) * batch_beam_size_ * parameters.max_length; - sequences_space = AllocateBuffer(allocator, sequences_space_buffer_, sequences_bytes, true /* fill */); + sequences_space = AllocateBuffer(allocator, sequences_space_buffer_, sequences_bytes, stream, true /* fill */); sequences.Init(sequences_space, batch_beam_size_, parameters.sequence_length, parameters.max_length); if (is_cuda) { // buffers used by CUDA operator but not by CPU operator. - topk_scores = AllocateBuffer(allocator, topk_scores_buffer_, 2 * static_cast(batch_beam_size_)); - topk_tokens = AllocateBuffer(allocator, topk_tokens_buffer_, 2 * static_cast(batch_beam_size_)); - topk_indices = AllocateBuffer(allocator, topk_indices_buffer_, 2 * static_cast(batch_beam_size_)); - final_beam_scores = AllocateBuffer(allocator, final_beam_scores_buffer_, batch_beam_size_); + topk_scores = AllocateBuffer(allocator, topk_scores_buffer_, 2 * static_cast(batch_beam_size_), stream); + topk_tokens = AllocateBuffer(allocator, topk_tokens_buffer_, 2 * static_cast(batch_beam_size_), stream); + topk_indices = AllocateBuffer(allocator, topk_indices_buffer_, 2 * static_cast(batch_beam_size_), stream); + final_beam_scores = AllocateBuffer(allocator, final_beam_scores_buffer_, batch_beam_size_, stream); + + size_t next_token_size = SafeInt(batch_beam_size_) * parameters.vocab_size; + next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size, stream); } } @@ -124,12 +128,13 @@ struct BeamSearchCpuState : IBeamSearchCpuState { const IGenerationParameters& parameters_; const int batch_beam_size_{parameters_.batch_size * parameters_.num_beams}; - BufferUniquePtr final_beam_scores_buffer_; - BufferUniquePtr sequence_lengths_buffer_; - BufferUniquePtr topk_scores_buffer_; - BufferUniquePtr topk_tokens_buffer_; - BufferUniquePtr topk_indices_buffer_; - BufferUniquePtr sequences_space_buffer_; + IAllocatorUniquePtr final_beam_scores_buffer_; + IAllocatorUniquePtr sequence_lengths_buffer_; + IAllocatorUniquePtr topk_scores_buffer_; + IAllocatorUniquePtr topk_tokens_buffer_; + IAllocatorUniquePtr topk_indices_buffer_; + IAllocatorUniquePtr sequences_space_buffer_; + IAllocatorUniquePtr next_token_scores_buffer_; }; // Base class of beam search implementation that is common for GPT-2, T5, and Whisper. diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index 205d94fae9fab..56d950ca2f41e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -215,7 +215,8 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, - this->IsCuda()}; + this->IsCuda(), + this->ort_stream_}; // buffer in GPU for input_ids, position_ids and attention_mask IAllocatorUniquePtr buffer; @@ -240,7 +241,8 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch BeamSearchState beam_state{*parameters, this->temp_space_allocator_, gpt_subgraph_.has_decoder_masked_attention_, - true /* use_position */}; + true /* use_position */, + this->ort_stream_}; init_beam_state_func_(&beam_state, cpu_state.sequence_lengths, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 14a0db57c45de..94547887d3a90 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -144,7 +144,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, - this->IsCuda()}; + this->IsCuda(), + this->ort_stream_}; IAllocatorUniquePtr buffer; @@ -195,7 +196,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches BeamSearchState beam_state{*parameters, this->temp_space_allocator_, decoder_subgraph_.has_decoder_masked_attention_, - false /* use_position */}; + false /* use_position */, + this->ort_stream_}; init_beam_state_func_(&beam_state, cpu_state.sequence_lengths, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 198dec011c56f..91b93a125ad7a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -36,7 +36,9 @@ class BeamSearchWhisper : public BeamSearchBase { const GenerationDeviceHelper::UpdateDecoderFeedsFunc& update_decoder_feeds_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float16_func, - const GenerationDeviceHelper::CreateBeamScorer& create_beam_scorer_func) + const GenerationDeviceHelper::CreateBeamScorer& create_beam_scorer_func, + const GenerationDeviceHelper::UpdateDecoderCrossQKFunc& update_decoder_cross_qk_func, + const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc& finalize_decoder_cross_qk_func) : BeamSearchBase(context, decoder_session_state, thread_pool, ort_stream, cuda_dumper, params, topk_func, process_logits_func, device_copy_func, device_copy_int32_func), @@ -49,7 +51,11 @@ class BeamSearchWhisper : public BeamSearchBase { update_decoder_feeds_func_(update_decoder_feeds_func), expand_buffer_float_func_(expand_buffer_float_func), expand_buffer_float16_func_(expand_buffer_float16_func), - create_beam_scorer_func_(create_beam_scorer_func) {} + create_beam_scorer_func_(create_beam_scorer_func), + update_decoder_cross_qk_func_(update_decoder_cross_qk_func), + finalize_decoder_cross_qk_func_(finalize_decoder_cross_qk_func), + cuda_device_prop_(nullptr), + cuda_device_arch_(0) {} #ifdef USE_CUDA Status InitializeCuda( @@ -95,6 +101,8 @@ class BeamSearchWhisper : public BeamSearchBase { GenerationDeviceHelper::ExpandBufferFunc expand_buffer_float16_func_; GenerationDeviceHelper::CreateBeamScorer create_beam_scorer_func_; + const GenerationDeviceHelper::UpdateDecoderCrossQKFunc update_decoder_cross_qk_func_; + const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc finalize_decoder_cross_qk_func_; const void* cuda_device_prop_ = nullptr; int cuda_device_arch_ = 0; }; @@ -122,6 +130,17 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe TensorShape scores_shape(&scores_dims[0], sizeof(scores_dims) / sizeof(scores_dims[0])); Tensor* output_scores = this->context_.Output(2, scores_shape); + if (parameters->no_speech_probs_output_id > 0) { + TensorShape no_speech_probs_shape{parameters->batch_size}; + Tensor* no_speech_probs = this->context_.Output(parameters->no_speech_probs_output_id, no_speech_probs_shape); + if (no_speech_probs && no_speech_probs->MutableData()) { + ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size, + "no_speech_token id out of range, it is ", parameters->no_speech_token, + ", vocab_size is ", parameters->vocab_size); + this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData(); + } + } + // Update the flag to indicate whether scores exists in output this->parameters_->output_scores = (output_scores != nullptr); @@ -136,7 +155,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, - this->IsCuda()}; + this->IsCuda(), + this->ort_stream_}; IAllocatorUniquePtr buffer; @@ -188,7 +208,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe BeamSearchState beam_state{*parameters, this->temp_space_allocator_, decoder_subgraph_.has_decoder_masked_attention_, - false /* use_position */}; + false /* use_position */, + this->ort_stream_}; init_beam_state_func_(&beam_state, cpu_state.sequence_lengths, @@ -222,6 +243,16 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe std::vector decoder_feeds; int current_length = parameters->sequence_length; + // for decoder subgraph output cross qk + int64_t frames_of_k = 0LL; + Tensor* cross_qk_output = nullptr; // output tensor + int64_t cross_qk_layer_head_pair_count = 0LL; + OrtValue cross_qk_buffer_value; + float* cross_qk_buffer_data = nullptr; + std::vector cross_qk_all_layer_heads; + const int32_t* cross_qk_layer_head_pairs = nullptr; + IAllocatorUniquePtr qk_layer_pointers; // if needed, device array hold the cross qk data pointers, shape of [num_layers] + std::vector decoder_fetches; if (current_length + 1 < parameters->max_length) { @@ -265,6 +296,41 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe } } + if (decoder_subgraph_.output_cross_qk_) { + ORT_ENFORCE(decoder_subgraph_.has_decoder_masked_attention_, "decoder subgraph: output_cross_qk could only work with has_decoder_masked_attention"); + ORT_ENFORCE(decoder_subgraph_.past_present_share_buffer_, "decoder subgraph: output_cross_qk could only work with past_present_share_buffer"); + + cross_qk_layer_head_pair_count = parameters->num_layers * parameters->num_heads; + const auto* input_tensor_cross_qk_layer_head = this->context_.template Input(parameters->cross_qk_layer_head_input_id); + ORT_ENFORCE(input_tensor_cross_qk_layer_head != nullptr, "Must specify input cross_qk_layer_head"); + cross_qk_layer_head_pair_count = input_tensor_cross_qk_layer_head->Shape()[0]; + cross_qk_layer_head_pairs = input_tensor_cross_qk_layer_head->template Data(); // it is on GPU + + size_t decoder_input_first_cross_key = static_cast(decoder_subgraph_.GetFirstPastInputIndex()) + (2 * decoder_subgraph_.num_layers); + auto first_cross_attention_key = decoder_feeds[decoder_input_first_cross_key].GetMutable(); + frames_of_k = first_cross_attention_key->Shape()[2]; + + TensorShape layer_cross_qk_shape{ + static_cast(parameters->BatchBeamSize()), + static_cast(parameters->num_heads), + 1LL, + static_cast(frames_of_k)}; + for (int layer = 0; layer < decoder_subgraph_.num_layers; layer++) { + OrtValue cross_qk_value; + Tensor::InitOrtValue(DataTypeImpl::GetType(), layer_cross_qk_shape, this->temp_space_allocator_, cross_qk_value); + decoder_fetches.emplace_back(cross_qk_value); + } + + TensorShape cross_qk_shape{ + static_cast(parameters->batch_size), + static_cast(parameters->num_beams), + cross_qk_layer_head_pair_count, + static_cast(parameters->max_length), + frames_of_k}; + Tensor::InitOrtValue(DataTypeImpl::GetType(), cross_qk_shape, this->temp_space_allocator_, cross_qk_buffer_value); + cross_qk_buffer_data = cross_qk_buffer_value.GetMutable()->MutableData(); + } + if (decoder_subgraph_.has_decoder_masked_attention_) { size_t offset = static_cast(decoder_subgraph_.GetFirstPastInputIndex()); // Need to check cross attention's past key tensor size, suppose all layers cross attention key size are same @@ -316,6 +382,21 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe ORT_RETURN_IF_ERROR(status); + if (decoder_subgraph_.output_cross_qk_) { + int decoder_output_first_cross_qk = decoder_subgraph_.GetFirstPresentOutputIndex() + (2 * decoder_subgraph_.num_layers); + ORT_RETURN_IF_ERROR(this->update_decoder_cross_qk_func_( + iteration_counter, + this->ort_stream_, + &decoder_fetches[decoder_output_first_cross_qk], + qk_layer_pointers, + parameters->num_layers, + static_cast(cross_qk_layer_head_pair_count), + cross_qk_layer_head_pairs, + cross_qk_buffer_data, + parameters->max_length, + this->temp_space_allocator_)); + } + #ifdef DEBUG_GENERATION for (int i = 0; i <= decoder_subgraph_.GetFirstPresentOutputIndex(); i++) { dumper->Print("decoder_fetches", i, true); @@ -383,6 +464,35 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe } } + if (decoder_subgraph_.output_cross_qk_) { + TensorShape cross_qk_shape{ + static_cast(parameters->batch_size), + static_cast(parameters->num_return_sequences), + cross_qk_layer_head_pair_count, + static_cast(iteration_counter - 1), + frames_of_k}; + cross_qk_output = this->context_.Output(parameters->cross_qk_output_id, cross_qk_shape); + + size_t cache_indir_input_offset = static_cast(decoder_subgraph_.GetFirstPastInputIndex()) + 4 * static_cast(decoder_subgraph_.num_layers) + 2; + const int* cache_indir_data = decoder_feeds[cache_indir_input_offset].GetMutable()->Data(); + auto beam_indices = this->beam_scorer_->GetNextIndicesGPU(); // currently only support on GPU + ORT_RETURN_IF_ERROR(this->finalize_decoder_cross_qk_func_( + this->ort_stream_, + iteration_counter, + parameters->sequence_length, + parameters->batch_size, + parameters->num_beams, + parameters->max_length, + static_cast(cross_qk_layer_head_pair_count), + cross_qk_layer_head_pairs, + static_cast(frames_of_k), + cross_qk_buffer_data, + cross_qk_output->MutableData(), + parameters->num_return_sequences, + cache_indir_data, + beam_indices)); + } + gsl::span final_beam_scores = beam_state.beam_scores; this->beam_scorer_->Finalize(cpu_state.sequences, final_beam_scores, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 76011a5c89b66..3962486d5b5eb 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -47,6 +47,23 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { } batch_size = static_cast(dims[0]); + extra_decoding_ids = gsl::span(); + if (this->model_type == IGenerationParameters::kModelTypeWhisper && extra_decoding_ids_input_id > 0) { + const Tensor* extra_decoder_tensor = context->Input(extra_decoding_ids_input_id); + if (extra_decoder_tensor != nullptr) { + const auto& extra_decoder_tensor_dims = extra_decoder_tensor->Shape().GetDims(); + ORT_ENFORCE(extra_decoder_tensor_dims.size() == 2, + "extra_decoder_tensor shall have 2 dimensions. Got ", + extra_decoder_tensor_dims.size()); + ORT_ENFORCE(extra_decoder_tensor_dims[0] == batch_size, + "extra_decoder_tensor first dim not same as batch_size. Got ", + extra_decoder_tensor_dims[0], ", expecting ", batch_size); + if (extra_decoder_tensor->Shape().Size() > 0) { + extra_decoding_ids = gsl::span(extra_decoder_tensor->Data(), (size_t)extra_decoder_tensor->Shape().Size()); + } + } + } + if (this->model_type == IGenerationParameters::kModelTypeGpt) { sequence_length = static_cast(dims[1]); } else if (this->model_type == IGenerationParameters::kModelTypeWhisper) { @@ -119,6 +136,18 @@ void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, num_layers = layers; } +void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) { + BeamSearchParameters::ParseFromAttributes(info); + model_type = static_cast(info.GetAttrOrDefault("model_type", IGenerationParameters::kModelTypeWhisper)); + ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper); + + no_speech_token = static_cast(info.GetAttrOrDefault("no_speech_token", -1LL)); + cross_qk_layer_head_input_id = 12; + extra_decoding_ids_input_id = 13; + cross_qk_output_id = 3; + no_speech_probs_output_id = 4; +} + } // namespace transformers } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h index 0cb2b39976cc3..87bc6cdbfe72c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h @@ -11,17 +11,23 @@ namespace contrib { namespace transformers { struct BeamSearchParameters : public IGenerationParameters { + virtual ~BeamSearchParameters() {} + Status Validate() const; int BatchBeamSize() const { return batch_size * num_beams; } - void ParseFromAttributes(const OpKernelInfo& info); + virtual void ParseFromAttributes(const OpKernelInfo& info); void ParseFromInputs(OpKernelContext* context); void SetSubgraphParameters(int vocab_size, int num_heads, int head_size, int num_layers); }; +struct WhisperBeamSearchParameters : public BeamSearchParameters { + void ParseFromAttributes(const OpKernelInfo& info) override; +}; + } // namespace transformers } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h index e889281abb023..680cb23fd887a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h @@ -33,24 +33,43 @@ gsl::span AllocateBuffer(AllocatorPtr allocator, return span; } +template +gsl::span AllocateBuffer(AllocatorPtr allocator, + IAllocatorUniquePtr& buffer, + size_t elements, + Stream* stream, + bool fill = false, + T fill_value = T{}) { + size_t bytes = SafeInt(sizeof(T)) * elements; + buffer = IAllocator::MakeUniquePtr(allocator, bytes, false, stream); + T* first = reinterpret_cast(buffer.get()); + auto span = gsl::make_span(first, elements); + + if (fill) { + std::fill_n(first, elements, fill_value); + } + + return span; +} + template inline void AllocateTempBufferForGetGreedySearchTopOne( int32_t batch_size, AllocatorPtr allocator, - BufferUniquePtr& buffer, + IAllocatorUniquePtr& buffer, gsl::span& stage_1_scores, // shape (batch_size, parts_of_vocab) gsl::span& stage_1_tokens, // shape (batch_size, parts_of_vocab) gsl::span& output_scores, // shape (batch_size) - gsl::span& output_tokens // shape (batch_size) -) { + gsl::span& output_tokens, // shape (batch_size) + Stream* stream) { constexpr size_t kMaxPartsPerVocab = 128; const size_t stage_1_element_size = kMaxPartsPerVocab * batch_size; const size_t output_element_size = batch_size; // Note: use float to allocate buffer for temporary value buffer to avoid unalignment - void* topk_data = allocator->Alloc((stage_1_element_size + output_element_size) * (sizeof(float) + sizeof(int32_t))); - BufferUniquePtr temp_buffer(topk_data, BufferDeleter(allocator)); - buffer = std::move(temp_buffer); + size_t bytes = (stage_1_element_size + output_element_size) * (sizeof(float) + sizeof(int32_t)); + buffer = IAllocator::MakeUniquePtr(allocator, bytes, false, stream); + void* topk_data = buffer.get(); ElementType* stage_1_scores_data = reinterpret_cast(topk_data); stage_1_scores = gsl::make_span(stage_1_scores_data, stage_1_element_size); diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 88348ad88dc27..927d3a58e5a6f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -1084,6 +1084,38 @@ template Status CreateWhisperEncoderInputs( OrtValue& encoder_input_features, OrtValue& decoder_input_ids); +Status UpdateDecoderCrossQK( + [[maybe_unused]] int iteration_number, + [[maybe_unused]] Stream* tream, + [[maybe_unused]] OrtValue* cross_qks, + [[maybe_unused]] IAllocatorUniquePtr& qk_layer_pointers, + [[maybe_unused]] int num_layers, + [[maybe_unused]] int cross_qk_layer_head_pair_count, + [[maybe_unused]] const int* cross_qk_layer_head_pairs, + [[maybe_unused]] float* cross_qk_buffer_data, + [[maybe_unused]] int max_length, + [[maybe_unused]] AllocatorPtr allocator) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CPU beam search current not support output cross QK."); +} + +Status FinalizeDecoderCrossQK( + [[maybe_unused]] Stream* stream, + [[maybe_unused]] int iteration_number, + [[maybe_unused]] int context_decoding_len, + [[maybe_unused]] int batch_size, + [[maybe_unused]] int num_beams, + [[maybe_unused]] int max_length, + [[maybe_unused]] int cross_qk_layer_head_pair_count, + [[maybe_unused]] const int* cross_qk_layer_head_pairs, + [[maybe_unused]] int frames_of_k, + [[maybe_unused]] const float* cross_qk_buffer_data, + [[maybe_unused]] float* cross_qk_output, + [[maybe_unused]] int num_return_sequences, + [[maybe_unused]] const int* cache_indir_data, + [[maybe_unused]] gsl::span beam_indices) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CPU beam search current not support output cross QK."); +} + } // namespace GenerationCpuDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index ba1b0b662f1a5..6dfdc6b027671 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -204,6 +204,35 @@ using ExpandBufferFunc = std::function; + +using UpdateDecoderCrossQKFunc = std::function& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator)>; + +using FinalizeDecoderCrossQKFunc = std::function beam_indices)>; + } // namespace GenerationDeviceHelper // These are CPU specific device helper implementations @@ -368,6 +397,34 @@ Status ExpandBuffer( bool only_copy_shape, int max_sequence_length); +Status UpdateDecoderCrossQK( + int iteration_number, + Stream* stream, + OrtValue* cross_qks, + IAllocatorUniquePtr& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator); + +Status FinalizeDecoderCrossQK( + Stream* stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + gsl::span beam_indices); + } // namespace GenerationCpuDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 719dd302d274d..f6faf2e325f8f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -53,6 +53,7 @@ struct IBeamSearchCpuState { gsl::span topk_tokens; // shape (batch_size, 2*num_beams), tokens of topk candidates. gsl::span topk_indices; // shape (batch_size, 2*num_beams), beam indices of topk candidates. gsl::span final_beam_scores; // shape (batch_size, num_beams) + gsl::span next_token_scores; // shape (batch_size, num_beams * vocab_size) }; template @@ -175,6 +176,17 @@ struct IGenerationParameters { int seed = 0; int min_tokens_to_keep = 1; bool custom_sampling = false; + + // Parameters for whisper model + bool decoder_output_cross_qk = false; + gsl::span extra_decoding_ids; + int32_t no_speech_token = -1; + void* no_speech_probs = nullptr; + + int cross_qk_layer_head_input_id = -1; + int extra_decoding_ids_input_id = -1; + int cross_qk_output_id = -1; + int no_speech_probs_output_id = -1; }; } // namespace transformers diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index be974ed2159d9..9f372e5b3a673 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -20,26 +20,27 @@ struct SamplingState : public ISamplingState { int vocab_size, int max_iter, int seed, - bool is_cuda) { + bool is_cuda, + Stream* stream) { int total_count = batch_size * vocab_size; - this->h_softmaxed_score = AllocateBuffer(cpu_allocator, h_softmaxed_score_buffer_, SafeInt(total_count)); + this->h_softmaxed_score = AllocateBuffer(cpu_allocator, h_softmaxed_score_buffer_, SafeInt(total_count), stream); this->generator = std::default_random_engine{gsl::narrow_cast(seed)}; if (is_cuda) { - this->d_index_in = AllocateBuffer(allocator, d_index_in_buffer_, SafeInt(total_count)); - this->d_index_out = AllocateBuffer(allocator, d_index_out_buffer_, SafeInt(total_count)); - this->d_offset = AllocateBuffer(allocator, d_offset_buffer_, SafeInt(batch_size + 1)); - this->d_sorted_score = AllocateBuffer(allocator, d_sorted_score_buffer_, SafeInt(total_count)); - this->d_sorted_softmaxed_score = AllocateBuffer(allocator, d_sorted_softmaxed_score_buffer_, SafeInt(total_count)); - this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count)); - this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size)); - this->h_sampled_all = AllocateBuffer(cpu_allocator, h_sampled_all_buffer_, SafeInt(batch_size * max_iter)); - this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size)); + this->d_index_in = AllocateBuffer(allocator, d_index_in_buffer_, SafeInt(total_count), stream); + this->d_index_out = AllocateBuffer(allocator, d_index_out_buffer_, SafeInt(total_count), stream); + this->d_offset = AllocateBuffer(allocator, d_offset_buffer_, SafeInt(batch_size + 1), stream); + this->d_sorted_score = AllocateBuffer(allocator, d_sorted_score_buffer_, SafeInt(total_count), stream); + this->d_sorted_softmaxed_score = AllocateBuffer(allocator, d_sorted_softmaxed_score_buffer_, SafeInt(total_count), stream); + this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count), stream); + this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size), stream); + this->h_sampled_all = AllocateBuffer(cpu_allocator, h_sampled_all_buffer_, SafeInt(batch_size * max_iter), stream); + this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size), stream); this->temp_storage_bytes = 0; // TODO: Do not allocate this buffer if there's no presence_mask - this->d_presence_mask = AllocateBuffer(allocator, d_presence_mask_buffer_, SafeInt(total_count)); + this->d_presence_mask = AllocateBuffer(allocator, d_presence_mask_buffer_, SafeInt(total_count), stream); std::uniform_real_distribution distribution(0.0, 1.0); static_cast(distribution(this->generator)); @@ -48,25 +49,25 @@ struct SamplingState : public ISamplingState { } } else { // TODO: Some buffer can be reused for CPU - this->sorted_scores = AllocateBuffer(cpu_allocator, sorted_scores_buffer_, SafeInt(total_count)); - this->cumulative_probs = AllocateBuffer(cpu_allocator, cumulative_probs_buffer_, SafeInt(total_count)); + this->sorted_scores = AllocateBuffer(cpu_allocator, sorted_scores_buffer_, SafeInt(total_count), stream); + this->cumulative_probs = AllocateBuffer(cpu_allocator, cumulative_probs_buffer_, SafeInt(total_count), stream); } } private: - BufferUniquePtr d_index_in_buffer_; - BufferUniquePtr d_index_out_buffer_; - BufferUniquePtr d_offset_buffer_; - BufferUniquePtr d_sorted_score_buffer_; - BufferUniquePtr d_sorted_softmaxed_score_buffer_; - BufferUniquePtr d_softmaxed_score_buffer_; - BufferUniquePtr h_softmaxed_score_buffer_; - BufferUniquePtr d_sampled_buffer_; - BufferUniquePtr h_sampled_all_buffer_; - BufferUniquePtr d_indices_buffer_; - BufferUniquePtr d_presence_mask_buffer_; - BufferUniquePtr sorted_scores_buffer_; - BufferUniquePtr cumulative_probs_buffer_; + IAllocatorUniquePtr d_index_in_buffer_; + IAllocatorUniquePtr d_index_out_buffer_; + IAllocatorUniquePtr d_offset_buffer_; + IAllocatorUniquePtr d_sorted_score_buffer_; + IAllocatorUniquePtr d_sorted_softmaxed_score_buffer_; + IAllocatorUniquePtr d_softmaxed_score_buffer_; + IAllocatorUniquePtr h_softmaxed_score_buffer_; + IAllocatorUniquePtr d_sampled_buffer_; + IAllocatorUniquePtr h_sampled_all_buffer_; + IAllocatorUniquePtr d_indices_buffer_; + IAllocatorUniquePtr d_presence_mask_buffer_; + IAllocatorUniquePtr sorted_scores_buffer_; + IAllocatorUniquePtr cumulative_probs_buffer_; }; template @@ -82,24 +83,25 @@ struct GreedySearchState : public IGreedySearchState { int num_heads, int head_size, bool has_decoder_masked_self_attention, - bool is_cuda) { + bool is_cuda, + Stream* stream) { // below buffers are on cpu this->sequences_space = AllocateBuffer(cpu_allocator, sequences_space_buffer_, - SafeInt(2) * batch_size * max_length); + SafeInt(2) * batch_size * max_length, stream); memset(this->sequences_space.data(), 0, this->sequences_space.size_bytes()); this->sequences.Init(this->sequences_space, static_cast(batch_size), sequence_length, max_length); - this->sequence_lengths = AllocateBuffer(cpu_allocator, sequence_lengths_buffer_, batch_size); - this->eos_meet = AllocateBuffer(cpu_allocator, eos_meet_buffer_, batch_size); + this->sequence_lengths = AllocateBuffer(cpu_allocator, sequence_lengths_buffer_, batch_size, stream); + this->eos_meet = AllocateBuffer(cpu_allocator, eos_meet_buffer_, batch_size, stream); memset(this->eos_meet.data(), 0, this->eos_meet.size_bytes()); - this->next_tokens = AllocateBuffer(cpu_allocator, next_tokens_buffer_, SafeInt(batch_size)); + this->next_tokens = AllocateBuffer(cpu_allocator, next_tokens_buffer_, SafeInt(batch_size), stream); // below buffers are on cpu or cuda size_t next_token_size = SafeInt(batch_size) * vocab_size; - this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size); - this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_size); + this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size, stream); + this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_size, stream); if (is_cuda) { AllocateTempBufferForGetGreedySearchTopOne( @@ -109,7 +111,8 @@ struct GreedySearchState : public IGreedySearchState { this->temp_topk_scores_buffer, this->temp_topk_tokens_buffer, this->topk_scores_buffer, - this->topk_tokens_buffer); + this->topk_tokens_buffer, + stream); // If at all we need to, we only need to re-order past state for CUDA as //`DecoderMaskedSelfAttention` is only supported on CUDA @@ -137,14 +140,14 @@ struct GreedySearchState : public IGreedySearchState { } private: - BufferUniquePtr sequences_space_buffer_; - BufferUniquePtr sequence_lengths_buffer_; - BufferUniquePtr next_token_scores_buffer_; - BufferUniquePtr next_tokens_buffer_; - BufferUniquePtr next_positions_buffer_; - BufferUniquePtr eos_meet_buffer_; - BufferUniquePtr temp_topk_buffer_; - BufferUniquePtr staging_for_past_state_reorder_buffer_; + IAllocatorUniquePtr sequences_space_buffer_; + IAllocatorUniquePtr sequence_lengths_buffer_; + IAllocatorUniquePtr next_token_scores_buffer_; + IAllocatorUniquePtr next_tokens_buffer_; + IAllocatorUniquePtr next_positions_buffer_; + IAllocatorUniquePtr eos_meet_buffer_; + IAllocatorUniquePtr temp_topk_buffer_; + IAllocatorUniquePtr staging_for_past_state_reorder_buffer_; }; // Base class of gready search implementation that is common for both GPT-2 and Bart/T5. diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index 4504b099e32bd..69d25eaabbe02 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -211,7 +211,8 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_ static_cast(parameters->num_heads), static_cast(parameters->head_size), gpt_subgraph_.has_decoder_masked_attention_, - this->IsCuda()); + this->IsCuda(), + this->ort_stream_); SamplingState sampling_state; if (std::is_same::value) { @@ -221,7 +222,8 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_ static_cast(parameters->vocab_size), static_cast(parameters->max_length - parameters->sequence_length), parameters->seed, - this->IsCuda()); + this->IsCuda(), + this->ort_stream_); } IAllocatorUniquePtr buffer; diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_parameters.h index f1150fdba00e8..4ef0c180eba34 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_parameters.h @@ -13,7 +13,7 @@ namespace transformers { struct GreedySearchParameters : public BeamSearchParameters { int BatchBeamSize() const { return batch_size; } - void ParseFromAttributes(const OpKernelInfo& info); + void ParseFromAttributes(const OpKernelInfo& info) override; void ParseFromInputs(OpKernelContext* context); }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 9f77c32f0c7cc..f39f090c78b0c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -17,20 +17,6 @@ namespace onnxruntime { namespace contrib { namespace transformers { -template -gsl::span NextTokenScores::GetScores(int batch_beam_index) { - assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size); - return scores.subspan(static_cast(batch_beam_index) * vocab_size, vocab_size); -} - -template -void NextTokenScores::SetScore(int token_id, T score) { - assert(token_id >= 0 && token_id < vocab_size); - for (int i = 0; i < batch_beam_size; i++) { - scores[static_cast(i) * vocab_size + token_id] = score; - } -} - #ifdef DEBUG_GENERATION template void DumpScores(const char* name, const NextTokenScores& next_token_scores) { @@ -238,128 +224,6 @@ void PresencePenaltyLogitsProcessor::Process(const ISequences*, #endif } -template -TimestampLogitsProcessor::TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) - : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} - -template -void TimestampLogitsProcessor::Process(const ISequences* sequences, - NextTokenScores& next_token_scores) { - const int beg_token_id_ = eos_token_id_ + 107; - const int not_token_id_ = eos_token_id_ + 106; - const int solm_token_id_ = eos_token_id_ + 105; - const int sot_token_id_ = eos_token_id_ + 1; - constexpr int translate_token_id_ = 50358; - constexpr int transcribe_token_id_ = 50359; - - const int batch_beam_size = next_token_scores.batch_beam_size; - const int vocab_size = next_token_scores.vocab_size; - for (int i = 0; i < batch_beam_size; i++) { - gsl::span beam_token_scores = next_token_scores.GetScores(i); - gsl::span sequence = sequences->GetSequence(i); - const size_t seq_length = sequence.size(); - - // Find first timestamp - size_t sample_begin = 0; - for (size_t j = 0; j < seq_length; j++) { - sample_begin++; - if (sequence[j] >= beg_token_id_) { - break; - } - } - - // Suppress tokens - for (int j = 0; j < vocab_size; j++) { - // Suppress notimestamps and solm tokens - if (j == not_token_id_ || j == solm_token_id_) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - - // Suppress sot, translate and transcribe tokens - if (seq_length > sample_begin) { - if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - } - - // Timestamps should be in pair except the first one - const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; - const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; - if (last_was_timestamp) { - if (penultimate_was_timestamp) { - // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated - for (int j = beg_token_id_; j < vocab_size; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } else { - // If timestamp doesn't show up in pair, generate timestamp - for (int j = 0; j < eos_token_id_; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - } - - // Find timestamp tokens - std::vector timestamps; - for (const auto& word_id : sequence) { - if (word_id >= beg_token_id_) { - timestamps.push_back(word_id); - } - } - - // Timestamps will not decrease - const size_t timestamps_len = timestamps.size(); - if (timestamps_len > 0) { - int timestamp_last = 0; - if (last_was_timestamp && !penultimate_was_timestamp) { - // For single timestamp at the end, next timestamp must not be smaller - timestamp_last = timestamps.back(); - } else { - // For paired timestamp at the end, next timestamp must be greater - timestamp_last = timestamps.back() + 1; - } - - for (int j = beg_token_id_; j < timestamp_last; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - - if (seq_length == sample_begin) { - const int last_allowed = beg_token_id_ + max_initial_timestamp_index_; - for (int j = last_allowed + 1; j < vocab_size; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - - // Caculate logsumexp on timestamps - float timestamp_logprob = std::numeric_limits::lowest(); - { - float logsumexp = 0.0f; - const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); - for (int j = beg_token_id_; j < vocab_size; ++j) { - if (beam_token_scores[j] > std::numeric_limits::lowest()) { - logsumexp += expf(beam_token_scores[j] - logprob_max); - } - } - if (logsumexp > 0.0f) { - timestamp_logprob = logf(logsumexp) + logprob_max; - } - } - - const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_); - if (timestamp_logprob > max_text_token_logprob) { - for (int j = 0; j < beg_token_id_; ++j) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - } - -#ifdef DEBUG_GENERATION - DumpScores("TimestampLogitsProcessor", next_token_scores); -#endif -} - void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { LogitsProcessorInitImpl(parameters); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 664c497a106d4..4688ff272cee9 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -6,6 +6,7 @@ #include "core/common/inlined_containers.h" #include "contrib_ops/cpu/transformers/sequences.h" #include "contrib_ops/cpu/transformers/beam_search_parameters.h" +#include "contrib_ops/cpu/transformers/dump_tensor.h" #include "contrib_ops/cpu/transformers/greedy_search_parameters.h" #include "contrib_ops/cpu/transformers/sampling_parameters.h" #include "contrib_ops/cpu/transformers/generation_shared.h" @@ -20,9 +21,17 @@ struct NextTokenScores { int batch_beam_size; int vocab_size; - gsl::span GetScores(int batch_beam_index); + gsl::span GetScores(int batch_beam_index) { + assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size); + return scores.subspan(static_cast(batch_beam_index) * vocab_size, vocab_size); + } - void SetScore(int token_id, T score); + void SetScore(int token_id, T score) { + assert(token_id >= 0 && token_id < vocab_size); + for (int i = 0; i < batch_beam_size; i++) { + scores[static_cast(i) * vocab_size + token_id] = score; + } + } }; // Interface for all scorers for beam search or beam sample. @@ -141,10 +150,126 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { template class TimestampLogitsProcessor : public ILogitsProcessor { public: - TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index); + TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) + : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} void Process(const ISequences* sequences, - NextTokenScores& next_token_scores) override; + NextTokenScores& next_token_scores) override { + // TODO: translate_token_id_ and transcribe_token_id_ need to support both multilingual and English-only models. + const int beg_token_id_ = eos_token_id_ + 107; + const int not_token_id_ = eos_token_id_ + 106; + const int solm_token_id_ = eos_token_id_ + 105; + const int sot_token_id_ = eos_token_id_ + 1; + constexpr int translate_token_id_ = 50358; + constexpr int transcribe_token_id_ = 50359; + + const int batch_beam_size = next_token_scores.batch_beam_size; + const int vocab_size = next_token_scores.vocab_size; + for (int i = 0; i < batch_beam_size; i++) { + gsl::span beam_token_scores = next_token_scores.GetScores(i); + gsl::span sequence = sequences->GetSequence(i); + const size_t seq_length = sequence.size(); + + // Find first timestamp + size_t sample_begin = 0; + for (size_t j = 0; j < seq_length; j++) { + sample_begin++; + if (sequence[j] >= beg_token_id_) { + break; + } + } + + // Suppress tokens + for (int j = 0; j < vocab_size; j++) { + // Suppress notimestamps and solm tokens + if (j == not_token_id_ || j == solm_token_id_) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + + // Suppress sot, translate and transcribe tokens + if (seq_length > sample_begin) { + if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + } + + // Timestamps should be in pair except the first one + const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; + const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; + if (last_was_timestamp) { + if (penultimate_was_timestamp) { + // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated + for (int j = beg_token_id_; j < vocab_size; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } else { + // If timestamp doesn't show up in pair, generate timestamp + for (int j = 0; j < eos_token_id_; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + } + + // Find timestamp tokens + std::vector timestamps; + for (const auto& word_id : sequence) { + if (word_id >= beg_token_id_) { + timestamps.push_back(word_id); + } + } + + // Timestamps will not decrease + const size_t timestamps_len = timestamps.size(); + if (timestamps_len > 0) { + int timestamp_last = 0; + if (last_was_timestamp && !penultimate_was_timestamp) { + // For single timestamp at the end, next timestamp must not be smaller + timestamp_last = timestamps.back(); + } else { + // For paired timestamp at the end, next timestamp must be greater + timestamp_last = timestamps.back() + 1; + } + + for (int j = beg_token_id_; j < timestamp_last; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + + if (seq_length == sample_begin) { + const int last_allowed = beg_token_id_ + max_initial_timestamp_index_; + for (int j = last_allowed + 1; j < vocab_size; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + + // Caculate logsumexp on timestamps + float timestamp_logprob = std::numeric_limits::lowest(); + { + float logsumexp = 0.0f; + const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); + for (int j = beg_token_id_; j < vocab_size; ++j) { + if (beam_token_scores[j] > std::numeric_limits::lowest()) { + logsumexp += expf(beam_token_scores[j] - logprob_max); + } + } + if (logsumexp > 0.0f) { + timestamp_logprob = logf(logsumexp) + logprob_max; + } + } + + const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_); + if (timestamp_logprob > max_text_token_logprob) { + for (int j = 0; j < beg_token_id_; ++j) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + } + +#ifdef DEBUG_GENERATION + DumpScores("TimestampLogitsProcessor", next_token_scores); +#endif + } private: int eos_token_id_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h index af203abc15d01..7779a73feffba 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h @@ -11,7 +11,7 @@ namespace contrib { namespace transformers { struct SamplingParameters : public GreedySearchParameters { - void ParseFromAttributes(const OpKernelInfo& info); + void ParseFromAttributes(const OpKernelInfo& info) override; void ParseFromInputs(OpKernelContext* context); }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h index 3c11d2d324a85..487a35c55a85f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h @@ -45,6 +45,7 @@ class Subgraph { int num_layers; bool past_present_share_buffer_; bool has_decoder_masked_attention_; + bool output_cross_qk_ = false; // Setup execution Status Setup(const SessionState& session_state, diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 28acd81ae95fd..4d61ce71c69be 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -172,7 +172,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); AllocatorPtr buffer_allocator = std::make_shared(); size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size); + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); int* seq_copy_ptr = seq_copy.get(); if (!use_sequence_as_input_ids_) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h index 085d8f3903976..83dae49c7dcbd 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h @@ -5,6 +5,7 @@ #include "contrib_ops/cpu/transformers/subgraph_base.h" #include "contrib_ops/cpu/transformers/sequences.h" +#include "core/framework/op_kernel.h" namespace onnxruntime { namespace contrib { @@ -20,6 +21,13 @@ class T5DecoderSubgraph : public Subgraph { has_hidden_state_(false), use_sequence_as_input_ids_(true) { first_present_output_index_ = 1; + + // Currently just using parent node's attribute. Maybe better to find it purely in subgraph. + const auto& attributes = node_in.GetAttributes(); + if (attributes.find("decoder_output_cross_qk") != attributes.end()) { + auto& attr = attributes.at("decoder_output_cross_qk"); + output_cross_qk_ = (attr.i() != 0LL); + } } // Create inputs for first inference of decoder subgraph. @@ -62,7 +70,7 @@ class T5DecoderSubgraph : public Subgraph { return first_present_output_index_; } - bool UseSequenceAsInputIds() const { + inline bool UseSequenceAsInputIds() const { return use_sequence_as_input_ids_; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc index 887a6a8984b83..7d0c62b618ee2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc @@ -70,8 +70,15 @@ Status WhisperDecoderSubgraph::Validate(const std::vector& subgr "number of inputs expected to be kFirstPastInputIndex + 4 * layers + 1, got:", num_subgraph_inputs); } - ORT_RETURN_IF(num_subgraph_outputs < 3 || (num_subgraph_outputs - first_present_output_index_) % 2 != 0, - "number of outputs expected to be 1 + 2 * layers, got:", num_subgraph_outputs); + if (!output_cross_qk_) { + ORT_RETURN_IF(num_subgraph_outputs < 3 || (num_subgraph_outputs - first_present_output_index_) % 2 != 0, + "number of outputs expected to be first_present_output_index_", + first_present_output_index_, " + 2 * layers, got:", num_subgraph_outputs); + } else { + ORT_RETURN_IF(num_subgraph_outputs < 4 || (num_subgraph_outputs - first_present_output_index_) % 3 != 0, + "When outputing cross qk, number of outputs expected to be first_present_output_index_", + first_present_output_index_, " + 3 * layers, got:", num_subgraph_outputs); + } ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name()); @@ -90,7 +97,8 @@ Status WhisperDecoderSubgraph::Validate(const std::vector& subgr // Save parameters related to the subgraph. ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false)); - num_layers = (static_cast(subgraph_outputs.size()) - first_present_output_index_) / 2; + + num_layers = (static_cast(subgraph_outputs.size()) - first_present_output_index_) / (output_cross_qk_ ? 3 : 2); // If input_ids's shape is ['batch_size', 1] then use next token as input_ids. // Otherwise in the case of shape ['batch_size', 'sequence'], use sequence as input_ids. @@ -112,12 +120,7 @@ Status WhisperDecoderSubgraph::Validate(const std::vector& subgr for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) { ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type, - "decoder subgraph past inputs shall have same data type as that of encoder_hidden_states"); - } - - for (int i = 0; i < num_subgraph_outputs; i++) { - ORT_RETURN_IF(subgraph_outputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type, - "decoder subgraph output shall have same data type as that of encoder_hidden_states"); + "decoder subgraph past inputs shall have same data type as that of encoder_hidden_states."); } is_output_float16_ = (subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type() == float16_type); @@ -166,7 +169,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds( AllocatorPtr buffer_allocator = std::make_shared(); size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size); + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); int* seq_copy_ptr = seq_copy.get(); if (!use_sequence_as_input_ids_) { diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index d846f55f1e28d..626e4c0b87a3c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -287,9 +287,9 @@ __global__ void AddBiasTransposeQKV(int M, const T* input, const T* biases, T* o T* k_smem = q_smem + rotary_embedding_dim; const int half_rotary_dim = rotary_embedding_dim / 2; - const int half_idx = (head_idx) / half_rotary_dim; - const int intra_half_idx = (head_idx) % half_rotary_dim; - const int smem_pitch = half_rotary_dim; + const int half_idx = (head_idx) / half_rotary_dim; + const int intra_half_idx = (head_idx) % half_rotary_dim; + const int smem_pitch = half_rotary_dim; if (do_rotary) { *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; @@ -441,7 +441,6 @@ __global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, co } } - template __global__ void AddBiasTransposeCutlass(const T* input, const T* biases, T* output, int v_head_size) { // Format 3 for cutlass memory efficient attention @@ -651,7 +650,7 @@ void InvokeAddBiasTranspose( if (format != 1 && format != 2 && format != 3) { ORT_THROW("format must be 1, 2 or 3 for rotary attention"); } - if (qk_head_size != 64 && qk_head_size !=128) { + if (qk_head_size != 64 && qk_head_size != 128) { ORT_THROW("qk_head_size must be 64 or 128 for rotary attention"); } if (v_head_size != -1 && qk_head_size != v_head_size) { diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index f0385ea5abdfb..bf6431cf1afb2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -135,32 +135,52 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { use_flash_attention = false; } + // Allocate buffers + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif if (!use_flash_attention) { - if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT - // GPT fused kernels requires left side padding. mask can be: - // none (no padding), 1D sequence lengths or 2d mask. - // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token - // where past state is empty. - bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; - bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && - nullptr == relative_position_bias && - parameters.past_sequence_length == 0 && - parameters.hidden_size == parameters.v_hidden_size && - FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, true); - if (use_causal_fused_runner) { - // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. - if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); + if (is_unidirectional_) { // GPT + if (enable_fused_causal_attention_) { + // GPT fused kernels requires left side padding. mask can be: + // none (no padding), 1D sequence lengths or 2d mask. + // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token + // where past state is empty. + bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; + bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && + nullptr == relative_position_bias && + parameters.past_sequence_length == 0 && + parameters.hidden_size == parameters.v_hidden_size && + FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, true); + if (use_causal_fused_runner) { + // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. + if (nullptr == fused_fp16_runner_.get()) { + std::call_once(fused_fp16_runner_created_, [&]() { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + }); + } + + // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. + fused_runner = fused_fp16_runner_.get(); } - - // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. - fused_runner = fused_fp16_runner_.get(); } } else { // BERT bool use_fused_runner = !disable_fused_self_attention_ && @@ -175,8 +195,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (use_fused_runner) { // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); + std::call_once(fused_fp16_runner_created_, [&]() { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + }); } // In case some kernel not loaded due to shared memory limit, we need to double check here. @@ -213,11 +235,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; - IAllocatorUniquePtr gemm_buffer; + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); int m = batch_size * sequence_length; int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size); int k = parameters.input_hidden_size; - gemm_buffer = GetScratchBuffer(static_cast(m) * n, context->GetComputeStream()); + IAllocatorUniquePtr gemm_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(m * n) * sizeof(T), false, context->GetComputeStream()); CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = ToCudaType::FromFloat(0.0f); @@ -244,7 +267,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention); - auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); + IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); + ; typedef typename ToCudaType::MappedType CudaT; AttentionData data; @@ -271,6 +295,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.fused_runner = reinterpret_cast(fused_runner); data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } return QkvToContext(device_prop, cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.h b/onnxruntime/contrib_ops/cuda/bert/attention.h index 455e55ba05a66..acafb379d713f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" @@ -28,6 +29,7 @@ class Attention final : public CudaKernel, public AttentionBase { bool disable_memory_efficient_attention_; int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; + mutable std::once_flag fused_fp16_runner_created_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index b4a4ae208ceb1..83c426e7e6ed7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -316,7 +316,9 @@ Status FlashAttention( ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, - parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional)); + parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, + parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), + true)); DUMP_TENSOR("flash attention output", data.output, parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); @@ -372,6 +374,7 @@ Status EfficientAttention( p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.total_sequence_length; + p.max_sequence_length = parameters.total_sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = parameters.is_unidirectional; @@ -393,10 +396,12 @@ Status EfficientAttention( p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; + p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) ? data.scratch : nullptr; p.stream = stream; + p.has_custom_right_padding = false; run_memory_efficient_attention(p); DUMP_TENSOR("efficient attention output", data.output, parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index d0a5fb51a25d6..3e78978c3cc43 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -88,6 +88,11 @@ struct AttentionData { T* v = nullptr; T* scratch = nullptr; AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + + // Flash buffers + T* softmax_lse = nullptr; + T* softmax_lse_accum = nullptr; + T* out_accum = nullptr; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu b/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu index 2af748d8d4a62..32ed961a68049 100644 --- a/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu +++ b/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu @@ -367,32 +367,32 @@ __global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK) const int* attention_masks, const int batch_size, const int sequence_length) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - const int batch_id = blockIdx.x; - const int* batch_mask = attention_masks + (batch_id * sequence_length); - const bool leftmost_non_zero = (batch_mask[0] != 0); - int biggest_position = 0; - - for (int i = threadIdx.x; i < sequence_length; i += blockDim.x) { - if (leftmost_non_zero == (batch_mask[i] != 0)) { - biggest_position = i; - } else { - break; - } + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + const int batch_id = blockIdx.x; + const int* batch_mask = attention_masks + (batch_id * sequence_length); + const bool leftmost_non_zero = (batch_mask[0] != 0); + int biggest_position = 0; + + for (int i = threadIdx.x; i < sequence_length; i += blockDim.x) { + if (leftmost_non_zero == (batch_mask[i] != 0)) { + biggest_position = i; + } else { + break; } + } - int last_leading_position = BlockReduce(temp_storage).Reduce(biggest_position, cub::Max(), blockDim.x); + int last_leading_position = BlockReduce(temp_storage).Reduce(biggest_position, cub::Max(), blockDim.x); - if (threadIdx.x == 0) { - int batch_offset = batch_id * sequence_length; - trt_mha_padding_offset[2 * batch_id] = batch_offset; - trt_mha_padding_offset[2 * batch_id + 1] = batch_offset + last_leading_position + 1; - if (batch_id == gridDim.x - 1) { - trt_mha_padding_offset[2 * batch_id + 2] = batch_offset + sequence_length; - } + if (threadIdx.x == 0) { + int batch_offset = batch_id * sequence_length; + trt_mha_padding_offset[2 * batch_id] = batch_offset; + trt_mha_padding_offset[2 * batch_id + 1] = batch_offset + last_leading_position + 1; + if (batch_id == gridDim.x - 1) { + trt_mha_padding_offset[2 * batch_id + 2] = batch_offset + sequence_length; } + } } // only support simple left padding with mask 0s on leading left, diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index ed330b0fca332..db78722cc0e4c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -16,6 +16,133 @@ namespace onnxruntime { namespace contrib { namespace cuda { +template +struct RightPaddingBatchHook { + using scalar_t = typename AttentionKernel::scalar_t; + using accum_t = typename AttentionKernel::accum_t; + using lse_scalar_t = typename AttentionKernel::lse_scalar_t; + using output_t = typename AttentionKernel::output_t; + using output_accum_t = typename AttentionKernel::output_accum_t; + + static constexpr bool kSupportsDropout = AttentionKernel::kSupportsDropout; + static constexpr bool kSupportsBias = AttentionKernel::kSupportsBias; + static constexpr int kKeysPerBlock = AttentionKernel::kKeysPerBlock; + static constexpr bool kIsAligned = AttentionKernel::kIsAligned; + static constexpr bool kSingleValueIteration = AttentionKernel::kSingleValueIteration; + static constexpr int32_t kAlignLSE = AttentionKernel::kAlignLSE; // block size of backward + static constexpr bool kPreloadV = AttentionKernel::kPreloadV; + static constexpr bool kKeepOutputInRF = AttentionKernel::kKeepOutputInRF; + static constexpr bool kNeedsOutputAccumulatorBuffer = AttentionKernel::kNeedsOutputAccumulatorBuffer; + + template + static CUTLASS_DEVICE bool AdvanceToBlockForGQA(Params& p) { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE; + + // Advance to current batch - in case of different sequence lengths + if (p.seqlen_k_ptr) { + p.num_keys = p.seqlen_k_ptr[batch_id]; + } + + if (query_start >= p.num_queries) { + return false; + } + + // Advance to the current batch / head / query_start + p.query_ptr += batch_id * p.q_strideB + query_start * p.q_strideM + head_id * p.q_strideH; + p.key_ptr += batch_id * p.k_strideB + head_id * p.k_strideH; + p.value_ptr += batch_id * p.v_strideB + head_id * p.v_strideH; + p.output_ptr += int64_t(batch_id * p.num_queries) * p.o_strideM + int64_t(query_start) * p.o_strideM + head_id * p.head_dim_value; + + if (kSupportsBias && p.attn_bias_ptr != nullptr) { + p.attn_bias_ptr += (batch_id * p.bias_strideB) + (head_id * p.bias_strideH); + } + if (p.output_accum_ptr != nullptr) { + p.output_accum_ptr += int64_t(batch_id * p.num_queries) * (p.head_dim_value * p.num_heads) + + int64_t(query_start) * (p.head_dim_value * p.num_heads) + + head_id * p.head_dim_value; + } else { + // Accumulate directly in the destination buffer (eg for f32) + p.output_accum_ptr = (accum_t*)(p.output_ptr); + } + + if (p.logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + p.logsumexp_ptr += + batch_id * lse_dim * p.num_heads + head_id * lse_dim + query_start; + } + + // Custom masking + if (p.causal_diagonal_ptr) { + p.causal_diagonal_offset = p.causal_diagonal_ptr[batch_id]; + } + if (p.custom_mask_type == AttentionKernel::CausalFromBottomRight) { + p.causal_diagonal_offset += p.num_keys - p.num_queries; + } + if (p.custom_mask_type == AttentionKernel::CausalFromTopLeft || + p.custom_mask_type == AttentionKernel::CausalFromBottomRight) { + // the bottom row of the current block is query_start + kQueriesPerBlock + // the last active key is then query_start + causal_diagonal_offset + + // kQueriesPerBlock so num_keys is the min between actual num_keys and + // this to avoid extra computations + p.num_keys = cutlass::fast_min( + int32_t(query_start + p.causal_diagonal_offset + kQueriesPerBlock), + p.num_keys); + } + + p.num_queries -= query_start; + p.num_batches = 0; // no longer used after + + // If num_queries == 1, and there is only one key head we're wasting + // 15/16th of tensor core compute In that case : + // - we only launch kernels for head_id % kQueriesPerBlock == 0 + // - we iterate over heads instead of queries (strideM = strideH) + if (p.num_queries == 1 && p.k_strideH == 0 && p.v_strideH == 0) { + if (head_id % kQueriesPerBlock != 0) + return false; + p.q_strideM = p.q_strideH; + p.num_queries = p.num_heads; + p.num_heads = 1; // unused but here for intent + // remove causal since n_query = 1 + // otherwise, offset would change with head ! + p.custom_mask_type = AttentionKernel::NoCustomMask; + p.o_strideM = p.head_dim_value; + } + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + p.query_ptr = warp_uniform(p.query_ptr); + p.key_ptr = warp_uniform(p.key_ptr); + p.value_ptr = warp_uniform(p.value_ptr); + if (kSupportsBias) { + p.attn_bias_ptr = warp_uniform(p.attn_bias_ptr); + } + p.output_ptr = warp_uniform(p.output_ptr); + p.output_accum_ptr = warp_uniform(p.output_accum_ptr); + p.logsumexp_ptr = warp_uniform(p.logsumexp_ptr); + p.num_queries = warp_uniform(p.num_queries); + p.num_keys = warp_uniform(p.num_keys); + p.num_heads = warp_uniform(p.num_heads); + p.head_dim = warp_uniform(p.head_dim); + p.head_dim_value = warp_uniform(p.head_dim_value); + p.o_strideM = warp_uniform(p.o_strideM); + p.custom_mask_type = warp_uniform(p.custom_mask_type); + return true; + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched_impl_right_padding(typename AK::Params p) { + if (!RightPaddingBatchHook::AdvanceToBlockForGQA(p)) { + return; + } + AK::attention_kernel(p); +} + template void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { using Attention = AttentionKernel; @@ -51,28 +178,52 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.num_keys = params.kv_sequence_length; if (params.causal) { - p.custom_mask_type = Attention::CausalFromTopLeft; + p.custom_mask_type = Attention::CausalFromBottomRight; } - // Input format is BxSxNxH, output is BxSxNxH - p.q_strideH = params.qk_head_size; - p.k_strideH = params.qk_head_size; - p.v_strideH = params.v_head_size; - p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; - - p.q_strideM = params.num_heads * params.qk_head_size; - p.k_strideM = params.num_heads * params.qk_head_size; - p.v_strideM = params.num_heads * params.v_head_size; - p.o_strideM = params.num_heads * params.v_head_size; - p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; - - p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; - p.k_strideB = static_cast(p.k_strideM) * params.kv_sequence_length; - p.v_strideB = static_cast(p.v_strideM) * params.kv_sequence_length; - p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + // We use max_sequence_length to calculate KV stride + if (params.is_kv_bsnh) { + // Input Q, K, V format is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.qk_head_size; + p.v_strideH = params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.num_heads * params.qk_head_size; + p.v_strideM = params.num_heads * params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; + p.k_strideB = static_cast(p.k_strideM) * params.max_sequence_length; + p.v_strideB = static_cast(p.v_strideM) * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } else { + // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.max_sequence_length * params.qk_head_size; + p.v_strideH = params.max_sequence_length * params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.qk_head_size; + p.v_strideM = params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length; + p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length; + p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } + } + + auto kernel_fn = attention_kernel_batched_impl; + if (params.has_custom_right_padding) { + kernel_fn = attention_kernel_batched_impl_right_padding; } - constexpr auto kernel_fn = attention_kernel_batched_impl; int smem_bytes = sizeof(typename Attention::SharedStorage); if (smem_bytes > 0xc000) { ORT_ENFORCE(params.sm >= 70, "This kernel requires too much shared memory on this machine!"); diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index f725be8d7cf89..484b783db1724 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -14,10 +14,12 @@ namespace cuda { struct MemoryEfficientAttentionParams { int32_t sm; bool is_half; + bool is_kv_bsnh = true; int32_t batch_size; int32_t num_heads; int32_t sequence_length; int32_t kv_sequence_length; + int32_t max_sequence_length; int32_t qk_head_size; int32_t v_head_size; bool causal; @@ -41,6 +43,8 @@ struct MemoryEfficientAttentionParams { static bool need_workspace(size_t v_head_size, bool is_float) { return (v_head_size > 128 && !is_float); } + + bool has_custom_right_padding = false; }; void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params); diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index 4bdc6db30b036..54aad9cbaf387 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -22,6 +22,7 @@ static constexpr int kBeamWidthInputIndex = 8; static constexpr int kCacheIndirectionInputIndex = 9; static constexpr int kPastInputIndex = 5; static constexpr int kPresentOutputIndex = 1; +static constexpr int kQKOutputIndex = 3; static constexpr int kBiasIndex = 10; #define REGISTER_KERNEL_TYPED(T1, T2) \ @@ -50,6 +51,7 @@ DecoderMaskedMultiHeadAttention::DecoderMaskedMultiHeadAttention(const O mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); scale_ = info.GetAttrOrDefault("scale", 0.0f); past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL); + output_qk_ = info.GetAttrOrDefault("output_qk", 0LL); } template @@ -98,7 +100,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* // This kernel is for decoding only (i.e.) sequence length has to be 1 if (sequence_length != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention. Actual length is ", sequence_length); } if (parameters.head_size != parameters.v_head_size) { @@ -125,6 +127,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* TensorShape present_shape(present_dims); Tensor* present_key = context->Output(kPresentOutputIndex, present_shape); Tensor* present_value = context->Output(kPresentOutputIndex + 1, present_shape); + Tensor* cross_qk = nullptr; auto cuda_stream = Stream(context); @@ -191,6 +194,13 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* parameters.v_cache = present_value_data; } + if (output_qk_) { + int64_t qk_dims[] = {parameters.batch_size, parameters.num_heads, 1, parameters.total_sequence_length}; + TensorShape qk_shape(&qk_dims[0], sizeof(qk_dims) / sizeof(qk_dims[0])); + cross_qk = context->Output(kQKOutputIndex, qk_shape); + parameters.out_qk = cross_qk->MutableData(); + } + parameters.out = output->MutableDataRaw(); // Scale diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h index 8200a66db383f..b5476e6b54c44 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h @@ -22,6 +22,7 @@ class DecoderMaskedMultiHeadAttention final : public CudaKernel { float mask_filter_value_; float scale_; bool past_present_share_buffer_; + bool output_qk_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu index a2dfca8cd6f09..ae53eca541fa5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu @@ -86,10 +86,10 @@ __global__ void MaskIndexKernel(int sequence_length, const int* mask, int* mask_ } inline Status ComputeMaskIndex(cudaStream_t stream, - const int sequence_length, - const int batch_size, - const int* mask, - int* mask_index) { + const int sequence_length, + const int batch_size, + const int* mask, + int* mask_index) { // Mask idx is of length batch_size and assumes the valid region is contiguous starting // from the beginning of the sequence @@ -133,7 +133,7 @@ __global__ void EmbedLayerNormKernel( } if (nullptr == position_ids) { position_id = blockIdx.x; - } else if (broadcast_position_ids){ + } else if (broadcast_position_ids) { position_id = position_ids[sequence_position % gridDim.x]; } else { position_id = position_ids[sequence_position]; @@ -212,13 +212,12 @@ Status LaunchEmbedLayerNormKernel( void* embedding_sum, const int* position_ids, const bool broadcast_position_ids) { - if (mask_index != nullptr) { if (nullptr == input_mask) { CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mask_index, 0, sizeof(int) * batch_size, stream)); } else { ORT_RETURN_IF_ERROR( - ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast(mask_index))); + ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast(mask_index))); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu index 1b0de47a834ec..c9498eb1bcd7b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu @@ -66,7 +66,7 @@ __global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int template <> Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, - const float* input, const float* bias, float* output, bool /*use_half2*/) { + const float* input, const float* bias, float* output, bool /*use_half2*/) { constexpr int blockSize = 256; const int gridSize = (input_length + blockSize - 1) / blockSize; FastGeluKernel<<>>(A, B, C, input_length, bias_length, @@ -77,7 +77,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int template <> Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, - const half* input, const half* bias, half* output, bool use_half2) { + const half* input, const half* bias, half* output, bool use_half2) { constexpr int blockSize = 256; if (use_half2 && 0 == (bias_length & 1) && prop.major >= 7) { const int n = input_length / 2; @@ -101,7 +101,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int template <> Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, - const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) { + const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) { constexpr int blockSize = 256; // remove nv_bfloat162 implementation for now to fix build issue diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index c8877a5e3f872..33e7a33494778 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -427,6 +427,15 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio // Compute the logits and start the sum. float sum = 0.f; int sum_tlength = params.is_cross_attention ? tlength - 1 : tlength; + + if (params.out_qk != nullptr) { + // store cross qk before softmax, out_qk has shape [B(batchxbeam), #Head, 1, total_sequence_length] + float* target = ((float*)params.out_qk) + ((int64_t)bhi * tlength); + for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) { + target[ti] = (float)(qk_smem[ti]); + } + } + for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) { // This is a deviation from FasterTransformer kernel implementation // but this aligns with ORT's other Attention kernels which strives to diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h index 6d7f368db4dd4..4b408dafa2d81 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h @@ -37,6 +37,7 @@ struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters { void* v_cache = nullptr; void* out = nullptr; + void* out_qk = nullptr; const int32_t* cache_indir = nullptr; const int32_t* mask = nullptr; // [B, total_sequence_length] diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h index 9db98061bbd66..811b1be7d4315 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h @@ -12,9 +12,13 @@ struct BlockInfo { template __device__ BlockInfo(const Params& params, const int bidb) : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]), - sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]), - actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q), - actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) { + sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]), + actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , + seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])), + actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { } template @@ -30,6 +34,8 @@ struct BlockInfo { const int sum_s_q; const int sum_s_k; const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int seqlen_k_cache; const int actual_seqlen_k; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index 9394a19c9897a..cbe536c6ce45a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -18,68 +18,112 @@ constexpr int D_DIM = 2; struct Qkv_params { using index_t = uint32_t; // The QKV matrices. - void* __restrict__ q_ptr; - void* __restrict__ k_ptr; - void* __restrict__ v_ptr; + void* __restrict__ q_ptr = nullptr; + void* __restrict__ k_ptr = nullptr; + void* __restrict__ v_ptr = nullptr; // The stride between rows of the Q, K and V matrices. - index_t q_batch_stride; - index_t k_batch_stride; - index_t v_batch_stride; - index_t q_row_stride; - index_t k_row_stride; - index_t v_row_stride; - index_t q_head_stride; - index_t k_head_stride; - index_t v_head_stride; + index_t q_batch_stride = 0; + index_t k_batch_stride = 0; + index_t v_batch_stride = 0; + index_t q_row_stride = 0; + index_t k_row_stride = 0; + index_t v_row_stride = 0; + index_t q_head_stride = 0; + index_t k_head_stride = 0; + index_t v_head_stride = 0; // The number of heads. - int h, h_k; + int h = 0; + int h_k = 0; // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be // different from nheads (query). - int h_h_k_ratio; // precompute h / h_k, + int h_h_k_ratio = 0; // precompute h / h_k, }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Flash_fwd_params : public Qkv_params { // The O matrix (output). - void* __restrict__ o_ptr; + void* __restrict__ o_ptr = nullptr; + void* __restrict__ oaccum_ptr = nullptr; // The stride between rows of O. - index_t o_batch_stride; - index_t o_row_stride; - index_t o_head_stride; + index_t o_batch_stride = 0; + index_t o_row_stride = 0; + index_t o_head_stride = 0; // The pointer to the P matrix. - void* __restrict__ p_ptr; + void* __restrict__ p_ptr = nullptr; // The pointer to the softmax sum. - void* __restrict__ softmax_lse_ptr; + void* __restrict__ softmax_lse_ptr = nullptr; + void* __restrict__ softmax_lseaccum_ptr = nullptr; // The dimensions. - int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + int b = 0; + int seqlen_q = 0; + int seqlen_k = 0; + int seqlen_knew = 0; + int d = 0; + int seqlen_q_rounded = 0; + int seqlen_k_rounded = 0; + int d_rounded = 0; + int rotary_dim = 0; // The scaling factors for the kernel. - float scale_softmax; - float scale_softmax_log2; + float scale_softmax = 0.0; + float scale_softmax_log2 = 0.0; // array of length b+1 holding starting offset of each sequence. - int* __restrict__ cu_seqlens_q; - int* __restrict__ cu_seqlens_k; + int* __restrict__ cu_seqlens_q = nullptr; + int* __restrict__ cu_seqlens_k = nullptr; - int* __restrict__ blockmask; + int* __restrict__ blockmask = nullptr; + + // The K_new and V_new matrices. + void* __restrict__ knew_ptr = nullptr; + void* __restrict__ vnew_ptr = nullptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride = 0; + index_t vnew_batch_stride = 0; + index_t knew_row_stride = 0; + index_t vnew_row_stride = 0; + index_t knew_head_stride = 0; + index_t vnew_head_stride = 0; + + // The cos and sin matrices for rotary embedding. + void* __restrict__ rotary_cos_ptr = nullptr; + void* __restrict__ rotary_sin_ptr = nullptr; + + // The indices to index into the KV cache. + int* __restrict__ cache_batch_idx = nullptr; + + // Local window size + int window_size_left = -1; + int window_size_right = -1; bool is_bf16 = false; - bool is_causal; + bool is_causal = false; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative = true; + + bool is_rotary_interleaved = false; - const cudaDeviceProp* dprops; + int num_splits = 0; // For split-KV version + + const cudaDeviceProp* dprops = nullptr; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream); +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); } // namespace flash -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 87831d1eddfe9..76190aad68fdb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -34,24 +34,39 @@ void set_params_fprop(Flash_fwd_params& params, void* p_d, void* softmax_lse_d, float softmax_scale, - bool is_causal) { + bool is_causal, + bool kv_bsnh = true, + int window_size_left = -1, + int window_size_right = -1) { // Set the pointers and strides. params.q_ptr = q; params.k_ptr = k; params.v_ptr = v; params.o_ptr = out; - // All stride are in elements, not bytes. - params.q_row_stride = num_heads * head_size; - params.k_row_stride = num_heads_k * head_size; - params.v_row_stride = num_heads * head_size; - params.q_head_stride = head_size; - params.k_head_stride = head_size; - params.v_head_stride = head_size; - params.o_row_stride = num_heads * head_size; - params.o_head_stride = head_size; params.is_bf16 = false; + // All stride are in elements, not bytes. + if (kv_bsnh) { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = num_heads_k * head_size; + params.v_row_stride = num_heads_k * head_size; + params.q_head_stride = head_size; + params.k_head_stride = head_size; + params.v_head_stride = head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } else { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = head_size; + params.v_row_stride = head_size; + params.q_head_stride = head_size; + params.k_head_stride = seqlen_k * head_size; + params.v_head_stride = seqlen_k * head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } + if (cu_seqlens_q_d == nullptr) { params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0) params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) @@ -89,7 +104,22 @@ void set_params_fprop(Flash_fwd_params& params, params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; + // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API seperates + // local and causal, meaning when we have local window size params.is_causal = is_causal; + if (is_causal && (window_size_left >= 0 || window_size_right != 0)) { + params.is_causal = false; + } + if (window_size_left < 0 && window_size_right >= 0) { + window_size_left = seqlen_k; + } + if (window_size_left >= 0 && window_size_right < 0) { + window_size_right = seqlen_k; + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.is_seqlens_k_cumulative = true; } size_t get_softmax_lse_size(int seqlen, int batch_size, int num_heads) { @@ -97,14 +127,104 @@ size_t get_softmax_lse_size(int seqlen, int batch_size, int num_heads) { return bytes; } -void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream) { +size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q) { + size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads; + return bytes; +} + +size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded) { + size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads * head_size_rounded; + return bytes; +} + +void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) { FP16_SWITCH(!params.is_bf16, [&] { FWD_HEADDIM_SWITCH(params.d, [&] { - run_mha_fwd_(params, stream); + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } }); }); } +// Find the number of splits that maximizes the occupancy. For example, if we have +// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is +// better than having 3 splits (efficiency = 0.67). However, we also don't want too many +// splits as that would incur more HBM reads/writes. +// So we find the best efficiency, then find the smallest number of splits that gets 85% +// of the best efficiency. +int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, + int max_splits) { + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; + // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. + // In any case we don't expect seqlen_q to be larger than 64 for inference. + const int num_m_blocks = (seqlen_q + 64 - 1) / 64; + int batch_nheads_mblocks = batch_size * num_heads * num_m_blocks; + // If we have enough to almost fill the SMs, then just use 1 split + if (batch_nheads_mblocks >= 0.8f * num_SMs) { + return 1; + } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. + auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + efficiency.push_back(0.f); + } else { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { + max_efficiency = eff; + } + efficiency.push_back(eff); + } + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + continue; + } + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +// Returns (num_splits, softmax_lse_accum bytes, out_accum bytes) +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs) { + int max_splits = 128; + // split kv buffers + int num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size, + num_SMs, max_splits); + if (num_splits > 1) { + // softmax_lse_accum buffer + int softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q); + // out_accum buffer + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + int out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads, seqlen_q, head_size_rounded); + return {num_splits, softmax_lse_accum_bytes, out_accum_bytes}; + } else { + return {0, 0, 0}; + } +} + Status mha_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size @@ -119,14 +239,18 @@ Status mha_fwd(const cudaDeviceProp& dprops, int seqlen_q, int seqlen_k, float softmax_scale, - bool is_causal) { + bool is_causal, + int num_splits, + void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + bool kv_bsnh, + int local_window_size) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); Flash_fwd_params params; - params.dprops = &dprops; set_params_fprop(params, batch_size, seqlen_q, seqlen_k, @@ -139,7 +263,28 @@ Status mha_fwd(const cudaDeviceProp& dprops, nullptr, softmax_lse, softmax_scale, - is_causal); + is_causal, + kv_bsnh, + local_window_size, + is_causal ? 0 : -1); + params.dprops = &dprops; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + params.knew_batch_stride = 0; + params.vnew_batch_stride = 0; + params.knew_row_stride = 0; + params.vnew_row_stride = 0; + params.knew_head_stride = 0; + params.vnew_head_stride = 0; + + params.num_splits = num_splits; + if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { + params.softmax_lseaccum_ptr = softmax_lse_accum; + params.oaccum_ptr = out_accum; + } else { + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + } run_mha_fwd(params, stream); return Status::OK(); @@ -168,7 +313,6 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); Flash_fwd_params params; - params.dprops = &dprops; set_params_fprop(params, batch_size, max_seqlen_q, max_seqlen_k, @@ -181,7 +325,16 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, nullptr, softmax_lse, softmax_scale, - is_causal); + is_causal, + true, + -1, + is_causal ? 0 : -1); + params.dprops = &dprops; + params.num_splits = 0; + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; run_mha_fwd(params, stream); return Status::OK(); } @@ -192,6 +345,103 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in return (is_sm8x || is_sm90) && (head_size % 8 == 0) && (head_size <= 256) && (num_heads % num_heads_k == 0); } +// This API is used when past key and value are present... since cached, these are assumed to have sequence length +// of max_sequence_length, so seqlen_k == max_sequence_length. The actual past sequence length is held in seqlens_k_. +Status mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size + void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + const float softmax_scale, + bool is_causal, + bool past_bsnh, // otherwise bnsh + int num_splits, + void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size) { + // if (seqlen_q == 1) { + // is_causal = false; + // } // causal=true is the same as causal=false in this case + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, kcache, vcache, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*p_ptr=*/nullptr, + softmax_lse, + softmax_scale, + is_causal, + past_bsnh, + local_window_size, + is_causal ? 0 : -1); + params.dprops = &dprops; + + if (k != nullptr && v != nullptr) { + params.seqlen_knew = seqlen_k_new; + params.knew_ptr = k; + params.vnew_ptr = v; + // All stride are in elements, not bytes. + params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.knew_row_stride = num_heads_k * head_size; + params.vnew_row_stride = num_heads_k * head_size; + params.knew_head_stride = head_size; + params.vnew_head_stride = head_size; + } else { + params.seqlen_knew = 0; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + params.knew_batch_stride = 0; + params.vnew_batch_stride = 0; + params.knew_row_stride = 0; + params.vnew_row_stride = 0; + params.knew_head_stride = 0; + params.vnew_head_stride = 0; + } + + params.is_seqlens_k_cumulative = seqlens_k_ == nullptr; + if (seqlens_k_ != nullptr) { + params.cu_seqlens_k = static_cast(seqlens_k_); + } + + params.num_splits = num_splits; + if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { + params.softmax_lseaccum_ptr = softmax_lse_accum; + params.oaccum_ptr = out_accum; + } else { + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + } + + // Only split kernel supports appending to KV cache + run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr); + + return Status::OK(); +} + } // namespace flash } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 2ae46d34c373a..efc1f565c4fa0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -31,9 +31,11 @@ #if USE_FLASH_ATTENTION #include "core/providers/cuda/cuda_common.h" +#include namespace onnxruntime { namespace flash { + Status mha_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size @@ -48,7 +50,12 @@ Status mha_fwd(const cudaDeviceProp& dprops, int seqlen_q, int seqlen_k, float softmax_scale, - bool is_causal); + bool is_causal, + int num_splits = 0, + void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + bool kv_bsnh = true, + int local_window_size = -1); Status mha_varlen_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, @@ -68,8 +75,36 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, float softmax_scale, bool is_causal); +Status mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* k, // batch_size x seqlen_k_new x num_heads_k x head_size + void* v, // batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + const float softmax_scale, + bool is_causal, + bool past_bsnh, // otherwise bnsh + int num_splits = 0, + void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size = -1); + size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs); + bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k); } // namespace flash diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index b5af31e432d42..028233f66850f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -29,47 +29,6 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -CUTE_HOST_DEVICE auto -make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(cute::Layout, cute::Int>, - cute::Stride<_1, cute::Int>>{}, - make_layout(cute::size<2>(TileShape_MNK{}))); - - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_HOST_DEVICE auto -make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(cute::Layout, cute::Int>, - cute::Stride<_1, cute::Int>>{}, - // TODO: Shouldn't this be size<1>? - make_layout(cute::size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum, Tensor2& acc_o, float softmax_scale_log2) { @@ -79,7 +38,7 @@ inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, T flash::reduce_sum(scores, scores_sum); } else { cute::Tensor scores_max_prev = make_fragment_like(scores_max); - copy(scores_max, scores_max_prev); + cute::copy(scores_max, scores_max_prev); flash::template reduce_max(scores, scores_max); // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); @@ -109,7 +68,7 @@ inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, T template inline __device__ void write_softmax_to_gmem( - cute::Tensor const& tOrP, cute::Tensor& tPgP, TiledCopy gmem_thr_copy_P) { + cute::Tensor const& tOrP, cute::Tensor& tPgP, TiledCopy gmem_tiled_copy_P) { // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) cute::Layout l = tOrP.layout(); cute::Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); @@ -117,13 +76,13 @@ inline __device__ void write_softmax_to_gmem( CUTE_STATIC_ASSERT_V(cute::size<1>(tPrP) == cute::size<1>(tPgP)); #pragma unroll for (int mi = 0; mi < cute::size<1>(tPrP); ++mi) { - copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); + cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -144,9 +103,50 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); - if (Is_causal) { - n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + // We exit early and write 0 to gO and gLSE. + // Otherwise we might read OOB elements from gK and gV. + if (n_block_max <= n_block_min) { + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { + gLSE(row) = INFINITY; + } + } + return; + } } // We iterate over the blocks in reverse order. This is because the last block is the only one @@ -158,7 +158,6 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), cute::Shape, cute::Int>{}, make_stride(params.q_row_stride, _1{})); @@ -293,9 +292,9 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal + constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 - : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -325,22 +324,22 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. - if (!Is_causal) { + if (!Is_causal && !Is_local) { if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { // I can't get the stride from idx_row - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM + get<0>(idx_row(0)), - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, - kNWarps * 16); + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); } flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -351,8 +350,8 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // Convert scores from fp32 to fp16/bf16 cute::Tensor rP = flash::convert_type(scores); @@ -369,14 +368,14 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= 0) { + if (n_masking_steps > 1 && n_block <= n_block_min) { --n_block; break; } } // These are the iterations where we don't need masking on S - for (; n_block >= 0; --n_block) { + for (; n_block >= n_block_min; --n_block) { cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); flash::cp_async_wait<0>(); @@ -392,7 +391,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -402,8 +401,15 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi } // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); cute::Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) @@ -504,7 +510,546 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template +inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + using GmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::GmemTiledCopyOaccum, + typename Kernel_traits::GmemTiledCopyO>; + using ElementO = std::conditional_t; + + const BlockInfo binfo(params, bidb); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } + // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; + const int n_block_min = !Is_local + ? n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + } + if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 + // We exit early and write 0 to gOaccum and -inf to gLSEaccum. + // Otherwise we might read OOB elements from gK and gV, + // or get wrong results when we combine gOaccum from different blocks. + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrOaccum); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOgOaccum); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { + gLSEaccum(row) = Split ? -INFINITY : INFINITY; + } + } + return; + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + Tensor scores_max = make_tensor(Shape(acc_o)>>{}); + Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { + tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; + } +#pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { + tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + } + } + + // Prologue + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + if constexpr (Append_KV) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. + Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + flash::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_contiguous( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + } + } + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + } + // Need this before we can read in K again, so that we'll see the updated K values. + __syncthreads(); + if (n_block_max > n_block_copy_min) { + tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride; + tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride; + } + } + + // Read Q from gmem to smem, optionally apply rotary embedding. + Tensor tQrQ = make_fragment_like(tQgQ); + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } else { + flash::copy_rotary_contiguous( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + // flash::cp_async_wait<0>(); + // __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } + // __syncthreads(); + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread0()) { print(scores); } + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + if (!Is_causal && !Is_local) { + if (!Is_even_MN) { + flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); + } + } else { + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // __syncthreads(); + + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + // if (cute::thread0()) { print(acc_o_rowcol); } + Tensor lse = make_fragment_like(scores_sum); +#pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = inv_sum; +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scale; + } + } + // if (cute::thread0()) { print(lse); } + // if (cute::thread0()) { print(acc_o_rowcol); } + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum>; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + if constexpr (Split) { + __syncthreads(); + } + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { + gLSEaccum(row) = lse(mi); + } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); + // __syncthreads(); + // if (cute::thread0()) { print(tOgOaccum); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template inline __device__ void compute_attn(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -520,10 +1065,194 @@ inline __device__ void compute_attn(const Params& params) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_splitkv(const Params& params) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = Split ? blockIdx.z / params.h : blockIdx.y; + // The block index for the head. + const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; + const int n_split_idx = Split ? blockIdx.y : 0; + const int num_n_splits = Split ? gridDim.y : 1; + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void combine_attn_seqk_parallel(const Params& params) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + constexpr int kMaxSplits = 1 << Log_max_splits; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = Kernel_traits::kNThreads; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); + + // Shared memory. + // kBlockM + 1 instead of kBlockM to reduce bank conflicts. + __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1]; + + // The thread and block index. + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + const index_t row_offset_lse = bidx * kBlockM; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), + Shape, Int>{}, + make_stride(params.b * params.h * params.seqlen_q, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; + + // Read the LSE values from gmem and store them in shared memory, then tranpose them. + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + if (row < kMaxSplits) { + sLSE[row][col] = lse; + } + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } + __syncthreads(); + Tensor lse_accum = make_tensor(Shape>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // 16 rows, so each time we load we can load 8 rows). + // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + + // Compute the logsumexp of the LSE along the split dimension. + ElementAccum lse_max = lse_accum(0); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_max = max(lse_max, lse_accum(l)); + } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_sum += expf(lse_accum(l) - lse_max); + } + SumOp sum_op; + lse_sum = Allreduce::run(lse_sum, sum_op); + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } + if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { + gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + } +// Store the scales exp(lse - lse_logsum) in shared memory. +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { + sLSE[row][col] = expf(lse_accum(l) - lse_logsum); + } + } + __syncthreads(); + + const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + Stride, _1>{}); + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { + tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; + } + } + // Load Oaccum in then scale and accumulate to O + for (int split = 0; split < params.num_splits; ++split) { + flash::copy( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE[split][row]; +#pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { +#pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); + } + } + // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); } + } + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; + } + // if (cute::thread0()) { print(tOrO); } + + Tensor rO = flash::convert_type(tOrO); +// Write to gO +#pragma unroll + for (int m = 0; m < size<1>(rO); ++m) { + const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); + if (idx < params.b * params.h * params.seqlen_q) { + const int batch_idx = idx / (params.h * params.seqlen_q); + const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; + // The index to the rows of Q + const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; +#pragma unroll + for (int k = 0; k < size<2>(rO); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), + Shape(rO))::value>>{}, Stride<_1>{}); + // TODO: Should check if this is using vectorized store, but it seems pretty fast + copy(rO(_, m, k), gO); + // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } + // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); + } + } + } + } +} + } // namespace flash } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index e633ef4d45fbb..87d189a803f8a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -10,9 +10,33 @@ namespace onnxruntime { namespace flash { -template +template __global__ void flash_fwd_kernel(Flash_fwd_params params) { - flash::compute_attn(params); + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + flash::compute_attn(params); +#else + (void)params; +#endif +} + +template +__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + flash::compute_attn_splitkv(params); +#else + (void)params; +#endif +} + +template +__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static_assert(Log_max_splits >= 1); + flash::combine_attn_seqk_parallel(params); +#else + (void)params; +#endif } template @@ -25,35 +49,100 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; dim3 grid(num_m_block, params.b, params.h); - // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check - // for cu_seqlens_q as well. const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; const bool is_even_K = params.d == Kernel_traits::kHeadDim; BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - // Will only return softmax if dropout, to reduce compilation time. - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - // ORT_ENFORCE(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - int ctas_per_sm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // ORT_ENFORCE(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); + }); + }); +} + +template +void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); + auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>(params); + }); + }); + }); + }); }); }); + if (params.num_splits > 1) { + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } + }); + } +} + +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int kBlockM = 64; // Fixed for all head dimensions + constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd>(params, stream); } template void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 32; + constexpr static int Headdim = 32; BOOL_SWITCH(params.is_causal, Is_causal, [&] { run_flash_fwd, Is_causal>(params, stream); }); @@ -61,7 +150,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 64; + constexpr static int Headdim = 64; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower // Using block size (64 x 256) is 27% slower for seqlen=2k @@ -97,8 +186,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 128; - const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + constexpr static int Headdim = 128; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. @@ -124,8 +213,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 160; - const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + constexpr static int Headdim = 160; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, H100, 128 x 32 is the fastest. // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), @@ -164,12 +253,11 @@ void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { - constexpr size_t Headdim = 224; - constexpr size_t threshold = 2 * Headdim * (128 + 2 * 64); - size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + constexpr static int Headdim = 224; + int max_smem_per_block = params.dprops->sharedMemPerBlockOptin; // printf("max_smem_per_block = %d\n", max_smem_per_block); BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if (max_smem_per_block >= threshold) { // 112 KB + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB run_flash_fwd, Is_causal>(params, stream); } else { run_flash_fwd, Is_causal>(params, stream); @@ -185,16 +273,14 @@ void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) { - constexpr size_t Headdim = 256; - constexpr size_t min_threshold = 2 * Headdim * (128 + 2 * 64); - constexpr size_t max_threshold = 4 * Headdim * (64 + 2 * 64); + constexpr static int Headdim = 256; size_t max_smem_per_sm = params.dprops->sharedMemPerMultiprocessor; size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, we want to run with 128 x 64 (128KB smem). // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. - if (max_smem_per_block >= min_threshold && max_smem_per_sm < max_threshold) { + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { run_flash_fwd, Is_causal>(params, stream); } else { run_flash_fwd, Is_causal>(params, stream); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu new file mode 100644 index 0000000000000..68ae2ea759813 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu new file mode 100644 index 0000000000000..94564a6aba8f3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu new file mode 100644 index 0000000000000..ec9e9e738c5b3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu new file mode 100644 index 0000000000000..e6c4ff5d95584 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu new file mode 100644 index 0000000000000..552966852cdbe --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu new file mode 100644 index 0000000000000..e9f191a4828d6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu new file mode 100644 index 0000000000000..d628a556680ad --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu new file mode 100644 index 0000000000000..88b6cc0fb1e22 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h index 0c967faa85c45..1c0ed7f2fc2e8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h @@ -111,7 +111,8 @@ struct Flash_fwd_kernel_traits : public Base { using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom; static constexpr int kSmemQCount = cute::size(SmemLayoutQ{}); static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2; @@ -139,18 +140,35 @@ struct Flash_fwd_kernel_traits : public Base { DefaultCopy>; using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read + cute::Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store + cute::Layout>{})); // Val layout, 8 vals per store static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); - using GmemLayoutAtomP = Layout, Int>, - Stride, _1>>; + using GmemLayoutAtomP = cute::Layout, cute::Int>, + cute::Stride, _1>>; using GmemTiledCopyP = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtomP{}, - Layout>{})); // Val layout, 8 vals per store + cute::Layout>{})); // Val layout, 8 vals per store + + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + cute::Layout, // Thread layout, 8 threads per row + cute::Stride<_8, _1>>, + cute::Layout, // Thread layout, 16 threads per row + cute::Stride<_16, _1>>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load }; // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. @@ -289,13 +307,13 @@ struct Flash_bwd_kernel_traits : public Base { static constexpr int kSmemdSCount = cute::size(SmemLayoutPdS{}); static constexpr int kSmemPCount = cute::size(SmemLayoutPdS{}); static constexpr int kSmemdQCount = cute::size(SmemLayoutdQ{}); - static constexpr int kSmemdPsumCount = kBlockM; + // static constexpr int kSmemdPsumCount = kBlockM; static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); - static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); + // static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h index 842edf3a98a86..8017f83bbb01d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -139,10 +139,11 @@ inline __device__ void apply_mask(Tensor& tensor, const int max_ } } -template -inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset_, - const int max_seqlen_q, const int warp_row_stride) { +template +inline __device__ void apply_mask_local(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); const int lane_id = threadIdx.x % 32; @@ -155,14 +156,15 @@ inline __device__ void apply_mask_causal(Tensor& tensor, const i #pragma unroll for (int i = 0; i < size<0, 0>(tensor); ++i) { const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit) { + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; } } @@ -176,6 +178,15 @@ inline __device__ void apply_mask_causal(Tensor& tensor, const i } } +template +inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_, + max_seqlen_q, warp_row_stride, -1, 0); +} + template inline __device__ void apply_mask_causal_w_idx( Tensor& tensor, Tensor const& idx_rowcol, diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h index 49ee687419d0e..271112c5e890a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h @@ -96,46 +96,6 @@ inline __device__ uint32_t convert_relu2(const float2 x) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ float2 half2_unpack(uint32_t a); - -template <> -inline __device__ float2 half2_unpack<__half>(uint32_t a) { - return __half22float2(reinterpret_cast<__half2(&)>(a)); -} - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -template <> -inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) { - return __bfloat1622float2(reinterpret_cast<__nv_bfloat162(&)>(a)); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Convert two half2's or bf162's into float, then take their dot product. -template -inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) { - float2 af = flash::half2_unpack(a); - float2 bf = flash::half2_unpack(b); - return af.x * bf.x + af.y * bf.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Converted two vectors of 8 half's or bf16's into float, then take their dot product. -template -inline __device__ float hmulsum8(const uint4 a, const uint4 b) { - float sum; - sum = flash::hfma2_to_float(a.x, b.x); - sum += flash::hfma2_to_float(a.y, b.y); - sum += flash::hfma2_to_float(a.z, b.z); - sum += flash::hfma2_to_float(a.w, b.w); - return sum; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template struct MaxOp { __device__ inline T operator()(T const& x, T const& y) { return x > y ? x : y; } @@ -245,7 +205,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { static_assert(decltype(size<0>(acc_layout))::value == 4); static_assert(decltype(rank(acc_layout))::value == 3); auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting + // "int_tuple.hpp(74): error: conversion to inaccessible base class" + // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -261,9 +224,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { static_assert(mma_shape_K == 8 || mma_shape_K == 16); constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) - return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), - get<0, 1>(l), - get<1, 1, 1>(l)); + // TD [2023-08-13]: Same error as above on Cutlass 3.2 + // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + // get<0, 1>(l), + // get<1, 1, 1>(l)); + return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), + get<1>(get<0>(l)), + get<1>(get<1>(get<1>(l)))); }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -338,9 +305,9 @@ CUTE_HOST_DEVICE void cp_async_wait() { template -inline __device__ void copy(TiledCopy thr_copy, Tensor const& S, +inline __device__ void copy(TiledCopy tiled_copy, Tensor const& S, Tensor& D, Tensor const& identity_MN, - Tensor const& predicate_K, int max_MN = 0) { + Tensor const& predicate_K, const int max_MN = 0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA @@ -354,13 +321,176 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor const& #pragma unroll for (int k = 0; k < size<2>(S); ++k) { if (Is_even_K || predicate_K(k)) { - copy(thr_copy, S(_, m, k), D(_, m, k)); + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); } else if (Clear_OOB_K) { - clear(D(_, m, k)); + cute::clear(D(_, m, k)); } } } else if (Clear_OOB_MN) { - clear(D(_, m, _)); + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_w_min_idx(Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, + const int max_MN = 0, const int min_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_rotary_interleaved(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_rotary_contiguous(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } } } } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc new file mode 100644 index 0000000000000..93892169f6c79 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cuda/bert/group_query_attention.h" +#include "contrib_ops/cuda/bert/group_query_attention_helper.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GroupQueryAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", {DataTypeImpl::GetTensorType()}) \ + .MayInplace(3, 1) \ + .MayInplace(4, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 6), \ + GroupQueryAttention); + +// REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) + : CudaKernel(info) { + int64_t num_heads = 0; + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); + num_heads_ = static_cast(num_heads); + kv_num_heads_ = static_cast(kv_num_heads); + is_past_bsnh_ = false; // info.GetAttrOrDefault("is_past_bsnh", 1) == 1; + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + scale_ = info.GetAttrOrDefault("scale", 0.0f); + +#if USE_FLASH_ATTENTION + disable_flash_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); +#else + disable_flash_attention_ = true; +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + disable_memory_efficient_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); +#else + disable_memory_efficient_attention_ = true; +#endif +} + +template +Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { + const Tensor* query = context->Input(0); + const Tensor* key = context->Input(1); + const Tensor* value = context->Input(2); + const Tensor* past_key = context->Input(3); + const Tensor* past_value = context->Input(4); + const Tensor* seqlens_k = context->Input(5); + const Tensor* total_seqlen = context->Input(6); + + auto& device_prop = GetDeviceProp(); + GroupQueryAttentionParameters parameters; + typedef typename ToCudaType::MappedType CudaT; + GroupQueryAttentionData data; + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, + key, + value, + past_key, + past_value, + ¶meters, + num_heads_, + kv_num_heads_, + seqlens_k, + total_seqlen, + is_past_bsnh_, + scale_, + device_prop.maxThreadsPerBlock)); + parameters.local_window_size = local_window_size_; + int sequence_length = parameters.sequence_length; + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(parameters.batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(parameters.hidden_size); + Tensor* output = context->Output(0, output_shape); + +#if USE_FLASH_ATTENTION + bool use_flash_attention = !disable_flash_attention_ && + onnxruntime::flash::is_supported(device_prop, + parameters.head_size, + parameters.num_heads, + parameters.kv_num_heads); + // Allocate buffers + size_t softmax_lse_bytes = 0; + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + // softmax buffer + softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads); + // split kv buffer + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); +#else + constexpr bool use_flash_attention = false; + auto softmax_lse_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + int sm = (device_prop.major * 10) + device_prop.minor; + bool use_memory_efficient_attention = + !use_flash_attention && + !disable_memory_efficient_attention_ && + local_window_size_ == -1 && + (parameters.head_size & 7) == 0 && + parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && + (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + has_memory_efficient_attention(sm, sizeof(T) == 2); + // allocate buffers + size_t kv_buffer_bytes = 0; + // need a buffer if we must ungroup kv + const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads); + if (use_memory_efficient_attention && needs_buff) { + kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.seqlen_present_kv_cache * parameters.head_size); + } + size_t fmha_buffer_bytes = 0; + if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) { + fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float)); + } + auto k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); +#else + constexpr bool use_memory_efficient_attention = false; + auto k_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto v_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto fmha_buffer = GetScratchBuffer(0, context->GetComputeStream()); +#endif + + // seqlens_k buffer + size_t seqlens_k_bytes = 0; + seqlens_k_bytes = sizeof(int) * parameters.batch_size; + auto seqlens_k_buffer = GetScratchBuffer(seqlens_k_bytes, context->GetComputeStream()); + + std::vector present_dims; + if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { + present_dims = { + parameters.batch_size, parameters.seqlen_present_kv_cache, parameters.kv_num_heads, parameters.head_size}; + } else { // BNSH + present_dims = { + parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, parameters.head_size}; + } + TensorShape present_shape(present_dims); + Tensor* present_key = context->Output(1, present_shape); + Tensor* present_value = context->Output(2, present_shape); + + data.query = reinterpret_cast(query->Data()); + data.key = reinterpret_cast(key->Data()); + data.value = reinterpret_cast(value->Data()); + data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); + data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); + data.output = reinterpret_cast(output->MutableData()); + data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); + data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast(present_value->MutableData()); + data.seqlens_k = const_cast(seqlens_k->Data()); + data.use_flash_attention = use_flash_attention; + data.use_memory_efficient_attention = use_memory_efficient_attention; + if (data.past_key == data.present_key) { + parameters.kv_share_buffer = true; + } else { + parameters.kv_share_buffer = false; + } + // Flash Buffers + if (softmax_lse_buffer != nullptr) { + data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); + } + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } + if (seqlens_k_buffer != nullptr) { + data.seqlens_k_total = reinterpret_cast(seqlens_k_buffer.get()); + } + // Memory Efficient Buffers + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + if (fmha_buffer != nullptr) { + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + if (fmha_buffer != nullptr) { + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + } + + cublasHandle_t cublas = GetCublasHandle(context); + + return QkvToContext( + device_prop, cublas, context->GetComputeStream(), parameters, data); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h new file mode 100644 index 0000000000000..54a8127e29e7b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/providers/cuda/cuda_kernel.h" +#include "contrib_ops/cuda/bert/group_query_attention_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class GroupQueryAttention final : public CudaKernel { + public: + GroupQueryAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + int num_heads_; // number of attention heads + int kv_num_heads_; // different for k and v for group query attention + int local_window_size_; + bool is_past_bsnh_; + float scale_; + bool disable_flash_attention_; + bool disable_memory_efficient_attention_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h new file mode 100644 index 0000000000000..2cb9955807f26 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -0,0 +1,228 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/common.h" +#include "contrib_ops/cpu/bert/attention_common.h" + +namespace onnxruntime { +namespace contrib { +namespace group_query_attention_helper { + +Status CheckInputs(const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* past_key, + const Tensor* past_value, + void* parameters, + int num_heads, + int kv_num_heads, + const Tensor* seqlens_k, + const Tensor* total_seqlen, + bool is_past_bsnh, + float scale) { + // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length + // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) + // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) + // no packing for q/k/v: + // query (Q) : (B, S, D) + // key (K) : (B, S, D_kv) + // value (V) : (B, S, D_kv) + ORT_UNUSED_PARAMETER(value); + + AttentionQkvFormat qkv_format = Q_K_V_BSNH; + AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH; + + const auto& query_dims = query->Shape().GetDims(); + const auto& key_dims = key->Shape().GetDims(); + + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", + query_dims.size()); + } + + int batch_size = static_cast(query_dims[0]); + int sequence_length = static_cast(query_dims[1]); + int q_hidden_size = static_cast(query_dims[2]); + int head_size = static_cast(q_hidden_size) / num_heads; + + int kv_hidden_size = static_cast(key_dims[2]); + + int32_t past_sequence_length = 0; + if (past_key != nullptr && past_value != nullptr) { + const auto& past_key_dims = past_key->Shape().GetDims(); + const auto& past_value_dims = past_value->Shape().GetDims(); + + if (past_key_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' is expected to have 4 dimensions, got ", + past_key_dims.size()); + } + if (past_value_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' is expected to have 4 dimensions, got ", + past_value_dims.size()); + } + + if (past_key_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 0 should be batch_size, got ", + past_key_dims[0]); + } + if (past_value_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 0 should be batch_size, got ", + past_value_dims[0]); + } + + // BNSH + if (!is_past_bsnh) { + if (past_key_dims[2] != past_value_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence" + "length or past sequence length), got ", + past_key_dims[1]); + } + if (past_key_dims[1] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' shall have kv_num_heads"); + } + if (past_value_dims[1] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' shall have kv_num_heads"); + } + // We assume all sequence in past kv are right-padded to max or past sequence length + past_sequence_length = static_cast(past_key_dims[2]); + // BSNH + } else { + if (past_key_dims[1] != past_value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BNSH Input 'past_key' and 'past_value' should have same dimension 1 (max sequence" + "length or past sequence length), got ", + past_key_dims[1]); + } + if (past_key_dims[2] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' shall have kv_num_heads"); + } + if (past_value_dims[2] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' shall have kv_num_heads"); + } + // We assume all sequence in past kv are right-padded to max or past sequence length + past_sequence_length = static_cast(past_key_dims[1]); + } + + if (past_key_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 3 should be same as head_size, got ", + past_key_dims[3]); + } + if (past_value_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 3 should be same as head_size, got ", + past_value_dims[3]); + } + } else if (past_key != nullptr || past_value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' and 'past_value' shall be both present or both absent."); + } + + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } + if (query_dims[0] != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (batch size)"); + } + + if (num_heads % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", + num_heads % kv_num_heads); + } + + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } + + if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch_size)"); + } + + if (static_cast(sequence_length) != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query,' 'key,' and 'value' shall have the same dim 1 (sequence_length)"); + } + + if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); + } + + // Check seqlens_k tensor (holding past seqlen for token gen) + const auto& seqlens_dim = seqlens_k->Shape().GetDims(); + if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "seqlens_k must be shape (batch_size)."); + } + + // Set present sequence length and kv_share_buffer from input total_seqlen tensor + if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "total_sequence_length tensor must be of one element."); + } + int total_sequence_length = *((*total_seqlen).template Data()); + int present_sequence_length = std::max(total_sequence_length, past_sequence_length); + + bool is_prompt = sequence_length != 1; + + if (parameters != nullptr) { + GroupQueryAttentionParameters* output_parameters = reinterpret_cast(parameters); + output_parameters->batch_size = batch_size; + output_parameters->sequence_length = sequence_length; // sequence length of Q + output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors + output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors + output_parameters->hidden_size = q_hidden_size; + output_parameters->num_heads = num_heads; + output_parameters->head_size = q_hidden_size / num_heads; + output_parameters->kv_hidden_size = kv_hidden_size; + output_parameters->kv_num_heads = kv_num_heads; + output_parameters->is_unidirectional = true; + output_parameters->is_prompt = is_prompt; + output_parameters->scale = scale; + output_parameters->qkv_format = qkv_format; + output_parameters->past_kv_format = past_kv_format; + } + + return Status::OK(); +} + +Status CheckInputs(const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* past_key, + const Tensor* past_value, + void* parameters, + int num_heads, + int kv_num_heads, + const Tensor* seqlens_k, + const Tensor* total_seqlen, + bool is_past_bsnh, + float scale, + int max_threads_per_block) { + if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); + } + + return CheckInputs(query, key, value, past_key, past_value, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); +} + +} // namespace group_query_attention_helper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu new file mode 100644 index 0000000000000..b22ccb68c1e7b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -0,0 +1,718 @@ +/* + The implementation of this file is based on our Multi-Head Attention impl.cu file, + which is based on qkvToContext plugin in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Modifications: +// (1) support GPT-2 past state, unidirectional mask (causal) +// (2) use flash attention kernel from (https://github.com/Dao-AILab/flash-attention) +// (3) support different number of heads for Q and KV +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/bert/attention_softmax.h" +#include "contrib_ops/cuda/bert/transformer_common.h" +#include "contrib_ops/cuda/bert/add_bias_transpose.h" +#include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/cuda/bert/bert_padding.h" +#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cuda/bert/attention_impl.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +////////// Auxiliary Kernels for KV prep + +// Kernel for seqlens_k +__global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size) { + int id = blockDim.x * blockIdx.x + threadIdx.x; + if (id < batch_size) seqlens_k[id] = seqlen; +} + +// Kernel to append new and past kv in either BSNH or BNSH format +// Adapted from ConcatTensorToTensor kernel in attention_kv_cache.cu file +template +__global__ void ConcatNewToPastKV(const int new_seqlen, + const int past_buffer_seqlen, + const T* past_kv, + const T* new_kv, + T* present_kv, + const int* seqlens_k, + const bool is_bsnh) { // refers to past; otherwise bnsh + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int present_buffer_seqlen = gridDim.x; + const int num_heads = blockDim.y; + const int H = blockDim.x; + + const int present_batch_stride = present_buffer_seqlen * num_heads * H; + const int row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_buffer_seqlen * H; + + // past_kv: BPNH or BNPH + // new_kv: BLNH + // present_kv: BTNH or BNTH, where T = P + L + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; + if (s < past_seqlen) { + const int past_batch_stride = past_buffer_seqlen * num_heads * H; + const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; + const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; + present_kv[out_offset] = past_kv[in_offset]; + } else if (s < past_seqlen + new_seqlen) { + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; + present_kv[out_offset] = new_kv[in_offset]; + } +} + +// Use when (H*)*num_heads > 1024 +template +__global__ void ConcatNewToPastKVLarge(const int new_seqlen, + const int past_buffer_seqlen, + const int H, + const int num_heads, + const T* past_kv, + const T* new_kv, + T* present_kv, + const int* seqlens_k, + const bool is_bsnh) { + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int present_buffer_seqlen = gridDim.y; + + const int present_batch_stride = present_buffer_seqlen * num_heads * H; + const int row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_buffer_seqlen * H; + + // past_kv: BPNH or BNPH + // new_kv: BLNH + // present_kv: BTNH or BNTH, where T = P + L + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; + if (s < past_seqlen) { + const int past_batch_stride = past_buffer_seqlen * num_heads * H; + const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; + const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; + present_kv[out_offset] = past_kv[in_offset]; + } else if (s < past_seqlen + new_seqlen) { + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; + present_kv[out_offset] = new_kv[in_offset]; + } + } +} + +// Concat new to past in present. Supports past BSNH or past BNSH +template +Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.sequence_length; + const int past_sequence_length = parameters.seqlen_past_kv_cache; + const int present_sequence_length = parameters.seqlen_present_kv_cache; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + const int* seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast(data.seqlens_k); + + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; // divide by 4 so kernel can operate on 4 float16 elements at a time. + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(present_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatNewToPastKV<<>>(kv_sequence_length, + past_sequence_length, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKV<<>>(kv_sequence_length, + past_sequence_length, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = (H * kv_num_heads + 255) / 256; + const dim3 grid(steps, present_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + past_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + past_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel to append new kv to kv buffer in place +template +__global__ void ConcatKVInPlace(const int max_seqlen, + T* kv_buff, + const T* new_kv, + const int* seqlens_k, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int new_seqlen = gridDim.x; + const int num_heads = blockDim.y; + const int H = blockDim.x; + + const int present_batch_stride = max_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : max_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + const int past_seq_len = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + (s + past_seq_len) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; +} + +template +__global__ void ConcatKVInPlaceLarge(const int max_seqlen, + const int H, + const int num_heads, + T* kv_buff, + const T* new_kv, + const int* seqlens_k, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int new_seqlen = gridDim.y; + + const int present_batch_stride = max_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : max_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + const int past_seq_len = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + (s + past_seq_len) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; + } +} + +// Concat new to kv buffer in place +template +Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.sequence_length; + const int present_sequence_length = parameters.seqlen_present_kv_cache; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + + // Indicates past sequence_length of each sequence + const int* seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast(data.seqlens_k); + + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(kv_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatKVInPlace<<>>(present_sequence_length, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlace<<>>(present_sequence_length, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = int(ceil(float(H * kv_num_heads) / 256.0)); + const dim3 grid(steps, kv_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatKVInPlaceLarge<<>>(present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlaceLarge<<>>(present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel for use with memory efficient kernel... kv_in is grouped and of bnsh or bsnh... kv_out is ungrouped and bsnh +template +__global__ void Ungroup(const T* kv_in, + T* kv_out, + const int in_seqlen, + const int kv_num_heads, + const bool is_bsnh) { + const int h = threadIdx.x; + const int out_n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int out_seqlen = gridDim.x; + const int q_num_heads = blockDim.y; + const int H = blockDim.x; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + const int in_n = out_n / q_kv_head_ratio; + + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; +} + +template +__global__ void UngroupLarge(const T* kv_in, + T* kv_out, + const int H, + const int in_seqlen, + const int q_num_heads, + const int kv_num_heads, + const bool is_bsnh) { + int i = threadIdx.x + (blockDim.x * blockIdx.x); // index along H * q_num_heads elements + if (i < H * q_num_heads) { + const int out_seqlen = gridDim.y; + const int s = blockIdx.y; + const int b = blockIdx.z; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + + const int h = i % H; + const int out_n = i / H; + const int in_n = out_n / q_kv_head_ratio; + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; + } +} + +// Ungroup kv or present kv for use in Memory Efficient kernel. If present kv is not null and is BNSH, transposes it. +Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, + float2* k_buff, float2* v_buff, + const float2* k_og, const float2* v_og, + const int buff_seqlen, const int og_seqlen, + const bool is_bsnh, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + + const int H = head_size / 4; + if (H * num_heads <= max_threads_per_block) { + const dim3 grid(buff_seqlen, batch_size, 1); + const dim3 block(H, num_heads, 1); + Ungroup<<>>(k_og, + k_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + Ungroup<<>>(v_og, + v_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + } else { + int steps = int(ceil(float(H * num_heads) / 256.0)); + const dim3 grid(steps, buff_seqlen, batch_size); + const dim3 block(256, 1, 1); + UngroupLarge<<>>(k_og, + k_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + UngroupLarge<<>>(v_og, + v_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + } + return CUDA_CALL(cudaGetLastError()); +} + + +__global__ void PastToTotalSeqlen(int32_t* seqlens_k, + int32_t* seqlens_k_buff, + const int add_seqlen) { + seqlens_k_buff[threadIdx.x] = seqlens_k[threadIdx.x] + add_seqlen; +} + +// Convert Past to Total sequence length tensor +Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, + int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, + const int threads_per_block) { + if (parameters.is_prompt) { + return Status::OK(); + } + const int batch_size = parameters.batch_size; + const int add_seqlen = is_total ? parameters.sequence_length : 0; + + const dim3 grid(1, 1, 1); + // TODO(aciddelgado): unlikely but could have a bigger batch_size than max_threads + const dim3 block(batch_size, 1, 1); + + // TODO(aciddelgado): small version + PastToTotalSeqlen<<>>(seqlens_k, seqlens_k_buff, add_seqlen); + + return CUDA_CALL(cudaGetLastError()); +} + +////////// Launch Kernels + +#if USE_FLASH_ATTENTION +template +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.sequence_length; + const int present_sequence_length = parameters.seqlen_present_kv_cache; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + void* query = reinterpret_cast(const_cast(data.query)); + void* key = reinterpret_cast(const_cast(data.key)); + void* value = reinterpret_cast(const_cast(data.value)); + + bool is_causal = true; + + // Note: seqlens_k is past sequence length for flash + if (parameters.is_prompt) { + // Launch kernel to copy seqlen + constexpr int thr_per_blk = 256; + int blk_in_grid = (batch_size + thr_per_blk -1) / thr_per_blk; + repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); + } + + void* seqlens_k = reinterpret_cast(data.seqlens_k); + + if (parameters.kv_share_buffer) { + // Share buffer case + if (data.past_key == nullptr || data.past_key != data.present_key) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past and present kv shall share the same tensor when kv_share_buffer is on."); + } + + if (parameters.is_prompt) { + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + key = nullptr; + value = nullptr; + seqlens_k = reinterpret_cast(data.seqlens_k_total); + } + + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), + seqlens_k, batch_size, num_heads, kv_num_heads, + head_size, sequence_length, present_sequence_length, kv_sequence_length, + scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), parameters.local_window_size)); + } else { + // Not share buffer case + // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient + if (data.past_key != nullptr && data.past_key == data.present_key) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past and present kv share the same tensor but kv_share_buffer is not on."); + } + + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + + if (!parameters.is_prompt) { + ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); + } + + seqlens_k = reinterpret_cast(data.seqlens_k_total); + + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); + DUMP_TENSOR("Q", data.query, batch_size, sequence_length, num_heads, head_size); + DUMP_TENSOR("K", data.present_key, batch_size, kv_num_heads, present_sequence_length, head_size); + DUMP_TENSOR("V", data.present_value, batch_size, kv_num_heads, present_sequence_length, head_size); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, nullptr, nullptr, data.output, reinterpret_cast(data.softmax_lse), + seqlens_k, batch_size, num_heads, kv_num_heads, + head_size, sequence_length, present_sequence_length, 0, + scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), parameters.local_window_size)); + } + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return Status::OK(); +} +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION +template +Status EfficientAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int present_sequence_length = parameters.seqlen_present_kv_cache; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + const void* query = reinterpret_cast(data.query); + const void* key = reinterpret_cast(data.key); + const void* value = reinterpret_cast(data.value); + + if (parameters.is_prompt) { + // Launch kernel to copy seqlen + constexpr int thr_per_blk = 256; + int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; + repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); + } else { + ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); + } + + if (parameters.kv_share_buffer) { + // Share buffer case + if (data.past_key == nullptr || data.past_key != data.present_key) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past and present kv shall share the same tensor when kv_share_buffer is on."); + } + // Concatenate new kv in place + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + } else { + // Not share buffer case + if (data.past_key != nullptr && data.past_key == data.present_key) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past and present kv share the same tensor but kv_share_buffer is not on."); + } + // Copy past and concat new KV to present buffer + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + } + + // Ungroup if grouped, otherwise use present kv directly + const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + if (num_heads == kv_num_heads) { + // Use present kv directly if not grouped + key = reinterpret_cast(data.present_key); + value = reinterpret_cast(data.present_value); + } else { + // Otherwise we use intermediate buffers to run memory efficient attention... best avoid this path + float2* k_buff = reinterpret_cast(data.k); + float2* v_buff = reinterpret_cast(data.v); + const float2* k_og = reinterpret_cast(data.present_key); + const float2* v_og = reinterpret_cast(data.present_value); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, present_sequence_length, + present_sequence_length, is_bsnh, stream, max_threads_per_block)); + key = reinterpret_cast(data.k); + value = reinterpret_cast(data.v); + } + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("seqlens_k", data.seqlens_k_total, batch_size, 1); + + MemoryEfficientAttentionParams p; + p.sm = device_prop.major * 10 + device_prop.minor; + p.is_half = sizeof(T) == 2; + p.batch_size = batch_size; + p.num_heads = num_heads; + p.sequence_length = sequence_length; + p.kv_sequence_length = present_sequence_length; // TOTALLY UNNECESSARY IF WE HAVE SEQLENS_K, maybe remove + p.max_sequence_length = present_sequence_length; + p.qk_head_size = head_size; + p.v_head_size = head_size; + p.causal = true; + p.scale = scale; + p.seqlen_k_ptr = data.seqlens_k_total; // Note: seqlens_k is total sequence length for efficient + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; + p.query = query; + p.key = key; + p.value = value; + p.attn_bias = nullptr; + p.is_attn_bias_batched = false; + p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + p.output = data.output; + p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float)) + ? data.fmha_buffer + : nullptr; + p.stream = stream; + p.has_custom_right_padding = true; + run_memory_efficient_attention(p); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return Status::OK(); +} +#endif + +////////// API Functions + +template +Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data) { + auto stream = static_cast(ort_stream->GetHandle()); + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; + +#if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, stream, parameters, data, scale); + } +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.use_memory_efficient_attention) { + return EfficientAttention(device_prop, stream, parameters, data, scale); + } +#endif + + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet."); +} + +template struct GroupQueryAttentionData; + +template Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h new file mode 100644 index 0000000000000..de32d7ea93163 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" +#include +#include +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/framework/allocator.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +struct GroupQueryAttentionData { + // Input Tensors + const T* query = nullptr; + const T* key = nullptr; + const T* value = nullptr; + const T* past_key = nullptr; + const T* past_value = nullptr; + int* seqlens_k = nullptr; + // Flash buffers + T* softmax_lse = nullptr; + T* softmax_lse_accum = nullptr; + T* out_accum = nullptr; + int* seqlens_k_total = nullptr; + // Memory Efficient buffers + T* fmha_buffer = nullptr; + T* k = nullptr; + T* v = nullptr; + // Output Tensors + T* output = nullptr; + T* present_key = nullptr; + T* present_value = nullptr; + // Kernel Flags + bool use_flash_attention = false; + bool use_memory_efficient_attention = false; +}; + +template +Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh index 5c083d64ee542..ff3178b56c2a6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh @@ -147,14 +147,16 @@ __device__ inline void LayerNormSmall(const T* input_v, const cub::KeyValuePair< __shared__ T rsigma; // 1 / std.dev. T beta_v[ILP], gamma_v[ILP], output_v[ILP]; - if (beta != nullptr) { - VecT* beta_val = reinterpret_cast(&beta_v); - *beta_val = *reinterpret_cast(&beta[threadIdx.x * ILP]); - } - VecT* gamma_val = reinterpret_cast(&gamma_v); - *gamma_val = *reinterpret_cast(&gamma[threadIdx.x * ILP]); + const bool is_valid = ILP * threadIdx.x < ld; + if (is_valid) { + if (beta != nullptr) { + VecT* beta_val = reinterpret_cast(&beta_v); + *beta_val = *reinterpret_cast(&beta[threadIdx.x * ILP]); + } - VecT* output_val = reinterpret_cast(&output_v); + VecT* gamma_val = reinterpret_cast(&gamma_v); + *gamma_val = *reinterpret_cast(&gamma[threadIdx.x * ILP]); + } KeyValuePairSum pair_sum; const cub::KeyValuePair sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); @@ -165,13 +167,15 @@ __device__ inline void LayerNormSmall(const T* input_v, const cub::KeyValuePair< } __syncthreads(); - if (ILP * threadIdx.x < ld) { + if (is_valid) { #pragma unroll for (int i = 0; i < ILP; i++) { output_v[i] = (beta != nullptr) ? gamma_v[i] * (input_v[i] - mu) * rsigma + beta_v[i] : gamma_v[i] * (input_v[i] - mu) * rsigma; } + + VecT* output_val = reinterpret_cast(&output_v); *(reinterpret_cast(&output[idx])) = *output_val; } } @@ -186,12 +190,15 @@ __device__ inline void SimplifiedLayerNormSmall(const T* input_v, const T& threa using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; __shared__ T rsigma; // 1 / std.dev. - T gamma_v[ILP], output_v[ILP]; - VecT* gamma_val = reinterpret_cast(&gamma_v); - *gamma_val = *reinterpret_cast(&gamma[threadIdx.x * ILP]); + const bool is_valid = ILP * threadIdx.x < ld; - VecT* output_val = reinterpret_cast(&output_v); + T gamma_v[ILP], output_v[ILP]; + + if (is_valid) { + VecT* gamma_val = reinterpret_cast(&gamma_v); + *gamma_val = *reinterpret_cast(&gamma[threadIdx.x * ILP]); + } const T sum = BlockReduce(temp_storage).Sum(thread_data); @@ -200,11 +207,13 @@ __device__ inline void SimplifiedLayerNormSmall(const T* input_v, const T& threa } __syncthreads(); - if (ILP * threadIdx.x < ld) { + if (is_valid) { #pragma unroll for (int i = 0; i < ILP; i++) { output_v[i] = gamma_v[i] * input_v[i] * rsigma; } + + VecT* output_val = reinterpret_cast(&output_v); *(reinterpret_cast(&output[idx])) = *output_val; } } diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu index de3c3fb6ca065..f00239460071b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu @@ -924,55 +924,55 @@ Status LongformerQkvToContext( if (disable_compact_memory) { ORT_RETURN_IF_ERROR(LaunchLongformerSoftmaxSimpleKernel( - stream, - cublas, - workspace, - q, - k, - v, - attention_mask, - global_q, - global_k, - global_v, - global_attention, - global_index, - batch_global_num, - pinned_buffer, - temp_output, - rsqrt_head_size, - batch_size, - sequence_length, - num_heads, - head_size, - window, - element_size)); + stream, + cublas, + workspace, + q, + k, + v, + attention_mask, + global_q, + global_k, + global_v, + global_attention, + global_index, + batch_global_num, + pinned_buffer, + temp_output, + rsqrt_head_size, + batch_size, + sequence_length, + num_heads, + head_size, + window, + element_size)); } else { ORT_ENFORCE(max_num_global <= window); ORT_RETURN_IF_ERROR(LaunchLongformerSoftmaxKernel( - stream, - cublas, - workspace, - q, - k, - v, - attention_mask, - max_num_global, - compact_global_q, - global_q, - global_k, - global_v, - global_attention, - global_index, - batch_global_num, - pinned_buffer, - temp_output, - rsqrt_head_size, - batch_size, - sequence_length, - num_heads, - head_size, - window, - element_size)); + stream, + cublas, + workspace, + q, + k, + v, + attention_mask, + max_num_global, + compact_global_q, + global_q, + global_k, + global_v, + global_attention, + global_index, + batch_global_num, + pinned_buffer, + temp_output, + rsqrt_head_size, + batch_size, + sequence_length, + num_heads, + head_size, + window, + element_size)); } // The temp_output is BxNxSxH, transpose it to final output BxSxNxH diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 25f3f59165e43..ebd66d8c6528e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -153,8 +153,24 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { use_flash_attention = false; } + // Allocate buffers + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif bool use_fused_cross_attention = !use_flash_attention && @@ -194,8 +210,10 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { // Here we assume that num_heads and head_size does not change for a MultiHeadAttention node. if (nullptr == fused_fp16_runner_.get()) { constexpr bool is_unidirectional = false; - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create( - num_heads_, parameters.head_size, sm, is_unidirectional, enable_trt_flash_attention_, parameters.scale); + std::call_once(fused_fp16_runner_created_, [&]() { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional, + enable_trt_flash_attention_, parameters.scale); + }); } // In case some kernel not loaded due to shared memory limit, we need to double check here. @@ -289,6 +307,12 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.use_memory_efficient_attention = use_memory_efficient_attention; data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_); data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_); + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 33fa3d50e4564..c162f7133cc1c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -32,6 +32,7 @@ class MultiHeadAttention final : public CudaKernel { bool disable_memory_efficient_attention_; int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; + mutable std::once_flag fused_fp16_runner_created_; mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h index 0cdd8021de4a1..f00c112fc73d2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h @@ -24,10 +24,11 @@ class TrtFusedAttention { protected: MHARunner* GetFusedRunner(const cudaDeviceProp& device_prop, const PackedAttentionParameters& parameters) const; - private: + protected: bool disable_fused_runner_; bool enable_trt_flash_attention_; mutable std::unique_ptr fused_fp16_runner_; + mutable std::once_flag fused_fp16_runner_created_; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index aba0efdbd7d5f..3b52320839403 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -507,10 +507,12 @@ Status FusedScaledDotProductAttentionCutlass( MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; p.is_half = sizeof(T) == 2; + p.is_kv_bsnh = true; p.batch_size = parameters.batch_size; p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.sequence_length; + p.max_sequence_length = parameters.sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; @@ -527,6 +529,7 @@ Status FusedScaledDotProductAttentionCutlass( p.output = data.output; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? accum_workspace : nullptr; p.stream = stream; + p.has_custom_right_padding = false; run_memory_efficient_attention(p); DUMP_TENSOR("PackedAttention cutlass output", data.output, parameters.token_count, num_heads, v_head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index e09fd9e6b36e5..8a508241d80ba 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -688,6 +688,7 @@ Status FusedAttentionCutlass( p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.sequence_length; + p.max_sequence_length = parameters.sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; @@ -702,10 +703,12 @@ Status FusedAttentionCutlass( p.attn_bias = data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; + p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? (data.workspace + (data.no_qkv_workspace ? 0 : (elements_qk + elements_qk + elements_v))) : nullptr; p.stream = stream; + p.has_custom_right_padding = false; run_memory_efficient_attention(p); DUMP_TENSOR_INIT(); diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc new file mode 100644 index 0000000000000..2d12e975d88d7 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cpu/bert/rotary_embedding_helper.h" +#include "contrib_ops/cuda/bert/rotary_embedding.h" +#include "contrib_ops/cuda/bert/rotary_embedding_impl.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::rotary_embedding_helper; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RotaryEmbedding, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + RotaryEmbedding); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) { + scale = info.GetAttrOrDefault("scale", 1.0); + interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); +} + +template +Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* position_ids = context->Input(1); + const Tensor* cos_cache = context->Input(2); + const Tensor* sin_cache = context->Input(3); + + RotaryParameters parameters = {}; + ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(input, + position_ids, + cos_cache, + sin_cache, + ¶meters)); + + Tensor* output = context->Output(0, input->Shape()); + + if (parameters.sequence_length > parameters.max_sequence_length) { + // Launch update_cos_sin_cache kernel with scale + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); + } + + // Launch rotary embedding kernel + typedef typename ToCudaType::MappedType CudaT; + auto& device_prop = GetDeviceProp(); + return LaunchRotaryEmbeddingKernel( + Stream(context), + reinterpret_cast(output->template MutableData()), + reinterpret_cast(input->template Data()), + position_ids->Data(), + reinterpret_cast(cos_cache->template Data()), + reinterpret_cast(sin_cache->template Data()), + parameters.batch_size, + parameters.sequence_length, + parameters.num_heads, + parameters.head_size, + parameters.max_sequence_length, + parameters.position_ids_format, + interleaved, + device_prop.maxThreadsPerBlock, + parameters.transposed); + + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h new file mode 100644 index 0000000000000..6dab2ad56749e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class RotaryEmbedding final : public CudaKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + float scale; + bool interleaved; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu new file mode 100644 index 0000000000000..e1b83bd8caf54 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -0,0 +1,158 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ + +/* +Kernel implementation for rotary embeddings. +*/ + +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/bert/rotary_embedding_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH + const T* input, // BxSxNxH + const T* cos_cache, // Mx(H/2) + const T* sin_cache, // Mx(H/2) + const int64_t* position_ids, // (1) or BxS + const int sequence_length, + const int num_heads, + const int head_size, + const int position_ids_format, + const bool interleaved, + const int batch_stride, + const int seq_stride, + const int head_stride) { + // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length + // Use .x in innermost loop to access global memory efficiently + + const int b = blockIdx.z; + const int s = blockIdx.y; + const int n = blockIdx.x; + + const int i = threadIdx.x; + + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; + + const T* input_data = input + block_offset; + T* output_data = output + block_offset; + + // Cache is (M, H/2) + const int half_head_size = head_size / 2; + const int position_id = (position_ids_format == 0) ? \ + static_cast(position_ids[0]) + s \ + : static_cast(position_ids[b * sequence_length + s]); + const int cache_offset = position_id * half_head_size; + const T* cos_data = cos_cache + cache_offset; + const T* sin_data = sin_cache + cache_offset; + + int cache_idx = 0; + T sign = 0; + int j = 0; + if (interleaved) { + cache_idx = (i / 2) % half_head_size; + sign = (i % 2 == 0) ? -1 : 1; + j = (i % 2 == 0) ? i+1 : i-1; // i - sign + } else { + cache_idx = i % half_head_size; + sign = (i < half_head_size) ? -1 : 1; + j = (i + half_head_size) % head_size; + } + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; +} + + +template +Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + T* output, + const T* input, + const int64_t* position_ids, + const T* cos_cache, + const T* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block, + const bool transposed) { + + constexpr int smem_size = 0; + const dim3 grid(num_heads, sequence_length, batch_size); + const dim3 block(head_size, 1, 1); + + // Note: Current implementation assumes head_size <= max_threads_per_block + // because head_size is currently large for LLaMA-2. For smaller head_size + // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` + // instead. This will require kernel changes to support. + + // Default input tensor shape is [batch, seq, hidden_size] + int head_stride = head_size; + int seq_stride = num_heads * head_stride; + int batch_stride = sequence_length * seq_stride; + if (transposed) { + // When transposed, input tensor shape is [batch, num_heads, seq, head_size] + seq_stride = head_size; + head_stride = sequence_length * seq_stride; + batch_stride = num_heads * head_stride; + } + + assert(head_size <= max_threads_per_block); + RotaryEmbeddingBSNH<<>>( + output, input, cos_cache, sin_cache, position_ids, + sequence_length, num_heads, head_size, position_ids_format, interleaved, + batch_stride, seq_stride, head_stride + ); + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + float* output, + const float* input, + const int64_t* position_ids, + const float* cos_cache, + const float* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block, + const bool transposed); + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + half* output, + const half* input, + const int64_t* position_ids, + const half* cos_cache, + const half* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block, + const bool transposed); + + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h new file mode 100644 index 0000000000000..ee1ccc43dcbff --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + T* output, + const T* input, + const int64_t* position_ids, + const T* cos_cache, + const T* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block, + const bool transposed); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 78174181acdc8..3299bc2cb11de 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -3,6 +3,7 @@ #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/nn/layer_norm_impl.h" +#include "core/common/narrow.h" #include "skip_layer_norm.h" #include "skip_layer_norm_impl.h" #include "contrib_ops/cpu/skip_layer_norm_helper.h" @@ -50,6 +51,11 @@ template Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { const Tensor* input = ctx->Input(0); const Tensor* skip = ctx->Input(1); + if (strict_ && skip->Shape() != input->Shape()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "'input' and 'skip' shall have same shape when enable_skip_layer_norm_strict_mode is True"); + } + const Tensor* gamma = ctx->Input(2); const Tensor* beta = Simplified ? nullptr : ctx->Input(3); @@ -57,16 +63,13 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const Tensor* output = ctx->Output(0, input->Shape()); - // For inferencing, we support one more optional output which is the sum - // of the input and skip tensors - Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape()); + // Optional output for the sum of skip, input and bias tensors (It is also the input of Layer Normalization). + Tensor* sum_output = ctx->Output(3, input->Shape()); const auto& input_dims = input->Shape().GetDims(); size_t input_dims_size = input_dims.size(); - const auto& skip_dims = skip->Shape().GetDims(); - size_t skip_dims_size = skip_dims.size(); - int hidden_size = static_cast(input_dims[input_dims_size - 1]); + int hidden_size = onnxruntime::narrow(input_dims[input_dims_size - 1]); ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs(input, skip, @@ -76,12 +79,15 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const hidden_size, input_dims_size)); - const bool skip_broadcasted = (skip_dims[0] == 1 || skip_dims_size == 2) ? true : false; - const int skip_size = static_cast(skip_dims[skip_dims_size - 1] * skip_dims[skip_dims_size - 2]); + int row_count = onnxruntime::narrow(input->Shape().SizeToDimension(input_dims_size - 1)); + if (row_count == 0) { + return Status::OK(); + } - int row_count = gsl::narrow(input->Shape().SizeToDimension(input_dims_size - 1)); typedef typename ToCudaType::MappedType CudaT; + const int skip_size = onnxruntime::narrow(skip->Shape().Size()); + if (strict_) { HostApplyLayerNorm( GetDeviceProp(), @@ -97,21 +103,20 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, // beta reinterpret_cast(skip->Data()), // skip or residual to add (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, // bias to add - skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr); + sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr); } else { LaunchSkipLayerNormKernel( Stream(ctx), reinterpret_cast(output->MutableData()), - skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, + sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr, reinterpret_cast(input->Data()), reinterpret_cast(skip->Data()), + (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, reinterpret_cast(gamma->Data()), (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, - (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, epsilon_, hidden_size, row_count, - skip_broadcasted, skip_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index f2ee076a8a03d..50c8e4b5e0398 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -51,61 +51,68 @@ half maybe2half(float x) { // Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case // in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time. -constexpr int kSizes[] = {32, 64, 128, 384, 768, 1024, 2048}; +constexpr int kSizes[] = {128, 320, 384, 640, 768, 1024, 1280, 2048, 4096, 5120, 8192}; +constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); +constexpr int kMaxSize = kSizes[kNumOfSizes - 1]; constexpr int kMinBlockSize = 32; -constexpr int kMaxBlockSize = 256; +constexpr int kMaxBlockSize = 1024; int NextSize(int x) { - size_t len = sizeof(kSizes) / sizeof(kSizes[0]); - for (size_t i = 0; i < len; ++i) { + for (size_t i = 0; i < kNumOfSizes; ++i) { if (x <= kSizes[i]) { return kSizes[i]; } } - return kSizes[len - 1]; + return kMaxSize + 1; } -template -bool CanVectorized(T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma, - const T* beta, const T* bias, const int ld, const int next_size) { - constexpr int alignment = std::alignment_of>::value; - return ld % NumUnroll == 0 && reinterpret_cast(output) % alignment == 0 && - reinterpret_cast(skip_input_bias_add_output) % alignment == 0 && - reinterpret_cast(input) % alignment == 0 && reinterpret_cast(skip) % alignment == 0 && - reinterpret_cast(gamma) % alignment == 0 && reinterpret_cast(beta) % alignment == 0 && - reinterpret_cast(bias) % alignment == 0 && next_size / NumUnroll >= kMinBlockSize && - next_size / NumUnroll <= kMaxBlockSize; +bool CanVectorized(void* output, void* sum_output, const void* input, const void* skip, const void* bias, + const void* gamma, const void* beta, const int ld, const int next_size, int num_unroll, int element_size) { + int alignment = element_size * num_unroll; + return ld % num_unroll == 0 && + reinterpret_cast(output) % alignment == 0 && + reinterpret_cast(sum_output) % alignment == 0 && + reinterpret_cast(input) % alignment == 0 && + reinterpret_cast(skip) % alignment == 0 && + reinterpret_cast(bias) % alignment == 0 && + reinterpret_cast(gamma) % alignment == 0 && + reinterpret_cast(beta) % alignment == 0 && + next_size / num_unroll >= kMinBlockSize && + next_size / num_unroll <= kMaxBlockSize; } } // namespace template __global__ void SkipLayerNormKernel( - const int ld, const T* input, const T* skip, - const T* beta, const T* gamma, const T* bias, - const T epsilon, T* output, T* skip_input_bias_add_output, const bool skip_broadcasted, int skip_size) { + T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon, + const int ld, int skip_size) { const T reverse_ld = T(1.f / ld); const int offset = blockIdx.x * ld; + const bool has_bias = (bias != nullptr); + // Reduce sum of x and x^2, and the results are divided by ld. KeyValuePairSum pair_sum; - // reduce x and x^2 cub::KeyValuePair thread_data(0, 0); - for (int i = threadIdx.x; i < ld; i += TPB) { const int idx = offset + i; - const T skip_data = skip_broadcasted ? skip[idx % skip_size] : skip[idx]; - const T val = (bias == nullptr) ? input[idx] + skip_data : input[idx] + skip_data + bias[i]; + T val = input[idx]; + if (has_bias) { + val += bias[i]; + } + val += skip[idx % skip_size]; const T rldval = reverse_ld * val; thread_data = pair_sum(thread_data, cub::KeyValuePair(rldval, rldval * val)); - if (skip_input_bias_add_output != nullptr) { - skip_input_bias_add_output[idx] = val; + if (sum_output != nullptr) { + sum_output[idx] = val; } output[idx] = val; } + if (Simplified) { SimplifiedLayerNorm(thread_data.value, ld, offset, gamma, epsilon, output); return; @@ -116,106 +123,114 @@ __global__ void SkipLayerNormKernel( // Vectorized kernel template __global__ void SkipLayerNormKernelSmall( - const int ld, const T* input, const T* skip, const T* beta, const T* gamma, - const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output, - bool hasBias, bool hasSkipInputBiasAdditionOutput, const bool skip_broadcasted, const int skip_size) { + T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon, + int ld, int skip_size) { const T rld = T(1.f / ld); - const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld + const int idx = blockIdx.x * ld + threadIdx.x * ILP; using VecT = aligned_vector; + T sum_v[ILP]; - T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP]; + cub::KeyValuePair thread_data(T(0.f), T(0.f)); - VecT* input_val = reinterpret_cast(&input_v); - *input_val = *reinterpret_cast(&input[idx]); + if (ILP * threadIdx.x < ld) { // load data under this guard to avoid reading out-of-bounds + T skip_v[ILP], bias_v[ILP]; - VecT* skip_val = reinterpret_cast(&skip_v); - if (skip_broadcasted){ - *skip_val = *reinterpret_cast(&skip[idx % skip_size]); - }else{ - *skip_val = *reinterpret_cast(&skip[idx]); - } + // load input to sum_v + VecT* sum_val = reinterpret_cast(&sum_v); + *sum_val = *reinterpret_cast(&input[idx]); - if (hasBias) { - VecT* bias_val = reinterpret_cast(&bias_v); - *bias_val = *reinterpret_cast(&bias[threadIdx.x * ILP]); - } + VecT* skip_val = reinterpret_cast(&skip_v); + *skip_val = *reinterpret_cast(&skip[idx % skip_size]); - cub::KeyValuePair thread_data(T(0.f), T(0.f)); + const bool has_bias = (bias != nullptr); + if (has_bias) { + VecT* bias_val = reinterpret_cast(&bias_v); + *bias_val = *reinterpret_cast(&bias[threadIdx.x * ILP]); + } - if (ILP * threadIdx.x < ld) { T rldval_sum = T(0.f); T rldvalsq_sum = T(0.f); + const bool has_sum_output = (sum_output != nullptr); + #pragma unroll for (int i = 0; i < ILP; i++) { - input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i]; - - if (hasSkipInputBiasAdditionOutput) { - skip_input_bias_add_output_v[i] = input_v[i]; + if (has_bias) { + sum_v[i] += bias_v[i]; } + sum_v[i] += skip_v[i]; - const T rldval = rld * input_v[i]; + const T rldval = rld * sum_v[i]; rldval_sum += rldval; - rldvalsq_sum += rldval * input_v[i]; + rldvalsq_sum += rldval * sum_v[i]; } - if (hasSkipInputBiasAdditionOutput) { - *(reinterpret_cast(&skip_input_bias_add_output[idx])) = *reinterpret_cast(&skip_input_bias_add_output_v); + if (has_sum_output) { + *(reinterpret_cast(&sum_output[idx])) = *reinterpret_cast(&sum_v); } thread_data = cub::KeyValuePair(rldval_sum, rldvalsq_sum); } if (Simplified) { - SimplifiedLayerNormSmall(input_v, thread_data.value, ld, idx, gamma, epsilon, output); + SimplifiedLayerNormSmall(sum_v, thread_data.value, ld, idx, gamma, epsilon, output); return; } - LayerNormSmall(input_v, thread_data, ld, idx, beta, gamma, epsilon, output); + LayerNormSmall(sum_v, thread_data, ld, idx, beta, gamma, epsilon, output); } template void LaunchSkipLayerNormKernel( - cudaStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma, - const T* beta, const T* bias, float epsilon, int ld, int row_count, bool skip_broadcasted, int skip_size) { - if (row_count == 0) { - return; - } - - bool hasBias = (bias == nullptr) ? false : true; - bool hasSkipInputBiasAdditionOutput = (skip_input_bias_add_output == nullptr) ? false : true; - + cudaStream_t stream, T* output, T* sum_output, + const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, float epsilon, + int ld, int row_count, int skip_size) { const int next_size = NextSize(ld); const int grid_size = row_count; - bool flag_vec2 = - CanVectorized(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size); - bool flag_vec4 = - CanVectorized(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size); + bool can_unroll_vec4 = CanVectorized(output, sum_output, input, + skip, bias, gamma, + beta, ld, next_size, + 4, sizeof(T)); + bool can_unroll_vec8 = CanVectorized(output, sum_output, input, + skip, bias, gamma, + beta, ld, next_size, + 8, sizeof(T)); + +#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \ + SkipLayerNormKernelSmall<<>>( \ + output, sum_output, input, skip, bias, gamma, beta, maybe2half(epsilon), ld, skip_size) + +#define LAUNCH_SKIP_LAYER_NORM_KERNEL() \ + SkipLayerNormKernel<<>>( \ + output, sum_output, input, skip, bias, gamma, beta, maybe2half(epsilon), ld, skip_size) + +#define CASE_NEXT_SIZE(next_size_value) \ + case next_size_value: { \ + static_assert(next_size_value >= kSizes[0] && next_size_value <= kMaxSize); \ + if constexpr (next_size_value >= 320) { \ + if (can_unroll_vec8) { \ + constexpr int block_size = next_size_value / 8; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(8); \ + } else { \ + constexpr int block_size = 256; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ + } \ + } else { \ + if (can_unroll_vec4) { \ + constexpr int block_size = next_size_value / 4; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \ + } else { \ + if (next_size_value <= kMaxBlockSize) { \ + constexpr int block_size = next_size_value; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \ + } else { \ + constexpr int block_size = 256; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ + } \ + } \ + } \ + } break switch (next_size) { -#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \ - SkipLayerNormKernelSmall \ - <<>>(ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, \ - skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput, skip_broadcasted, skip_size) -#define LAUNCH_SKIP_LAYER_NORM_KERNEL() \ - SkipLayerNormKernel<<>>( \ - ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, skip_input_bias_add_output, skip_broadcasted, skip_size) -#define CASE_NEXT_SIZE(next_size_value) \ - case next_size_value: { \ - if (flag_vec4) { \ - constexpr int block_size = next_size_value / 4; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \ - } else if (flag_vec2) { \ - constexpr int block_size = next_size_value / 2; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(2); \ - } else { \ - if (next_size_value <= kMaxBlockSize) { \ - constexpr int block_size = next_size_value; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \ - } else { \ - LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ - } \ - } \ - } break CASE_NEXT_SIZE(kSizes[0]); CASE_NEXT_SIZE(kSizes[1]); CASE_NEXT_SIZE(kSizes[2]); @@ -223,18 +238,27 @@ void LaunchSkipLayerNormKernel( CASE_NEXT_SIZE(kSizes[4]); CASE_NEXT_SIZE(kSizes[5]); CASE_NEXT_SIZE(kSizes[6]); + CASE_NEXT_SIZE(kSizes[7]); + CASE_NEXT_SIZE(kSizes[8]); + CASE_NEXT_SIZE(kSizes[9]); + CASE_NEXT_SIZE(kSizes[10]); + default: { + constexpr int block_size = 256; + LAUNCH_SKIP_LAYER_NORM_KERNEL(); + break; + } + } + #undef CASE_NEXT_SIZE #undef LAUNCH_SKIP_LAYER_NORM_KERNEL #undef LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL - } } -#define SKIPLAYERNORM_IMPL(T, Simplified) \ - template void LaunchSkipLayerNormKernel(cudaStream_t stream, T * output, \ - T * skip_input_bias_add_output, \ - const T* input, const T* skip, const T* gamma, \ - const T* beta, const T* bias, float epsilon, \ - int ld, int row_count, bool skip_broadcasted, int skip_size); +#define SKIPLAYERNORM_IMPL(T, Simplified) \ + template void LaunchSkipLayerNormKernel(cudaStream_t stream, T * output, T * sum_output, \ + const T* input, const T* skip, const T* bias, \ + const T* gamma, const T* beta, float epsilon, \ + int ld, int row_count, int skip_size); SKIPLAYERNORM_IMPL(float, true); SKIPLAYERNORM_IMPL(float, false); SKIPLAYERNORM_IMPL(half, true); diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h index ffb5850c827fe..9727dd6236ec8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h @@ -11,18 +11,17 @@ namespace cuda { template void LaunchSkipLayerNormKernel( cudaStream_t stream, - T* output, // normalized output tensor - T* skip_input_bias_add_output, // sum of the input and skip (and bias if it exists) tensors output - const T* input, // input tensor - const T* skip, // skip tensor - const T* gamma, // Layer normalization gamma tensor - const T* beta, // Layer normalization beta tensor - const T* bias, // Layer normalization beta tensor - float epsilon, // Layer normalization epsilon - int hidden_size, // hidden size, it is the leading dimension (ld) - int row_count, // number of rows. That is total number of elements divided by hidden size. - bool skip_broadcasted, // determines if broadcasting should be implemented - int skip_size); // determines size of the skip tensor + T* output, // normalized output tensor + T* sum_output, // sum of the input and skip (and bias if it exists) tensors output + const T* input, // input tensor + const T* skip, // skip tensor + const T* bias, // bias tensor + const T* gamma, // Layer normalization gamma tensor + const T* beta, // Layer normalization beta tensor + float epsilon, // Layer normalization epsilon + int hidden_size, // hidden size, it is the leading dimension (ld) + int row_count, // number of rows. That is total number of elements divided by hidden size. + int skip_size); // number of elements of the skip tensor } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc new file mode 100644 index 0000000000000..3cfa3ab959343 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_expand.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/tensor/expand.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +DistributedExpand::DistributedExpand(const OpKernelInfo& info) : DistributedKernel(info) {} + +template +Status DistributedExpand::ComputeInternal(OpKernelContext* context) const { + ORT_ENFORCE(context != nullptr); + // Assumptions. + // - Shape is not sharded. + // Algorithm. + // - Compute logical output shape. + // - Compute local output shape. + // - Expand from local input to local output. + + auto input_tensor = context->Input(0); + auto shape_tensor = context->Input(1); + const auto& input_sharding_spec = input_shard_specs_.at(0); + const auto& shape_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + ORT_ENFORCE(shape_sharding_spec.HasNoShard(), + "It's not worth to shard Shape tensor. " + "If sharding shape is needed, please submit a feature request."); + // Compute logical input shape. + const auto original_input_shape = ComputeOriginShape(input_tensor->Shape(), input_sharding_spec); + + // Compute logical output shape. + // This `shape_tensor` stores the logical output shape. + const auto* p_shape = shape_tensor->Data(); + TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()}; + TensorShape original_output_shape(original_output_dims); + ORT_ENFORCE( + onnxruntime::cuda::ComputeOutputShape( + Node().Name(), + original_input_shape, + original_output_dims, original_output_shape) + .IsOK()); + + // Compute local output shape. + const auto local_output_shape = ComputeShardShape(original_output_shape, output_sharding_spec); + + auto output_tensor = context->Output(0, local_output_shape); + + return FuncExpand( + this, + context, + input_tensor, + shape_tensor, + output_tensor); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedExpand, + kMSDomain, + 1, + int64_t, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedExpand); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedExpand, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedExpand); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedExpand, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedExpand); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.h b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.h new file mode 100644 index 0000000000000..dedb1bdc5aa36 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedExpand final : public DistributedKernel { + public: + explicit DistributedExpand(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc new file mode 100644 index 0000000000000..9008edbf3db30 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc @@ -0,0 +1,292 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "sharding.h" +#include "distributed_matmul.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cpu/tensor/slice.h" +#include "core/providers/cuda/tensor/slice.h" +#include "core/providers/cuda/math/matmul.h" +#include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/cuda_check_memory.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +static TensorShape InferMatmulOutputShape( + const TensorShape& shape_A, + const TensorShape& shape_B) { + // left_shape: [M, K] + // right_shape: [K, N] + // output_shape: [M, N] + ORT_ENFORCE( + shape_A.NumDimensions() >= 2 && shape_B.NumDimensions() >= 2, + "1-D tensor is not supported by this MatMul."); + ORT_ENFORCE( + shape_A.NumDimensions() == shape_B.NumDimensions(), + "A and B must have the same rank after shape broadcasting."); + size_t rank = shape_A.NumDimensions(); + std::vector shape_Y(rank, 0); + for (size_t i = 0; i < rank; ++i) { + const int64_t dim_A = shape_A[i]; + const int64_t dim_B = shape_B[i]; + if (i == rank - 1) { + shape_Y[i] = dim_B; + } else if (i == rank - 2) { + shape_Y[i] = dim_A; + } else if (dim_A == 1 && dim_B >= 1) { + // dim_A is 1. + // dim_B can be either 1 or other positive integer. + // due ot shape broadcast. + shape_Y[i] = dim_B; + } else if (dim_B == 1 && dim_A >= 1) { + // dim_B is 1. + // dim_A can be either 1 or other positive integer. + // due ot shape broadcast. + shape_Y[i] = dim_A; + } else { + ORT_ENFORCE(dim_A == dim_B, "Broadcasting can only happen when one of dim_A and dim_B is 1."); + shape_Y[i] = dim_A; + } + } + return TensorShape(shape_Y); +}; + +template +DistributedMatMul::DistributedMatMul(const OpKernelInfo& info) : DistributedKernel(info) { +} + +template +Status DistributedMatMul::ComputeInternal(OpKernelContext* context) const { + const auto tensor_shard_A = context->Input(0); + const auto tensor_shard_B = context->Input(1); + const auto& tensor_shard_shape_A = tensor_shard_A->Shape(); + const auto& tensor_shard_shape_B = tensor_shard_B->Shape(); + + auto rank_A = tensor_shard_shape_A.NumDimensions(); + auto rank_B = tensor_shard_shape_B.NumDimensions(); + // TODO(wechi): Fix MatMul(1-D, *) and MatMul(*, 1-D) cases. + ORT_ENFORCE(rank_A >= 2 && rank_B >= 2, "Broadcast rule for 1-D tensor is different than other cases."); + + const TensorPartitionSpec& spec_A = input_shard_specs_[0]; + const TensorPartitionSpec& spec_B = input_shard_specs_[1]; + const TensorPartitionSpec& spec_Y = output_shard_specs_[0]; + + const auto tensor_shape_A = ComputeOriginShape(tensor_shard_shape_A, spec_A); + const auto tensor_shape_B = ComputeOriginShape(tensor_shard_shape_B, spec_B); + + TensorShape normalized_shape_A; + TensorShape normalized_shape_B; + std::tie(normalized_shape_A, normalized_shape_B) = NormalizeShapes(tensor_shape_A, tensor_shape_B); + + TensorPartitionSpec normalized_spec_A; + TensorPartitionSpec normalized_spec_B; + std::tie(normalized_spec_A, normalized_spec_B) = NormalizeTensorPartitionSpecs(spec_A, spec_B); + + const auto tensor_shape_Y = InferMatmulOutputShape(normalized_shape_A, normalized_shape_B); + const auto tensor_shard_shape_Y = ComputeShardShape(tensor_shape_Y, spec_Y); + + // Case 1: A is not sharded, B is sharded. + // 1. shard on -1: MatMul(RR, RS) -> RS + // 2. shard on -2: MatMul(RR, SR) -> MatMul(RS, SR) + AllReduce -> RR + // 3. shard on other axis + if (normalized_spec_A.HasNoShard() && normalized_spec_B.HasShard()) { + if (normalized_spec_B.OnlyShardAxis(-1)) { + // Case 1-1 + // MatMul(RR, RS) -> RS + ORT_ENFORCE(spec_Y.OnlyShardAxis(-1), "Not supported yet."); + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tensor_shard_A, tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + } else if (normalized_spec_B.OnlyShardAxis(-2)) { + // Case 1-2 + // MatMul(RR, SR) -> MatMul(RS, SR) + AllReduce -> RR + auto tmp_spec_A = CreateTensorShardSpec(spec_A.device_mesh, 0, -1, rank_A); + auto tmp_tensor_shard_A = ReshardTensor(this, context, spec_A, tmp_spec_A, nccl_->Rank(), tensor_shard_A); + + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tmp_tensor_shard_A.get(), tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + ORT_ENFORCE(FuncAllReduce( + nccl_->Comm(), Stream(context), tensor_shard_Y, tensor_shard_Y) == Status::OK()); + } else { + // Case 1-3 + ORT_THROW("Not supported yet."); + } + } + + // Case 2: A is sharded, B is not sharded. + // 1. shard on -1: MatMul(RS, RR) -> MatMul(RS, SR) -> MatMul(RS, SR) + AllReduce -> RR + // 2. shard on -2: MatMul(SR, RR) -> SR + // 3. shard on other axis: : MatMul(SRR, RRR) -> MatMul(SRR, SRR) -> SRR + if (spec_A.HasShard() && spec_B.HasNoShard()) { + if (spec_A.OnlyShardAxis(-1) && spec_Y.HasNoShard()) { + // Case 2-1 + // Y is not really sharded in this case. + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + + // TODO: Support cases with multi-dimension device mesh. + TensorPartitionSpec new_spec_B = CreateTensorShardSpec(spec_B.device_mesh, 0, -2, rank_B); + auto tensor_reshard_B = ShardTensor(this, context, new_spec_B, nccl_->Rank(), tensor_shard_B); + + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tensor_shard_A, tensor_reshard_B.get(), 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + + ORT_ENFORCE(FuncAllReduce( + nccl_->Comm(), Stream(context), tensor_shard_Y, tensor_shard_Y) == Status::OK()); + return Status::OK(); + } else if (spec_A.OnlyShardAxis(-2) && spec_Y.OnlyShardAxis(-2)) { + // Case 2-2 + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tensor_shard_A, tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + return Status::OK(); + } else if (spec_A.GetPartitionAxis() < gsl::narrow(tensor_shape_A.NumDimensions()) - 2 && normalized_spec_A.GetPartitionAxis() == spec_Y.GetPartitionAxis()) { + // Case 2-3 + if (normalized_shape_B[normalized_spec_A.GetPartitionAxis()] == 1) { + // Case 2-3-1. + // B is broadcasted to along sharding axis in A. + // E.g., MatMul(A(SRR), B(RR)) where normalized_shape_A = [2, 3, 4] and normalized_shape_B = [1, 4, 3]. + // No resharding is required. + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tensor_shard_A, tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + return Status::OK(); + } else { + // Case 2-3-2. + // No broadcasting + // Allocate tensor based on shape sharded non-matrix axis. + // MatMul(SRR, RRR) -> MatMul(SRR, SRR) -> SRR + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + TensorPartitionSpec new_spec_B = CreateTensorShardSpec( + spec_B.device_mesh, + 0, + spec_A.GetNegativePartitionAxis(), + rank_B); + auto tensor_reshard_B = ShardTensor(this, context, new_spec_B, nccl_->Rank(), tensor_shard_B); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tensor_shard_A, tensor_reshard_B.get(), 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + return Status::OK(); + } + } else { + ORT_THROW("Not supported yet."); + } + } + + // Case 3: A is sharded, B is sharded. + // 1. shard on (-1, -1): MatMul(RS, RS) -> MatMul(RS, SR) + AllReduce -> RR + // -> MatMul(RR, RS) -> RS + // 2. shard on (-1, -2): MatMul(RS, SR) -> MatMul(RS, SR) + AllReduce -> RR + // 3. shard on (-2, -1): MatMul(SR, RS) -> MatMul(RS, SR) + AllReduce -> RR + // 4. shard on (-2, -2): MatMul(SR, SR) -> MatMul(RS, SR) + AllReduce -> RR + // 5. shard on other axes + if (spec_A.HasShard() && spec_B.HasShard()) { + if (spec_A.OnlyShardAxis(-1) && spec_B.OnlyShardAxis(-1)) { + // Case 3-1 + if (spec_Y.HasNoShard()) { + // Case 3-1-1 + auto tmp_spec_B = CreateTensorShardSpec(spec_B.device_mesh, 0, -2, rank_B); + auto tmp_tensor_shard_B = ReshardTensor(this, context, spec_B, tmp_spec_B, nccl_->Rank(), tensor_shard_B); + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tensor_shard_A, tmp_tensor_shard_B.get(), 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + ORT_ENFORCE(FuncAllReduce( + nccl_->Comm(), Stream(context), tensor_shard_Y, tensor_shard_Y) == Status::OK()); + } else if (spec_Y.OnlyShardAxis(-1)) { + // Cas 3-1-2 + auto tmp_spec_A = TensorPartitionSpec::CreateAllReplica(spec_A); + auto tmp_tensor_shard_A = ReshardTensor(this, context, spec_A, tmp_spec_A, nccl_->Rank(), tensor_shard_A); + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tmp_tensor_shard_A.get(), tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + } else { + ORT_THROW("Not supported yet."); + } + } else if (spec_A.OnlyShardAxis(-1) && spec_B.OnlyShardAxis(-2) && spec_Y.HasNoShard()) { + // Case 3-2 + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + + auto status = onnxruntime::cuda::FuncMatMul( + this, context, tensor_shard_A, tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y); + ORT_ENFORCE(status == Status::OK(), status.ErrorMessage()); + + status = FuncAllReduce( + nccl_->Comm(), Stream(context), tensor_shard_Y, tensor_shard_Y); + ORT_ENFORCE(status == Status::OK(), status.ErrorMessage()); + } else if (spec_A.OnlyShardAxis(-2) && spec_B.OnlyShardAxis(-1)) { + // Case 3-3: + // MatMul(SR, RS) -> MatMul(RS, SR) + AllReduce -> RR + ORT_ENFORCE(spec_Y.HasNoShard(), "Not supported yet."); + + // A[RS] + auto tmp_spec_A = CreateTensorShardSpec(spec_A.device_mesh, 0, -1, rank_A); + auto tmp_tensor_shard_A = ReshardTensor(this, context, spec_A, tmp_spec_A, nccl_->Rank(), tensor_shard_A); + + // B[SR] + auto tmp_spec_B = CreateTensorShardSpec(spec_B.device_mesh, 0, -2, rank_B); + auto tmp_tensor_shard_B = ReshardTensor(this, context, spec_B, tmp_spec_B, nccl_->Rank(), tensor_shard_B); + + // Allocate Y[RR] + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + + // Run local MatMul(A[RS], B[SR]) + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tmp_tensor_shard_A.get(), tmp_tensor_shard_B.get(), 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + ORT_ENFORCE(FuncAllReduce( + nccl_->Comm(), Stream(context), tensor_shard_Y, tensor_shard_Y) == Status::OK()); + } else if (spec_A.OnlyShardAxis(-2) && spec_B.OnlyShardAxis(-2)) { + // Case 3-4 + // MatMul(SR, SR) -> MatMul(RS, SR) + AllReduce -> RR + ORT_ENFORCE(spec_Y.HasNoShard(), "Not supported yet."); + auto tmp_spec_A = CreateTensorShardSpec(spec_A.device_mesh, 0, -1, rank_A); + auto tmp_tensor_shard_A = ReshardTensor(this, context, spec_A, tmp_spec_A, nccl_->Rank(), tensor_shard_A); + auto tensor_sard_Y = context->Output(0, tensor_shard_shape_Y); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tmp_tensor_shard_A.get(), tensor_shard_B, 1.0, false, false, false, false, tensor_sard_Y) == Status::OK()); + } else { + // Case 3-5 + ORT_THROW("Not supported yet."); + } + } + + // Case 4: A is not sharded, B is not sharded. + // - Easy! + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedMatMul, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedMatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedMatMul, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedMatMul); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h new file mode 100644 index 0000000000000..da07f9a8b2c7b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "sharding.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedMatMul final : public DistributedKernel { + public: + explicit DistributedMatMul(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc new file mode 100644 index 0000000000000..967f30a304ac2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc @@ -0,0 +1,175 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_reduce.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/reduction/reduction_ops.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +DistributedReduceBase::DistributedReduceBase( + const OpKernelInfo& info, + cudnnReduceTensorOp_t cudnn_reduce_op) : DistributedKernel(info) { + keepdims_ = info.GetAttrOrDefault("keepdims", 1); + cudnn_reduce_op_ = cudnn_reduce_op; +}; + +template +Status DistributedReduceBase::ComputeInternal(OpKernelContext* context) const { + const auto& input_sharding_spec = input_shard_specs_.at(0); + const auto& axes_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + ORT_ENFORCE(axes_sharding_spec.HasNoShard(), + "It's not worthy to shard axes tensor. " + "If sharding axes is needed, please submit a feature request."); + + const Tensor* input_tensor = context->Input(0); + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be an 1-D tensor."); + auto axes_span = axes_tensor->DataAsSpan(); + + // Case 1: empty axes means treating this reduction as an identity. + if (axes_span.empty()) { + ORT_ENFORCE( + input_sharding_spec == output_sharding_spec, + "Input and output sharding specs should be the same. Otherwise, resharding is needed."); + auto* output_tensor = context->Output(0, input_tensor->Shape()); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData(), input_tensor->Data(), input_tensor->SizeInBytes(), + cudaMemcpyDeviceToDevice, Stream(context))); + return Status::OK(); + } + + // Case 2: this is a valid reduction. Let's prepare for it. + + bool sharding_on_reduced_axes = false; + for (auto axis_it = axes_span.begin(); input_sharding_spec.HasShard() && axis_it != axes_span.end(); ++axis_it) { + if (*axis_it == input_sharding_spec.GetPartitionAxis()) { + sharding_on_reduced_axes = true; + break; + } + } + + if (sharding_on_reduced_axes) { + // Case 2-1: sharding on reduced axes. + ORT_THROW(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Not implemented. Resharding is required to make reduced axes replica."); + } else { + // Case 2-2: sharding on passing-through axes or no shard. + ORT_ENFORCE( + input_sharding_spec == output_sharding_spec, + "Input and output sharding specs should be the same. Otherwise, resharding is needed."); + onnxruntime::cuda::PrepareReduceMetadata metadata; + ORT_RETURN_IF_ERROR( + onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata)); + auto output_tensor = context->Output(0, metadata.squeezed_output_dims); + + // Fast reduction is not deterministic, so sometimes we want to turn it off. + const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute(); + return onnxruntime::cuda::ReduceComputeCore( + /* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault), + *input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span, + /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false, + enable_fast_but_non_deterministic_reduction, context->GetComputeStream()); + } + return Status::OK(); +} + +template +DistributedReduceSum::DistributedReduceSum( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_ADD){}; + +template +DistributedReduceMean::DistributedReduceMean( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_AVG){}; + +template +DistributedReduceMax::DistributedReduceMax( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_MAX){}; + +// ReduceSum +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceSum, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceSum); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceSum, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceSum); + +// ReduceMean +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMean, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMean); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMean, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMean); + +// ReduceMax +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMax, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMax); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMax, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMax); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h new file mode 100644 index 0000000000000..2939852c75c60 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedReduceBase : public DistributedKernel { + public: + explicit DistributedReduceBase(const OpKernelInfo& info, cudnnReduceTensorOp_t cudnn_reduce_op); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + // ONNX attribute. If true, reduced axes are retained as dimensions with size one. + // Otherwise, drop reduced axes. + bool keepdims_; + cudnnReduceTensorOp_t cudnn_reduce_op_; +}; + +template +class DistributedReduceSum final : public DistributedReduceBase { + public: + explicit DistributedReduceSum(const OpKernelInfo& info); +}; + +template +class DistributedReduceMean final : public DistributedReduceBase { + public: + explicit DistributedReduceMean(const OpKernelInfo& info); +}; + +template +class DistributedReduceMax final : public DistributedReduceBase { + public: + explicit DistributedReduceMax(const OpKernelInfo& info); +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc new file mode 100644 index 0000000000000..e413ccf580870 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc @@ -0,0 +1,861 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_reshape.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/cuda_check_memory.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +// Return true if src_shape[src_begin:src_end] is the same as +// dst_shape[dst_begin:dst_end]. Otherwise, return false. +// TODO: replace std::vector with gsl::span. +bool CompareSubVectors( + const std::vector& src_shape, + const std::vector& dst_shape, + size_t src_begin, size_t src_end, + size_t dst_begin, size_t dst_end) { + if (src_end - src_begin != dst_end - dst_begin) { + // Sub-vectors have different lengths. + return false; + } + for (size_t src_index = src_begin, dst_index = dst_begin; + src_index < src_end && dst_index < dst_end; + ++src_index, ++dst_index) { + if (src_shape[src_index] != dst_shape[dst_index]) { + // Sub-vectors have different elements. + return false; + } + } + // Sub-vectors have same length and same elements. + return true; +} + +// TODO: replace std::vector with gsl::span. +std::tuple IsTwoAxisFusion( + const std::vector& src_shape, + const std::vector& dst_shape) { + // Return values: + // - bool: whether two consecutive axes are fused. + // - size_t: the axis in destination shape formed by fusing two source axes. + // - size_t: the first axis fused. + // - size_t: the length of fusion. In two-axis fusion considered by this + // function, the length of fusion is always 2. + const size_t src_rank = src_shape.size(); + const size_t dst_rank = dst_shape.size(); + if (src_rank < 2 || dst_rank < 1) { + return std::make_tuple(false, -1, -1, -1); + } + if (src_rank - 1 != dst_rank) { + return std::make_tuple(false, -1, -1, -1); + } + for (size_t i_src = 0; i_src < src_rank; ++i_src) { + if (i_src + 1 > src_rank - 1) { + // We are at src_shape[i] and we need + // src_shape[i + 1] to fuse. + // If we are at the last axis, we cannot fuse. + break; + } + const int64_t prod = src_shape[i_src] * src_shape[i_src + 1]; + + for (size_t i_dst = 0; i_dst < dst_rank; ++i_dst) { + // Check if shape[i_src:i_src+2] (i.e., shape[i_src] and shape[i_src+1]) + // for source tensor are fused into shape[i_dst] for destination tensor. + if (prod != dst_shape[i_dst]) { + continue; + } + // Check if corresponding dimensions before fusion area + // are the same. + const bool prefix_shape_match = CompareSubVectors( + src_shape, + dst_shape, + // Represent src_shape[0:i_src]. + 0, i_src, + // Represent dst_shape[0:i_dst]. + 0, i_dst); + const bool suffix_shape_match = CompareSubVectors( + src_shape, + dst_shape, + // Represent src_shape[i_src+2:]. + i_src + 2, src_rank, + // Represent dst_shape[i_dst+1:]. + i_dst + 1, dst_rank); + if (prefix_shape_match && suffix_shape_match) { + return std::make_tuple( + true, i_dst, i_src, 2); + } + } + } + return std::make_tuple(false, 0, 0, 0); +} + +std::tuple IsTwoAxisDecomposition( + const std::vector& src_shape, + const std::vector& dst_shape) { + // Return values: + // - bool: whether one source axis is decomposed into two consecutive destination axes. + // - size_t: the axis in source shape decomposed into two consecutive destination axes. + // - size_t: the first axis the source axis decomposed into. + // - size_t: the number of decomposed axes. It's always 2 in this function. + return IsTwoAxisFusion(dst_shape, src_shape); +} + +std::vector RepeatVector(const std::vector& vec, int64_t repeat) { + std::vector new_vec; + for (int64_t i = 0; i < repeat; ++i) { + new_vec.insert(new_vec.end(), vec.begin(), vec.end()); + } + return new_vec; +} + +DeviceMesh CreateInterleaveDeviceMesh( + const DeviceMesh& source_mesh, const int64_t repeat) { + // Given a 1-D device mesh [0, 1] and repeat=2, + // return 1-D device mesh [0, 1, 0, 1]. + if (source_mesh.device_mesh_shape.size() != 1) { + throw std::runtime_error("Source mesh shape 1-D."); + } + + // Mesh to return. + DeviceMesh new_mesh; + + std::vector& elements = new_mesh.device_mesh_elements; + for (int64_t i = 0; i < repeat; ++i) { + elements.insert( + elements.end(), + source_mesh.device_mesh_elements.begin(), + source_mesh.device_mesh_elements.end()); + } + + // source mesh must be 1-D so we only care its 1st dimension. + new_mesh.device_mesh_shape.push_back(source_mesh.device_mesh_shape[0] * repeat); + + return new_mesh; +} + +std::tuple ComputeNativeSpecForTwoAxisFusion( + const TensorPartitionSpec& src_spec, + const std::vector& src_shape, + const std::vector& dst_shape, + const int64_t fused_axis_in_src, + const int64_t fusion_axis_in_dst) { + // TODO(wechi): use device mesh stride to support non-1 stride. + // Example: S[0]R, shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1] + // Example: RS[0], shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1, 0, 1] + // Example: S[0]RR, shape=[2, 3, 5], device_mesh=[0, 1] -> S[0]R, shape = [2, 15], device_mesh=[0, 1] + ORT_ENFORCE(src_spec.CountShardingAxes() == 1, "Tensor to be reshaped has too many sharding axes."); + ORT_ENFORCE(src_spec.device_mesh.device_mesh_shape.size() == 1, "Source device mesh be 1-D."); + + if (src_spec.HasNoShard()) { + return std::make_tuple(true, TensorPartitionSpec::CreateAllReplica(dst_shape.size(), src_spec.device_mesh)); + } else if (src_spec.HasShard() && src_spec.OnlyShardAxis(fused_axis_in_src)) { + // Example: S[0]R, shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1] + // Example 1: + // - logical input shape: [2, 8] + // - logical output shape: [16] + // - input sharding spec: S[0]R, device_mesh=[0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 0, 0, 0, 0, 0, 0], (device assignment) + // [1, 1, 1, 1, 1, 1, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15]] + // - Device 0's local tensor (shape: [2, 4]). + // [[ 0, 1, 2, 3, 4, 5, 6, 7]] + // - Device 1's local tensor (shape: [2, 4]). + // [[ 8, 9, 10, 11, 12, 13, 14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [2, 4]. + // 3. Run local reshape (reshape from shape [2, 4] to shape [8]): + // - Device 0's local output tensor. + // [ 0, 1, 2, 3, 4, 5, 6, 7] + // - Device 1's local output tensor. + // [ 8, 9, 10, 11, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - Device assignment by comparing local tensors and logical output tensor: + // [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1] = input device mesh. + // 5. Native output sharding spec: + // - S[0] with device_mesh [0, 1] + // + // Example 2: + // - logical input shape: [8, 2] + // - logical output shape: [16] + // - input sharding spec: S[0]R, device_mesh=[0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0], (device assignment) + // [0, 0], + // [0, 0], + // [0, 0], + // [1, 1], + // [1, 1], + // [1, 1], + // [1, 1]] + // [[ 0, 1], (values) + // [ 2, 3], + // [ 4, 5], + // [ 6, 7], + // [ 8, 9], + // [10, 11], + // [12, 13], + // [14, 15]] + // - Device 0's local tensor (shape: [4, 2]). + // [[ 0, 1], + // [ 2, 3], + // [ 4, 5], + // [ 6, 7]] + // - Device 1's local tensor (shape: [4, 2]). + // [[ 8, 9], + // [10, 11], + // [12, 13], + // [14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [4, 2]. + // 3. Run local reshape (reshape from shape [4, 2] to shape [8]): + // - Device 0's local output tensor. + // [ 0, 1, 2, 3, 4, 5, 6, 7] + // - Device 1's local output tensor. + // [ 8, 9, 10, 11, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - Device assignment by comparing local tensors and logical output tensor: + // [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1] = input device mesh. + // 5. Native output sharding spec: + // - S[0] with device_mesh [0, 1] + // + // Example 3: + // - logical input shape: [8, 2] + // - logical output shape: [16] + // - input sharding spec: S[0]R, device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0], (device assignment) + // [0, 0], + // [1, 1], + // [1, 1], + // [0, 0], + // [0, 0], + // [1, 1], + // [1, 1]] + // [[ 0, 1], (values) + // [ 2, 3], + // [ 4, 5], + // [ 6, 7], + // [ 8, 9], + // [10, 11], + // [12, 13], + // [14, 15]] + // - Device 0's local tensor (shape: [4, 2]). + // [[ 0, 1], + // [ 2, 3], + // [ 8, 9], + // [10, 11]] + // - Device 1's local tensor (shape: [4, 2]). + // [[ 4, 5], + // [ 6, 7], + // [12, 13], + // [14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [4, 2]. + // 3. Run local reshape (reshape from shape [4, 2] to shape [8]): + // - Device 0's local output tensor. + // [ 0, 1, 2, 3, 8, 9, 10, 11] + // - Device 1's local output tensor. + // [ 4, 5, 6, 7, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - Device assignment by comparing local tensors and logical output tensor: + // [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1] = input device mesh. + // 5. Native output sharding spec: + // - S[0] with device_mesh [0, 1, 0, 1] + + // Reuse original device mesh but shard the fusion axis in output tensor. + auto dst_spec = TensorPartitionSpec::CreateOneTensorAxisOneDeviceMeshAxisSharding( + dst_shape.size(), src_spec.device_mesh, fusion_axis_in_dst, /* 1-D mesh */ 0); + return std::make_tuple(true, dst_spec); + } else if (src_spec.HasShard() && src_spec.OnlyShardAxis(fused_axis_in_src + 1)) { + // Example 1 of determining native output sharding spec: + // - logical input shape: [3, 4] + // - logical output shape: [12] + // - input sharding spec: RS[0], device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 1, 0, 1], (device assignment) + // [0, 1, 0, 1], + // [0, 1, 0, 1]] + // [[0, 1, 2, 3], (values) + // [4, 5, 6, 7], + // [8, 9, 10, 11]], + // - Device 0's local tensor. + // [[0, 0], + // [0, 0], + // [0, 0]] + // [[0, 2], + // [4, 6], + // [8, 10]], + // - Device 1's local tensor. + // [[1, 1], + // [1, 1], + // [1, 1]] + // [[1, 3], + // [5, 7], + // [9, 11]], + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [6] by fusing both axes in shape [3, 2]. + // 3. Run local reshape (reshape from [3, 2] to [6]): + // - Device 0's local output tensor. + // [0, 0, 0, 0, 0, 0] + // [0, 2, 4, 6, 8, 10] + // - Device 1's local output tensor. + // [1, 1, 1, 1, 1, 1] + // [1, 3, 5, 7, 9, 11] + // 4. Determine native output sharding spec by comparing local output tensors and logical tensor. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + // - S[0] with device_mesh = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] = [0, 1, 0, 1] * (first fused dimension). + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1, 0, 1] * (first fused dimension) = [0, 1, 0, 1] * 3 + // + // Example 2 of determining native output sharding spec: + // - logical input shape: [3, 8] + // - logical output shape: [24] + // - input sharding spec: RS[0], device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 1, 1, 0, 0, 1, 1], (device assignment) + // [0, 0, 1, 1, 0, 0, 1, 1], + // [0, 0, 1, 1, 0, 0, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15], + // [16, 17, 18, 19, 20, 21, 22, 23]] + // - Device 0's local tensor (shape: [3, 4]). + // [[0, 0, 0, 0], + // [0, 0, 0, 0], + // [0, 0, 0, 0]] + // [[ 0, 1, 4, 5], + // [ 8, 9, 12, 13], + // [16, 17, 20, 21]] + // - Device 1's local tensor (shape: [3, 4]). + // [[1, 1, 1, 1], + // [1, 1, 1, 1], + // [1, 1, 1, 1]] + // [[ 2, 3, 6, 7], + // [10, 11, 14, 15], + // [18, 19, 22, 23]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [12] by fusing both axes in shape [3, 4]. + // 3. Run local reshape (reshape from [3, 4] to [12]): + // - Device 0's local output tensor . + // [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21] + // - Device 1's local output tensor . + // [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23] + // 4. Determine native output sharding spec from local output tensors. + // - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + // - [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + // - S[0] with device_mesh = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] = . + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1, 0, 1] * (first fused dimension) = [0, 1, 0, 1] * 3 + // + // Example 3: + // - logical input shape: [2, 8] + // - logical output shape: [16] + // - input sharding spec: RS[0], device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 1, 1, 0, 0, 1, 1], (device assignment) + // [0, 0, 1, 1, 0, 0, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15]] + // - Device 0's local tensor (shape: [2, 4]). + // [[0, 0, 0, 0], + // [0, 0, 0, 0]] + // [[ 0, 1, 4, 5], + // [ 8, 9, 12, 13]] + // - Device 1's local tensor (shape: [2, 4]). + // [[1, 1, 1, 1], + // [1, 1, 1, 1]] + // [[ 2, 3, 6, 7], + // [10, 11, 14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [2, 4]. + // 3. Run local reshape (reshape from [2, 4] to [8]): + // - Device 0's local output tensor . + // [ 0, 1, 4, 5, 8, 9, 12, 13] + // - Device 1's local output tensor . + // [ 2, 3, 6, 7, 10, 11, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + // - S[0] with device_mesh = [0, 1, 0, 1, 0, 1, 0, 1] = [0, 1, 0, 1] * (first fused dimension). + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1, 0, 1] * (first fused dimension) = [0, 1, 0, 1] * 2 + // + // Example 4: + // - logical input shape: [2, 8] + // - logical output shape: [16] + // - input sharding spec: RS[0], device_mesh=[0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 0, 0, 1, 1, 1, 1], (device assignment) + // [0, 0, 0, 0, 1, 1, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15]] + // - Device 0's local tensor (shape: [2, 4]). + // [[0, 0, 0, 0], + // [0, 0, 0, 0]] + // [[ 0, 1, 2, 3], + // [ 8, 9, 10, 11]] + // - Device 1's local tensor (shape: [2, 4]). + // [[1, 1, 1, 1], + // [1, 1, 1, 1]] + // [[ 4, 5, 6, 7], + // [12, 13, 14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [2, 4]. + // 3. Run local reshape (reshape from [2, 4] to [8]): + // - Device 0's local output tensor . + // [ 0, 1, 2, 3, 8, 9, 10, 11] + // - Device 1's local output tensor . + // [ 4, 5, 6, 7, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1, 0, 1] = [0, 1] * (first fused dimension). + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1] * (first fused dimension) = [0, 1] * 2 = [0, 1, 0, 1] + + // The output device mesh is the repeats of the original device. + // Let's use Python syntax. If the original device mesh is [0, 1, 0, 1], and + // the first fused dimension is 3, then the output device mesh is [0, 1, 0, 1] * 3. + auto dst_device_mesh = DeviceMesh::Create1D( + src_spec.device_mesh.device_mesh_elements, + src_shape[fused_axis_in_src]); + // Sharding happens in the fusion axis with the new device mesh. + auto dst_spec = TensorPartitionSpec::CreateOneTensorAxisOneDeviceMeshAxisSharding( + dst_shape.size(), dst_device_mesh, fusion_axis_in_dst, /* 1-D mesh */ 0); + return std::make_tuple(true, dst_spec); + } else if (src_spec.HasShard() && (src_spec.GetPartitionAxis() < fused_axis_in_src || src_spec.GetPartitionAxis() > fused_axis_in_src + 1)) { + // It's two-axis fusion but the fused axes is not sharded. + // Example: S[0]RR, shape=[2, 3, 5], device_mesh=[0, 1] -> S[0]R, shape = [2, 15], device_mesh=[0, 1] + auto dst_spec = TensorPartitionSpec::CreateByDropAxes( + src_spec, {fused_axis_in_src + 1}); + return std::make_tuple(true, dst_spec); + } else { + return std::make_tuple(false, TensorPartitionSpec()); + } +} + +// Arguments: +// - device_elements: a vector of device IDs. +// It should only contain unique device IDs or +// repeats of a list of unique device IDs. Otherwise, +// (0, 0) is returned. +// Returns: +// - count per device ID (all device IDs should have the same count) +// - number of unique device IDs +// Examples: +// - [0, 1] -> (2, 1) +// - [0, 1, 2, 0, 1, 2] -> (2, 3) +std::tuple ComputeRepeatAndRepeatStride( + const std::vector& device_elements) { + int64_t first_device_id = device_elements.at(0); + int64_t first_device_id_count = 0; + for (size_t i = 0; i < device_elements.size(); ++i) { + if (device_elements.at(i) == first_device_id) { + ++first_device_id_count; + } + } + size_t repeat_stride = device_elements.size() / first_device_id_count; + + // Check if the device mesh pattern is supported. + // Supported examples: [0, 1, 2] and [0, 1, 0, 1, 0, 1]. + // Unsupported examples: [0, 1, 2, 1, 2, 0] and [0, 1, 2, 0]. + for (size_t repeat = 0; repeat < first_device_id_count; ++repeat) { + for (size_t device_id = 0; device_id < repeat_stride; ++device_id) { + ORT_ENFORCE( + device_elements.at(repeat * repeat_stride + device_id) == device_elements.at(device_id), + "Unsupported device mesh pattern."); + } + } + + // If device_mesh=[0, 1, 2, 0, 1, 2], returns (2, 3), which means + // - each device repeats twice for "2" in (2, 3). + // - there are 3 unique devices for "3" in (2, 3). + return std::make_tuple(first_device_id_count, repeat_stride); +} + +std::tuple ComputeNativeSpecForTwoAxisDecomposition( + const TensorPartitionSpec& src_spec, + const std::vector& src_shape, + const std::vector& dst_shape, + const int64_t decomposed_axis_in_src, + const int64_t decomposition_axis_in_dst) { + // TODO(wechi): use device mesh stride to support non-1 stride. + // Example: S[0], shape=[8], device_mesh=[0, 1] -> S[0]R + // Example: S[0], shape=[8], device_mesh=[0, 1] -> RS[0] + // Example: S[0], shape=[8], device_mesh=[0, 1, 0, 1] -> S[0]R + // Example: S[0], shape=[8], device_mesh=[0, 1, 0, 1] -> RS[0] + // Example: RS[0]R, shape=[8], device_mesh=[0, 1] -> RS[0]RR + // Example: RS[0]R, shape=[8], device_mesh=[0, 1] -> RRS[0]R + if (src_spec.CountShardingAxes() != 1) { + throw std::runtime_error("Too many sharding axes."); + } + if (src_spec.device_mesh.device_mesh_shape.size() != 1) { + throw std::runtime_error("Source device mesh be 1-D."); + } + + if (src_spec.HasNoShard()) { + return std::make_tuple(true, TensorPartitionSpec::CreateAllReplica(dst_shape.size(), src_spec.device_mesh)); + } else if (src_spec.OnlyShardAxis(decomposed_axis_in_src)) { + const int64_t device_stride = src_shape[decomposed_axis_in_src] / src_spec.device_mesh.device_mesh_shape[0]; + if (device_stride >= dst_shape[decomposition_axis_in_dst + 1] && device_stride % dst_shape[decomposition_axis_in_dst + 1] == 0) { + // Since 2nd decomposition dimension is a factor of device stride, + // Sharding happens at 1st decomposition axis in dst. + // device_stride = 10 + // S[0], shape=[20], device=[0, 1] -> S[0]R, shape=[2, 10], device=[0, 1] + // + // device_stride = 8 + // S[0], shape=[16], device=[0, 1] -> RS[0], shape=[1, 16], device=[0, 1] + // + // device_stride = 8 + // S[0], shape=[16], device=[0, 1] -> S[0]R, shape=[4, 4], device=[0, 1] + std::vector dst_axis_specs; + for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { + if (src_axis != decomposed_axis_in_src) { + // Sharding spec is copied if the axis is not decomposed. + // E.g, shape [5, 6] -> Reshape -> shape [5, 3, 2] + // The spec for "5" is copied. + dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); + } else if (dst_shape[decomposition_axis_in_dst] == 1) { + // S[0] -> RS[0] + // E.g., shape [5] -> Reshape -> shape [1, 5] + // The spec for "5" is copied and "1" is replica. + // This reshape only adds a dummy new axis without affecting + // the underlying sharding status. + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + } else { + // S[0] -> S[0]R + // E.g., shape [5] -> Reshape -> shape [5, 1] + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + } + } + // Now, we know sharding happens at decomposed_axis_in_src axis in destination tensor. + // - effective_device_stride along decomposed_axis_in_src: device_stride / dst_shape[decomposed_axis_in_src + 1] + // - The original device patterns repeats: dst_shape[decomposed_axis_in_src] / effective_device_stride times. + const int64_t effective_device_stride = device_stride / dst_shape[decomposed_axis_in_src + 1]; + // How many times a device ID changes along decomposed_axis_in_src axis in destination tensor. + const int64_t number_of_device_changes = dst_shape[decomposed_axis_in_src] / effective_device_stride; + if ((size_t)number_of_device_changes != src_spec.device_mesh.device_mesh_elements.size()) { + throw std::runtime_error("Not supported. Resharding is required."); + } + auto dst_device_mesh = CreateInterleaveDeviceMesh( + src_spec.device_mesh, 1); + return std::make_tuple(true, TensorPartitionSpec::Create(dst_axis_specs, dst_device_mesh)); + } else if (dst_shape[decomposition_axis_in_dst + 1] > device_stride && dst_shape[decomposition_axis_in_dst + 1] % device_stride == 0) { + // Since 2nd decomposition dimension is a multiple of device stride, + // sharding happens at 2nd decomposition axis in dst. + // stride = 4 + // S[0], shape=[8], device=[0, 1] -> S[0]R, shape=[4, 2], device=[0, 1] + // + // stride = 8 + // S[0], shape=[32], device=[0, 1, 0, 1] -> RS[0], shape=[2, 16], device=[0, 1] + std::vector dst_axis_specs; + // How many times a device ID appears. + // E.g., [0, 1, 0, 1, 0, 1] -> 3 + int64_t repeats = 0; + // Number of unique devices. + // E.g., [0, 1, 0, 1, 0, 1] -> 2 + int64_t repeat_stride = 0; + DeviceMesh dst_device_mesh; + std::tie(repeats, repeat_stride) = ComputeRepeatAndRepeatStride(src_spec.device_mesh.device_mesh_elements); + for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { + if (src_axis != decomposed_axis_in_src) { + dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); + } else if (dst_shape[decomposition_axis_in_dst] == 1) { + // S[0] -> RS[0] + // E.g., shape [5] -> Reshape -> shape [1, 5] + // In this case "1" is added as a dummy axis without affecting + // the underlying sharding status, so we just copy the spec + // for input "5" to output "5". + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_device_mesh = src_spec.device_mesh; + } else if (dst_shape[decomposition_axis_in_dst + 1] == 1) { + // S[0] -> S[0]R + // E.g., shape [5] -> Reshape -> shape [5, 1] + // In this case "1" is added as a dummy axis without affecting + // the underlying sharding status, so we just copy the spec + // for input "5" to output "5". + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_device_mesh = src_spec.device_mesh; + } else if (repeats == 1 && dst_shape[decomposition_axis_in_dst + 1] == device_stride * repeat_stride) { + // S[0] -> RS[0] + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_device_mesh = src_spec.device_mesh; + } else if (repeats != 1 && dst_shape[decomposition_axis_in_dst + 1] % (device_stride * repeat_stride) == 0) { + // S[0] -> RS[0] + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + // Extract [0, 1] from [0, 1, 0, 1]. + std::vector unique_device_mesh_elements( + src_spec.device_mesh.device_mesh_elements.begin(), + src_spec.device_mesh.device_mesh_elements.begin() + repeat_stride); + // Compute new repeats. + // Example of repeats change from 2 to 1: + // [16]-shape tensor [2, 8]-shape tensor + // with 1-D device mesh -> Reshape -> with 1-D device mesh + // [0, 1, 0, 1] (repeats=2) [0, 1] (repeats=1) + const int64_t new_repeat = dst_shape[decomposition_axis_in_dst + 1] / (device_stride * repeat_stride); + dst_device_mesh.device_mesh_shape.push_back(repeat_stride); + dst_device_mesh.device_mesh_elements = RepeatVector(unique_device_mesh_elements, new_repeat); + } else { + throw std::runtime_error("Not supported. Resharding is required."); + } + } + return std::make_tuple(true, TensorPartitionSpec::Create(dst_axis_specs, dst_device_mesh)); + } else { + // Not supported. Resharding is required. + return std::make_tuple(false, TensorPartitionSpec()); + } + } else { + // Source tensor is sharded on non-decomposed axis. + std::vector dst_axis_specs; + for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { + if (src_axis != decomposed_axis_in_src) { + dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); + } else { + // R -> RR + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + } + } + + return std::make_tuple(true, TensorPartitionSpec::Create(dst_axis_specs, src_spec.device_mesh)); + } +} + +// Arguments: +// global_data_shape: logical shape of Reshape's 1st input. +// global_shape_span: logical content of Reshape's 2nd input. +// Returns: +// logical shape of Reshape's output. +inline TensorShape InferDistributedReshapeLogicalOutputShape( + const TensorShape& global_data_shape, + const gsl::span& global_shape_span, + const int64_t allow_zero) { + return onnxruntime::cuda::InferReshapeOutputShape( + global_data_shape, + global_shape_span, + allow_zero); +} + +template +DistributedReshape::DistributedReshape(const OpKernelInfo& info) : DistributedKernel(info) { + allow_zero_ = info.GetAttrOrDefault("allowzero", static_cast(0)); +} + +template +Status DistributedReshape::ComputeInternal(OpKernelContext* context) const { + ORT_ENFORCE(context != nullptr); + auto data_tensor = context->Input(0); + auto shape_tensor = context->Input(1); + const auto& data_sharding_spec = input_shard_specs_.at(0); + const auto& shape_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + if (data_sharding_spec.HasNoShard() && shape_sharding_spec.HasNoShard() && output_sharding_spec.HasNoShard()) { + // Case: all inputs and outputs are not sharded. + const auto target_shape = onnxruntime::cuda::InferReshapeOutputShape( + data_tensor, + shape_tensor, + allow_zero_); + + auto output_tensor = context->Output(0, target_shape); + + // Copy data from input from output. + return FuncReshape( + this, + context, + data_tensor, + shape_tensor, + allow_zero_, + output_tensor); + } else { + ORT_ENFORCE(shape_sharding_spec.HasNoShard(), + "Shape tensor should not be sharded because it will trigger communication. " + "If sharding shape is needed, please request this feature on Github."); + ORT_ENFORCE(shape_tensor->Shape().NumDimensions() == 1, "Shape must be a 1-D tensor."); + const auto original_data_shape = ComputeOriginShape(data_tensor->Shape(), data_sharding_spec); + const auto original_output_shape = InferDistributedReshapeLogicalOutputShape( + original_data_shape, + shape_tensor->template DataAsSpan(), + allow_zero_); + + // TODO: remove below code after replacing std::vector with TensorShape in other APIs. + std::vector src_shape(original_data_shape.GetDims().begin(), original_data_shape.GetDims().end()); + std::vector dst_shape(original_output_shape.GetDims().begin(), original_output_shape.GetDims().end()); + + // Case: Two axis fusion + bool is_two_axis_fusion = false; + size_t two_axis_fusion_axis_in_dst = 0; + size_t two_axis_fusion_first_fused_axis_in_src = 0; + size_t two_axis_fusion_fused_axis_count = 0; + std::tie( + is_two_axis_fusion, + two_axis_fusion_axis_in_dst, + two_axis_fusion_first_fused_axis_in_src, + two_axis_fusion_fused_axis_count) = IsTwoAxisFusion(src_shape, dst_shape); + + if (is_two_axis_fusion) { + bool is_supported = false; + TensorPartitionSpec native_dst_spec; + std::tie(is_supported, native_dst_spec) = ComputeNativeSpecForTwoAxisFusion( + data_sharding_spec, + src_shape, + dst_shape, + two_axis_fusion_first_fused_axis_in_src, + two_axis_fusion_axis_in_dst); + + if (is_supported && native_dst_spec == output_sharding_spec) { + // In this case, we can apply Reshape with local shape on local tensor without resharding. + // Those local output tensors match the output tensors defined by + // sharding the logical tensor following the native sharding spec. + TensorShape local_shape = ComputeShardShape(original_output_shape, native_dst_spec); + auto output_tensor = context->Output(0, local_shape); + return FuncReshape( + this, + context, + data_tensor, + shape_tensor, + allow_zero_, + output_tensor); + } else { + // TODO: Reshape outputs from `native_dst_spec` to `output_sharding_spec`. + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported reshape pattern."); + } + } + + // Case: Two axis decomposition + bool is_two_axis_decomposition = false; + size_t two_axis_decomposition_decomposed_axis_in_src = 0; + size_t two_axis_decomposition_first_factor_axis_in_dst = 0; + size_t two_axis_decomposition_factor_axis_count_in_dst = 0; + std::tie( + is_two_axis_decomposition, + two_axis_decomposition_decomposed_axis_in_src, + two_axis_decomposition_first_factor_axis_in_dst, + two_axis_decomposition_factor_axis_count_in_dst) = IsTwoAxisDecomposition(src_shape, dst_shape); + + if (is_two_axis_decomposition) { + bool is_supported = false; + TensorPartitionSpec native_dst_spec; + std::tie(is_supported, native_dst_spec) = ComputeNativeSpecForTwoAxisDecomposition( + data_sharding_spec, + src_shape, + dst_shape, + two_axis_decomposition_decomposed_axis_in_src, + two_axis_decomposition_first_factor_axis_in_dst); + + if (is_supported && native_dst_spec == output_sharding_spec) { + // In this case, we can apply Reshape with local shape on local tensor without resharding. + // Those local output tensors match the output tensors defined by + // sharding the logical tensor following the native sharding spec. + TensorShape local_shape = ComputeShardShape(original_output_shape, native_dst_spec); + auto output_tensor = context->Output(0, local_shape); + return FuncReshape( + this, + context, + data_tensor, + shape_tensor, + allow_zero_, + output_tensor); + } else { + // TODO: Reshape outputs from `native_dst_spec` to `output_sharding_spec`. + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported reshape pattern."); + } + } + } + + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported reshape pattern."); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReshape, + kMSDomain, + 1, + int64_t, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReshape); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReshape, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReshape); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReshape, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReshape); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h new file mode 100644 index 0000000000000..e251c3cdc38d7 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/framework/tensor_shape.h" +#include "core/providers/cuda/tensor/reshape.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedReshape final : public DistributedKernel { + public: + explicit DistributedReshape(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t allow_zero_; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_slice.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_slice.cc new file mode 100644 index 0000000000000..5768dba791292 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_slice.cc @@ -0,0 +1,181 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_slice.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cpu/tensor/slice.h" +#include "core/providers/cuda/tensor/slice.h" +#include "core/providers/cuda/math/matmul.h" +#include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/cuda_check_memory.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) +template +DistributedSlice::DistributedSlice(const OpKernelInfo& info) : DistributedKernel(info) { +} + +template +Status DistributedSlice::ComputeInternal(OpKernelContext* context) const { + const auto tensor_shard_data = context->Input(0); + const auto tensor_shard_starts = context->Input(1); + const auto tensor_shard_ends = context->Input(2); + + const TensorPartitionSpec& spec_data = input_shard_specs_[0]; + const TensorPartitionSpec& spec_starts = input_shard_specs_[1]; + const TensorPartitionSpec& spec_ends = input_shard_specs_[2]; + const TensorPartitionSpec& spec_Y = output_shard_specs_[0]; + + const auto tensor_shard_axes = context->Input(3); + const TensorPartitionSpec& spec_axes = input_shard_specs_[3]; + + if (spec_starts.HasShard() || + spec_ends.HasShard() || + spec_axes.HasShard() || + (input_shard_specs_.size() > 4 && input_shard_specs_[4].HasShard())) + ORT_THROW("DistributedSlice: shard on starts / ends / axes / steps are not supported yet."); + + std::vector input_starts; + std::vector input_ends; + auto starts_data = tensor_shard_starts->DataAsSpan(); + input_starts.resize(starts_data.size()); + std::copy(starts_data.begin(), starts_data.end(), input_starts.begin()); + auto ends_data = tensor_shard_ends->DataAsSpan(); + input_ends.resize(ends_data.size()); + std::copy(ends_data.begin(), ends_data.end(), input_ends.begin()); + + std::vector input_axes; + if (tensor_shard_axes) { + auto axes_data = tensor_shard_axes->DataAsSpan(); + input_axes.resize(axes_data.size()); + std::copy(axes_data.begin(), axes_data.end(), input_axes.begin()); + } + + std::vector input_steps; + const auto tensor_shard_steps = context->Input(4); + if (tensor_shard_steps) { + const TensorPartitionSpec& spec_steps = input_shard_specs_[4]; + if (spec_steps.HasShard()) + ORT_THROW("Not supported yet."); + + auto steps_data = tensor_shard_steps->DataAsSpan(); + input_steps.resize(steps_data.size()); + std::copy(steps_data.begin(), steps_data.end(), input_steps.begin()); + } + + if (spec_data.GetPartitionAxis() != -1 && + std::find(input_axes.begin(), input_axes.end(), spec_data.GetPartitionAxis()) != input_axes.end()) { + // shard on slice axes, reshard first + auto tmp_spec_data = TensorPartitionSpec::CreateAllReplica(spec_data); + auto tensor_data = ReshardTensor(this, context, spec_data, tmp_spec_data, nccl_->Rank(), tensor_shard_data); + + const auto& input_shape = tensor_data->Shape(); + const auto input_dimensions = input_shape.GetDims(); + if (input_dimensions.empty()) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars"); + + SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions); + ORT_RETURN_IF_ERROR(SliceBase::PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata)); + TensorShape output_shape(compute_metadata.output_dims_); + + if (spec_Y.HasNoShard()) { + ORT_RETURN_IF_ERROR(FuncSlice(this, + context, + tensor_data.get(), + input_starts, + input_ends, + input_axes, + input_steps, + context->Output(0, output_shape))); + } else { + AllocatorPtr alloc; + ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc) == Status::OK()); + auto dst_tensor = Tensor::Create(tensor_data->DataType(), output_shape, alloc); + ORT_RETURN_IF_ERROR(FuncSlice(this, + context, + tensor_data.get(), + input_starts, + input_ends, + input_axes, + input_steps, + dst_tensor.get())); + auto tmp_spec_output = TensorPartitionSpec::CreateAllReplica(spec_Y); + ReshardTensor(this, context, tmp_spec_output, spec_Y, nccl_->Rank(), dst_tensor.get(), 0); + } + } else { + const auto& input_shape = tensor_shard_data->Shape(); + const auto input_dimensions = input_shape.GetDims(); + if (input_dimensions.empty()) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars"); + + SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions); + ORT_RETURN_IF_ERROR(SliceBase::PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata)); + TensorShape output_shape(compute_metadata.output_dims_); + + if (spec_Y.GetPartitionAxis() == spec_data.GetPartitionAxis()) { + ORT_RETURN_IF_ERROR(FuncSlice(this, + context, + tensor_shard_data, + input_starts, + input_ends, + input_axes, + input_steps, + context->Output(0, output_shape))); + } else { + AllocatorPtr alloc; + ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc) == Status::OK()); + auto dst_tensor = Tensor::Create(tensor_shard_data->DataType(), output_shape, alloc); + ORT_RETURN_IF_ERROR(FuncSlice(this, + context, + tensor_shard_data, + input_starts, + input_ends, + input_axes, + input_steps, + dst_tensor.get())); + ReshardTensor(this, context, spec_data, spec_Y, nccl_->Rank(), dst_tensor.get(), 0); + } + } + + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedSlice, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .InputMemoryType(OrtMemTypeCPUInput, 3) + .InputMemoryType(OrtMemTypeCPUInput, 4) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), + DistributedSlice); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedSlice, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .InputMemoryType(OrtMemTypeCPUInput, 3) + .InputMemoryType(OrtMemTypeCPUInput, 4) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), + DistributedSlice); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_slice.h b/onnxruntime/contrib_ops/cuda/collective/distributed_slice.h new file mode 100644 index 0000000000000..48c77eee241de --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_slice.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "sharding.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedSlice final : public DistributedKernel { + public: + explicit DistributedSlice(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_squeeze.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_squeeze.cc new file mode 100644 index 0000000000000..c3cae2d0bf8ca --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_squeeze.cc @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_squeeze.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/cuda_check_memory.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) +template +DistributedSqueeze::DistributedSqueeze(const OpKernelInfo& info) : DistributedKernel(info) { +} + +template +Status DistributedSqueeze::ComputeInternal(OpKernelContext* context) const { + auto input_tensor = context->Input(0); + auto axes_tensor = context->Input(1); + auto axes_span = axes_tensor->DataAsSpan(); + + const TensorPartitionSpec& input_spec = input_shard_specs_[0]; + const TensorPartitionSpec& axes_spec = input_shard_specs_[1]; + const TensorPartitionSpec& output_spec = output_shard_specs_[0]; + + ORT_ENFORCE(axes_spec.HasNoShard(), "Axes tensor cannot be sharded."); + + // Non-negative collection of axes to drop. + std::vector axes; + for (const auto axis : axes_span) { + axes.push_back(axis >= 0 ? axis : axis + input_tensor->Shape().NumDimensions()); + } + // Shape after dropping axes. + auto dims = input_tensor->Shape().AsShapeVector(); + // Sort in descending order so that we can drop axes from the end. + std::sort(axes.begin(), axes.end(), [](Tind a, Tind b) { return a > b; }); + for (const auto axis : axes) { + ORT_ENFORCE(input_tensor->Shape()[axis] == 1, "Cannot squeeze non-singleton dimension."); + dims.erase(dims.begin() + axis); + } + auto native_output_spec = TensorPartitionSpec::CreateByDropAxes( + input_spec, + axes); + ORT_ENFORCE( + output_spec == native_output_spec, + "Re-sharding is required but not supported yet for this case."); + auto output_tensor = context->Output(0, dims); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + output_tensor->MutableDataRaw(), + input_tensor->DataRaw(), + input_tensor->SizeInBytes(), cudaMemcpyDeviceToDevice, Stream(context))); + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedSqueeze, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedSqueeze); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedSqueeze, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedSqueeze); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedSqueeze, + kMSDomain, + 1, + int64_t, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedSqueeze); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_squeeze.h b/onnxruntime/contrib_ops/cuda/collective/distributed_squeeze.h new file mode 100644 index 0000000000000..5b81d9c4792bd --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_squeeze.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "sharding.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedSqueeze final : public DistributedKernel { + public: + explicit DistributedSqueeze(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_unsqueeze.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_unsqueeze.cc new file mode 100644 index 0000000000000..a78f19101b0da --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_unsqueeze.cc @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_unsqueeze.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/cuda_check_memory.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) +template +DistributedUnsqueeze::DistributedUnsqueeze(const OpKernelInfo& info) : DistributedKernel(info) { +} + +template +Status DistributedUnsqueeze::ComputeInternal(OpKernelContext* context) const { + auto input_tensor = context->Input(0); + auto axes_tensor = context->Input(1); + auto axes_span = axes_tensor->DataAsSpan(); + + const TensorPartitionSpec& input_spec = input_shard_specs_[0]; + const TensorPartitionSpec& axes_spec = input_shard_specs_[1]; + const TensorPartitionSpec& output_spec = output_shard_specs_[0]; + + ORT_ENFORCE(axes_spec.HasNoShard(), "Axes tensor cannot be sharded."); + + std::vector axes(axes_span.begin(), axes_span.end()); + std::sort(axes.begin(), axes.end()); + auto dims = input_tensor->Shape().AsShapeVector(); + auto native_output_spec = input_spec; + for (auto axis : axes) { + if (axis < 0) { + axis += input_tensor->Shape().NumDimensions() + 1; + } + dims.insert(dims.begin() + axis, 1); + native_output_spec = TensorPartitionSpec::CreateByInsertOneAxis( + native_output_spec, + axis); + } + ORT_ENFORCE( + output_spec == native_output_spec, + "Re-sharding is required but not supported yet for this case. ", + "Specified: ", output_spec.ToString(), + " Actual: ", native_output_spec.ToString()); + auto output_tensor = context->Output(0, dims); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + output_tensor->MutableDataRaw(), + input_tensor->DataRaw(), + input_tensor->SizeInBytes(), cudaMemcpyDeviceToDevice, Stream(context))); + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedUnsqueeze, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedUnsqueeze); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedUnsqueeze, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedUnsqueeze); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedUnsqueeze, + kMSDomain, + 1, + int64_t, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedUnsqueeze); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_unsqueeze.h b/onnxruntime/contrib_ops/cuda/collective/distributed_unsqueeze.h new file mode 100644 index 0000000000000..005093ef78fb9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_unsqueeze.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "sharding.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedUnsqueeze final : public DistributedKernel { + public: + explicit DistributedUnsqueeze(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc index bb924a0d49cfe..574a3133de815 100644 --- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc @@ -1,10 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include "nccl_kernels.h" #include "mpi_include.h" +#include "core/providers/cpu/tensor/slice.h" +#include "core/providers/cuda/tensor/slice.h" +#include "core/providers/cuda/math/matmul.h" #include "core/providers/cuda/tensor/transpose.h" #include "core/providers/cuda/cuda_check_memory.h" +#include "core/platform/env_var_utils.h" namespace onnxruntime { namespace contrib { @@ -35,21 +49,160 @@ static ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type) { } } -#ifdef USE_MPI -static Status CreateNcclCommByMPI(int world_size, int rank, ncclComm_t* comm) { +namespace IPC { +#define FLLOG LOGS_DEFAULT(VERBOSE) +#define FLLOGERRNO LOGS_DEFAULT(WARNING) << "error:" << strerror(errno) +#define FLLOGGAI LOGS_DEFAULT(WARNING) << "error:" << gai_strerror(ret) + +typedef std::shared_ptr AddrInfoPtr; + +int CreateSocket(bool is_server) { + int sockfd = -1; + + struct addrinfo hints; + struct addrinfo* result = nullptr; + AddrInfoPtr result_ptr(result, [](struct addrinfo* p) { if(p){freeaddrinfo(p);} }); + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; /* Allow IPv4 or IPv6 */ + hints.ai_socktype = SOCK_STREAM; /* TCP socket. use SOCK_DGRAM for UDP */ + hints.ai_flags = AI_PASSIVE; /* For wildcard IP address */ + hints.ai_protocol = 0; /* Any protocol */ + + std::string rank0_ip = ParseEnvironmentVariableWithDefault("RANK0_IP", "localhost"); + std::string port_number = ParseEnvironmentVariableWithDefault("RANK0_PORT", "18888"); + + int ret = getaddrinfo(is_server ? nullptr : rank0_ip.c_str(), port_number.c_str(), &hints, &result); + if (ret != 0) { + FLLOGGAI << " getaddrinfo failed\n"; + return sockfd; + } + + for (struct addrinfo* rp = result; rp != nullptr; rp = rp->ai_next) { + sockfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + if (sockfd == -1) { + continue; + } + + int on = 1; + int rc = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char*)&on, sizeof(on)); + if (rc < 0) { + FLLOGERRNO << ("setsockopt() failed\n"); + close(sockfd); + continue; + } + + if (is_server) { + if (bind(sockfd, rp->ai_addr, rp->ai_addrlen) == 0) { + FLLOG << "Listening on port " << port_number << " for the other GPU processores...\n"; + } else { + FLLOGERRNO << ("bind failed\n"); + close(sockfd); + sockfd = -1; + } + } else { + time_t start_time = time(0); + int conn_ret = connect(sockfd, rp->ai_addr, rp->ai_addrlen); + while (time(0) - start_time < 40 && conn_ret < 0) { + FLLOGERRNO << (" waiting the RANK 0 ready...\n"); /* terminate */ + sleep(1); + conn_ret = connect(sockfd, rp->ai_addr, rp->ai_addrlen); + } + if (conn_ret < 0) { + close(sockfd); + sockfd = -1; + FLLOGERRNO << ("connect failed with timeout\n"); /* terminate */ + } else { + FLLOG << "connect to " << rank0_ip << ":" << port_number << "success \n"; + } + } + break; + } + return sockfd; +} + +int WriteOnRank0(ncclUniqueId* nccl_id, int word_size) { + int fd = CreateSocket(true); + if (fd < 0) { + FLLOGERRNO << (" create socket\n"); /* terminate */ + return -1; + } + + /* listen to the socket */ + if (listen(fd, word_size) < 0) { + FLLOGERRNO << ("listen\n"); /* terminate */ + return -1; + } + + word_size--; // rank 0 is not in word_size + while (word_size-- > 0) { + int client_fd = accept(fd, nullptr, nullptr); /* accept blocks */ + if (client_fd < 0) { + FLLOGERRNO << ("accept\n"); /* terminate */ + return -1; + } + FLLOG << ("Accepted new GPU\n"); + if (write(client_fd, (nccl_id), sizeof(ncclUniqueId)) != sizeof(ncclUniqueId)) { + FLLOGERRNO << ("write\n"); /* terminate */ + return -1; + } + close(client_fd); + } + close(fd); + return 0; +} + +int ReadFromRank0(ncclUniqueId* nccl_id) { + int sockfd = CreateSocket(false); + if (sockfd < 0) { + FLLOGERRNO << ("socket"); + return -1; + } + + if (read(sockfd, (nccl_id), sizeof(ncclUniqueId)) != sizeof(ncclUniqueId)) { + FLLOGERRNO << ("read"); /* terminate */ + return -1; + } + + close(sockfd); + return 0; +} + +int IPC_Bcast(ncclUniqueId* nccl_id, int rank, int world_size) { + if (rank == 0) { + if (WriteOnRank0(nccl_id, world_size) != 0) { + return (-1); + } + } else if (ReadFromRank0(nccl_id) != 0) { + return (-1); + } + + return 0; +} +} // namespace IPC + +static Status CreateNcclCommunicator(int world_size, int rank, ncclComm_t* comm, bool is_launched_by_mpi) { // Create new NCCL communicator ncclUniqueId nccl_id; if (rank == 0) { NCCL_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id)); } - MPI_CHECK(MPI_Bcast(&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD)); + if (is_launched_by_mpi) { +#ifdef USE_MPI + MPI_CHECK(MPI_Bcast(&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD)); +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Please compile ORT with USE_MPI."); +#endif + } else if (IPC::IPC_Bcast(&nccl_id, rank, world_size) != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "IPC_Bcast nccl_id failed with :", strerror(errno)); + } NCCL_RETURN_IF_ERROR(ncclCommInitRank(comm, world_size, nccl_id, rank)); return Status::OK(); } -#endif NcclContext::NcclContext() { + world_size_ = -1; #ifdef USE_MPI int is_mpi_initialized = 0; MPI_Initialized(&is_mpi_initialized); @@ -62,14 +215,19 @@ NcclContext::NcclContext() { MPI_Comm_size(MPI_COMM_WORLD, &world_size_); MPI_Comm_rank(MPI_COMM_WORLD, &rank_); +#endif + // world_size_ would be zero if MPI is being compiled but not launched by MPI. + bool is_launched_by_mpi = true; + if (world_size_ < 1) { + is_launched_by_mpi = false; + world_size_ = ParseEnvironmentVariableWithDefault("LOCAL_WORLD_SIZE", -1); + rank_ = ParseEnvironmentVariableWithDefault("LOCAL_RANK", -1); + ORT_ENFORCE(world_size_ != -1 && rank_ != -1); + } // Initialize global Parallel Group NCCL Communicator - auto ret = CreateNcclCommByMPI(world_size_, rank_, &comm_); + auto ret = CreateNcclCommunicator(world_size_, rank_, &comm_, is_launched_by_mpi); ORT_ENFORCE(ret.IsOK()); - -#else - ORT_THROW("ORT must be built with MPI to use NCCL."); -#endif } NcclContext::~NcclContext() { @@ -246,6 +404,116 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), AllToAll); +Status FuncAllReduce( + ncclComm_t comm, + cudaStream_t stream, + const Tensor* input, + Tensor* output) { + const void* input_data = input->DataRaw(); + const auto input_shape = input->Shape(); + int64_t input_count = input_shape.Size(); + + void* output_data = output->MutableDataRaw(); + + ncclDataType_t dtype = GetNcclDataType(input->DataType()); + NCCL_RETURN_IF_ERROR(ncclAllReduce(input_data, output_data, input_count, dtype, ncclSum, comm, stream)); + return Status::OK(); +} + +static std::vector CalculatePermToSwapAxes( + const int64_t axis, + const int64_t another_axis, + const size_t rank) { + // This is for swapping axis and another_axis. + // NCCL's AllGather only gathers along axis 0. If gathering along another axis is needed, + // we need to call transpose. E.g., + // Case 1: + // AllGather(axis=0) + // Case 2: + // AllGather(axis=3) = Transpose(perm=[3, 1, 2, 0]) -> AllGather(axis=0) -> Transpose(perm=[3, 1, 2, 0]) + std::vector permutation(rank); + std::iota(std::begin(permutation), std::end(permutation), 0); + permutation[axis] = another_axis; + permutation[another_axis] = axis; + return permutation; +} + +void FuncAllGather( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const Tensor* input, + const int64_t group_size, + const int64_t axis, + Tensor* output) { + ORT_ENFORCE(output->Shape().Size() == input->Shape().Size() * group_size, "Input and output shapes mismatch."); + ORT_ENFORCE(group_size >= 0, "group_size should be non-negative."); + ORT_ENFORCE(axis >= 0, "axis should be non-negative."); + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc) == Status::OK(), "Fail to find allocator."); + if (axis == 0) { + const void* input_data = input->DataRaw(); + const auto input_shape = input->Shape(); + void* output_data = output->MutableDataRaw(); + ncclAllGather( + input_data, + output_data, + input_shape.Size(), + GetNcclDataType(input->DataType()), + nccl_kernel->Comm(), + nccl_kernel->Stream(ctx)); + } else { + const auto source_shape = input->Shape(); + TensorShape transposed_shape(source_shape); + transposed_shape[0] = source_shape[axis]; + transposed_shape[axis] = source_shape[0]; + + auto transposed_buffer = Tensor::Create(input->DataType(), transposed_shape, alloc); + + // swap axis 0 and axis axis + std::vector perm = CalculatePermToSwapAxes(0, axis, source_shape.NumDimensions()); + + ORT_ENFORCE(onnxruntime::cuda::Transpose::DoTranspose(nccl_kernel->GetDeviceProp(), + nccl_kernel->Stream(ctx), + nccl_kernel->GetCublasHandle(ctx), + perm, *input, *transposed_buffer) == Status::OK()); + + TensorShape gathered_shape(transposed_shape); + gathered_shape[0] = group_size * transposed_shape[0]; + auto gathered_buffer = Tensor::Create(input->DataType(), gathered_shape, alloc); + + ncclAllGather( + transposed_buffer->DataRaw(), + gathered_buffer->MutableDataRaw(), + transposed_shape.Size(), + GetNcclDataType(input->DataType()), + nccl_kernel->Comm(), + nccl_kernel->Stream(ctx)); + + ORT_ENFORCE(gathered_buffer->Shape().Size() == output->Shape().Size()); + ORT_ENFORCE(onnxruntime::cuda::Transpose::DoTranspose(nccl_kernel->GetDeviceProp(), + nccl_kernel->Stream(ctx), + nccl_kernel->GetCublasHandle(ctx), + perm, *gathered_buffer, *output) == Status::OK()); + } +} + +std::unique_ptr FuncAllGather( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const Tensor* input, + const int64_t group_size, + const int64_t axis) { + ORT_ENFORCE(group_size >= 0); + ORT_ENFORCE(axis >= 0); + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc) == Status::OK()); + TensorShape output_shape(input->Shape()); + output_shape[axis] = group_size * output_shape[axis]; + auto output = Tensor::Create(input->DataType(), output_shape, alloc); + FuncAllGather(nccl_kernel, ctx, input, group_size, axis, output.get()); + return output; +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h index 24df69ea50224..7fc26e6be57b9 100644 --- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h +++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h @@ -6,7 +6,12 @@ #include "core/providers/cuda/cuda_kernel.h" #if defined(ORT_USE_NCCL) +#include +#include +#include +#include #include +#include #endif namespace onnxruntime { @@ -44,6 +49,10 @@ class NcclKernel : public ::onnxruntime::cuda::CudaKernel { public: explicit NcclKernel(const OpKernelInfo& info); + ncclComm_t Comm() const { + return nccl_->Comm(); + } + protected: NcclContext* nccl_ = nullptr; }; @@ -81,6 +90,27 @@ class AllToAll final : public NcclKernel { int64_t group_size_ = -1; }; +Status FuncAllReduce( + ncclComm_t comm, + cudaStream_t stream, + const Tensor* input, + Tensor* output); + +void FuncAllGather( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const Tensor* input, + const int64_t group_size, + const int64_t axis, + Tensor* output); + +std::unique_ptr FuncAllGather( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const Tensor* input, + const int64_t group_size, + const int64_t axis); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding.cc b/onnxruntime/contrib_ops/cuda/collective/sharding.cc new file mode 100644 index 0000000000000..b6b509023a1a9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/sharding.cc @@ -0,0 +1,300 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding.h" +#include "mpi_include.h" +#include "sharding_spec.h" + +#include +#include +#include "core/providers/cpu/tensor/slice.h" +#include "core/providers/cuda/tensor/slice.h" +#include "core/providers/cuda/math/matmul.h" +#include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/cuda_check_memory.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +void GatherTensor( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const NcclKernel* nccl_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const Tensor* tensor, + Tensor* gathered) { + const int64_t shard_axis = spec.GetPartitionAxis(); + const int64_t shard_count = spec.GetUniqueDeviceCount(shard_axis); + + FuncAllGather( + nccl_kernel, + ctx, + tensor, + shard_count, + shard_axis, + gathered); +} + +std::unique_ptr GatherTensor( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const NcclKernel* nccl_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const Tensor* tensor) { + const int64_t shard_axis = spec.GetPartitionAxis(); + const int64_t shard_count = spec.GetUniqueDeviceCount(shard_axis); + TensorShape gathered_shape(tensor->Shape()); + gathered_shape[shard_axis] *= shard_count; + + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc) == Status::OK()); + auto gathered = Tensor::Create(tensor->DataType(), gathered_shape, alloc); + + FuncAllGather( + nccl_kernel, + ctx, + tensor, + shard_count, + shard_axis, + gathered.get()); + + return gathered; +} + +void ShardTensor( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const NcclKernel* nccl_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const int64_t device_id, + const Tensor* tensor, + Tensor* shard_tensor) { + const int64_t shard_axis = spec.GetPartitionAxis(); + const int64_t shard_count = spec.GetUniqueDeviceCount(shard_axis); + TensorShape shard_shape = ComputeShardShape( + tensor->Shape(), + shard_axis, + shard_count); + const int64_t shard_dim = shard_shape[shard_axis]; + const std::vector starts = {shard_dim * device_id}; + const std::vector ends = {shard_dim * (device_id + 1)}; + const std::vector axes = {shard_axis}; + const std::vector steps = {1}; + + ORT_ENFORCE(FuncSlice( + nccl_kernel, + ctx, + tensor, + starts, + ends, + axes, + steps, + shard_tensor) == Status::OK()); +} + +std::unique_ptr ShardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const int64_t device_id, + const Tensor* tensor) { + // Shard all-replica tensor per spec. + + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc) == Status::OK()); + + TensorShape shard_shape = ComputeShardShape( + tensor->Shape(), + spec.GetPartitionAxis(), + spec.GetUniqueDeviceCount(spec.GetPartitionAxis())); + auto shard_buffer = Tensor::Create(tensor->DataType(), shard_shape, alloc); + + // Shard with pre-allocated buffer. + ShardTensor( + nccl_kernel, + ctx, + spec, + device_id, + tensor, + shard_buffer.get()); + + return shard_buffer; +} + +void ReshardTensor( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const NcclKernel* nccl_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const TensorPartitionSpec& src_spec, + const TensorPartitionSpec& dst_spec, + const int64_t device_id, + const Tensor* src, + Tensor* dst) { + if (src_spec.HasShard() && dst_spec.HasNoShard()) { + GatherTensor( + nccl_kernel, + ctx, + src_spec, + src, + dst); + return; + } else if (src_spec.HasNoShard() && dst_spec.HasShard()) { + ShardTensor( + nccl_kernel, + ctx, + dst_spec, + device_id, + src, + dst); + } else if (src_spec.HasShard() && dst_spec.HasShard()) { + int64_t src_axis = src_spec.GetPartitionAxis(); + int64_t dst_axis = dst_spec.GetPartitionAxis(); + ORT_ENFORCE(src_axis != dst_axis, "No reshard is needed. Don't call this function."); + + auto all_replica_buffer = GatherTensor( + nccl_kernel, + ctx, + src_spec, + src); + + ShardTensor( + nccl_kernel, + ctx, + dst_spec, + device_id, + all_replica_buffer.get(), + dst); + } else { + ORT_THROW("Not supported yet. Probably resharding is not needed."); + } +} + +std::unique_ptr ReshardTensor( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const NcclKernel* nccl_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const TensorPartitionSpec& src_spec, + const TensorPartitionSpec& dst_spec, + const int64_t device_id, + const Tensor* src) { + // Implement ReshardTensor but returning a unique_ptr to Tensor instead. + const auto origin_shape = ComputeOriginShape(src->Shape(), src_spec); + const auto dst_shape = ComputeShardShape(origin_shape, dst_spec); + ORT_ENFORCE(CanShard(origin_shape, dst_spec), "Cannot shard tensor. Shape:", origin_shape, ", sharding spec: ", dst_spec.ToString()); + + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc) == Status::OK()); + auto dst = Tensor::Create(src->DataType(), dst_shape, alloc); + ReshardTensor( + nccl_kernel, + ctx, + src_spec, + dst_spec, + device_id, + src, + dst.get()); + return dst; +} + +void ReshardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& src_spec, + const TensorPartitionSpec& dst_spec, + const int64_t device_id, + const Tensor* src, + int output_idx) { + // Implement ReshardTensor but returning a unique_ptr to Tensor instead. + const auto origin_shape = ComputeOriginShape(src->Shape(), src_spec); + const auto dst_shape = ComputeShardShape(origin_shape, dst_spec); + ORT_ENFORCE(CanShard(origin_shape, dst_spec), "Cannot shard tensor. Shape:", origin_shape, ", sharding spec: ", dst_spec.ToString()); + + auto* dst = ctx->Output(output_idx, dst_shape); + ReshardTensor( + nccl_kernel, + ctx, + src_spec, + dst_spec, + device_id, + src, + dst); +} + +DistributedKernel::DistributedKernel(const OpKernelInfo& info) : NcclKernel(info) { + // input_device_mesh_shapes[i] is the shape of device mesh for the i-th input. + // E.g., device_mesh_shapes = ["[2]", "[1]"] means the first input is + // stored on a 1-D mesh with 2 devices and the second input on another 1-D + // mesh with 1 device. + std::vector attr_input_device_mesh_shapes; + ORT_ENFORCE(info.GetAttrs("input_device_mesh_shapes", attr_input_device_mesh_shapes).IsOK()); + + // input_device_mesh_elements[i] is the flattened device mesh for the i-th input. + // Note that its actual shape is input_device_mesh_shapes[i]. + // Example: + // Assume + // device_mesh_shapes = ["[2]", "[1]"] + // device_mesh_elements = ["[0,1]", "[0]"] + // Then the first input is stored on a 1-D mesh with 2 devices and the second + // input on another 1-D mesh with 1 device. + std::vector attr_input_device_mesh_elements; + ORT_ENFORCE(info.GetAttrs("input_device_mesh_elements", attr_input_device_mesh_elements).IsOK()); + + // input_shard_specs[i] is the sharding spec of the i-th input; e.g., + // "RR" if the i-th input is not sharded. + std::vector input_shard_specs; + ORT_ENFORCE(info.GetAttrs("input_shard_specs", input_shard_specs).IsOK()); + + ORT_ENFORCE(attr_input_device_mesh_shapes.size() == attr_input_device_mesh_elements.size()); + ORT_ENFORCE(attr_input_device_mesh_shapes.size() == input_shard_specs.size()); + + // Begin parsing sharding metadata for inputs. + for (size_t i = 0; i < input_shard_specs.size(); ++i) { + auto device_mesh_shape = ParseStringAsInt64Vector(attr_input_device_mesh_shapes[i]); + auto device_mesh_elements = ParseStringAsInt64Vector(attr_input_device_mesh_elements[i]); + auto spec = CreateTensorPartitionSpec(input_shard_specs[i], device_mesh_shape, device_mesh_elements); + input_shard_specs_.push_back(spec); + } + + std::vector attr_output_device_mesh_shapes; + ORT_ENFORCE(info.GetAttrs("output_device_mesh_shapes", attr_output_device_mesh_shapes).IsOK()); + + std::vector attr_output_device_mesh_elements; + ORT_ENFORCE(info.GetAttrs("output_device_mesh_elements", attr_output_device_mesh_elements).IsOK()); + + std::vector output_shard_specs; + ORT_ENFORCE(info.GetAttrs("output_shard_specs", output_shard_specs).IsOK()); + + ORT_ENFORCE(attr_output_device_mesh_shapes.size() == attr_output_device_mesh_elements.size()); + ORT_ENFORCE(attr_output_device_mesh_shapes.size() == output_shard_specs.size()); + + for (size_t i = 0; i < output_shard_specs.size(); ++i) { + auto device_mesh_shape = ParseStringAsInt64Vector(attr_output_device_mesh_shapes[i]); + auto device_mesh_elements = ParseStringAsInt64Vector(attr_output_device_mesh_elements[i]); + auto spec = CreateTensorPartitionSpec(output_shard_specs[i], device_mesh_shape, device_mesh_elements); + output_shard_specs_.push_back(spec); + } +} + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding.h b/onnxruntime/contrib_ops/cuda/collective/sharding.h new file mode 100644 index 0000000000000..81a0f72f0c32f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/sharding.h @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "sharding_spec.h" +#include "nccl_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +void GatherTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const Tensor* tensor, + Tensor* gathered); + +std::unique_ptr GatherTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const Tensor* tensor); + +void ShardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const int64_t device_id, + const Tensor* tensor, + Tensor* shard_tensor); + +std::unique_ptr ShardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const int64_t device_id, + const Tensor* tensor); + +void ReshardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& src_spec, + const TensorPartitionSpec& dst_spec, + const int64_t device_id, + const Tensor* src, + Tensor* dst); + +// Output from ctx +void ReshardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& src_spec, + const TensorPartitionSpec& dst_spec, + const int64_t device_id, + const Tensor* src, + int output_idx); + +std::unique_ptr ReshardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& src_spec, + const TensorPartitionSpec& dst_spec, + const int64_t device_id, + const Tensor* src); + +class TensorPartitionSpec; + +class DistributedKernel : public NcclKernel { + public: + explicit DistributedKernel(const OpKernelInfo& info); + + protected: + std::vector input_shard_specs_; + std::vector output_shard_specs_; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc new file mode 100644 index 0000000000000..20c936e1b6718 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc @@ -0,0 +1,223 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" + +#include "core/common/common.h" +#include "core/common/gsl.h" +#include "core/framework/tensor_shape.h" + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +void ValidateAxisIndex(const int64_t axis, const int64_t rank) { + int64_t adjusted_axis = axis; + if (axis < 0) { + adjusted_axis = axis + rank; + } else { + adjusted_axis = axis; + } + ORT_ENFORCE(adjusted_axis >= 0 && adjusted_axis < rank, "axis,", axis, ", should be in [", -rank, ",", rank, ")."); +} + +std::vector ParseStringAsInt64Vector(const std::string& str) { + if (str.empty() || str.front() != '[' || str.back() != ']') { + throw std::invalid_argument("Invalid input string format"); + } + // Parsed vector. + // If input is "[0, 1, 2]", result should be {0, 1, 2}. + std::vector result; + // Skip '[' and ']' + std::istringstream iss(str.substr(1, str.size() - 2)); + + // Extract integers separated by ',' or whitespaces. + int64_t num = -1; + while (/* Read one number at a time */ iss >> num) { + result.push_back(num); + // Skip the comma + if (iss.peek() == ',') { + iss.ignore(); + } + } + return result; +} + +DeviceMesh CreateDeviceMesh( + std::vector device_mesh_shape, + std::vector device_mesh_elements) { + DeviceMesh device_mesh; + device_mesh.device_mesh_shape = device_mesh_shape; + device_mesh.device_mesh_elements = device_mesh_elements; + return device_mesh; +} + +TensorPartitionSpec CreateTensorPartitionSpec(std::string spec_string, std::vector device_mesh_shape, std::vector device_mesh_elements) { + // "S[0]R" + std::vector axis_specs; + size_t dim_index = 0; + size_t token_index = 0; + while (token_index < spec_string.size()) { + char token = spec_string.at(token_index); + if (token == 'R') { + AxisPartitionSpec axis_spec = AxisPartitionSpec::CreateReplica(); + axis_specs.push_back(axis_spec); + ++token_index; + ++dim_index; + } else if (token == 'S') { + std::stringstream ss; + // Next should be "[". + ++token_index; + char left_bracket = spec_string.at(token_index); + ORT_ENFORCE(left_bracket == '[', "Invalid partition token: ", left_bracket, " in ", spec_string); + // Move to digit part. + ++token_index; + while (spec_string.at(token_index) != ']') { + // Now token_index should points to the first digit of + // axis index. + char digit = spec_string.at(token_index); + ORT_ENFORCE(std::isdigit(digit), "Invalid partition token: ", token, " in ", spec_string); + ss << digit; + // Loaded a digit. Go to next token. + ++token_index; + } + int device_mesh_index = 0; + ss >> device_mesh_index; + AxisPartitionSpec axis_spec = AxisPartitionSpec::CreateShard(device_mesh_index); + axis_specs.push_back(axis_spec); + // Skip "]". + char right_bracket = spec_string.at(token_index); + ORT_ENFORCE(right_bracket == ']', "Invalid partition token: ", token, " in ", spec_string); + ++token_index; + } else { + throw std::invalid_argument("Invalid partition token: " + token); + } + } + DeviceMesh device_mesh = CreateDeviceMesh(device_mesh_shape, device_mesh_elements); + return TensorPartitionSpec::Create(axis_specs, device_mesh); +} + +TensorPartitionSpec CreateTensorShardSpec( + const DeviceMesh& device_mesh, + int64_t device_mesh_axis, + int64_t shard_axis, + int64_t tensor_rank) { + if (shard_axis < 0) { + shard_axis += tensor_rank; + } + std::vector axis_specs; + for (int64_t i = 0; i < tensor_rank; ++i) { + if (i == shard_axis) { + axis_specs.push_back(AxisPartitionSpec::CreateShard(device_mesh_axis)); + } else { + axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + } + } + return TensorPartitionSpec::Create(axis_specs, device_mesh); +} + +TensorShape ComputeOriginShape(const TensorShape& shard_shape, const TensorPartitionSpec& spec) { + ORT_ENFORCE(gsl::narrow(shard_shape.NumDimensions()) == spec.Rank(), "Shard shape and spec rank mismatch."); + if (spec.HasNoShard()) { + return shard_shape; + } + TensorShape shape(shard_shape); + const int64_t axis = spec.GetPartitionAxis(); + shape[axis] *= spec.GetUniqueDeviceCount(axis); + return shape; +} + +TensorShape ComputeShardShape(const TensorShape& shape, const TensorPartitionSpec& spec) { + ORT_ENFORCE(gsl::narrow(shape.NumDimensions()) == spec.Rank(), "Shape and spec rank mismatch."); + TensorShape shard_shape(shape); + if (spec.HasNoShard()) { + return shard_shape; + } + const int64_t axis = spec.GetPartitionAxis(); + const int64_t unique_device_count = spec.GetUniqueDeviceCount(axis); + ORT_ENFORCE(shard_shape[axis] % unique_device_count == 0, "Number of shards must be divisible by sharded axis' dimension."); + // If a [8, 16]-tensor is shared by device mesh [0, 1, 0, 1] along axis=1 (2nd axis), + // the local tensors on device 0 & 1 have same shape [8, 8 (from 16/2)] instead of + // [8, 4 (from 16/4)]. The reason is that + // - First, the original tensor are split into 4 sub-tensors [8, 4] along the 2nd axis. + // - The 1st and 3rd sub-tensors are concatenated along axis=1 to one tensor on device 0. + // - The 2nd and 4th sub-tensors are concatenated along axis=1 to one tensor on device 1. + shard_shape[axis] /= unique_device_count; + return shard_shape; +} + +TensorShape ComputeShardShape(const TensorShape source_shape, int64_t shard_axis, int64_t shard_count) { + if (shard_axis < 0) { + shard_axis += gsl::narrow(source_shape.NumDimensions()); + } + TensorShape shard_shape(source_shape); + ORT_ENFORCE(shard_axis < gsl::narrow(source_shape.NumDimensions()), "Shard axis must be less than the number of dimensions of the source tensor."); + ORT_ENFORCE(source_shape[shard_axis] % shard_count == 0, "Number of shards must be divisible by sharded axis' dimension."); + shard_shape[shard_axis] = source_shape[shard_axis] / shard_count; + return shard_shape; +} + +std::tuple NormalizeShapes(const TensorShape& left, const TensorShape& right) { + if (left.NumDimensions() > right.NumDimensions()) { + std::vector right_vector(right.NumDimensions(), 0); + right.CopyDims(right_vector.data(), right.NumDimensions()); + // Fill 1's to right shape. E.g., + // left: [1, 2, 3, 4], right: [5, 6, 7] -> left: [1, 2, 3, 4], right: [1, 5, 6, 7] + right_vector.insert(right_vector.begin(), left.NumDimensions() - right.NumDimensions(), 1); + return std::make_tuple(left, TensorShape(right_vector)); + } else if (left.NumDimensions() < right.NumDimensions()) { + std::vector left_vector(left.NumDimensions(), 0); + left.CopyDims(left_vector.data(), left.NumDimensions()); + // Fill 1's to left shape. E.g., + // left: [1, 2, 3], right: [4, 5, 6, 7] -> left: [1, 2, 3, 1], right: [4, 5, 6, 7] + left_vector.insert(left_vector.begin(), right.NumDimensions() - left.NumDimensions(), 1); + return std::make_tuple(TensorShape(left_vector), TensorShape(right)); + } else { + return std::make_tuple(TensorShape(left), TensorShape(right)); + } +} + +std::tuple NormalizeTensorPartitionSpecs( + const TensorPartitionSpec& left, const TensorPartitionSpec& right) { + // TODO: Make it to modify left and right instead of returning new values. + if (left.axis_specs.size() > right.axis_specs.size()) { + auto new_right = TensorPartitionSpec::Create(right.axis_specs, right.device_mesh); + new_right.axis_specs.insert(new_right.axis_specs.begin(), left.axis_specs.size() - right.axis_specs.size(), AxisPartitionSpec::CreateReplica()); + return std::make_tuple(left, new_right); + } else if (left.axis_specs.size() < right.axis_specs.size()) { + auto new_left = TensorPartitionSpec::Create(left.axis_specs, left.device_mesh); + new_left.axis_specs.insert(new_left.axis_specs.begin(), right.axis_specs.size() - left.axis_specs.size(), AxisPartitionSpec::CreateReplica()); + return std::make_tuple(new_left, right); + } else { + return std::make_tuple(left, right); + } +} + +bool CanShard(const TensorShape& shape, const TensorPartitionSpec& spec) { + if (spec.HasNoShard()) { + return true; + } + if (gsl::narrow(shape.NumDimensions()) != spec.Rank()) { + return false; + } + const int64_t axis = spec.GetPartitionAxis(); + if (axis < 0 || gsl::narrow(axis) >= shape.NumDimensions()) { + return false; + } + if (shape[axis] % spec.GetDeviceCount(axis) != 0) { + return false; + } + return true; +} + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h new file mode 100644 index 0000000000000..5abc50a61c9a3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h @@ -0,0 +1,492 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "core/common/common.h" +#include "core/framework/tensor_shape.h" + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +class DeviceMesh { + public: + // [Device Mesh and Tensor Sharding for Tensor Parallel] + // Device mesh is a tensor of device indices. + // A tensor can then be partitioned along specific mesh axes. + // + // Assume we have 4 GPUs indexed by 0, 1, 2, and 3. + // Let's consider some examples. + // 1. 1D device mesh [0, 1, 2, 3]. In this case, + // device_mesh_shape is [4] and device_mesh_elements + // is [0, 1, 2, 3]. + // If we want to shard a 2-D tensor along its axis 1, the + // corresponding sharding spec is a string "RS[0]". + // 2. 2D device mesh [[0, 1], [2, 3]]. In this case, + // device_mesh_shape is [2, 2] and device_mesh_elements + // is [0, 1, 2, 3]. + // If we want to shard a 2-D tensor's + // rows along mesh axis 1 and + // columns along mesh axis 0, the + // corresponding sharding spec is a string "S[1]S[0]". + // If that 2-D tensor's value is np.array([[5, 6], [7, 8]]), + // GPU 0/1/2/3 owns 5/7/6/8. Below is a visualization the sharding + // proccess. + // - Start with a 2-D device mesh [[0, 1], [2, 3]] and + // a 2-D tensor [[5, 6], [7, 8]] + // - GPU: [[0, 1], [2, 3]], Tensor: [[5, 6], [7, 8]] + // - Split GPU mesh along axis 1 and tensor along + // axis 0 for "S[1]" in "S[1]S[0]" + // - GPU: [[0], [2]], Tensor: [[5, 6]] + // GPU: [[1], [3]], Tensor: [[7, 8]] + // - Split GPU mesh along axis 0 and tensor along + // axis 1 for "S[0]" in "S[1]S[0]" + // - GPU: [[0]], Tensor: [[5]] + // - GPU: [[2]], Tensor: [[6]] + // - GPU: [[1]], Tensor: [[7]] + // - GPU: [[3]], Tensor: [[8]] + + // Actual shape of device mesh represented by `device_mesh_elements`. + std::vector device_mesh_shape; + + // Flattened device mesh. + std::vector device_mesh_elements; + + // Helper to debug and generate error message; e.g., + // "DeviceMesh{Shape: [2,2,], Elements: [0,1,2,3,]}". + std::string ToString() const { + std::ostringstream os; + os << "DeviceMesh{Shape: ["; + for (const auto& shape : device_mesh_shape) + os << shape << ","; + os << "], Elements: ["; + for (const auto& element : device_mesh_elements) + os << element << ","; + os << "]}"; + return os.str(); + } + + // Call this in GDB to visualize the mesh. + void Print() const { + std::cout << ToString() << std::endl; + } + + static DeviceMesh Create1D(std::vector device_mesh_elements, size_t repeats = 1) { + DeviceMesh device_mesh; + device_mesh.device_mesh_shape.push_back(device_mesh_elements.size() * repeats); + for (size_t i = 0; i < repeats; ++i) { + device_mesh.device_mesh_elements.insert( + device_mesh.device_mesh_elements.end(), + device_mesh_elements.begin(), + device_mesh_elements.end()); + } + return device_mesh; + } + + // If the two meshes have the same shape and elements, return true. + // Otherwise, return false. + bool operator==(const DeviceMesh& other) const { + if (device_mesh_shape.size() != other.device_mesh_shape.size() || + device_mesh_elements.size() != other.device_mesh_elements.size()) { + return false; + } + + for (size_t i = 0; i < device_mesh_elements.size(); ++i) { + if (device_mesh_elements.at(i) != other.device_mesh_elements.at(i)) { + return false; + } + } + for (size_t i = 0; i < device_mesh_shape.size(); ++i) { + if (device_mesh_shape.at(i) != other.device_mesh_shape.at(i)) { + return false; + } + } + return true; + } + + bool operator!=(const DeviceMesh& other) const { + return !(*this == other); + } +}; + +class AxisPartitionSpec { + // [Device Mesh and Tensor Sharding for Tensor Parallel] + // This class is the in-memory representation of + // 1. if a tensor is sharded or not (aka replica), and + // 2. which tensor axis is shard by which device mesh axis. + // Let's consider sharding 2-D tensor along column axis on + // device mesh [0, 1] as an example. + // The required sharding spec RS[0] can be represented by + // - AxisPartitionSpec(Condition::Replica, -1) + // - AxisPartitionSpec(Condition::Shard, 0) + public: + // Status of a tensor axis. + // A tensor axis can be either sharded or replicated + // along a device mesh axis. + enum class Condition { Replica, + Shard }; + + // This field tells if a tensor axis is sharded or not. + Condition cond; + + // If a tensor axis is sharded, this field tells which device + // mesh axis to distribute the shards along. + // If a tensor axis is not sharded, this field is ignored. + int device_mesh_axis; + + // A helper to construct a replica spec for a tensor axis. + static AxisPartitionSpec CreateReplica() { + return AxisPartitionSpec(Condition::Replica, -1); + } + + // A helper to construct a sharding spec for a tensor axis. + // This tensor axis is sharded along `device_mesh_axis` in device mesh. + static AxisPartitionSpec CreateShard(int device_mesh_axis) { + return AxisPartitionSpec(Condition::Shard, device_mesh_axis); + } + + static AxisPartitionSpec CreateCopy(const AxisPartitionSpec& spec) { + return AxisPartitionSpec(spec.cond, spec.device_mesh_axis); + } + + // A normal ctor. + // TODO(wechi): Consider to hide it and revise the `public` members/functions + // exposed to the user. + AxisPartitionSpec(Condition cond_, int device_mesh_axis_) : cond(cond_), device_mesh_axis(device_mesh_axis_) {} + + // Helper to debug and generate error message; e.g., + // "RS[0]". + std::string ToString() const { + std::ostringstream os; + os << (cond == Condition::Replica ? "R" : "S"); + if (cond == Condition::Shard) os << "[" << device_mesh_axis << "]"; + return os.str(); + } + + // Call this in GDB to visualize the spec. + void Print() const { + std::cout << ToString() << std::endl; + } + + bool operator==(const AxisPartitionSpec& other) const { + return cond == other.cond && device_mesh_axis == other.device_mesh_axis; + } + + bool operator!=(const AxisPartitionSpec& other) const { + return !(*this == other); + } +}; + +// Return true if `axis` is a valid axis index for a tensor of rank `rank`. +// Negative `axis` is allowed (e.g., -1 for the last axis). +void ValidateAxisIndex(const int64_t axis, const int64_t rank); + +class TensorPartitionSpec { + // [Device Mesh and Tensor Sharding for Tensor Parallel] + // TensorPartitionSpec holds a collection of AxisPartitionSpec and an + // associated DeviceMesh. It is responsible for determining how a tensor + // should be partitioned across a device mesh. + // + // Example 1: RS[0] + // In this scenario, `axis_specs` would contain two `AxisPartitionSpec` objects. + // - The first object is a Replica, denoting that the first axis of the tensor is + // not sharded but is instead replicated. + // - The second object is a Shard along the 0-th axis of the device mesh. It denotes + // that the second axis of the tensor is sharded along the first axis of the + // device mesh. + // + // Example 2: S[0]RR + // In this scenario, `axis_specs` would contain three `AxisPartitionSpec` objects. + // - The first object is a Shard along the 0-th axis of the device mesh, indicating + // that the first axis of the tensor is sharded along the first axis of the + // device mesh. + // - The second and third objects are Replicas, indicating that the second and third + // axes of the tensor are not sharded but are instead replicated. + public: + // axis_specs[i]: AxisPartitionSpec for tensor axis i. For a 2-D tensor, + // axis_specs[0] is for row axis and axis_specs[1] is for + // column axis. axis_specs[i].device_mesh_axis = j means that + // tensor axis i is sharded along device mesh axis j. + std::vector axis_specs; + + // device_mesh: DeviceMesh for sharding the associated tensor. + // Read [Device Mesh and Tensor Sharding for Tensor Parallel] in DeviceMesh's comment. + DeviceMesh device_mesh; + + // Replacement of ctor. + static TensorPartitionSpec Create( + const std::vector& axis_specs, const DeviceMesh& device_mesh) { + TensorPartitionSpec spec; + spec.axis_specs = axis_specs; + spec.device_mesh = device_mesh; + return spec; + } + + // Copy-construct `spec` but with all tensor axes replicated. + // The new spec have the same number of axis specs and the same device mesh. + static TensorPartitionSpec CreateAllReplica( + const TensorPartitionSpec& spec) { + TensorPartitionSpec new_spec = spec; + new_spec.axis_specs[spec.GetPartitionAxis()] = AxisPartitionSpec::CreateReplica(); + return new_spec; + } + + // TODO(wechi): Create a halper to copy-construct a new spec with different sharding axis. + // static TensorPartitionSpec CreateReshard( + // const TensorPartitionSpec& spec, int64_t new_shard_axis) { + // } + + // Copy-construct `spec` but with all tensor axes replicated. + // The new spec have the same number of axis specs and the same device mesh. + static TensorPartitionSpec CreateAllReplica( + const size_t rank, const DeviceMesh& device_mesh) { + std::vector axis_specs(rank, AxisPartitionSpec::CreateReplica()); + return TensorPartitionSpec::Create(axis_specs, device_mesh); + } + + static TensorPartitionSpec CreateOneTensorAxisOneDeviceMeshAxisSharding( + const size_t rank, const DeviceMesh& device_mesh, const size_t tensor_axis, const size_t device_mesh_axis) { + std::vector axis_specs(rank, AxisPartitionSpec::CreateReplica()); + axis_specs[tensor_axis] = AxisPartitionSpec::CreateShard(device_mesh_axis); + return TensorPartitionSpec::Create(axis_specs, device_mesh); + } + + static TensorPartitionSpec CreateByDropAxes( + const TensorPartitionSpec& spec, const std::vector& axes_to_drop) { + std::vector axis_specs; + for (size_t i = 0; i < spec.axis_specs.size(); ++i) { + if (std::find(axes_to_drop.begin(), axes_to_drop.end(), i) != axes_to_drop.end()) { + // This axis, i, is in axes_to_drop. Let's not copy its spec. + continue; + } + axis_specs.push_back(spec.axis_specs[i]); + } + return TensorPartitionSpec::Create(axis_specs, spec.device_mesh); + } + + static TensorPartitionSpec CreateByInsertOneAxis( + const TensorPartitionSpec& spec, + const size_t axis_to_insert) { + std::vector axis_specs(spec.axis_specs); + axis_specs.insert(axis_specs.begin() + axis_to_insert, AxisPartitionSpec::CreateReplica()); + return TensorPartitionSpec::Create(axis_specs, spec.device_mesh); + } + + // Helper to debug and generate error message; e.g., + // "TensorPartitionSpec{RS[0], Device Mesh: DeviceMesh{Shape: [4,], Elements: [0,1,2,3,]}}". + std::string ToString() const { + std::ostringstream os; + os << "TensorPartitionSpec{"; + for (const auto& spec : axis_specs) + os << spec.ToString(); + os << ", DeviceMesh: " << device_mesh.ToString() << "}"; + return os.str(); + } + + // Call this in GDB to visualize the spec. + void Print() const { + std::cout << ToString() << std::endl; + } + + // Return true if at least one tensor axis is sharded. + // Otherwise, return false. + bool HasShard() const { + for (const auto& spec : axis_specs) + if (spec.cond == AxisPartitionSpec::Condition::Shard) return true; + return false; + } + + // Return true if no tensor axis is sharded. + // Otherwise, return false. + bool HasNoShard() const { + return !HasShard(); + } + + // Return true if the only sharded tensor axis is `axis`. + // Otherwise, return false. + bool OnlyShardAxis(int64_t axis) const { + ValidateAxisIndex(axis, Rank()); + if (axis < 0) { + axis += Rank(); + } + bool answer = true; + for (int64_t i = 0; i < Rank(); ++i) { + if (i == axis && axis_specs[i].cond != AxisPartitionSpec::Condition::Shard) { + answer = false; + } else if (i != axis && axis_specs[i].cond == AxisPartitionSpec::Condition::Shard) { + answer = false; + } + } + return answer; + } + + // Rank of the owing tensor of this spec. + int64_t Rank() const { + return gsl::narrow(axis_specs.size()); + } + + // Return the number of sharded tensor axes. + // Currently we only support one sharded tensor axis, so + // we may assert the returned value is 1 in related APIs. + int64_t CountShardingAxes() const { + int64_t count = 0; + for (const auto& spec : axis_specs) + if (spec.cond == AxisPartitionSpec::Condition::Shard) count++; + return count; + } + + // Return the AxisPartitionSpec for `axis`-th tensor axis. + const AxisPartitionSpec& GetAxisSpec(int64_t axis) const { + ValidateAxisIndex(axis, Rank()); + if (axis < 0) { + axis += Rank(); + } + return axis_specs.at(axis); + } + + // Get the first sharded tensor axis' sharding spec. + const AxisPartitionSpec& GetPartitionAxisSpec() const { + // TODO: support multiple sharding axes. + ORT_ENFORCE(CountShardingAxes() == 1, "TensorPartitionSpec must have exactly one sharding axis."); + return GetAxisSpec(GetPartitionAxis()); + } + + // Get the first sharded tensor axis' index. + // E.g., spec "RS[0]" should return 1, spec "S[0]R" should return 0, spec "RR" should return -1. + // Returned value -1 means no sharded tensor axis. + int64_t GetPartitionAxis() const { + // TODO: support multiple sharding axes. + ORT_ENFORCE(CountShardingAxes() == 1, "TensorPartitionSpec must have exactly one sharding axis."); + for (int64_t i = 0; i < gsl::narrow(axis_specs.size()); ++i) { + if (axis_specs[i].cond == AxisPartitionSpec::Condition::Shard) { + return i; + } + } + return -1; + } + + // Similarly to GetPartitionAxis(), but returns the negative index of the first sharded tensor axis. + // E.g., spec "RS[0]" should return -1, spec "S[0]R" should return -2, and spec "RR" should return 0. + // Returned value 0 means no sharded tensor axis. + int64_t GetNegativePartitionAxis() const { + // TODO: support multiple sharding axes. + ORT_ENFORCE(CountShardingAxes() == 1, "TensorPartitionSpec must have exactly one sharding axis."); + for (int64_t i = 0; i < gsl::narrow(axis_specs.size()); ++i) { + if (axis_specs[i].cond == AxisPartitionSpec::Condition::Shard) { + return i - axis_specs.size(); + } + } + return 0; + } + + // Return the number of shards along the first sharded tensor axis. + // This value matches the number of devices along the associated mesh axis. + // Return 1 if there is no sharding. + int64_t GetDeviceCount(int64_t axis) const { + ValidateAxisIndex(axis, Rank()); + auto axis_spec = GetAxisSpec(axis); + if (axis_spec.cond == AxisPartitionSpec::Condition::Replica) { + return 1; + } else { + return device_mesh.device_mesh_shape.at(axis_spec.device_mesh_axis); + } + } + + // Similar to GetDeviceCount(), but returns the number of unique devices + // along the first sharded tensor axis. + int64_t GetUniqueDeviceCount(int64_t axis) const { + ValidateAxisIndex(axis, Rank()); + auto axis_spec = GetAxisSpec(axis); + if (axis_spec.cond == AxisPartitionSpec::Condition::Replica) { + return 1; + } else { + std::set device_ids( + device_mesh.device_mesh_elements.begin(), + device_mesh.device_mesh_elements.end()); + return device_ids.size(); + } + } + + bool operator==(const TensorPartitionSpec& other) const { + if (axis_specs.size() != other.axis_specs.size()) { + return false; + } + for (size_t i = 0; i < axis_specs.size(); ++i) { + if (!(axis_specs.at(i) == other.axis_specs.at(i))) { + return false; + } + } + return device_mesh == other.device_mesh; + } + + bool operator!=(const TensorPartitionSpec& other) const { + return !(*this == other); + } +}; + +// Parse "[0, 1, 2, 3]" as std::vector{0, 1, 2, 3}. +std::vector ParseStringAsInt64Vector(const std::string& str); + +DeviceMesh CreateDeviceMesh( + std::vector device_mesh_shape, + std::vector device_mesh_elements); + +TensorPartitionSpec CreateTensorPartitionSpec( + std::string spec_string, + std::vector device_mesh_shape, + std::vector device_mesh_elements); + +TensorPartitionSpec CreateTensorShardSpec( + const DeviceMesh& device_mesh, + int64_t device_mesh_axis, + int64_t shard_axis, + int64_t tensor_rank); + +// Return the shape of the original tensor before sharding. +// E.g., assume tensor shard's shape is [5, 7] and sharding spec is "S[0]R" +// with 1-D device mesh [0, 1, 2]. +// This function returns [15, 7]. +// +// `shard_shape`: the shape of a shard. +// `spec`: the sharding spec of the original tensor. +TensorShape ComputeOriginShape(const TensorShape& shard_shape, const TensorPartitionSpec& spec); + +// Return the shape of a shard. +// E.g., assume tensor's shape is [15, 7] and sharding spec is "S[0]R" +// with 1-D device mesh [0, 1, 2]. +// This function returns [5, 7]. +// +// `shape`: the shape of the original tensor. +// `spec`: the sharding spec of the original tensor. +TensorShape ComputeShardShape(const TensorShape& shape, const TensorPartitionSpec& spec); + +// Similarly to ComputeShardShape(), but takes a shard axis and counts of all tensor shards +// instead of a spec. +TensorShape ComputeShardShape(const TensorShape source_shape, int64_t shard_axis, int64_t shard_count); + +// Prepend 1's to `shape` to make `left` and `right` have the same rank. +// E.g., if `left` is [3, 7] and `right` is [5, 6, 7], this function returns [1, 3, 7] and [5, 6, 7]. +std::tuple NormalizeShapes(const TensorShape& left, const TensorShape& right); + +// Prepend `R` (aks replicating axis) to `spec` to make `left` and `right` have the same rank. +// E.g., if `left` is S[0]R and `right` is `RRR`, this function returns `RS[0]R` and `RRR`. +std::tuple NormalizeTensorPartitionSpecs( + const TensorPartitionSpec& left, const TensorPartitionSpec& right); + +// Return true if `shape` can be sharded according to `spec`. +// Otherwise, return false. +// Note that an axis is shardable along a device mesh axis only if +// the dimension of the axis is divisible by the number of devices along the device mesh axis. +bool CanShard(const TensorShape& shape, const TensorPartitionSpec& spec); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/conv_transpose_with_dynamic_pads.h b/onnxruntime/contrib_ops/cuda/conv_transpose_with_dynamic_pads.h index 6f7a04d059034..a768b2a7d8a24 100644 --- a/onnxruntime/contrib_ops/cuda/conv_transpose_with_dynamic_pads.h +++ b/onnxruntime/contrib_ops/cuda/conv_transpose_with_dynamic_pads.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -10,12 +11,12 @@ namespace contrib { namespace cuda { template -class ConvTransposeWithDynamicPads : public ::onnxruntime::cuda::ConvTranspose { +class ConvTransposeWithDynamicPads : public ::onnxruntime::cuda::ConvTranspose { public: - ConvTransposeWithDynamicPads(const OpKernelInfo& info) : ::onnxruntime::cuda::ConvTranspose(info) {} + ConvTransposeWithDynamicPads(const OpKernelInfo& info) : ::onnxruntime::cuda::ConvTranspose(info) {} Status ComputeInternal(OpKernelContext* context) const override { - return ::onnxruntime::cuda::ConvTranspose::DoConvTranspose(context, true); + return ::onnxruntime::cuda::ConvTranspose::DoConvTranspose(context, true); } }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 86c1cb93e8b6f..108eea1a73fe9 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -65,12 +65,16 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BeamSearch); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WhisperBeamSearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); @@ -89,10 +93,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); @@ -112,7 +119,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MatMulBnb4); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear); @@ -134,6 +148,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GemmFloat8); #ifdef ENABLE_ATEN class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen); @@ -145,10 +160,41 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ShrunkenGather); #endif -#if defined(USE_MPI) && defined(ORT_USE_NCCL) +#if defined(ORT_USE_NCCL) class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSlice); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedUnsqueeze); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedUnsqueeze); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedUnsqueeze); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedSqueeze); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSqueeze); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze); #endif template <> @@ -212,12 +258,16 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -236,10 +286,13 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -259,6 +312,11 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -270,6 +328,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // TransposedMatMul is still here for backward compatibility @@ -287,6 +347,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef ENABLE_ATEN BuildKernelCreateInfo, @@ -298,10 +359,41 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #endif -#if defined(USE_MPI) && defined(ORT_USE_NCCL) +#if defined(ORT_USE_NCCL) BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc index a38dfd34cc977..274bc9a730d87 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc @@ -44,11 +44,6 @@ Status BiasAdd::ComputeInternal(OpKernelContext* context) const { "The input is expected to have 3 dimensions, got ", input_dims.size()); } - if (input_dims[2] != 320 && input_dims[2] != 640 && input_dims[2] != 1280) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels should be 320, 640 or 1280, got ", input_dims[2]); - } - const Tensor* bias = context->Input(1); const auto& bias_dims = bias->Shape().GetDims(); if (bias_dims.size() != 1) { diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu index 2983cc99e30b1..8e8068b5e56ca 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu @@ -42,6 +42,17 @@ __global__ void BiasAddKernel(T const* input, T const* bias, T const* residual, } } +template +__global__ void BiasAddLargeKernel( + int32_t const ld, const T* input, const T* bias, const T* residual, T* output) { + int32_t const offset = blockIdx.x * ld; + + for (int32_t i = threadIdx.x; i < ld; i += TPB) { + int32_t const base_offset = offset + i; + output[base_offset] = input[base_offset] + bias[i] + residual[base_offset]; + } +} + template __global__ void BiasAddKernel(float const*, float const*, float const*, float*); template __global__ void BiasAddKernel(float const*, float const*, float const*, float*); template __global__ void BiasAddKernel(float const*, float const*, float const*, float*); @@ -52,19 +63,19 @@ template __global__ void BiasAddKernel(half const*, half const* template void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels, T const* input, T const* bias, T const* residual, T* output) { - constexpr int32_t TPB = 320; // thread per block switch (num_channels) { case 320: - (BiasAddKernel)<<>>(input, bias, residual, output); + (BiasAddKernel)<<>>(input, bias, residual, output); break; case 640: - (BiasAddKernel)<<>>(input, bias, residual, output); + (BiasAddKernel)<<>>(input, bias, residual, output); break; case 1280: - (BiasAddKernel)<<>>(input, bias, residual, output); + (BiasAddKernel)<<>>(input, bias, residual, output); break; default: - ORT_NOT_IMPLEMENTED("Not implemented"); + BiasAddLargeKernel<<>>(num_channels, input, bias, residual, output); + break; } } diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc index 2b13cdbd803ef..cb02bd8541623 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc @@ -39,9 +39,13 @@ Status BiasSplitGelu::ComputeInternal(OpKernelContext* context) const { "input is expected to have 3 dimensions, got ", input_dims.size()); } - if (input_dims[2] != 2560 && input_dims[2] != 5120 && input_dims[2] != 10240) { + if (input_dims[2] != 2560 && + input_dims[2] != 5120 && + input_dims[2] != 6144 && + input_dims[2] != 10240 && + input_dims[2] != 12288) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "hidden size should be 2560, 5120 or 10240, got ", input_dims[2]); + "hidden size should be 2560, 5120, 6144, 10240 or 12288, got ", input_dims[2]); } const Tensor* bias = context->Input(1); diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu index 19e05a9573f7c..3ae9611d4dfad 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu @@ -65,6 +65,12 @@ void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t h case 5120: (biasSplitGeluKernel)<<>>(input, bias, output); break; + case 3072: + (biasSplitGeluKernel)<<>>(input, bias, output); + break; + case 6144: + (biasSplitGeluKernel)<<>>(input, bias, output); + break; default: ORT_NOT_IMPLEMENTED("Not implemented"); } @@ -73,9 +79,13 @@ void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t h template __global__ void biasSplitGeluKernel(float const*, float const*, float*); template __global__ void biasSplitGeluKernel(float const*, float const*, float*); template __global__ void biasSplitGeluKernel(float const*, float const*, float*); +template __global__ void biasSplitGeluKernel(float const*, float const*, float*); +template __global__ void biasSplitGeluKernel(float const*, float const*, float*); template __global__ void biasSplitGeluKernel(half const*, half const*, half*); template __global__ void biasSplitGeluKernel(half const*, half const*, half*); template __global__ void biasSplitGeluKernel(half const*, half const*, half*); +template __global__ void biasSplitGeluKernel(half const*, half const*, half*); +template __global__ void biasSplitGeluKernel(half const*, half const*, half*); template void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, float const* input, float const* bias, float* output); diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index 301b2e76b1b2d..87e88ac31c998 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -1,6 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/diffusion/group_norm.h" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" @@ -15,14 +14,22 @@ ONNX_OPERATOR_KERNEL_EX( GroupNorm, kMSDomain, 1, kCudaExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); +ONNX_OPERATOR_KERNEL_EX( + SkipGroupNorm, kMSDomain, 1, kCudaExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); + using namespace ONNX_NAMESPACE; namespace { + template struct DispatchGroupNorm { Status operator()(cudaStream_t stream, Tensor* output, + Tensor* add_out, const Tensor* input, + const Tensor* skip, + const Tensor* bias, const Tensor* gamma, const Tensor* beta, void* workspace, @@ -32,12 +39,17 @@ struct DispatchGroupNorm { int height, int width, int num_groups, - bool use_swish_activation) { + bool use_swish_activation, + bool broadcast_skip, + int channels_per_block) { typedef typename ToCudaType::MappedType CudaT; return LaunchGroupNormKernel( stream, reinterpret_cast(output->MutableData()), + add_out == nullptr ? nullptr : reinterpret_cast(add_out->MutableData()), reinterpret_cast(input->Data()), + skip == nullptr ? nullptr : reinterpret_cast(skip->Data()), + bias == nullptr ? nullptr : reinterpret_cast(bias->Data()), gamma->Data(), beta->Data(), workspace, @@ -47,13 +59,21 @@ struct DispatchGroupNorm { height, width, num_groups, - use_swish_activation); + use_swish_activation, + broadcast_skip, + channels_per_block); } }; } // namespace GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { + has_skip_ = false; + const std::string& op_name = op_info.GetKernelDef().OpName(); + if (op_name == "SkipGroupNorm") { + has_skip_ = true; + } + epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); ORT_ENFORCE(epsilon_ >= 0); @@ -68,6 +88,23 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { use_swish_activation_ = (activation == 1); channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); + + channels_per_block_ = 0; +} + +Status GroupNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/, + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + + // Compute and cache cPerBlock using number of channels from gamma tensor shape. + if (input_idx == 1) { + auto gamma_shape = tensor.Shape(); + if (gamma_shape.NumDimensions() == 1) { + channels_per_block_ = GetChannelsPerBlock(static_cast(gamma_shape[0]), num_groups_); + } + } + + return Status::OK(); } Status GroupNorm::ComputeInternal(OpKernelContext* context) const { @@ -77,22 +114,38 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { Tensor* output = context->Output(0, input->Shape()); if (!channels_last_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "only the channels_last layout is supported"); } + if (!gamma->IsDataType() || !beta->IsDataType()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm only supports gamma and beta in float type"); + } + const auto& input_dims = input->Shape().GetDims(); if (input_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input is expected to have 4 dimensions, got ", input_dims.size()); } + // Only support NHWC format right now. + int batch_size = static_cast(input_dims[0]); + int height = static_cast(input_dims[1]); + int width = static_cast(input_dims[2]); + int num_channels = static_cast(input_dims[3]); + + if (num_channels % num_groups_ != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "number of channels should be divisiable by num_groups"); + } + const auto& gamma_dims = gamma->Shape().GetDims(); if (gamma_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "gamma is expected to have 1 dimension, got ", gamma_dims.size()); } - if (gamma_dims[0] != input_dims[3]) { + if (gamma_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in gamma and input does not match"); } @@ -102,22 +155,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "beta is expected to have 1 dimension, got ", beta_dims.size()); } - if (beta_dims[0] != input_dims[3]) { + if (beta_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in beta and input does not match"); } - // Input and output format is NHWC - int batch_size = static_cast(input_dims[0]); - int num_channels = static_cast(input_dims[3]); - int height = static_cast(input_dims[1]); - int width = static_cast(input_dims[2]); - - if (num_channels % num_groups_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "number of channels should be divisiable by num_groups"); - } - if (context->GetUseDeterministicCompute()) { static std::once_flag log_warning; std::call_once(log_warning, []() { @@ -125,17 +167,59 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { }); } - auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); + const Tensor* skip = nullptr; + const Tensor* bias = nullptr; + Tensor* add_out = nullptr; + + bool broadcast_skip = false; + if (has_skip_) { + skip = context->Input(3); + bias = context->Input(4); + add_out = context->Output(1, input->Shape()); + + if (bias != nullptr) { // Bias is optional + // If provided, bias has shape (C). + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "bias is expected to have 1 dimension, got ", bias_dims.size()); + } + if (bias_dims[0] != num_channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels in bias and input does not match"); + } + } + + // Check whether skip can be broadcasted to input shape. + if (skip->Shape() != input->Shape()) { + const auto& dims = skip->Shape().GetDims(); + // The shape of ship can be (N, C) or (N, 1, 1, C) for broadcast. + const bool b2 = (dims.size() == 2 && dims[0] == batch_size && dims[1] == num_channels); + const bool b4 = (dims.size() == 4 && dims[0] == batch_size && + dims[1] == 1 && dims[2] == 1 && dims[3] == num_channels); + broadcast_skip = b2 || b4; + if (!broadcast_skip) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip shape is expected to be (N, H, W, C) or (N, 1, 1, C) or (N, C)"); + } + } + } + + auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups_), + context->GetComputeStream()); utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(Stream(context), output, input, gamma, beta, workspace.get(), + return dispatcher.InvokeRet(Stream(context), output, add_out, input, skip, bias, + gamma, beta, workspace.get(), epsilon_, batch_size, num_channels, height, width, num_groups_, - use_swish_activation_); + use_swish_activation_, + broadcast_skip, + channels_per_block_); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h index 52c006e6bdb96..b408b3c1ee79b 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -16,11 +16,16 @@ class GroupNorm final : public CudaKernel { GroupNorm(const OpKernelInfo& op_kernel_info); Status ComputeInternal(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) override; + private: - bool use_swish_activation_; + bool use_swish_activation_; // use SiLU (also known as Swish) activation after group normalization? float epsilon_; int num_groups_; bool channels_last_; + bool has_skip_; // true for SkipGroupNorm operator; false for GroupNorm + int channels_per_block_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 01ba078b4be77..48b161552ce0c 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -16,18 +16,45 @@ */ // The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include #include #include #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +using namespace onnxruntime::cuda; + namespace onnxruntime { namespace contrib { namespace cuda { -static inline int32_t divUp(int32_t m, int32_t n) { +namespace { + +// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. +constexpr static int32_t CHANNELS_PER_THREAD = 2; + +constexpr static int kSizes[] = {128, 256, 320, 384, 512}; +constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); +constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; + +int NextSize(int x) { + for (size_t i = 0; i < kNumOfSizes; ++i) { + if (x <= kSizes[i]) { + return kSizes[i]; + } + } + + return x; +} +} // namespace + +static inline int32_t DivUp(int32_t m, int32_t n) { return (m + n - 1) / n; } @@ -41,14 +68,14 @@ struct GroupSums { // The sum. float sum; // The sum of squares. - float sumSq; + float sum_sq; }; struct GroupSumsOp { inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { GroupSums dst; dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq); + dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); dst.flag = a.flag + b.flag; return dst; } @@ -56,54 +83,85 @@ struct GroupSumsOp { template struct GroupNormNHWCParams { - // The output buffer. Layout NHWC. + // The output buffer. Shape is (n, h, w, c). T* dst; - // The input buffer. Layout NHWC. + + // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). + T* add_out; + + // The input buffer. Shape is (n, h, w, c). T const* src; + + // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). + T const* skip; + + // Optional input buffer for bias tensor. Shape is (c). + T const* bias; + // The gamma scaling factor. float const* gamma; + // The beta term to add in GN. float const* beta; - // The temporary buffer to do the global parallel reduction. Size: - // BLOCKS_PER_BATCH x C x 2. - float* redBuffer; + + // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. + float* group_sum_buffer; // The number of instances in the batch. int32_t n; + // The height and width of each activation map. int32_t h; int32_t w; - // The number of channels. + + // Number of channels. int32_t c; - // The number of groups. + + // Number of groups. int32_t groups; - // Do we apply the Swish activation function? - bool withSwish; + + // Do we apply the SiLU activation function? + bool use_silu; // Precomputed values and parameters to control the execution of the kernels. - // The number of activations per instance (h * w) and the number of - // activations per block. + // Number of activations per instance (h * w) int32_t hw; - int32_t hwPerBlock; - // The number of channels per group and blocks per activation in the C - // dimension. - int32_t cPerBlock; - int32_t cPerGroup; + + // Number of activations per block + int32_t hw_per_block; + + // Number of channels per block in the C dimension. + int32_t channels_per_block; + + // Number of channels per group in the C dimension. + int32_t channels_per_group; // The precomputed stride between instances. int32_t hwc; - // The inverse of hwc in floats (to compute mean/var). - float invHWC; + // The inverse of hw*channels_per_group to compute mean of a group. + float inv_hw_channels_per_group; // The precomputed number of groups per block. - int32_t groupsPerBlock; + int32_t groups_per_block; + + // Number of threads per block + int32_t threads_per_block; + + // Epsilon to get stable variance in normalization. + float epsilon; + + // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. + bool broadcast_skip; + + // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. + T* skip_workspace; }; template -inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sumSq); +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); template <> -inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sumSq) { +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -113,11 +171,11 @@ inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, fl sum += f2.x + f2.y; // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; } template <> -inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sumSq) { +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. float2 f2 = *reinterpret_cast(&src[offset]); @@ -125,119 +183,220 @@ inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, f sum += f2.x + f2.y; // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] +template +inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); + h2 = h2 + b; + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + float2 b = *reinterpret_cast(&bias[bias_offset]); + f2.x += s.x + b.x; + f2.y += s.y + b.y; + + *reinterpret_cast(&add_out[offset]) = f2; + + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] +template +inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + f2.x += s.x; + f2.y += s.y; + *reinterpret_cast(&add_out[offset]) = f2; + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; } -template -__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) { +template +__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { // The object in charge of doing the sums for the different blocks. - typedef cub::BlockScan BlockScan; + typedef cub::BlockScan BlockScan; // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage tempStorage; - // Allocate shared memory for the groups. We could reduce the amount of shared - // memory reserved. - __shared__ float2 smem[tTHREADS_PER_BLOCK]; + __shared__ typename BlockScan::TempStorage temp_storage; + + // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. + __shared__ float2 smem[THREADS_PER_BLOCK]; // The instance in the batch. int32_t ni = blockIdx.z; - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; + + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + return; + } // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t hw_begin = blockIdx.y * params.hw_per_block; // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); // The sums. float sum = 0.F; - float sumSq = 0.F; + float sum_sq = 0.F; // Iterate over the activations to compute the sums. - if (ci < params.c) { - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The offset. - int64_t offset = static_cast(ni) * params.hwc + static_cast(hwi) * params.c + ci; - UpdateSum(params.src, offset, sum, sumSq); + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + if (params.skip != nullptr) { + // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) + const int64_t bias_offset = static_cast(ci); + T* add_out = params.skip_workspace; + if (params.broadcast_skip) { + const int64_t skip_offset = static_cast(ni) * params.c + ci; + + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); + } + } + } else { + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); + } + } + } + } else { // GroupNorm + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + UpdateSum(params.src, offset, sum, sum_sq); } } - // The group that thread works on and the channel in the group (modulus). - int32_t gi = threadIdx.x * 2 / params.cPerGroup; - int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi; + // The group index relative to the first group within the same block. + int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; + // The channel in the group. + int32_t cj = ci % params.channels_per_group; // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; + GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; - // Do the segmented scan. + // Do the segmented scan. InclusiveScan is not deterministic. GroupSums out; - BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); + BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); - // Store the results for the groups in shared memory (to produce coalesced - // stores later). - if (cj == params.cPerGroup - 2) { //2 channels per thread - smem[gi] = make_float2(out.sum, out.sumSq); + // Store the results for the groups in shared memory (to produce coalesced stores later). + // For each group, only the last thread of that group is picked to save sum to shared memory. + if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { + smem[gi] = make_float2(out.sum, out.sum_sq); } // Make sure the data is in shared memory. __syncthreads(); - // The global group index. - int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x; - // Threads that have nothing left to do, exit. - if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) { + if (threadIdx.x >= params.groups_per_block) { return; } - // The first threads (those storing to global memory, load the values). - float2 sums = smem[threadIdx.x]; - - // Store to global memory. - atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x); - atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); + // The global group index. + // Use neighboring threads for coalesced write. + int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; + + if (gj < params.groups) { + float2 sums = smem[threadIdx.x]; + const int index = (2 * ni) * params.groups + gj; + atomicAdd(¶ms.group_sum_buffer[index], sums.x); + atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); + } } template -void groupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { - // Make sure the values are as we expect. - ORT_ENFORCE(params.c % params.cPerBlock == 0 && params.hw % params.hwPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); - +void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = DivUp(params.c, params.channels_per_block); + // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = DivUp(params.hw, params.hw_per_block); + // The number of instances. grid.z = params.n; - switch (params.cPerBlock) { - case 320: - groupNormNHWCSumKernel<<>>(params); + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params.threads_per_block) { + case 256: + GroupNormNHWCSumKernel<<>>(params); break; - case 480: - groupNormNHWCSumKernel<<>>(params); + case 192: + GroupNormNHWCSumKernel<<>>(params); break; - case 256: - groupNormNHWCSumKernel<<>>(params); + case 160: + GroupNormNHWCSumKernel<<>>(params); break; case 128: - groupNormNHWCSumKernel<<>>(params); + GroupNormNHWCSumKernel<<>>(params); + break; + case 64: + GroupNormNHWCSumKernel<<>>(params); break; - default: - ORT_NOT_IMPLEMENTED("Not implemented"); } } template -__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float invStdDev, float2& gammaF2, float2& betaF2, bool swish); +__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu); template <> -__device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float invStdDev, - float2& gammaF2, float2& betaF2, bool swish) { +__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -245,15 +404,15 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo float2 f2 = __half22float2(h2); // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; - // Apply Swish if needed. - if (swish) { + // Apply SiLU activation if needed. + if (silu) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } @@ -262,21 +421,21 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo } template <> -__device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float invStdDev, - float2& gammaF2, float2& betaF2, bool swish) { +__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { // Fetch two channels per thread. float2 f2 = *reinterpret_cast(&src[offset]); // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; - // Apply Swish if needed. - if (swish) { + // Apply SiLU activation if needed. + if (silu) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } @@ -284,110 +443,142 @@ __device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, f *reinterpret_cast(&dst[offset]) = f2; } -template -__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) { - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; - if (ci >= params.c) { +template +__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { return; } // The instance in the batch. int32_t ni = blockIdx.z; - // The group that thread works on and the channel in the group (modulus). - int32_t gi = ci / params.cPerGroup; + // The group that thread works on. + int32_t gi = ci / params.channels_per_group; // Load the sum and sum of squares for the group. - float sum = 0.F, sumSq = 0.F; + float sum = 0.F, sum_sq = 0.F; if (gi < params.groups) { - sum = params.redBuffer[(2 * ni + 0) * params.groups + gi]; - sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi]; + const int index = (2 * ni) * params.groups + gi; + sum = params.group_sum_buffer[index]; + sum_sq = params.group_sum_buffer[index + params.groups]; } - // Load gamma/beta. - float2 gammaF2 = *reinterpret_cast(¶ms.gamma[ci]); - float2 betaF2 = *reinterpret_cast(¶ms.beta[ci]); + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); + float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); // Compute the mean. - float mean = sum * params.invHWC; + float mean = sum * params.inv_hw_channels_per_group; // Compute the variance. - float var = sumSq * params.invHWC - (mean * mean); + float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); // Compute the inverse of the stddev. - float invStdDev = var <= 0.F ? 1.F : rsqrtf(var); + float inv_std_dev = rsqrtf(var + params.epsilon); - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t hw_begin = blockIdx.y * params.hw_per_block; + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); - // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The src/dst offset. - int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; - - // Fetch two channels per thread. - computeGroupNorm(params.src, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish); + const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); } } template -void groupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { - // Make sure the dimensions are aligned with what we expect. - ORT_ENFORCE(params.c % params.cPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); - +void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = DivUp(params.c, params.channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = DivUp(params.hw, params.hw_per_block); // The number of instances. grid.z = params.n; - switch (params.cPerBlock) { - case 320: - groupNormNHWCScaleKernel<<>>(params); + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params.threads_per_block) { + case 256: + GroupNormNHWCScaleKernel<<>>(params); break; - case 480: - groupNormNHWCScaleKernel<<>>(params); + case 192: + GroupNormNHWCScaleKernel<<>>(params); break; - case 256: - groupNormNHWCScaleKernel<<>>(params); + case 160: + GroupNormNHWCScaleKernel<<>>(params); break; case 128: - groupNormNHWCScaleKernel<<>>(params); + GroupNormNHWCScaleKernel<<>>(params); + break; + case 64: + GroupNormNHWCScaleKernel<<>>(params); break; - default: - ORT_NOT_IMPLEMENTED("Not implemented"); } } -int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { - int32_t maxDivisor = -1; +int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { + int32_t max_divisor = -1; for (int32_t i = 1; i <= std::sqrt(n); i++) { if (n % i == 0) { int32_t divisor1 = n / i; int32_t divisor2 = i; - if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { - maxDivisor = divisor1; + if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { + max_divisor = divisor1; } - if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { - maxDivisor = divisor2; + if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { + max_divisor = divisor2; } } } - return maxDivisor; + return max_divisor; +} + +// Find proper channels per block based on a cost function: The cost is number of channels corresponding to +// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has +// work to do so it is ideal case. +int FindChannelsPerBlock(int num_channels, int channels_per_group) { + int min_cost = -1; + int best_candidate = -1; + for (size_t i = kNumOfSizes; i > 0; --i) { + if (kSizes[i - 1] < channels_per_group) { + break; + } + + int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; + int blocks = (num_channels + channels_per_block - 1) / channels_per_block; + int cost = blocks * kSizes[i - 1] - num_channels; + if (cost == 0) { + return channels_per_block; + } + + if (min_cost == -1 || cost < min_cost) { + min_cost = cost; + best_candidate = channels_per_block; + } + } + + return best_candidate; +} + +int GetChannelsPerBlock(int num_channels, int num_groups) { + int32_t channels_per_group = num_channels / num_groups; + int32_t channels_per_block = channels_per_group; + if (channels_per_group < kMaxSize / 2) { + channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); + } + return channels_per_block; } template Status LaunchGroupNormKernel( cudaStream_t stream, T* output, + T* add_out, const T* input, + const T* skip, + const T* bias, const float* gamma, const float* beta, void* workspace, @@ -397,79 +588,94 @@ Status LaunchGroupNormKernel( int height, int width, int num_groups, - bool use_swish_activation) { - if (batch_size > static_cast(kMaxGroupNormBatchSize)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only support batch_size <= 32. Got", batch_size); - } + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + GroupNormNHWCParams params; - if (num_groups != static_cast(kGroupNormNumberOfGroups)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only num_groups=32 is supported. Got", num_groups); + int32_t channels_per_group = num_channels / num_groups; + // channels_per_block is computed in PrePack. + // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. + if (channels_per_block < channels_per_group) { + channels_per_block = GetChannelsPerBlock(num_channels, num_groups); } - GroupNormNHWCParams params; - int32_t cPerBlock = 320; - int32_t maxBlocksPerHW = 1024; - switch (num_channels) { - case 960: - case 1920: - cPerBlock = 480; - break; - case 512: - case 256: - cPerBlock = 256; - break; - case 128: - cPerBlock = 128; - break; - default: - cPerBlock = 320; + // TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases + if (channels_per_block % channels_per_group != 0 || + channels_per_block > kMaxSize || + (channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in CUDA does not support the input: n=", batch_size, + " h=", height, + " w=", width, + " c=", num_channels, + " groups=", num_groups); } - params.withSwish = use_swish_activation; + params.use_silu = use_silu; params.dst = output; + params.add_out = add_out; params.src = input; + params.skip = skip; + params.bias = bias; params.gamma = gamma; params.beta = beta; - params.redBuffer = reinterpret_cast(workspace); + params.group_sum_buffer = reinterpret_cast(workspace); params.n = batch_size; params.h = height; params.w = width; params.c = num_channels; params.groups = num_groups; params.hw = params.h * params.w; - const int32_t blocksPerHW = findMaxDivisor(params.hw, maxBlocksPerHW); - params.hwPerBlock = divUp(params.hw, blocksPerHW); - params.cPerBlock = cPerBlock; - params.cPerGroup = params.c / params.groups; + + // This will allocate as many blocks as possible to partition HW. + // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. + // TODO: tune this logic to find proper blocks when hw is small. + constexpr int32_t max_blocks_per_hw = 1024; + const int32_t blocks_per_hw = FindMaxDivisor(params.hw, max_blocks_per_hw); + params.hw_per_block = DivUp(params.hw, blocks_per_hw); + + params.channels_per_block = channels_per_block; + params.channels_per_group = channels_per_group; params.hwc = params.hw * params.c; - params.invHWC = 1.F / (float)(params.hw * params.cPerGroup); - params.groupsPerBlock = cPerBlock / params.cPerGroup; + params.inv_hw_channels_per_group = 1.F / (float)(params.hw * params.channels_per_group); + params.groups_per_block = channels_per_block / params.channels_per_group; + params.epsilon = epsilon; + params.broadcast_skip = broadcast_skip; - DUMP_TENSOR_INIT(); - DUMP_TENSOR("input", input, batch_size, num_channels, height * width); - DUMP_TENSOR("gamma", gamma, 1, num_channels); - DUMP_TENSOR("beta", beta, 1, num_channels); - cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream); - groupNormNHWCSum(params, stream); - DUMP_TENSOR("workspace", params.redBuffer, batch_size, num_groups, 2); + // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. + params.skip_workspace = (params.add_out != nullptr) ? params.add_out : params.dst; + + params.threads_per_block = NextSize(channels_per_block) / CHANNELS_PER_THREAD; + + CUDA_RETURN_IF_ERROR(cudaMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream)); + + GroupNormNHWCSum(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - groupNormNHWCScale(params, stream); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("workspace", params.group_sum_buffer, batch_size, 2, num_groups); + + GroupNormNHWCScale(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - DUMP_TENSOR("output", output, batch_size, num_channels, height * width); + return Status::OK(); } -template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, - const half* input, const float* gamma, const float* beta, void* workspace, +template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, half* add_out, + const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + int height, int width, int num_groups, bool silu, + bool broadcast_skip, int channels_per_block); -template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, - const float* input, const float* gamma, const float* beta, void* workspace, +template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, float* add_out, + const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + int height, int width, int num_groups, bool silu, + bool broadcast_skip, int channels_per_block); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h index c7e9245050ee6..9532aeecb2f57 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h @@ -12,29 +12,33 @@ namespace onnxruntime { namespace contrib { namespace cuda { -constexpr size_t kMaxGroupNormBatchSize = 32; -constexpr size_t kGroupNormNumberOfGroups = 32; - -constexpr size_t GetGroupNormWorkspaceSizeInBytes() { +constexpr size_t GetGroupNormWorkspaceSizeInBytes(size_t batch_size, size_t num_groups) { // Two buffers for sum and squared sum - return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; + return (sizeof(float) * 2) * batch_size * num_groups; } +int GetChannelsPerBlock(int num_channels, int num_groups); + template Status LaunchGroupNormKernel( cudaStream_t stream, - T* output, // normalized output tensor - const T* input, // input tensor - const float* gamma, // gamma (also known as weight or scale) - const float* beta, // beta (also known as bias) - void* workspace, // Work space - float epsilon, // epsilon used normalization - int batch_size, // N - int num_channels, // C - int height, // H - int width, // W - int num_groups, // number of groups - bool use_swish_activation // Whether there is Swish activation after group normalization + T* output, // normalized output tensor. Shape is (n, h, w, c) + T* add_out, // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c) + const T* input, // input tensor. Shape is (n, h, w, c) + const T* skip, // optional skip tensor. Shape is (n, h, w, c) + const T* bias, // optional bias tensor. Shape is (c) for SkipGroupNorm or (n, c) for BiasGroupNorm + const float* gamma, // gamma (also known as weight or scale). Shape is (c) + const float* beta, // beta (also known as bias). Shape is (c) + void* workspace, // Work space + float epsilon, // epsilon used normalization + int batch_size, // N + int num_channels, // C + int height, // H + int width, // W + int num_groups, // number of groups + bool use_silu, // Whether there is Sigmoid Linear Unit (SiLU) activation after group normalization + bool broadcast_skip, // Whether skip need broadcast. When skip has shape (n, c) or (n, 1, 1, c), it need broadcast. + int channels_per_block // Pre-computed channels per block. ); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc new file mode 100644 index 0000000000000..6cdccdb1becb1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/providers/cuda/math/gemm.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/math/gemm_helper.h" +#include "contrib_ops/cuda/math/gemm_float8.h" + +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if !defined(DISABLE_FLOAT8_TYPES) +#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() +#else +#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() +#endif + +#define REGISTER_KERNEL() \ + ONNX_OPERATOR_KERNEL_EX( \ + GemmFloat8, \ + kMSDomain, \ + 1, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("TA", GEMM_FLOAT8_CONSTRAINTS) \ + .TypeConstraint("TB", GEMM_FLOAT8_CONSTRAINTS) \ + .TypeConstraint("TR", GEMM_FLOAT8_CONSTRAINTS) \ + .TypeConstraint("TS", BuildKernelDefConstraints()), \ + GemmFloat8); + +REGISTER_KERNEL() + +GemmFloat8::GemmFloat8(const OpKernelInfo& info) : CudaKernel(info) { + transA_ = info.GetAttrOrDefault("transA", 0); + transB_ = info.GetAttrOrDefault("transB", 0); + dtype_ = info.GetAttrOrDefault("dtype", ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto& device_prop = GetDeviceProp(); + sm_count_ = device_prop.multiProcessorCount; + alpha_ = info.GetAttrOrDefault("alpha", 1); + beta_ = info.GetAttrOrDefault("beta", 0); + +#if (CUDA_VERSION < 12000) + ORT_ENFORCE(beta_ == 0, "CUDA < 12.0 does not support bias, beta must be 0."); +#endif + + std::string stemp = info.GetAttrOrDefault("activation", "NONE"); + if (stemp == "NONE") { + epilogue_ = CUBLASLT_EPILOGUE_DEFAULT; + } else if (stemp == "RELU") { + epilogue_ = CUBLASLT_EPILOGUE_RELU; + } else if (stemp == "GELU") { + epilogue_ = CUBLASLT_EPILOGUE_GELU; + } else { + ORT_THROW("Unexpected value for activation: '", stemp, "'."); + } +} + +Status GemmFloat8::SetCheck(const TensorShape& a_shape, const TensorShape& b_shape, int& M, int& N, int& K) const { + GemmHelper helper(a_shape, transA_, b_shape, transB_, TensorShape({})); + if (!helper.State().IsOK()) + return helper.State(); + + M = gsl::narrow_cast(helper.M()); + N = gsl::narrow_cast(helper.N()); + K = gsl::narrow_cast(helper.K()); + return helper.State(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu new file mode 100644 index 0000000000000..56b541f5256bf --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -0,0 +1,411 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// The operator calls function 'cublasLtMatmul' +// (https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmul#cublasltmatmul). +// It lets the function checks what configuration is valid or not. If not, the error message +// shows the error message 'CUBLAS_STATUS_NOT_SUPPORTED'. NVIDIA documentation provides +// information on what attribute or type must be modified. +// This operator requires CUDA_VERSION >= 11.8 for float 8 and CUDA_VERSION >= 12.0 +// for beta != 0. + +#include +#include +#include +#include "contrib_ops/cuda/math/gemm_float8.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// It must exist somewhere already. +int32_t TypeSize(int32_t element_type) { + switch (element_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return 4; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return 2; +#if !defined(DISABLE_FLOAT8_TYPES) + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: + return 1; +#endif + default: + ORT_THROW("Unexpected element_type=", element_type, "."); + } +} + +void GemmFloat8::SetParams(const TensorShape& a_shape, const TensorShape& b_shape, + int& M, int& N, int& K, int& lda, int& ldb, int& ldd) const { + int m_idx = transA_ ? 1 : 0; + int k_idx = 1 - m_idx; + int n_idx = transB_ ? 0 : 1; + + M = static_cast(a_shape[m_idx]); + K = static_cast(a_shape[k_idx]); + N = static_cast(b_shape[n_idx]); + lda = static_cast(a_shape[1]); + ldb = static_cast(b_shape[1]); + ldd = static_cast(b_shape[n_idx]); +} + +template +int32_t GetTypeAndShape(const TValue* input, + TensorShape& shape, + bool swap = false) { + shape = input->Shape(); + ORT_ENFORCE(shape.NumDimensions() == 2); + if (swap) { + std::swap(shape[0], shape[1]); + } + return input->GetElementType(); +} + +Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* input_A = nullptr; + const Tensor* input_B = nullptr; + const Tensor* input_C = nullptr; + const Tensor* scale_A = nullptr; + const Tensor* scale_B = nullptr; + const Tensor* scale_Y = nullptr; + bool has_scales = false; + bool has_bias = false; + int n_inputs = ctx->InputCount(); + + input_A = ctx->Input(0); + input_B = ctx->Input(1); + if (n_inputs == 3) { + input_C = ctx->Input(2); + has_bias = true; + } else if (n_inputs > 3) { + ORT_ENFORCE(n_inputs >= 5, "Unexpected number of inputs=", n_inputs, "."); + has_scales = true; + scale_A = ctx->Input(3); + scale_B = ctx->Input(4); + scale_Y = n_inputs < 6 ? nullptr : ctx->Input(5); + ORT_ENFORCE(scale_A->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + ORT_ENFORCE(scale_B->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + ORT_ENFORCE(scale_Y == nullptr || scale_Y->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + if (ctx->Input(2) != nullptr) { + input_C = ctx->Input(2); + has_bias = true; + ORT_ENFORCE(input_C->GetElementType() == dtype_, "Bias type must be equal to dtype."); + } + } + + auto first_type = input_A->GetElementType(); +#if !defined(DISABLE_FLOAT8_TYPES) + bool is_float8 = first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2; + if (!is_float8) +#endif + return ComputeRowMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, + input_C, scale_A, scale_B, scale_Y); +#if !defined(DISABLE_FLOAT8_TYPES) + return ComputeColMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, + input_C, scale_A, scale_B, scale_Y); +#endif +} + +Status GemmFloat8::ComputeRowMajor( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + const Tensor* input_A, const Tensor* input_B, + const Tensor* input_C, const Tensor* scale_A, + const Tensor* scale_B, const Tensor* scale_Y) const { + TensorShape shape_A, shape_B, shape_C, shape_Y; + int32_t dtype_A, dtype_B, dtype_C, dtype_Y; + dtype_A = GetTypeAndShape(input_A, shape_A); + dtype_B = GetTypeAndShape(input_B, shape_B); + + int M, N, K, lda, ldb, ldd; + SetParams(shape_A, shape_B, M, N, K, lda, ldb, ldd); + + TensorShape dimensions{M, N}; + Tensor* Y = ctx->Output(0, dimensions); + dtype_Y = GetTypeAndShape(Y, shape_Y); + dtype_C = has_bias ? GetTypeAndShape(input_C, shape_C) + : ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + return ComputeGemm(ctx, n_inputs, has_bias, has_scales, dtype_A, dtype_B, dtype_C, + dtype_Y, shape_A, shape_B, shape_C, shape_Y, transA_, transB_, + input_A->DataRaw(), input_B->DataRaw(), + has_bias ? input_C->DataRaw() : nullptr, + has_scales ? scale_A->DataRaw() : nullptr, + has_scales ? scale_B->DataRaw() : nullptr, + has_scales && scale_Y != nullptr ? scale_Y->DataRaw() : nullptr, + Y->MutableDataRaw(), M, N, K, lda, ldb, ldd, true); +} + +Status GemmFloat8::ComputeColMajor( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + const Tensor* input_A, const Tensor* input_B, + const Tensor* input_C, const Tensor* scale_A, + const Tensor* scale_B, const Tensor* scale_Y) const { + TensorShape shape_A, shape_B, shape_C, shape_Y; + int32_t dtype_A, dtype_B, dtype_C, dtype_Y; + dtype_A = GetTypeAndShape(input_A, shape_A); + dtype_B = GetTypeAndShape(input_B, shape_B); + + int M, N, K, lda, ldb, ldd; + SetParams(shape_A, shape_B, M, N, K, lda, ldb, ldd); + + std::swap(shape_A[0], shape_A[1]); + std::swap(shape_B[0], shape_B[1]); + + TensorShape dimensions{M, N}; + Tensor* Y = ctx->Output(0, dimensions); + dtype_Y = GetTypeAndShape(Y, shape_Y); + dtype_C = has_bias ? GetTypeAndShape(input_C, shape_C, true) + : ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + + return ComputeGemm(ctx, n_inputs, has_bias, has_scales, dtype_B, dtype_A, dtype_C, + dtype_Y, shape_B, shape_A, shape_C, shape_Y, transB_, transA_, + input_B->DataRaw(), input_A->DataRaw(), + has_bias ? input_C->DataRaw() : nullptr, + has_scales ? scale_B->DataRaw() : nullptr, + has_scales ? scale_A->DataRaw() : nullptr, + has_scales && scale_Y != nullptr ? scale_Y->DataRaw() : nullptr, + Y->MutableDataRaw(), N, M, K, ldb, lda, ldd, false); +} + +Status GemmFloat8::ComputeGemm( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + int32_t dtype_A, int32_t dtype_B, + int32_t dtype_C, int32_t dtype_Y, + const TensorShape& shape_A, const TensorShape& shape_B, + const TensorShape& shape_C, const TensorShape& shape_Y, + bool trans_A, bool trans_B, const void* p_input_a, const void* p_input_b, + const void* p_input_c, const void* p_scale_a, const void* p_scale_b, + const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, + int ldb, int ldd, bool row_major_compute) const { + cudaStream_t stream = Stream(ctx); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + cublasLtHandle_t cublasLt; + CUBLAS_RETURN_IF_ERROR(cublasLtCreate(&cublasLt)); + + cublasLtMatmulDesc_t operationDesc = nullptr; + cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, + Ddesc = nullptr; + + // Create matrix descriptors. Not setting any extra attributes. + cudaDataType_t a_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_A); + cudaDataType_t b_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_B); + cudaDataType_t d_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_Y); + cudaDataType_t scale_cuda_type = + onnxruntime::cuda::ToCudaDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + cudaDataType_t bias_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_C); + + cublasComputeType_t compute_type; + switch (d_cuda_type) { + case CUDA_R_16F: + switch (a_cuda_type) { +#if !defined(DISABLE_FLOAT8_TYPES) +#if CUDA_VERSION < 11080 +#error CUDA_R_8F_E4M3 (float 8 types) is defined with CUDA>=11.8. Set flag DISABLE_FLOAT8_TYPES. +#endif + case CUDA_R_8F_E4M3: + case CUDA_R_8F_E5M2: + compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + break; +#endif + default: + compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + break; + } + break; + case CUDA_R_16BF: + compute_type = CUBLAS_COMPUTE_32F_FAST_16BF; + break; + case CUDA_R_32F: + compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + break; + default: + ORT_THROW("Unable to determine computeType in operator GemmFloat8."); + } + + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutCreate( + &Adesc, a_cuda_type, trans_A ? K : M, trans_A ? M : K, lda)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutCreate( + &Bdesc, b_cuda_type, trans_B ? N : K, trans_B ? K : N, ldb)); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Ddesc, d_cuda_type, M, N, ldd)); + + if (row_major_compute) { + cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW; + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + } + + CUBLAS_RETURN_IF_ERROR( + cublasLtMatmulDescCreate(&operationDesc, compute_type, scale_cuda_type)); + cublasOperation_t ctransa = trans_A ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t ctransb = trans_B ? CUBLAS_OP_T : CUBLAS_OP_N; + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &ctransa, sizeof(ctransa))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &ctransb, sizeof(ctransb))); + + if (sm_count_ != 0) { + int math_sm_count = static_cast(sm_count_); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count, + sizeof(math_sm_count))); + } + + if (has_scales) { + // gemm float 8 + const int8_t ifast_accumulation_mode = 1; + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, + cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_FAST_ACCUM, + &ifast_accumulation_mode, sizeof(ifast_accumulation_mode))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &p_scale_a, + sizeof(p_scale_a))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &p_scale_b, + sizeof(p_scale_b))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y, + sizeof(p_scale_b))); + + // float 8 +#if !defined(DISABLE_FLOAT8_TYPES) + if (dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN || + dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) { + // For FP8 output, cuBLAS requires C_type to be same as bias_type + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, bias_cuda_type, M, N, ldd)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_cuda_type, + sizeof(bias_cuda_type))); + } else { + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); + } +#else + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); +#endif + } else { + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); + } + + if (row_major_compute) { + cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW; + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Ddesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + } + + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue_, sizeof(epilogue_)); + + // See + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmulPreferenceAttributes_t#cublasltmatmulpreferenceattributes-t + // The workspace should be allocated once from OpKernelContext assuming + // only one cuda function is running at a time (which is not necessarily true + // with H100). + size_t workspaceSize = static_cast(1 << 25); // suggested fixed value 32Mb + cublasLtMatmulPreference_t preference = nullptr; + cublasLtMatmulPreferenceCreate(&preference); + cublasLtMatmulPreferenceSetAttribute(preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspaceSize, sizeof(workspaceSize)); + + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmulAlgoGetHeuristic#cublasltmatmulalgogetheuristic + cublasLtMatmulHeuristicResult_t heuristicResult = {}; + int returnedResults = 0; + cublasStatus_t cuda_status = cublasLtMatmulAlgoGetHeuristic( + cublasLt, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, + &heuristicResult, &returnedResults); + ORT_ENFORCE( + returnedResults > 0 && cuda_status == CUBLAS_STATUS_SUCCESS, + " Unable to find any suitable algorithm due to ", + onnxruntime::cuda::cublasGetErrorEnum(cuda_status), + ", returnedResults=", returnedResults, + ", alpha=", alpha_, ", beta=", beta_, ", n_inputs=", n_inputs, + ", A_type=", onnxruntime::cuda::CudaDataTypeToString(a_cuda_type), + ", B_type=", onnxruntime::cuda::CudaDataTypeToString(b_cuda_type), + ", C_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", result_type=", onnxruntime::cuda::CudaDataTypeToString(d_cuda_type), + ", bias_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", scale_type=", onnxruntime::cuda::CudaDataTypeToString(scale_cuda_type), + ", computeType=", onnxruntime::cuda::CublasComputeTypeToString(compute_type), + ", epilogue=", epilogue_, ", smCount=", sm_count_, ", transA=", trans_A, + ", transB=", trans_B, + ", fastAccumulationMode=", 1, + ", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x", + shape_B[1], ", shape_C=", (shape_C.NumDimensions() > 0 ? shape_C[0] : 0), "x", + (shape_C.NumDimensions() > 1 ? shape_C[1] : 0), ", M=", M, ", N=", N, ", K=", K, + ", lda=", lda, ", ldb=", ldb, ", ldd=", ldd, + ", workspaceSize=", workspaceSize, ", rowMajorCompute=", (row_major_compute ? 1 : 0), + ". Check NVIDIA documentation to see what combination is valid: ", + "https://docs.nvidia.com/cuda/cublas/" + "index.html?highlight=cublasLtMatmulAlgoGetHeuristic#" + "cublasltmatmulalgogetheuristic. CUDA>=11.8 is required to use float 8 types."); + + void* workspace = nullptr; + if (workspaceSize > 0) { + CUDA_RETURN_IF_ERROR(cudaMalloc(reinterpret_cast(&workspace), workspaceSize)); + } + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmul#cublasltmatmul + const void* bias = has_bias ? p_input_c : p_output_y; + cuda_status = cublasLtMatmul( + cublasLt, operationDesc, static_cast(&alpha_), /* alpha */ + p_input_a, /* A */ + Adesc, p_input_b, /* B */ + Bdesc, static_cast(&beta_), /* beta */ + bias, /* C */ + Cdesc, p_output_y, /* Y */ + Ddesc, &heuristicResult.algo, /* algo */ + workspace, /* workspace */ + workspaceSize, stream); /* stream */ + ORT_ENFORCE( + cuda_status == CUBLAS_STATUS_SUCCESS, + " Unable to run cublasLtMatmul due to ", + onnxruntime::cuda::cublasGetErrorEnum(cuda_status), + ", returnedResults=", returnedResults, ", alpha=", alpha_, + ", n_inputs=", n_inputs, ", A_type=", + onnxruntime::cuda::CudaDataTypeToString(a_cuda_type), + ", B_type=", onnxruntime::cuda::CudaDataTypeToString(b_cuda_type), + ", result_type=", onnxruntime::cuda::CudaDataTypeToString(d_cuda_type), + ", bias_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", scale_type=", onnxruntime::cuda::CudaDataTypeToString(scale_cuda_type), + ", computeType=", onnxruntime::cuda::CublasComputeTypeToString(compute_type), + ", epilogue=", epilogue_, ", smCount=", sm_count_, ", transA=", trans_A, + ", transB=", trans_B, + ", fastAccumulationMode=", 1, + ", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x", + shape_B[1], ", M=", M, ", N=", N, ", K=", K, ", lda=", lda, ", ldb=", ldb, + ", ldd=", ldd, ", workspaceSize=", workspaceSize, + ", rowMajorCompute=", (row_major_compute ? 1 : 0), + ". CUDA>=11.8 is required to use float 8 types."); + + if (workspaceSize > 0) { + CUDA_RETURN_IF_ERROR(cudaFree(workspace)); + } + + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulPreferenceDestroy(preference)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Ddesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Cdesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Bdesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Adesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescDestroy(operationDesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtDestroy(cublasLt)); + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.h b/onnxruntime/contrib_ops/cuda/math/gemm_float8.h new file mode 100644 index 0000000000000..e84ccd55b2003 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cublas_v2.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Calls https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul. +// D = alpha*(A*B) +class GemmFloat8 final : public onnxruntime::cuda::CudaKernel { + public: + GemmFloat8(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + void SetParams(const TensorShape& shape_a, + const TensorShape& shape_b, + int& M, int& N, int& K, + int& lda, int& ldb, int& ldd) const; + Status SetCheck(const TensorShape& shape_a, + const TensorShape& shape_b, + int& M, int& N, int& K) const; + + Status ComputeRowMajor(OpKernelContext* ctx, int n_inputs, bool has_bias, + bool has_scales, const Tensor* input_A, + const Tensor* input_B, const Tensor* input_C, + const Tensor* scale_A, const Tensor* scale_B, + const Tensor* scale_Y) const; + Status ComputeColMajor(OpKernelContext* ctx, int n_inputs, bool has_bias, + bool has_scales, const Tensor* input_A, + const Tensor* input_B, const Tensor* input_C, + const Tensor* scale_A, const Tensor* scale_B, + const Tensor* scale_Y) const; + + Status ComputeGemm( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + int32_t dtype_A, int32_t dtype_b, + int32_t dtype_c, int32_t dtype_Y, + const TensorShape& shape_A, const TensorShape& shape_B, + const TensorShape& shape_C, const TensorShape& shape_Y, + bool transa, bool transb, const void* p_input_a, const void* p_input_b, + const void* p_input_c, const void* p_scale_a, const void* p_scale_b, + const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, + int ldb, int ldd, bool row_major_compute) const; + + float alpha_; + float beta_; + bool transA_; + bool transB_; + int64_t sm_count_; + int64_t dtype_; + cublasLtEpilogue_t epilogue_; + + // TODO(xadupre): add epilogue (= activation function, Relu or Gelu are available). +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h new file mode 100644 index 0000000000000..86136ea244e23 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include "cutlass/device_kernel.h" + +using namespace onnxruntime; + +namespace ort_fastertransformer { + +template +inline int compute_occupancy_for_kernel() { + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size > (48 << 10)) { + cudaError_t status = + cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (status == cudaError::cudaErrorInvalidValue) { + // Clear the error bit since we can ignore this. + // This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an + // occupancy of 0. This will cause the heuristic to ignore this configuration. + status = cudaGetLastError(); + return 0; + } + CUDA_CALL_THROW(status); + } + + int max_active_blocks = -1; + CUDA_CALL_THROW(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::Kernel, + GemmKernel::kThreadCount, smem_size)); + + return max_active_blocks; +} + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc new file mode 100644 index 0000000000000..5d4c6793ec995 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cutlass_heuristic.h" + +#include +#include +#include + +namespace ort_fastertransformer { + +struct TileShape { + int m; + int n; +}; + +TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) { + switch (tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + return TileShape{32, 128}; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + return TileShape{64, 128}; + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + return TileShape{128, 128}; + default: + ORT_THROW("[FT Error][get_grid_shape_for_config] Invalid config"); + } +} + +bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, const TileShape tile_shape, + const int split_k_factor, const size_t workspace_bytes, const bool is_weight_only) { + // All tile sizes have a k_tile of 64. + static constexpr int k_tile = 64; + + // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k + if (is_weight_only) { + if ((k % k_tile) != 0) { + return false; + } + + if ((k % split_k_factor) != 0) { + return false; + } + + const int k_elements_per_split = static_cast(k / split_k_factor); + if ((k_elements_per_split % k_tile) != 0) { + return false; + } + } + + // Check that the workspace has sufficient space for this split-k factor + const int ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); + const int ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); + const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + + if (required_ws_bytes > workspace_bytes) { + return false; + } + + return true; +} + +std::vector get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only) { + std::vector simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; + + std::vector square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64}; + + std::vector quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; + + const std::vector allowed_configs = is_weight_only ? quant_B_configs : square_configs; + return simt_configs_only ? simt_configs : allowed_configs; +} + +std::vector get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only) { + std::vector tiles = get_candidate_tiles(is_weight_only, simt_configs_only); + + std::vector candidate_configs; + const int min_stages = 2; + const int max_stages = sm >= 80 ? 4 : 2; + + for (const auto& tile_config : tiles) { + for (int stages = min_stages; stages <= max_stages; ++stages) { + CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages}; + candidate_configs.push_back(config); + } + } + + return candidate_configs; +} + +CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector& candidate_configs, + const std::vector& occupancies, const int64_t m, + const int64_t n, const int64_t k, const int64_t, + const int split_k_limit, const size_t workspace_bytes, + const int multi_processor_count, const int is_weight_only) { + if (occupancies.size() != candidate_configs.size()) { + ORT_THROW( + "[FT Error][estimate_best_config_from_occupancies] occpancies and " + "candidate configs vectors must have equal length."); + } + + CutlassGemmConfig best_config; + // Score will be [0, 1]. The objective is to minimize this score. + // It represents the fraction of SM resources unused in the last wave. + float config_score = 1.0f; + int config_waves = INT_MAX; + int current_m_tile = 0; + + const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; + for (int ii = 0; ii < candidate_configs.size(); ++ii) { + CutlassGemmConfig candidate_config = candidate_configs[ii]; + TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); + int occupancy = occupancies[ii]; + + if (occupancy == 0) { + continue; + } + + // Keep small tile sizes when possible. + if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile && + current_m_tile < tile_shape.m) { + continue; + } + + const int ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); + const int ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); + + for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) { + if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) { + const int ctas_per_wave = occupancy * multi_processor_count; + const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; + + const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; + const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave); + const float current_score = float(num_waves_total) - num_waves_fractional; + + const float score_slack = 0.1f; + if (current_score < config_score || + ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) { + config_score = current_score; + config_waves = num_waves_total; + SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = + CutlassGemmConfig{candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; + current_m_tile = tile_shape.m; + } else if (current_score == config_score && + (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor || + current_m_tile < tile_shape.m)) { + // Prefer deeper pipeline or smaller split-k + SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = + CutlassGemmConfig{candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; + current_m_tile = tile_shape.m; + config_waves = num_waves_total; + } + } + } + } + + if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) { + ORT_THROW("[FT Error] Heurisitc failed to find a valid config."); + } + + return best_config; +} + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h new file mode 100644 index 0000000000000..e70efe0503b55 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ft_gemm_configs.h" + +#include +#include +#include + +#include "core/common/common.h" + +using namespace onnxruntime; + +namespace ort_fastertransformer { + +std::vector get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only); + +CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector& candidate_configs, + const std::vector& occupancies, const int64_t m, + const int64_t n, const int64_t k, const int64_t num_experts, + const int split_k_limit, const size_t workspace_bytes, + const int multi_processor_count, const int is_weight_only); + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h new file mode 100644 index 0000000000000..78d206bf1d9bc --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * @file epilogue_helpers.h + * + * This file includes types for the epilogues. The empty structs exist so we can signal to template + * code the type of epilogue we want to run, and let the underlying code specify the details such as + * element types, accumulator type and elements per vector access. + * + */ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/half.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" + +namespace cutlass { +namespace epilogue { +namespace thread { + +__forceinline__ __device__ float copysignf_pos(float a, float b) { + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__forceinline__ __device__ float tanh_opt(float x) { +#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) + const float exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#else + return fast_tanh(x); +#endif +} + +template <> +struct GELU_taylor { + static const bool kIsHeavy = true; + CUTLASS_DEVICE + float operator()(float const& z) const { + float k0 = float(0.7978845608028654); + float k1 = float(0.044715); + + return float( + cutlass::constants::half() * z * + (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_DEVICE + float operator()(float const& scalar, Params const& params_) const { return this->operator()(scalar); } +}; + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +namespace ort_fastertransformer { + +struct EpilogueOpBiasSilu {}; + +struct EpilogueOpBiasReLU {}; + +struct EpilogueOpBiasFtGelu {}; + +struct EpilogueOpBias {}; + +struct EpilogueOpNoBias {}; + +template +struct Epilogue {}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric< + cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator, + ElementAccumulator, cutlass::epilogue::thread::ScaleType::NoBetaScaling, + cutlass::FloatRoundStyle::round_to_nearest, true>; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +template +struct Epilogue { + using Op = + cutlass::epilogue::thread::LinearCombination; +}; + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h new file mode 100644 index 0000000000000..a5faad423fad9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace ort_fastertransformer { +// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape +// in the kernel layout details when doing weight only quantization. +enum class CutlassTileConfig { + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // SiMT config + CtaShape128x128x8_WarpShape64x64x8, + + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, + + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, + + // Warp configs for M=128 + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape128x32x64 +}; + +enum class SplitKStyle { + NO_SPLIT_K, + SPLIT_K_SERIAL, + // SPLIT_K_PARALLEL // Not supported yet +}; + +struct CutlassGemmConfig { + CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; +}; + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h new file mode 100644 index 0000000000000..311ed323cb90c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h @@ -0,0 +1,79 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Scheduler for grouped GEMM +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/matrix_coord.h" + +#include "moe_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct GemmMoeProblemVisitor + : public MoeProblemVisitor, ThreadblockShape, + GroupScheduleMode_, PrefetchTileCount, ThreadCount> { + static bool const kTransposed = Transposed; + + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using Base = + MoeProblemVisitor; + using Params = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + // + // Methods + // + CUTLASS_DEVICE + GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, shared_storage_, block_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h new file mode 100644 index 0000000000000..eb33a98e4246f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is + quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices + to be consumed by CUTLASS. + + Note that for int4, ThreadBlockK MUST be 64. + + */ + +#pragma once + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/platform/platform.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct LayoutDetailsB {}; + +// Volta specialiations. Volta will dequantize before STS, so we need a different operator +template +struct LayoutDetailsB { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = 8; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. +// TODO - Switch this to column major for weights since gemms should be more performant. +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct MixedGemmArchTraits {}; + +template +struct MixedGemmArchTraits { + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::RowMajor; + + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// ========================= Volta Traits =========================== +// Volta will always dequantize after the global memory load. +// This will instantiate any HMMA tensorcore kernels for Volta. +template +struct MixedGemmArchTraits< + TypeA, TypeB, cutlass::arch::Sm70, + typename cutlass::platform::enable_if::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Turing Traits ============================== +template +struct MixedGemmArchTraits< + TypeA, TypeB, cutlass::arch::Sm75, + typename cutlass::platform::enable_if::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ampere Traits ============================== +template +struct MixedGemmArchTraits< + TypeA, TypeB, cutlass::arch::Sm80, + typename cutlass::platform::enable_if::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = typename LayoutDetails::Operator; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h new file mode 100644 index 0000000000000..bfe30b71170d8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h @@ -0,0 +1,463 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" + +#include "gemm_moe_problem_visitor.h" +#include "tile_interleaved_layout.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// This section exists to that we can use the same kernel code for regular gemm and dequantizing gemms. +// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global. +template +using void_t = void; + +template +struct use_dq_gemm : platform::false_type {}; + +template +struct use_dq_gemm> : platform::true_type {}; + +// SFINAE overload for dequantizing gemm +template ::value, bool>::type = true> +CUTLASS_DEVICE static void run_mma(Mma mma, int gemm_k_iterations, typename Mma::FragmentC& accum, + typename Mma::IteratorA iterator_A, typename Mma::IteratorB iterator_B, + typename Mma::FragmentC const& src_accum, ElementScale* weight_scale_ptr, + MatrixCoord scale_extent, const int thread_idx, MatrixCoord tb_offset_scale) { + typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()), weight_scale_ptr, + scale_extent, thread_idx, tb_offset_scale); + + mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_scale, src_accum); +} + +// SFINAE overload for normal gemm. This completely ignores the scale parameters +template ::value, bool>::type = true> +CUTLASS_DEVICE static void run_mma(Mma mma, int gemm_k_iterations, typename Mma::FragmentC& accum, + typename Mma::IteratorA iterator_A, typename Mma::IteratorB iterator_B, + typename Mma::FragmentC const& src_accum, ElementScale* weight_scale_ptr, + MatrixCoord scale_extent, const int thread_idx, MatrixCoord tb_offset_scale) { + mma(gemm_k_iterations, accum, iterator_A, iterator_B, src_accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeFCGemm { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = false; + + // Optional transpose + using MapArguments = + kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + static_assert(!kTransposed, "Transpose problem not supported"); + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = + GemmMoeProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + int problem_count; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t* total_rows_before_expert; + int64_t gemm_n; + int64_t gemm_k; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + weight_scales(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + total_rows_before_expert(nullptr), + gemm_n(0), + gemm_k(0), + host_problem_sizes(nullptr) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(int problem_count, int threadblock_count, typename EpilogueOutputOp::Params output_op, + const ElementA* ptr_A, const ElementB* ptr_B, const ElementScale* weight_scales, const ElementC* ptr_C, + ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, + GemmCoord* host_problem_sizes = nullptr) + : problem_count(problem_count), + threadblock_count(threadblock_count), + output_op(output_op), + ptr_A(const_cast(ptr_A)), + ptr_B(const_cast(ptr_B)), + weight_scales(const_cast(weight_scales)), + ptr_C(const_cast(ptr_C)), + ptr_D(ptr_D), + total_rows_before_expert(total_rows_before_expert), + gemm_n(gemm_n), + gemm_k(gemm_k), + host_problem_sizes(nullptr) { + if (platform::is_same::value || platform::is_same::value) { + assert(weight_scales); + } + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() : ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr), ptr_C(nullptr), ptr_D(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor(args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, + tile_count), + threadblock_count(args.threadblock_count), + output_op(args.output_op), + ptr_A(args.ptr_A), + ptr_B(args.ptr_B), + weight_scales(args.weight_scales), + ptr_C(args.ptr_C), + ptr_D(args.ptr_D) {} + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) { + problem_visitor = typename ProblemVisitor::Params(args.total_rows_before_expert, args.gemm_n, args.gemm_k, + args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + weight_scales = args.weight_scales; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + MoeFCGemm() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { return Status::kSuccess; } + + static Status can_implement(Arguments const& args) { + if (args.weight_scales != nullptr) { + CUTLASS_TRACE_HOST( + "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t"); + return Status::kInvalid; + } + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + + // The dummy template parameter is not used and exists so that we can compile this code using + // a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exists in + // a namespace + template + struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const& params, SharedStorage& shared_storage) { CUTLASS_NOT_IMPLEMENTED(); } + }; + + template + struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const& params, SharedStorage& shared_storage) { + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + static_assert(platform::is_same::value && kInterleave == 1 || + platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // + // Problem visitor. + // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset(int(cta_idx / grid_shape.n()) * Mma::Shape::kM, + int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0); + + // Load element pointers. Exchange pointers and strides if working on the transpose + const int64_t rows_to_jump = + problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; + + char* byte_ptr_B = ((char*)params.ptr_B) + problem_idx * bytes_per_expert_matrix; + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + typename LayoutB::LongIndex ldm_B = + platform::is_same::value ? gemm_n : gemm_k * kInterleave; + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; + + cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, + {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, + tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. + __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); + run_mma(mma, gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, weight_scale_ptr, + {1, problem_size.n()}, thread_idx, tb_offset_scale); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = reinterpret_cast(params.ptr_C) + problem_idx * gemm_n; + ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + + LayoutC layout_C(0); + LayoutC layout_D(gemm_n); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params_C, ptr_C, problem_size.mn(), thread_idx, + threadblock_offset.mn()); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params_D, ptr_D, problem_size.mn(), thread_idx, + threadblock_offset.mn()); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } + }; + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h new file mode 100644 index 0000000000000..60608f462fde5 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "ft_gemm_configs.h" + +namespace ort_fastertransformer { + +enum class ActivationType { Gelu, + Relu, + Silu, + GeGLU, + ReGLU, + SiGLU, + Identity, + InvalidType }; + +template +class MoeGemmRunner { + public: + MoeGemmRunner(); + + void initialize(int sm); + + void moe_gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, ActivationType activation_type, cudaStream_t stream); + + void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert, + int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream); + + private: + template + void dispatch_to_arch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr); + + template + void run_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cudaStream_t stream); + + private: + int sm_; + int multi_processor_count_; +}; + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu new file mode 100644 index 0000000000000..1d9a249db4237 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moe_gemm_kernels_template.h" + +namespace ort_fastertransformer { +template class MoeGemmRunner; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu new file mode 100644 index 0000000000000..7b250e6ca9060 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moe_gemm_kernels_template.h" + +namespace ort_fastertransformer { +template class MoeGemmRunner; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h new file mode 100644 index 0000000000000..66950c9b65970 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -0,0 +1,428 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Ignore CUTLASS warnings about type punning +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/arch.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" + +#include "compute_occupancy.h" +#include "epilogue_helpers.h" +#include "layout_traits_helper.h" +#include "moe_cutlass_kernel.h" + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +#include "cutlass_heuristic.h" +#include "moe_gemm_kernels.h" + +#include +#include +#include +#include + +namespace ort_fastertransformer { + +// ============================= Variable batched Gemm things =========================== +template +void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, const int multi_processor_count, + cudaStream_t stream, int* kernel_occupancy = nullptr) { + if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) { + ORT_THROW("[FT Error][MoeGemm] Grouped gemm does not support split-k"); + } + + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for half, float"); + + static_assert(cutlass::platform::is_same::value || + cutlass::platform::is_same::value || + cutlass::platform::is_same::value, + ""); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using ElementType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; + using ElementType = ElementType_; + + using CutlassWeightType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, + WeightType>::type; + using CutlassWeightType = CutlassWeightType_; + + // We need separate config for each architecture since we will target different tensorcore instructions. For float, + // we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + using EpilogueOp = + typename Epilogue::Op; + + // Finally, set up the kernel. + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped< + ElementType, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessA, + CutlassWeightType, typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone, + MixedGemmArchTraits::ElementsPerAccessB, ElementType, cutlass::layout::RowMajor, ElementAccumulator, + typename MixedGemmArchTraits::OperatorClass, arch, ThreadblockShape, WarpShape, + typename MixedGemmArchTraits::InstructionShape, EpilogueOp, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, Stages, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, typename MixedGemmArchTraits::Operator>::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + if (kernel_occupancy != nullptr) { + *kernel_occupancy = compute_occupancy_for_kernel(); + return; + } + int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); + if (occupancy == 0) { + ORT_THROW("[FT Error][MoE Runner] GPU lacks the shared memory resources to run GroupedGEMM kernel"); + } + const int threadblock_count = multi_processor_count * occupancy; + + typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), ElementAccumulator(0.f)); + + typename GemmGrouped::Arguments args( + num_experts, threadblock_count, epilogue_op, reinterpret_cast(A), + reinterpret_cast(B), reinterpret_cast(weight_scales), + reinterpret_cast(biases), reinterpret_cast(C), total_rows_before_expert, gemm_n, + gemm_k); + + GemmGrouped gemm; + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = + "MoEFC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); + ORT_THROW("[FT Error][MoE Runner] " + err_msg); + } + + auto init_status = gemm.initialize(args); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to initialize cutlass variable batched gemm. Error: " + + std::string(cutlassGetStatusString(init_status)); + ORT_THROW("[FT Error][MoE Runner] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = + "Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status)); + ORT_THROW("[FT Error][MoE Runner] " + err_msg); + } +} + +template +struct dispatch_stages { + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { + std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); + ORT_THROW("[FT Error][dispatch_stages::dispatch] " + err_msg); + } +}; + +template +struct dispatch_stages { + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { + generic_moe_gemm_kernelLauncher( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + } +}; + +template +struct dispatch_stages 2)>::type> { + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { + generic_moe_gemm_kernelLauncher(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count, stream, occupancy); + } +}; + +template +void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { + switch (gemm_config.stages) { + case 2: + using DispatcherStages2 = dispatch_stages; + DispatcherStages2::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case 3: + using DispatcherStages3 = dispatch_stages; + DispatcherStages3::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case 4: + using DispatcherStages4 = dispatch_stages; + DispatcherStages4::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + default: + std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); + ORT_THROW("[FT Error][MoE][dispatch_gemm_config] " + err_msg); + break; + } +} + +// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32. +// This overload is only enabled when T == WeightType. +template < + typename T, typename WeightType, typename arch, typename EpilogueTag, + typename std::enable_if::value && std::is_same::value>::type* = nullptr> +void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::Undefined: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); + break; + default: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] Config is invalid for same type MoE tensorop GEMM."); + break; + } +} + +// Tensorop GEMM overload +// Overload for quantize MoE GEMMs. We disable some warp configs here since they will not be used and we can improve +// compile time +template < + typename T, typename WeightType, typename arch, typename EpilogueTag, + typename std::enable_if::value && !std::is_same::value>::type* = nullptr> +void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<128, 32, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::Undefined: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); + break; + default: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] Config is invalid for mixed type tensorop GEMM."); + break; + } +} + +// This overload will handle simt gemms. It is disabled via SFINAE for tensorop. +template ::value>::type* = nullptr> +void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::Undefined: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + ORT_THROW( + "[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config should have already been set by heuristic."); + break; + default: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] Unsupported config for float MoE gemm."); + break; + } +} + +template +MoeGemmRunner::MoeGemmRunner() {} + +template +void MoeGemmRunner::initialize(int sm_version) { + int device{-1}; + cudaGetDevice(&device); + sm_ = sm_version; + cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device); +} + +template +template +void MoeGemmRunner::dispatch_to_arch(const T* A, const WeightType* B, + const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, cudaStream_t stream, + int* occupancy) { + if (sm_ >= 70 && sm_ < 75) { + dispatch_moe_gemm_to_cutlass( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + sm_, multi_processor_count_, stream, occupancy); + } else if (sm_ >= 75 && sm_ < 80) { + dispatch_moe_gemm_to_cutlass( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + sm_, multi_processor_count_, stream, occupancy); + } else if (sm_ >= 80 && sm_ < 90) { + dispatch_moe_gemm_to_cutlass( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + sm_, multi_processor_count_, stream, occupancy); + } else { + ORT_THROW("[FT Error][MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); + } +} + +template +template +void MoeGemmRunner::run_gemm(const T* A, const WeightType* B, const T* weight_scales, + const T* biases, T* C, int64_t* total_rows_before_expert, + int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cudaStream_t stream) { + static constexpr bool is_weight_only = !std::is_same::value; + static constexpr bool only_simt_configs = std::is_same::value; + std::vector candidate_configs = get_candidate_configs(sm_, is_weight_only, only_simt_configs); + std::vector occupancies(candidate_configs.size()); + + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { + dispatch_to_arch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, candidate_configs[ii], stream, &occupancies[ii]); + } + + static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs. + static constexpr int split_k_limit = 1; // MoE GEMM does not support split-k. + CutlassGemmConfig chosen_config = + estimate_best_config_from_occupancies(candidate_configs, occupancies, total_rows, gemm_n, gemm_k, num_experts, + split_k_limit, workspace_bytes, multi_processor_count_, is_weight_only); + + dispatch_to_arch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, chosen_config, stream); +} + +template +void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, + const T* biases, T* C, int64_t* total_rows_before_expert, + int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, ActivationType activation_type, + cudaStream_t stream) { + switch (activation_type) { + case ActivationType::Relu: + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); + break; + case ActivationType::Gelu: + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, stream); + break; + case ActivationType::Silu: + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); + break; + case ActivationType::Identity: + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); + break; + case ActivationType::InvalidType: + ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); + break; + default: { + ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); + } + } +} + +template +void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, const T* weight_scales, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, + int64_t gemm_k, int num_experts, cudaStream_t stream) { + run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); +} + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu new file mode 100644 index 0000000000000..398ce4ee9880f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -0,0 +1,830 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +// Ignore CUTLASS warnings about type punning +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +#include "moe_kernel.h" + +#if CUDA_VERSION >= 11000 +#include +#include +#include +#else +#include "cub/cub.cuh" +#include "cub/device/device_radix_sort.cuh" +#include "cub/util_type.cuh" +#endif + +namespace ort_fastertransformer { + +static constexpr int WARP_SIZE = 32; + +// ====================== Softmax things =============================== +// We have our own implementation of softmax here so we can support transposing the output +// in the softmax kernel when we extend this module to support expert-choice routing. +template +__launch_bounds__(TPB) __global__ + void moe_softmax(const T* input, const bool* finished, T* output, const int num_cols) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + const int thread_row_offset = blockIdx.x * num_cols; + + cub::Sum sum; + float threadData(-FLT_MAX); + + // Don't touch finished rows. + if ((finished != nullptr) && finished[blockIdx.x]) { + return; + } + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData = max(static_cast(input[idx]), threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) { + float_max = maxElem; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + + if (threadIdx.x == 0) { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = T(val); + } +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 +template +__launch_bounds__(TPB) __global__ void moe_top_k(const T*, const bool*, T*, int*, int*, int, const int) { + // Does not support pre-Kepler architectures + ; +} +#else +template +__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, const bool* finished, T* output, + int* indices, int* source_rows, int num_experts, int k) { + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + int num_rows = gridDim.x; + const int block_row = blockIdx.x; + + const bool should_process_row = finished ? !finished[block_row] : true; + const int thread_read_offset = blockIdx.x * num_experts; + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = k * block_row + k_idx; + output[idx] = result_kvp.value; + indices[idx] = should_process_row ? result_kvp.key : num_experts; + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} +#endif + +// ====================== TopK softmax things =============================== + +/* + A Top-K gating softmax written to exploit when the number of experts in the MoE layers + are a small power of 2. This allows us to cleanly share the rows among the threads in + a single warp and eliminate communication between warps (so no need to use shared mem). + + It fuses the softmax, max and argmax into a single kernel. + + Limitations: + 1) This implementation is intended for when the number of experts is a small power of 2. + 2) This implementation assumes k is small, but will work for any k. +*/ + +template +__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ + void topk_gating_softmax(const T* input, const bool* finished, T* output, int num_rows, int* indices, + int* source_rows, int k) { + // We begin by enforcing compile time assertions and setting up compile time constants. + static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); + static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); + static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); + static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); + + // Number of bytes each thread pulls in per load + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static constexpr int ELTS_PER_ROW = NUM_EXPERTS; + static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; + static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + + // Restrictions based on previous section. + static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); + static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); + static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); + static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); + + // We have NUM_EXPERTS elements per row. We specialize for small #experts + static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; + static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; + static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; + + // Restrictions for previous section. + static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); + + // ===================== From this point, we finally start computing run-time variables. ======================== + + // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. + // This, each block processes a chunk of rows. We start by computing the start row for each block. + const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + + // Now, using the base row per thread block, we compute the base row per warp. + const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + + // The threads in a warp are split into sub-groups that will work on a row. + // We compute row offset for each thread sub-group + const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + const int thread_row = warp_base_row + thread_row_in_warp; + + // Threads with indices out of bounds should early exit here. + if (thread_row >= num_rows) return; + const bool should_process_row = finished ? !finished[thread_row] : true; + + // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the + // row it will read. + const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + + // Now, we compute the group each thread belong to in order to determine the first column to start loads. + const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, + // this can support all powers of 2 up to 16. + using AccessType = cutlass::AlignedArray; + + // Finally, we pull in the data from global mem + cutlass::Array row_chunk_input; + AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk_input); + const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + + using ComputeType = float; + using Converter = cutlass::NumericArrayConverter; + Converter compute_type_converter; + cutlass::Array row_chunk = compute_type_converter(row_chunk_input); + + // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just + // convert to float afterwards for the exp + sum reduction. + ComputeType thread_max = row_chunk[0]; +#pragma unroll + for (int ii = 1; ii < VPT; ++ii) { + thread_max = max(thread_max, row_chunk[ii]); + } + +// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); + } + + // From this point, thread max in all the threads have the max within the row. + // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. + float row_sum = 0; +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = expf(row_chunk[ii] - thread_max); + row_sum += row_chunk[ii]; + } + +// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); + } + + // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables + // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to + // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. + // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the + // argmax after computing the softmax. + const float reciprocal_row_sum = 1.f / row_sum; + +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; + } + + // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along + // with the max index.​ + int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + for (int k_idx = 0; k_idx < k; ++k_idx) { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; +#pragma unroll + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index are processed first and only + // updated if > (not >=) + if (val > max_val) { + max_val = val; + expert = col + ii; + } + } + } + +// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max. +// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can +// then blank out their max with -inf and the warp can run more iterations... +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); + int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this way + if (other_max > max_val || (other_max == max_val && other_expert < expert)) { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) { + // The lead thread from each sub-group will write out the final results to global memory. (This will be a + // single) thread per row of the input/output matrices. + const int idx = k * thread_row + k_idx; + output[idx] = T(max_val); + indices[idx] = should_process_row ? expert : NUM_EXPERTS; + source_rows[idx] = k_idx * num_rows + thread_row; + } + + // Finally, we clear the value in the thread with the current max if there is another iteration to run. + if (k_idx + 1 < k) { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = ComputeType(-10000.f); + } + } + } +} + +namespace detail { +// Constructs some constants needed to partition the work across threads at compile time. +template +struct TopkConstants { + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); + static constexpr int VECs_PER_THREAD = std::max(1, (int)EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +}; +} // namespace detail + +template +void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T* output, int* indices, int* source_row, + int num_rows, int num_experts, int k, cudaStream_t stream) { + static constexpr unsigned long MAX_BYTES_PER_LDG = 16; + + static constexpr int BYTES_PER_LDG = std::min((int)MAX_BYTES_PER_LDG, (int)sizeof(T) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + topk_gating_softmax + <<>>(input, finished, output, num_rows, indices, source_row, k); +} + +template +void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_output, + int* indices, int* source_row, int num_rows, int num_experts, + int k, cudaStream_t stream) { + static constexpr int WARPS_PER_TB = 4; + + switch (num_experts) { + case 2: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 4: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 8: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 16: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 32: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 64: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 128: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 256: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + default: { + static constexpr int TPB = 256; + moe_softmax<<>>(input, finished, softmax_temp_output, num_experts); + moe_top_k + <<>>(softmax_temp_output, finished, output, indices, source_row, num_experts, k); + } + } +} + +// ========================== CUB Sorting things ==================================== +CubKeyValueSorter::CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {} + +CubKeyValueSorter::CubKeyValueSorter(int num_experts) + : num_experts_(num_experts), num_bits_((int)log2(num_experts) + 1) {} + +void CubKeyValueSorter::update_num_experts(int num_experts) { + num_experts_ = num_experts; + num_bits_ = (int)log2(num_experts) + 1; +} + +size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs) { + num_key_value_pairs_ = num_key_value_pairs; + size_t required_storage = 0; + int* null_int = nullptr; + cub::DeviceRadixSort::SortPairs(NULL, required_storage, null_int, null_int, null_int, null_int, + (int)num_key_value_pairs, 0, num_bits_); + return required_storage; +} + +void CubKeyValueSorter::run(void* workspace, const size_t workspace_size, const int* keys_in, int* keys_out, + const int* values_in, int* values_out, const size_t num_key_value_pairs, + cudaStream_t stream) { + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs); + size_t actual_ws_size = workspace_size; + + if (expected_ws_size > workspace_size) { + ORT_THROW("Error. The allocated workspace is too small to run this problem. Expected workspace size of at least ", + expected_ws_size, " but got problem size ", workspace_size, "\n"); + } + cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out, values_in, values_out, + (int)num_key_value_pairs, 0, num_bits_, stream); +} + +// ============================== Infer GEMM sizes ================================= +__device__ inline int find_total_elts_leq_target(const int* sorted_indices, const int arr_length, const int target) { + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] > target) { + high = mid - 1; + } else { + low = mid + 1; + target_location = mid; + } + } + return target_location + 1; +} + +// Sets up the gemm assuming the inputs, experts and outputs are stored in row major order. +// Assumes we want to perform output = matmul(inputs, experts) + bias +__global__ void compute_total_rows_before_expert_kernel(const int* sorted_experts, const int sorted_experts_len, + const int64_t num_experts, int64_t* total_rows_before_expert) { + // First, compute the global tid. We only need 1 thread per expert. + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + if (expert >= num_experts) return; + + // This should construct the last index where each expert occurs. + total_rows_before_expert[expert] = find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); +} + +template +CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version) { + moe_gemm_runner_.initialize(sm_version); +} + +template +size_t CutlassMoeFCRunner::getWorkspaceSize(int num_rows, const int hidden_size, + const int inter_size, int num_experts, + int k) { + const int buf_size = static_cast(pad_to_multiple_of_16(k * num_rows * hidden_size)); + const int interbuf_size = static_cast(pad_to_multiple_of_16(k * num_rows * inter_size)); + const int padded_experts = static_cast(pad_to_multiple_of_16(num_experts)); + const int num_moe_inputs = static_cast(pad_to_multiple_of_16(k * num_rows)); + int num_softmax_outs = 0; + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + num_softmax_outs = static_cast(pad_to_multiple_of_16(num_rows * num_experts)); + } + + // softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them + // in Encoder or Decoder before invoking FfnLayer forward. + size_t total_ws_bytes = 3 * num_moe_inputs * sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + total_ws_bytes += buf_size * sizeof(T); // permuted_data + total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ + total_ws_bytes += num_softmax_outs * sizeof(T); + const int bytes_for_fc1_result = interbuf_size * sizeof(T); + const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows))); + sorter_.update_num_experts(num_experts); + + int bytes_for_intermediate_and_sorting = bytes_for_fc1_result; + if (sorter_ws_size_bytes > bytes_for_fc1_result) { + int remaining_bytes = static_cast(pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result)); + bytes_for_intermediate_and_sorting += remaining_bytes; + } + + total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace + return total_ws_bytes; +} + +template +void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, int num_rows, + const int hidden_size, const int inter_size, + int num_experts, int k) { + const int buf_size = static_cast(pad_to_multiple_of_16(k * num_rows * hidden_size)); + const int interbuf_size = static_cast(pad_to_multiple_of_16(k * num_rows * inter_size)); + const int padded_experts = static_cast(pad_to_multiple_of_16(num_experts)); + const int num_moe_inputs = static_cast(pad_to_multiple_of_16(k * num_rows)); + // const int num_softmax_outs = pad_to_multiple_of_16(num_rows * num_experts); + + source_rows_ = (int*)ws_ptr; + permuted_rows_ = source_rows_ + num_moe_inputs; + permuted_experts_ = permuted_rows_ + num_moe_inputs; + permuted_data_ = (T*)(permuted_experts_ + num_moe_inputs); + + total_rows_before_expert_ = (int64_t*)(permuted_data_ + buf_size); + + fc1_result_ = (T*)(total_rows_before_expert_ + padded_experts); + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + softmax_out_ = (T*)(fc1_result_ + interbuf_size); + } else { + softmax_out_ = nullptr; + } +} + +template +void CutlassMoeFCRunner::run_moe_fc( + const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, + const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights, + const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, + int k, char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, T* expert_scales, + int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) { + static constexpr bool scales_required = + std::is_same::value || std::is_same::value; + + if (scales_required) { + if (fc1_scales == nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for first matmul is a null pointer"); + } else if (fc2_scales == nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for second matmul is a null pointer"); + } + } else { + if (fc1_scales != nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC1"); + } else if (fc2_scales != nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC2"); + } + } + + configure_ws_ptrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts, k); + topk_gating_softmax_kernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row, + source_rows_, num_rows, num_experts, k, stream); + + const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows))); + sorter_.run((void*)fc1_result_, sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, source_rows_, + permuted_rows_, k * num_rows, stream); + + initialize_moe_routing_kernelLauncher(input_activations, permuted_data_, permuted_rows_, + expanded_source_row_to_expanded_dest_row, num_rows, active_rows, hidden_size, k, + stream); + + const int expanded_active_expert_rows = k * active_rows; + compute_total_rows_before_expert(permuted_experts_, expanded_active_expert_rows, num_experts, + total_rows_before_expert_, stream); + + moe_gemm_runner_.moe_gemm_bias_act(permuted_data_, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_result_, + total_rows_before_expert_, expanded_active_expert_rows, inter_size, hidden_size, + num_experts, fc1_activation_type, stream); + + moe_gemm_runner_.moe_gemm(fc1_result_, fc2_expert_weights, fc2_scales, fc2_result, total_rows_before_expert_, + expanded_active_expert_rows, hidden_size, inter_size, num_experts, stream); +} + +template +void CutlassMoeFCRunner::run_moe_fc( + const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, + const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights, + const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, + int k, char* workspace_ptr, T* fc2_result, T* expert_scales, int* expanded_source_row_to_expanded_dest_row, + int* expert_for_source_row, cudaStream_t stream) { + run_moe_fc(input_activations, gating_output, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_activation_type, + fc2_expert_weights, fc2_scales, num_rows, hidden_size, inter_size, num_experts, k, workspace_ptr, + fc2_result, nullptr, num_rows, expert_scales, expanded_source_row_to_expanded_dest_row, + expert_for_source_row, stream); +} + +template +void CutlassMoeFCRunner::compute_total_rows_before_expert(const int* sorted_indices, + const int total_indices, + int num_experts, + int64_t* total_rows_before_expert, + cudaStream_t stream) { + const int threads = std::min(1024, num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + compute_total_rows_before_expert_kernel<<>>(sorted_indices, total_indices, num_experts, + total_rows_before_expert); +} + +// ========================== Permutation things ======================================= + +// Duplicated and permutes rows for MoE. In addition, reverse the permutation map to help with finalizing routing. + +// "expanded_x_row" simply means that the number of values is num_rows x k. It is "expanded" since we will have to +// duplicate some rows in the input matrix to match the dimensions. Duplicates will always get routed to separate +// experts in the end. + +// Note that the expanded_dest_row_to_expanded_source_row map referred to here has indices in the range (0, +// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input +// all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we simply take the modulus +// of the expanded index. + +template +__global__ void initialize_moe_routing_kernel(const T* unpermuted_input, T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, int num_rows, + int active_rows, int cols) { + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the + // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + const int expanded_dest_row = blockIdx.x; + const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row; + } + + if (blockIdx.x < active_rows) { + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr = permuted_output + expanded_dest_row * cols; + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + dest_row_ptr[tid] = source_row_ptr[tid]; + } + } +} + +template +void initialize_moe_routing_kernelLauncher(const T* unpermuted_input, T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, int num_rows, + int active_rows, int cols, int k, cudaStream_t stream) { + const int blocks = num_rows * k; + const int threads = std::min(cols, 1024); + initialize_moe_routing_kernel + <<>>(unpermuted_input, permuted_output, expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, num_rows, k * active_rows, cols); +} + +// Final kernel to unpermute and scale +// This kernel unpermutes the original data, does the k-way reduction and performs the final skip connection. +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 +template +__global__ void finalize_moe_routing_kernel(const T*, T*, const T*, const T*, const T*, const T*, const int*, + const int*, int, const int) { + // Does not support pre-Kepler architectures + ; +} +#else +template +__global__ void finalize_moe_routing_kernel(const T* expanded_permuted_rows, T* reduced_unpermuted_output, + const T* skip_1, const T* skip_2, const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int cols, int k) { + const int original_row = blockIdx.x; + int num_rows = gridDim.x; + T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols; + + const T* skip_1_row_ptr = nullptr; + if (RESIDUAL_NUM == 1) { + skip_1_row_ptr = skip_1 + original_row * cols; + } + const T* skip_2_row_ptr = nullptr; + if (RESIDUAL_NUM == 2) { + skip_2_row_ptr = skip_2 + original_row * cols; + } + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + T thread_output; + if (RESIDUAL_NUM == 0) { + thread_output = T(0); + } else if (RESIDUAL_NUM == 1) { + thread_output = skip_1_row_ptr[tid]; + } else if (RESIDUAL_NUM == 2) { + thread_output = skip_1_row_ptr[tid] + skip_2_row_ptr[tid]; + } + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int expanded_original_row = original_row + k_idx * num_rows; + const int expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; + + const int64_t k_offset = original_row * k + k_idx; + const T row_scale = scales[k_offset]; + const T* expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols; + + const int expert_idx = expert_for_source_row[k_offset]; + const T* bias_ptr = bias + expert_idx * cols; + + thread_output = thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_ptr[tid]); + } + reduced_row_ptr[tid] = thread_output; + } +} +#endif + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* bias, + const T* scales, const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream) { + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, nullptr, nullptr, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); +} + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip, + const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream) { + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + finalize_moe_routing_kernel + <<>>(expanded_permuted_rows, reduced_unpermuted_output, skip, nullptr, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); +} + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, + const T* skip_2, const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream) { + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + if (skip_2 == nullptr) { + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); + } else { + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); + } +} + +// ========================= TopK Softmax specializations =========================== +template void topk_gating_softmax_kernelLauncher(const float*, const bool*, float*, float*, int*, int*, int, + int, int, cudaStream_t); +template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, + int, int, cudaStream_t); + +// ==================== Variable batched GEMM specializations ================================== +template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner; + +// ===================== Specializations for init routing ========================= +template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, + int, int, cudaStream_t); +template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, + int, int, cudaStream_t); + +// ==================== Specializations for final routing =================================== +template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const int*, + const int*, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const int*, const int*, + int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const float*, + const int*, const int*, int, int, int, + cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, const int*, + const int*, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const float*, + const float*, const int*, const int*, int, int, int, + cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, + const half*, const int*, const int*, int, int, int, + cudaStream_t); + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h new file mode 100644 index 0000000000000..5cefe4fa5dc47 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "moe_gemm_kernels.h" +#include + +#include "core/common/common.h" + +using namespace onnxruntime; + +namespace ort_fastertransformer { + +static inline size_t pad_to_multiple_of_16(size_t input) { + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} + +/* + Launches the topk gating softmax required for the MoE layers. + + Params: + input - a [num_rows x num_experts] + finished - [num_rows] vector with 1 if the sentence at this row is done translating and 0 otherwise. + output - a buffer of shape [num_rows x k] containing the top-k values of the softmax for each row. + indices - a matrix of shape [num_rows x k] containing the top-k experts each row should get routed to. + source_rows - a matrix of shape [num_rows x k] used internally for permuting. source_rows[row][k] = k * num_rows + + row. It is constructed like this so we can track where each of the original rows end up in order to perform the + "k-way" reduction later in the routing. + + num_rows - The number of rows in the matrix + num_experts - The number of expert layers present + k - k value in topk +*/ +template +void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_out, + int* indices, int* source_row, int num_rows, int num_experts, + int k, cudaStream_t stream); + +class CubKeyValueSorter { + public: + CubKeyValueSorter(); + + CubKeyValueSorter(int num_experts); + + void update_num_experts(int num_experts); + + size_t getWorkspaceSize(const size_t num_key_value_pairs); + + void run(void* workspace, const size_t workspace_size, const int* keys_in, int* keys_out, const int* values_in, + int* values_out, const size_t num_key_value_pairs, cudaStream_t stream); + + private: + size_t num_key_value_pairs_; + int num_experts_; + int num_bits_; +}; + +template +void initialize_moe_routing_kernelLauncher(const T* unpermuted_input, T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, int num_rows, + int active_rows, int cols, int k, cudaStream_t stream); + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* bias, + const T* scales, const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream); + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip, + const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream); + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, + const T* skip_2, const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream); + +// Assumes inputs activations are row major. Weights need to be preprocessed by th_op/weight_quantize.cc . +// Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive. +// Avoid making several duplicates of this class. +template +class CutlassMoeFCRunner { + public: + CutlassMoeFCRunner(int sm_version); + + size_t getWorkspaceSize(int num_rows, int hidden_size, int inter_size, int num_experts, int k); + + void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, + const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type, + const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size, + int inter_size, int num_experts, int k, char* workspace_ptr, T* fc2_result, + T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, + cudaStream_t stream); + + void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, + const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type, + const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size, + int inter_size, int num_experts, int k, char* workspace_ptr, T* fc2_result, + const bool* finished, int active_rows, T* expert_scales, + int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream); + + void compute_total_rows_before_expert(const int* sorted_indices, int total_indices, int num_experts, + int64_t* total_rows_before_expert, cudaStream_t stream); + + private: + void configure_ws_ptrs(char* ws_ptr, int num_rows, int hidden_size, int inter_size, int num_experts, int k); + + private: + CubKeyValueSorter sorter_; + MoeGemmRunner moe_gemm_runner_; + + // Pointers + int* source_rows_; + int* permuted_rows_; + int* permuted_experts_; + char* sorter_ws_; + T* permuted_data_; + T* softmax_out_; + + int64_t* total_rows_before_expert_; + + T* fc1_result_; +}; + +template +class CutlassMoeFCRunner::value>> { + public: + CutlassMoeFCRunner(int sm_version); + + size_t getWorkspaceSize(int num_rows, int hidden_size, int inter_size, int num_experts, int k) { + return 0; + } +}; + +} // namespace ort_fastertransformer \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h new file mode 100644 index 0000000000000..00f977c615df6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h @@ -0,0 +1,290 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Base scheduler for grouped problems, using MoE +*/ + +#pragma once + +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct BaseMoeProblemVisitor { + using ThreadblockShape = ThreadblockShape_; + + struct ProblemInfo { + static int32_t const kNoPrefetchEntry = -1; + int32_t problem_idx; + int32_t problem_start; + + CUTLASS_DEVICE + ProblemInfo() : problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) {} + + CUTLASS_DEVICE + ProblemInfo(int32_t problem_idx_, int32_t problem_start_) + : problem_idx(problem_idx_), problem_start(problem_start_) {} + }; + + struct Params { + int64_t const* last_row_for_problem; + int64_t gemm_n; + int64_t gemm_k; + int32_t problem_count; + void const* workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params() + : last_row_for_problem(nullptr), gemm_n(0), gemm_k(0), problem_count(0), workspace(nullptr), tile_count(0) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count, + void const* workspace = nullptr, int32_t tile_count = 0) + : last_row_for_problem(last_row_for_problem), + gemm_n(gemm_n), + gemm_k(gemm_k), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) {} + }; + + Params const& params; + int32_t tile_idx; + int32_t problem_tile_start; + int32_t problem_idx; + + // + // Methods + // + CUTLASS_DEVICE + BaseMoeProblemVisitor(Params const& params_, int32_t block_idx) + : params(params_), tile_idx(block_idx), problem_tile_start(0), problem_idx(0) {} + + /// Get the grid shape + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { + return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1); + } + + /// Gets the global tile index + CUTLASS_HOST_DEVICE + int32_t tile_index() const { return tile_idx; } + + /// Gets the index of the problem + CUTLASS_HOST_DEVICE + int32_t problem_index() const { return problem_idx; } + + CUTLASS_HOST_DEVICE + int32_t threadblock_idx() const { return tile_idx - problem_tile_start; } + + CUTLASS_DEVICE + void advance(int32_t grid_size) { tile_idx += grid_size; } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) { + ProblemSizeHelper::possibly_transpose_problem(problem); + } + + /// Returns the problem size for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size() const { return problem_size(problem_idx); } + + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size(int idx) const { + const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; + const int64_t current_problem_row = params.last_row_for_problem[idx]; + const int64_t gemm_m = current_problem_row - prev_problem_row; + GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k)); + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + CUTLASS_HOST_DEVICE + static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { return ProblemSizeHelper::tile_count(grid); } + + static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count) { + int32_t total_tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + auto problem = host_problem_sizes_ptr[i]; + possibly_transpose_problem(problem); + auto grid = grid_shape(problem); + total_tiles += tile_count(grid); + } + + return total_tiles; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeProblemVisitor; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// ProblemVisitor that performs all scheduling on device +// +template +struct MoeProblemVisitor : public BaseMoeProblemVisitor { + using Base = BaseMoeProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static bool const kRequiresPrecomputation = false; + static int const kThreadsPerWarp = 32; + + struct SharedStorage {}; + + // Final tile of the problem loaded by this thread. Each thread will hold + // a separate value. + int32_t problem_ending_tile; + + SharedStorage& shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, block_idx), problem_ending_tile(0), shared_storage(shared_storage_) { + this->problem_idx = -1 * kThreadsPerWarp; + this->problem_tile_start = 0; + } + + CUTLASS_DEVICE + bool next_tile() { + // Check whether the tile to compute is within the range of the current problem. + int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp); + if (this->tile_idx < problem_tile_end) { + return true; + } + + // Check whether the tile to compute is within the current group of problems fetched by the warp. + // The last tile for this group is the final tile of the problem held by the final thread in the warp. + int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + // Keep the starting problem for this group in `problem_idx`. This is done to reduce + // register pressure. The starting problem for this group is simply the first problem + // in the group most recently fetched by the warp. + int32_t& group_problem_start = this->problem_idx; + group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp; + + // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce + // register pressure. + int32_t& group_tile_start = this->problem_tile_start; + + // Each thread in the warp processes a separate problem to advance until + // reaching a problem whose starting tile is less less than tile_idx. + while (group_tile_end <= this->tile_idx) { + group_problem_start += kThreadsPerWarp; + if (group_problem_start > this->params.problem_count) { + return false; + } + + // Since `group_tile_start` is a reference to `this->problem_tile_start`, this + // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start` + // is also set here is used later in `next_tile`. + group_tile_start = group_tile_end; + + int lane_idx = threadIdx.x % kThreadsPerWarp; + int32_t lane_problem = group_problem_start + lane_idx; + + // Compute the number of tiles in the problem assigned to each thread. + problem_ending_tile = 0; + if (lane_problem < this->params.problem_count) { + cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + problem_ending_tile = this->tile_count(grid); + } + + // Compute a warp-wide inclusive prefix sum to compute the ending tile index of + // each thread's problem. + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kThreadsPerWarp; i <<= 1) { + int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i); + if (lane_idx >= i) { + problem_ending_tile += val; + } + } + + // The total tile count for this group is now in the final position of the prefix sum + int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + problem_ending_tile += group_tile_start; + group_tile_end += tiles_in_group; + } + + // The next problem to process is the first one that does not have ending tile position + // that is greater than or equal to tile index. + int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx)); + + this->problem_idx = group_problem_start + problem_idx_in_group; + + // The starting tile for this problem is the ending tile of the previous problem. In cases + // where `problem_idx_in_group` is the first problem in the group, we do not need to reset + // `problem_tile_start`, because it is set to the previous group's ending tile in the while + // loop above. + if (problem_idx_in_group > 0) { + this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1); + } + + return true; + } + + static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count, + int32_t block_count) { + return 0; + } + + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count, + int32_t block_count, void* host_workspace_ptr) {} +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h new file mode 100644 index 0000000000000..3505bea24e4d9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h @@ -0,0 +1,61 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines new layouts needed for MoE +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/pitch_linear_coord.h" + +namespace cutlass { +namespace layout { + +template +class ColumnMajorTileInterleave { + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; +}; + +template +struct IsColumnMajorTileInterleave { + static constexpr bool value = false; +}; + +template +struct IsColumnMajorTileInterleave> { + static constexpr bool value = true; +}; + +} // namespace layout +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc new file mode 100644 index 0000000000000..6f2ffe7a0cc43 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -0,0 +1,197 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_common.h" +#include "moe.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + MoE, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(0, 0) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + MoE); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +using namespace ONNX_NAMESPACE; + +template +Status MoE::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* router_probs = context->Input(1); + const Tensor* fc1_experts_weights = context->Input(2); + const Tensor* fc2_experts_weights = context->Input(3); + const Tensor* fc1_experts_bias_optional = context->Input(4); + const Tensor* fc2_experts_bias_optional = context->Input(5); + + const auto& input_dims = input->Shape().GetDims(); + const auto& router_probs_dims = router_probs->Shape().GetDims(); + const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); + const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); + + const int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1]; + const int64_t hidden_size = input_dims[input_dims.size() - 1]; + const int64_t num_experts = fc1_experts_weights_dims[0]; + const int64_t inter_size = fc1_experts_weights_dims[2]; + + // TODO: refactor to helper function. + if (fc1_experts_weights_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ", + fc1_experts_weights_dims.size()); + } + if (fc2_experts_weights_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims must be 3D, got ", + fc2_experts_weights_dims.size()); + } + if (fc1_experts_weights_dims[1] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_weights_dims[1] must be equal to hidden_size, got ", + fc1_experts_weights_dims[1], " and ", hidden_size); + } + if (fc2_experts_weights_dims[1] != inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_weights_dims[1] must be equal to inter_size, got ", fc2_experts_weights_dims[1], + " and ", inter_size); + } + if (fc1_experts_weights_dims[2] != inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_weights_dims[2] must be equal to inter_size, got ", fc1_experts_weights_dims[2], + " and ", inter_size); + } + if (fc2_experts_weights_dims[2] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_weights_dims[2] must be equal to hidden_size, got ", + fc2_experts_weights_dims[2], " and ", hidden_size); + } + if (router_probs_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ", + router_probs_dims.size()); + } + if (router_probs_dims[0] != num_rows) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ", + router_probs_dims[0], " and ", num_rows); + } + if (router_probs_dims[1] != num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[1] must be equal to num_experts, got ", + router_probs_dims[1], " and ", num_experts); + } + if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is set but fc2_experts_bias is not set"); + } + if (fc1_experts_bias_optional == nullptr && fc2_experts_bias_optional != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is not set but fc2_experts_bias is set"); + } + if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional != nullptr) { + const auto& fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims(); + const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims(); + if (fc1_experts_bias_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims must be 2D, got ", + fc1_experts_bias_dims.size()); + } + if (fc2_experts_bias_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims must be 2D, got ", + fc2_experts_bias_dims.size()); + } + if (fc1_experts_bias_dims[0] != num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_bias_dims[0] must be equal to num_experts, got ", fc1_experts_bias_dims[0], + " and ", num_experts); + } + if (fc2_experts_bias_dims[0] != num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_bias_dims[0] must be equal to num_experts, got ", fc2_experts_bias_dims[0], + " and ", num_experts); + } + if (fc1_experts_bias_dims[1] != inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_bias_dims[1] must be equal to inter_size, got ", fc1_experts_bias_dims[1], + " and ", inter_size); + } + if (fc2_experts_bias_dims[1] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_bias_dims[1] must be equal to hidden_size, got ", fc2_experts_bias_dims[1], + " and ", hidden_size); + } + } + + typedef typename ToCudaType::MappedType CudaT; + auto stream = context->GetComputeStream(); + + auto& device_prop = GetDeviceProp(); + const int sm = device_prop.major * 10 + device_prop.minor; + + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm); + + size_t ws_size = + moe_runner.getWorkspaceSize(static_cast(num_rows), static_cast(hidden_size), + static_cast(inter_size), static_cast(num_experts), static_cast(k_)); + size_t fc2_output_size = k_ * num_rows * hidden_size * sizeof(CudaT); + size_t expert_scales_size = k_ * num_rows * sizeof(CudaT); + size_t expanded_source_row_to_expanded_dest_row_size = k_ * num_rows * sizeof(int); + size_t expert_for_source_row_size = k_ * num_rows * sizeof(int); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + // TODO: allocate one buffer and reuse it. + IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, ws_size, false, stream); + IAllocatorUniquePtr fc2_output = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, stream); + IAllocatorUniquePtr expert_scales = + IAllocator::MakeUniquePtr(allocator, expert_scales_size, false, stream); + IAllocatorUniquePtr expanded_source_row_to_expanded_dest_row = + IAllocator::MakeUniquePtr(allocator, expanded_source_row_to_expanded_dest_row_size, false, stream); + IAllocatorUniquePtr expert_for_source_row = + IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); + + // fc1_scales and fc2_scales are used in quantized MoE + const CudaT* fc1_scales_ptr = nullptr; + const CudaT* fc2_scales_ptr = nullptr; + + moe_runner.run_moe_fc(reinterpret_cast(input->template Data()), + reinterpret_cast(router_probs->template Data()), + reinterpret_cast(fc1_experts_weights->template Data()), + std::move(fc1_scales_ptr), + fc1_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc1_experts_bias_optional->template Data()), + activation_type_, reinterpret_cast(fc2_experts_weights->template Data()), + std::move(fc2_scales_ptr), static_cast(num_rows), static_cast(hidden_size), + static_cast(inter_size), static_cast(num_experts), static_cast(k_), + reinterpret_cast(work_space.get()), reinterpret_cast(fc2_output.get()), + reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), Stream(context)); + + Tensor* output = context->Output(0, input->Shape()); + + ort_fastertransformer::finalize_moe_routing_kernelLauncher( + reinterpret_cast(fc2_output.get()), reinterpret_cast(output->template MutableData()), + fc2_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc2_experts_bias_optional->template Data()), + reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), static_cast(num_rows), static_cast(hidden_size), + static_cast(k_), Stream(context)); + + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.h b/onnxruntime/contrib_ops/cuda/moe/moe.h new file mode 100644 index 0000000000000..8035568693814 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/moe.h @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class MoE final : public CudaKernel { + public: + explicit MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("k", &k_).IsOK()); + + std::string activation_type_str; + ORT_ENFORCE(op_kernel_info.GetAttr("activation_type", &activation_type_str).IsOK()); + if (activation_type_str == "relu") { + activation_type_ = ort_fastertransformer::ActivationType::Relu; + } else if (activation_type_str == "gelu") { + activation_type_ = ort_fastertransformer::ActivationType::Gelu; + } else if (activation_type_str == "silu") { + activation_type_ = ort_fastertransformer::ActivationType::Silu; + } else if (activation_type_str == "identity") { + activation_type_ = ort_fastertransformer::ActivationType::Identity; + } else { + ORT_THROW("Unsupported MoE activation type: ", activation_type_str); + } + } + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + int64_t k_; + ort_fastertransformer::ActivationType activation_type_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu new file mode 100644 index 0000000000000..7921315ab52e1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -0,0 +1,375 @@ +// Modifications: scaling is moved from masked softmax to the gemm before that. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "dequantize_blockwise.cuh" + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + + +__device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) { + half2 scale_half2 = {scale, scale}; + half zp_adjust = -scale * __short2half_rn(zp); + half2 zp_adjust2 = {zp_adjust, zp_adjust}; + + alignas(16) half2 results[4]; + half v0 = __uint2half_rn(values_quant & 0xF); + half v1 = __uint2half_rn((values_quant >> 4) & 0xF); + results[0] = __halves2half2(v0, v1) * scale_half2 + zp_adjust2; + + half v2 = __uint2half_rn((values_quant >> 8) & 0xF); + half v3 = __uint2half_rn((values_quant >> 12) & 0xF); + results[1] = __halves2half2(v2, v3) * scale_half2 + zp_adjust2; + + half v4 = __uint2half_rn((values_quant >> 16) & 0xF); + half v5 = __uint2half_rn((values_quant >> 20) & 0xF); + results[2] = __halves2half2(v4, v5) * scale_half2 + zp_adjust2; + + half v6 = __uint2half_rn((values_quant >> 24) & 0xF); + half v7 = __uint2half_rn((values_quant >> 28) & 0xF); + results[3] = __halves2half2(v6, v7) * scale_half2 + zp_adjust2; + *(reinterpret_cast(output)) = *(reinterpret_cast(results)); +} + +__device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, float scale, float zp, float* output) { + float zp_adjust = -scale * zp; + output[0] = float(values_quant & 0xF) * scale + zp_adjust; + output[1] = float((values_quant >> 4) & 0xF) * scale + zp_adjust; + output[2] = float((values_quant >> 8) & 0xF) * scale + zp_adjust; + output[3] = float((values_quant >> 12) & 0xF) * scale + zp_adjust; + output[4] = float((values_quant >> 16) & 0xF) * scale + zp_adjust; + output[5] = float((values_quant >> 20) & 0xF) * scale + zp_adjust; + output[6] = float((values_quant >> 24) & 0xF) * scale + zp_adjust; + output[7] = float((values_quant >> 28) & 0xF) * scale + zp_adjust; +} + +template +__global__ void Dequantize4BitsKernel( + T* output, + const uint8_t* quant_data, + const T* scale_data, + const uint8_t* zero_points, + int block_size, + int blocks_per_K, + int blocks_per_threadblock, + int shift) { + int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift); + int n_idx = block_id / blocks_per_K; + int kb_idx = block_id % blocks_per_K; + int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1)); + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + T scale = *(scale_data + block_id); + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2]; + zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + + output = output + element_offset; + DequantizeEightElements(quant_value, scale, static_cast(zp), output); +} + +template +Status Dequantize4Bits( + T* output, + const uint8_t* quant_data, + const T* scales_data, + const uint8_t* zero_points, // shape: [N, (block_per_K + 1)/2] + int k, + int n, + int block_size, + cudaStream_t stream) { + // k is padded and equal to block_per_K * block_size + ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size"); + constexpr int element_per_thread = 8; + int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int blocks_per_K = k / block_size; + int blocks_per_grid = static_cast(CeilDiv(n * blocks_per_K, blocks_per_threadblock)); + int shift = static_cast(log2f(float(block_size))); + + Dequantize4BitsKernel<<>>( + output, + quant_data, + scales_data, + zero_points, + block_size, + blocks_per_K, + blocks_per_threadblock, + shift); + + return Status::OK(); +} + +template Status Dequantize4Bits( + float* output, + const uint8_t* quant_data, + const float* scales_data, + const uint8_t* zero_points, + int k, + int n, + int block_size, + cudaStream_t stream); + +template Status Dequantize4Bits( + half* output, + const uint8_t* quant_data, + const half* scales_data, + const uint8_t* zero_points, + int k, + int n, + int block_size, + cudaStream_t stream); + + +/////////////////////////////////////////////////////////////////////////////// +// A more general block-wise dequantization implementation that supports +// different block sizes and block orientations (row-wise/column-wise). + +template < + int Row_, ///< rows of a matrix + int Column_ ///< columns of a matrix + > +struct Shape2D { + static int const kRow = Row_; ///< rows of a matrix + static int const kColumn = Column_; ///< columns of a matrix + static int const kCount = Row_ * Column_; ///< total number of elements in a matrix +}; + +/** + * @brief Blockwise quantization constants + * @tparam ElementT source data type, e.g. fp32/fp16 + * @tparam block_size number of elemenets quantized together + * @tparam qbits number of bits in each quantized element + * @tparam Columnwise true: elements in a block come from one single column + * false: elements in a block come from one single row + */ +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +struct BlkQuantTraits { + // number of qbit elements to pack into whole bytes + static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; + static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); + + using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; + using ThreadBlk = Shape2D; +}; + +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +__global__ +void dequantizeThread(ElementT* dst, + const uint8_t* weights, + const ElementT* scales, + const uint8_t* zero_points, + int rows, + int columns, + int thrd_row_blks) { + using QuantBlk = typename BlkQuantTraits::QuantBlk; + using ThreadBlk = typename BlkQuantTraits::ThreadBlk; + + // !! 4b specific code + static_assert(qbits == 4, "Only 4b block quantization is supported!"); + + const auto block_idx = blockIdx.x * blockDim.x + threadIdx.x; + const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + const auto meta_rows = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + // quantized matrix is stored in column major, packed by column + const auto q_rows = (meta_rows * QuantBlk::kRow * qbits + 7) / 8; + + int32_t r_blk_idx = static_cast(block_idx % thrd_row_blks); + int32_t c_blk_idx = static_cast(block_idx / thrd_row_blks); + + int32_t r = r_blk_idx * ThreadBlk::kRow; + int32_t c = c_blk_idx * ThreadBlk::kColumn; + + int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + + // for 4b quant, kPackSize = 2, so we have 2 scales and 2 offsets + const ElementT scale_buf[2] = { + scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow], + ((r/QuantBlk::kRow) < (meta_rows - 1)) + ? scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow + 1] + : static_cast(0.0f)}; + const uint8_t zp_pair = (zero_points == nullptr) + ? 0x88 + : zero_points[(c / QuantBlk::kColumn) * ((row_blks + 1) / 2) + (r / QuantBlk::kRow) / 2]; + const uint16_t zp_buf[2] = {(uint16_t)(zp_pair & 0x0f), (uint16_t)((zp_pair >> 4) & 0x0f)}; + const ElementT adjust_buf[2] = {(-scale_buf[0]) * static_cast(zp_buf[0]), + (-scale_buf[1]) * static_cast(zp_buf[1])}; + + for (int32_t j = c; j < c_end; ++j) { + const uint8_t* q_ptr = weights + j * q_rows; + for (int32_t i = r; i < (r_end - 1); i += 2) { + const auto scale0 = scale_buf[(i - r) / QuantBlk::kRow]; + const auto adjust0 = adjust_buf[(i - r) / QuantBlk::kRow]; + + const auto scale1 = scale_buf[(i + 1 - r) / QuantBlk::kRow];; + const auto adjust1 = adjust_buf[(i + 1 - r) / QuantBlk::kRow]; + + const auto vi = q_ptr[i / 2]; + + if constexpr (std::is_same::value) { + half2 scale_half2 = {scale0, scale1}; + half2 zp_adjust2 = {adjust0, adjust1}; + + half2 v = {__ushort2half_rn(vi & 0xF), __ushort2half_rn((vi >> 4) & 0xF)}; + half2 results = v * scale_half2 + zp_adjust2; + + dst[j * rows + i] = results.x; + dst[j * rows + (i + 1)] = results.y; + } else { + static_assert(std::is_same::value, "Only float and half are supported!"); + const uint8_t vi0 = vi & 0xf; + const uint8_t vi1 = vi >> 4; + dst[j * rows + i] = static_cast(vi0) * scale0 + adjust0;; + dst[j * rows + (i + 1)] = static_cast(vi1) * scale1 + adjust1; + } + } + + if ((r_end & 1) && (r_end > r)) { + const auto scale0 = scale_buf[(r_end - 1 - r) / QuantBlk::kRow]; + const auto adjust0 = adjust_buf[(r_end - 1 - r) / QuantBlk::kRow]; + + const auto vi = q_ptr[(r_end - 1) / 2]; + const uint8_t vi0 = vi & 0xf; + + dst[j * rows + (r_end - 1)] = static_cast(vi0) * scale0 + adjust0; + } + } +} + +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +static void dequantize(ElementT* dst, const uint8_t* weights, const ElementT* scales, + const uint8_t* zero_points, int32_t rows, int32_t columns, + cudaStream_t stream) { + using QuantBlk = typename BlkQuantTraits::QuantBlk; + using ThreadBlk = typename BlkQuantTraits::ThreadBlk; + + // Thread partitioning + const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; + + const auto grids = (total_thrd_blks + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock; + dequantizeThread<<>>( + dst, + weights, + scales, + zero_points, + rows, + columns, + thrd_row_blks); +} + + +template +Status +DequantizeBlockwise4b( + T* dst, + const uint8_t* src, + const T* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream) { + switch (block_size) { + case 16: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 32: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 64: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 128: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, + columns, stream); + } else { + dequantize(dst, src, scales, zero_points, + rows, columns, stream); + } + return Status::OK(); + case 256: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, + columns, stream); + } else { + dequantize(dst, src, scales, zero_points, + rows, columns, stream); + } + return Status::OK(); + default: + // Only block size 16, 32, 64, 128, 256 are supported. + return Status(::onnxruntime::common::ONNXRUNTIME, ::onnxruntime::common::FAIL, + "Unsupported block size for blockwise quantization."); + } +} + +template +Status DequantizeBlockwise4b( + float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + +template +Status DequantizeBlockwise4b( + half* dst, + const uint8_t* src, + const half* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh new file mode 100644 index 0000000000000..f9c09c55fd893 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +template +Status Dequantize4Bits( + T* output, + const uint8_t* quant_data, + const T* scales_data, + const uint8_t* zero_points, + int k, + int n, + int block_size, + cudaStream_t stream); + + +/** + * @brief Dequantize a block-wise quantized matrix, and store the result in a + * column major matrix for use in subsequent GEMM. This implementation supports + * columnwise and rowwise block orientation. + * @param[out] dst pointer to the dequantized matrix, column major: [columns, rows] + * @param[in] qelements pointer to the quantized elements, column major: [columns, rows] + * @param[in] scales pointer to the scales of quantized blocks, column major layout + * @param[in] zero_points pointer to the zero points of quantized blocks, packed column major + * scales + * @param[in] block_size size of the quantized block + * @param[in] columnwise whether the quantized matrix is columnwise or rowwise quantized + * @param[in] rows + * @param[in] columns + */ +template +Status DequantizeBlockwise4b( + T* dst, + const uint8_t* qelements, + const T* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu new file mode 100644 index 0000000000000..2f74dd41f0759 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h" +#include "dequantize_blockwise_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + T host_quant_map[16]; + switch (quant_type) { + case FP4: + for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast(fp4_qaunt_map[i]); + break; + case NF4: + for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast(nf4_qaunt_map[i]); + break; + } + CUDA_CALL_THROW(cudaMemcpyAsync(quant_map_buffer, host_quant_map, sizeof(T) * 16, cudaMemcpyHostToDevice, stream)); + + return Status::OK(); +} + +template Status SetBnbQuantMap(int quant_type, float* quant_map_buffer, cudaStream_t stream); + +template Status SetBnbQuantMap(int quant_type, half* quant_map_buffer, cudaStream_t stream); + +template Status SetBnbQuantMap(int quant_type, BFloat16* quant_map_buffer, cudaStream_t stream); + +template +__global__ void kDequantizeBlockwise( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + const int block_size, + const int n) { + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); + + T vals[NUM_PER_TH * 2]; + uint8_t qvals[NUM_PER_TH]; + T local_abs_max = T(0.0f); + + typedef cub::BlockLoad LoadChar; + typedef cub::BlockStore StoreT; + + __shared__ typename LoadChar::TempStorage loadchar; + __shared__ typename StoreT::TempStorage storet; + + for (unsigned int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) { + valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i; + valid_items_store = n - i * 2 > TILE_SIZE * 2 ? TILE_SIZE * 2 : n - i * 2; + + local_abs_max = absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]; + + __syncthreads(); + LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128); + + #pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + vals[j * 2] = ScalarMul(quant_map[qvals[j] >> 4], local_abs_max); + vals[j * 2 + 1] = ScalarMul(quant_map[qvals[j] & 0x0F], local_abs_max); + } + + __syncthreads(); + StoreT(storet).Store(&(output[i * 2]), vals, valid_items_store); + } +} + +template +void CallkDequantizeBlockwise( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream) { + int tile_size = 1024; + kDequantizeBlockwise<<<(numel + tile_size - 1) / tile_size, 64, 0, stream>>>( + quant_map, + output, + quant_data, + absmax, + block_size / 2, + numel); +} + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream) { + CallkDequantizeBlockwise(quant_map, output, quant_data, absmax, block_size, numel, stream); + + return Status::OK(); +} + +template Status DequantizeBnb4( + const float* quant_map, + float* output, + const uint8_t* quant_data, + const float* absmax, + int block_size, + int numel, + cudaStream_t stream); + +template Status DequantizeBnb4( + const half* quant_map, + half* output, + const uint8_t* quant_data, + const half* absmax, + int block_size, + int numel, + cudaStream_t stream); + +template <> +Status DequantizeBnb4( + const BFloat16* quant_map, + BFloat16* output, + const uint8_t* quant_data, + const BFloat16* absmax, + int block_size, + int numel, + cudaStream_t stream) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + CallkDequantizeBlockwise( + reinterpret_cast(quant_map), + reinterpret_cast(output), + quant_data, + reinterpret_cast(absmax), + block_size, + numel, + stream); + #else + CallkDequantizeBlockwise(quant_map, output, quant_data, absmax, block_size, numel, stream); + #endif + + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh new file mode 100644 index 0000000000000..a0d38c9853cd6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream); + +// templated scalar multiply function +template +__device__ inline T ScalarMul(T a, T b); + +template <> +__device__ inline float ScalarMul(float a, float b) { + return a * b; +} + +template <> +__device__ inline half ScalarMul(half a, half b) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + return a * b; + #else + // half multiplication not supported + return static_cast(static_cast(a) * static_cast(b)); + #endif +} + +template <> +__device__ inline BFloat16 ScalarMul(BFloat16 a, BFloat16 b) { + return a * b; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// will use the native bfloat16 multiply instruction on sm_80+ +template <> +__device__ inline nv_bfloat16 ScalarMul(nv_bfloat16 a, nv_bfloat16 b) { + return a * b; +} +#endif + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc new file mode 100644 index 0000000000000..bbcb7de99781f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h" +#include "matmul_bnb4.cuh" +#include "dequantize_blockwise_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulBnb4 final : public CudaKernel { + public: + MatMulBnb4(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("quant_type", &quant_type_)); + ORT_ENFORCE( + quant_type_ == FP4 || quant_type_ == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + is_training_mode_ = static_cast(info.GetAttrOrDefault("training_mode", static_cast(0))); + transB_ = static_cast(info.GetAttrOrDefault("transB", static_cast(1))); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t quant_type_; + bool is_training_mode_; + bool transB_; +}; + +template +Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* a = ctx->Input(0); + const Tensor* b_quant = ctx->Input(1); + const Tensor* absmax = ctx->Input(2); + + const auto* a_data = a->Data(); + const uint8_t* b_quant_data = b_quant->Data(); + const auto* absmax_data = absmax->Data(); + + typedef typename ToCudaType::MappedType CudaT; + + // TODO: find a better way to create the quant_map without using a buffer + // don't want to use malloc directly so asking from the caller + // can create a __device__ static array for float but doesn't work for half + IAllocatorUniquePtr quant_map_buffer = GetScratchBuffer(16, ctx->GetComputeStream()); + auto* quant_map_buffer_data = quant_map_buffer.get(); + ORT_RETURN_IF_ERROR(SetBnbQuantMap( + SafeInt(quant_type_), + reinterpret_cast(quant_map_buffer_data), + static_cast(ctx->GetComputeStream()->GetHandle()))); + + constexpr bool transa = false; + const bool transb = transB_; + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR( + helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* Y = ctx->Output(0, helper.OutputShape()); + // Bail out early if the output is going to be empty + if (Y->Shape().Size() == 0) return Status::OK(); + + bool is_4bit_done = !is_training_mode_ // skip inference specific handle if in training mode + && TryMatMulBnb4( + reinterpret_cast(quant_map_buffer_data), + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + b_quant_data, + reinterpret_cast(absmax_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle())); + + if (!is_4bit_done) { + IAllocatorUniquePtr b_dequant_ptr = GetScratchBuffer(N_ * K_, ctx->GetComputeStream()); + auto* b_dequant_data = b_dequant_ptr.get(); + ORT_RETURN_IF_ERROR(DequantizeBnb4( + reinterpret_cast(quant_map_buffer_data), + reinterpret_cast(b_dequant_data), + b_quant_data, + reinterpret_cast(absmax_data), + SafeInt(block_size_), + SafeInt(N_ * K_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + transb ? CUBLAS_OP_T : CUBLAS_OP_N, // transB + CUBLAS_OP_N, // transA + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_dequant_data), + helper.Ldb(transb), // ldb + reinterpret_cast(a_data), + helper.Lda(transa), // lda + &zero, + reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp())); + } + + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + BFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu new file mode 100644 index 0000000000000..098e3618beddd --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu @@ -0,0 +1,268 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include +#include +#include +#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh" +#include "matmul_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__device__ inline float ScalarMulFloatOut(T a, T b); + +template <> +__device__ inline float ScalarMulFloatOut(float a, float b) { + return a * b; +} + +template <> +__device__ inline float ScalarMulFloatOut(half a, half b) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + return static_cast(a * b); + #else + // half multiplication not supported + return static_cast(a) * static_cast(b); + #endif +} + +template <> +__device__ inline float ScalarMulFloatOut(BFloat16 a, BFloat16 b) { + return a * b; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// will use the native bfloat16 multiply instruction on sm_80+ +template <> +__device__ inline float ScalarMulFloatOut(nv_bfloat16 a, nv_bfloat16 b) { + return static_cast(a * b); +} +#endif + +#define num_values_4bit 32 +template +__global__ void kgemm_4bit_inference_naive( + int M, + int N, + int K, + const T* __restrict__ A, + const uint8_t* B, + const T* absmax, + const T* datatype, + T* out, + int lda, + int ldb, + int ldc, + int block_size) { + // per threadblock: + // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1x32 * 32x4 -> 1x4 outputs per thread block + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32]; + + const int warp_idx = threadIdx.x / 32; + const int warp_lane = threadIdx.x % 32; + const int row_B = (THREADS / 32) * blockIdx.x + warp_idx; + const int num_values_8bit = num_values_4bit / 2; + float local_C = 0.0f; + + uint8_t local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit / 4]; + T local_A[num_values_4bit / 4]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); + + for (int i = threadIdx.x; i < 16; i++) quant_map[i] = T(datatype[i]); + __syncthreads(); + + // A: [1, K] + // B: [N, K] + for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) { + int inner_idx_halved = inner_idx / 2; + int offset_B = ldb * row_B; + int absidx = ((2 * offset_B) + inner_idx) / block_size; + local_absmax = absmax[absidx]; + + if (row_B < N) { + if ((inner_idx_halved + num_values_8bit) < (K / 2)) { + // this is the most important for performance considerations + reinterpret_cast(local_B_4bit)[0] = + reinterpret_cast(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; + } else { + #pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + if ((inner_idx_halved) + j < (K / 2)) + local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } else { + #pragma unroll + for (int j = 0; j < (num_values_8bit); j++) local_B_4bit[j] = 0b01110111; + } + + for (int i = 0; i < 4; i++) { + #pragma unroll + for (int k = 0; k < num_values_8bit / 4; k++) { + local_B[k * 2] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4], local_absmax); + local_B[k * 2 + 1] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F], local_absmax); + } + + if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { + // this is also relatively important for performance + if (BITS == 16) { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 4) + i]; + } else { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; + reinterpret_cast(local_A)[1] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; + } + } else { + #pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { + if (inner_idx + (i * num_values_4bit / 4) + k < K) + local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; + else + local_A[k] = T(0.0f); + } + } + + // accumulate in float; small performance hit for Ampere, but lower error for outputs + #pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { + local_C += ScalarMulFloatOut(local_A[k], local_B[k]); + } + } + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if (row_B < N && warp_lane == 0) out[row_B] = T(local_C); +} + +bool CheckDims(int m, int k, int block_size) { + if (k % block_size != 0 || m > 1) { + return false; + } + // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32] + if (block_size % 32 != 0 || block_size > 4096) { + return false; + } + return true; +} + +template +void Callkgemm_4bit_inference_naive( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + int lda = k; + int ldb = (k + 1) / 2; + int ldc = n; + int num_blocks = (n + 3) / 4; + + constexpr int bits = std::is_same_v ? 32 : 16; + kgemm_4bit_inference_naive<<>>( + m, n, k, a_data, b_data_quant, absmax, quant_map, output, lda, ldb, ldc, block_size); +} + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + if (!CheckDims(m, k, block_size)) { + return false; + } + + Callkgemm_4bit_inference_naive( + quant_map, output, a_data, b_data_quant, absmax, m, n, k, block_size, stream); + + return true; +} + +template bool TryMatMulBnb4( + const float* quant_map, + float* output, + const float* a_data, + const uint8_t* b_data_quant, + const float* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +template bool TryMatMulBnb4( + const half* quant_map, + half* output, + const half* a_data, + const uint8_t* b_data_quant, + const half* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +template <> +bool TryMatMulBnb4( + const BFloat16* quant_map, + BFloat16* output, + const BFloat16* a_data, + const uint8_t* b_data_quant, + const BFloat16* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + if (!CheckDims(m, k, block_size)) { + return false; + } + + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + Callkgemm_4bit_inference_naive( + reinterpret_cast(quant_map), + reinterpret_cast(output), + reinterpret_cast(a_data), + b_data_quant, + reinterpret_cast(absmax), + m, + n, + k, + block_size, + stream); + #else + Callkgemm_4bit_inference_naive( + quant_map, output, a_data, b_data_quant, absmax, m, n, k, block_size, stream); + #endif + + return true; +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh new file mode 100644 index 0000000000000..743234282fbf3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..5b0e61e197014 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This module define MatMulFp32Q4 operator, it is basically +// matmul float32 with right hand side being a 2-D matrix +// pre-packed and block-compacted into int4 +// + +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "matmul_nbits.cuh" +#include "dequantize_blockwise.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulNBits final : public CudaKernel { + public: + MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op," + " additional bits support is planned."); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t nbits_; + bool column_wise_quant_blk_{true}; +}; + +template +Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* a = ctx->Input(0); + const Tensor* b = ctx->Input(1); + const Tensor* scales = ctx->Input(2); + const Tensor* zero_points = ctx->Input(3); + + const auto* a_data = a->Data(); + const uint8_t* blob_data = b->Data(); + const auto* scales_data = scales->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + + typedef typename ToCudaType::MappedType CudaT; + + constexpr bool transa = false; + constexpr bool transb = true; + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR( + helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* Y = ctx->Output(0, helper.OutputShape()); + // Bail out early if the output is going to be empty + if (Y->Shape().Size() == 0) return Status::OK(); + + bool is_4bit_done = TryMatMul4Bits( + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + SafeInt(GetDeviceProp().sharedMemPerBlock), + static_cast(ctx->GetComputeStream()->GetHandle())); + if (!is_4bit_done) { + int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; + IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); + auto* b_data = b_data_ptr.get(); + if (column_wise_quant_blk_) { + // column-wise block + ORT_RETURN_IF_ERROR(Dequantize4Bits( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } else { + // row-wise block + K_padded = K_; + + ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(block_size_), + column_wise_quant_blk_, + SafeInt(K_), + SafeInt(N_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } +#if 0 + cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); + T* b_data_cpu = new T[K_ * N_]; + cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); + delete[] b_data_cpu; +#endif + + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + if (helper.OutputOffsets().size() == 1) { + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + CUBLAS_OP_T, + CUBLAS_OP_N, + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_data), + SafeInt(K_padded), + reinterpret_cast(a_data), + helper.Lda(transa), + &zero, + reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp())); + } + } + + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu new file mode 100644 index 0000000000000..f2600a506285d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -0,0 +1,229 @@ +// Modifications: scaling is moved from masked softmax to the gemm before that. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "matmul_nbits.cuh" + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a) { + half2 scale_half2 = {scale, scale}; + half zp_adjust = -scale * __short2half_rn(zp); + half2 zp_adjust2 = {zp_adjust, zp_adjust}; + uint4 vec_a = *(reinterpret_cast(a)); + + half2 element01 = __halves2half2(__uint2half_rn(values_quant & 0xF), __uint2half_rn((values_quant >> 4) & 0xF)); + half2 v0 = element01 * scale_half2 + zp_adjust2; + + half2 element23 = __halves2half2(__uint2half_rn((values_quant >> 8) & 0xF), __uint2half_rn((values_quant >> 12) & 0xF)); + half2 v1 = element23 * scale_half2 + zp_adjust2; + + half2 element45 = __halves2half2(__uint2half_rn((values_quant >> 16) & 0xF), __uint2half_rn((values_quant >> 20) & 0xF)); + half2 v2 = element45 * scale_half2 + zp_adjust2; + + half2 element67 = __halves2half2(__uint2half_rn((values_quant >> 24) & 0xF), __uint2half_rn((values_quant >> 28) & 0xF)); + half2 v3 = element67 * scale_half2 + zp_adjust2; + + v0 = v0 * (*(reinterpret_cast(&(vec_a.x)))); + v1 = v1 * (*(reinterpret_cast(&(vec_a.y)))); + v2 = v2 * (*(reinterpret_cast(&(vec_a.z)))) + v0; + v3 = v3 * (*(reinterpret_cast(&(vec_a.w)))) + v1; + v3 = v2 + v3; + return float(v3.x) + float(v3.y); +} + +__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a) { + float4 a_vec_0 = *(reinterpret_cast(a)); + float4 a_vec_1 = *(reinterpret_cast(a + 4)); + + float zp_adjust = -scale * zp; + float v0 = float(values_quant & 0xF) * scale + zp_adjust; + float v1 = float((values_quant >> 4) & 0xF) * scale + zp_adjust; + float v2 = float((values_quant >> 8) & 0xF) * scale + zp_adjust; + float v3 = float((values_quant >> 12) & 0xF) * scale + zp_adjust; + float v4 = float((values_quant >> 16) & 0xF) * scale + zp_adjust; + float v5 = float((values_quant >> 20) & 0xF) * scale + zp_adjust; + float v6 = float((values_quant >> 24) & 0xF) * scale + zp_adjust; + float v7 = float((values_quant >> 28) & 0xF) * scale + zp_adjust; + + v0 = v0 * a_vec_0.x; + v1 = v1 * a_vec_0.y; + v2 = v2 * a_vec_0.z; + v3 = v3 * a_vec_0.w; + v4 = v4 * a_vec_1.x + v0; + v5 = v5 * a_vec_1.y + v1; + v6 = v6 * a_vec_1.z + v2; + v7 = v7 * a_vec_1.w + v3; + return v4 + v5 + v6 + v7; +} + +constexpr int kColsPerThreadBlock = 8; +constexpr int kWarpSize = 32; + +// kernel for 4bits quantized gemv, i.e., computing A(1,K) x B(K, N) +// B(K, N) is quantized blockwise with 4bits and stored as [N, (K + block_size - 1)/block_size, blob] +// The thread block size is (kWarpSize, kColsPerThreadBlock) and grid size is (N/kColsPerThreadBlock, 1) +// Each thread block computes [1, K] x [kColsPerThreadBlock, (K + block_size - 1)/block_size, blob], +// i.e., computing kColsPerThreadBlock per block and a warp reduce (1, K) x (K) +template +__global__ void MatMulFloatInt4Kernel( + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int blocks_per_K) { + int n_block_id = blockIdx.x; + int m_id = blockIdx.y; + int lane_id = threadIdx.x; + int warp_id = threadIdx.y; + int n_id = n_block_id * kColsPerThreadBlock + warp_id; + int thread_id = warp_id * kWarpSize + lane_id; + constexpr int k_per_iter = 256; + int k_iter = k / k_per_iter; + + // blocks_per_k is the number of scales and zero points on the k dim + const int b_zp_k = (blocks_per_K + 1)/ 2; + + extern __shared__ char shared_buffer[]; + + // load scale to shared buffer + T* b_scale_vec = (T*)shared_buffer; + uint8_t* b_zp_vec = reinterpret_cast(b_scale_vec + kColsPerThreadBlock * blocks_per_K); + int offset = n_block_id * kColsPerThreadBlock * blocks_per_K; + for (int i = thread_id; i < kColsPerThreadBlock * blocks_per_K; i += kColsPerThreadBlock * kWarpSize) { + b_scale_vec[i] = scales_data[offset + i]; + } + + int zp_offset = n_block_id * kColsPerThreadBlock * b_zp_k; + for (int i = thread_id; i < kColsPerThreadBlock * b_zp_k; i += kColsPerThreadBlock * kWarpSize) { + b_zp_vec[i] = zero_points != nullptr ? zero_points[zp_offset + i] : uint8_t(0x88); + } + __syncthreads(); + + a_data += m_id * k; + b_data_quant += n_id * blocks_per_K * (block_size / 2); + + const int scale_col_offset = warp_id * blocks_per_K; + const int zp_col_offset = warp_id * b_zp_k; + + float sum = 0.f; + int k_id = 0; + for (; k_id < (k & 0xffffff00); k_id += k_per_iter) { + const int t_k = k_id + (lane_id << 3); // k index for this thread + const int t_meta_k = t_k / block_size; // k index for this thread, points to the scale and zero point + uint32_t value = *(reinterpret_cast(b_data_quant + (t_k >> 1))); + T scale = b_scale_vec[scale_col_offset + t_meta_k]; + uint8_t zp = b_zp_vec[zp_col_offset + t_meta_k/2]; + zp = (t_meta_k & 0x01) ? (zp >> 4) : (zp & 0x0f); + sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3)); + } + + // handle reminder + if (k_id + lane_id * 8 < k) { + const int t_k = k_id + (lane_id << 3); // k index for this thread + const int t_meta_k = t_k / block_size; // k index for this thread, points to the scale and zero point + uint32_t value = *(reinterpret_cast(b_data_quant + k_iter * 128 + lane_id * 4)); + T scale = b_scale_vec[scale_col_offset + t_meta_k]; + uint8_t zp = b_zp_vec[zp_col_offset + t_meta_k/2]; + zp = (t_meta_k & 0x01) ? (zp >> 4) : (zp & 0x0f); + sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3)); + } + + // warp reduction + for (int i = 16; i > 0; i = i / 2) { + sum += __shfl_down_sync(0xffffffff, sum, i); + } + + if (lane_id == 0) { + output[m_id * n + n_id] = sum; + } +} + +template +bool TryMatMul4Bits( + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream) { + if (n % kColsPerThreadBlock != 0 || k % 8 != 0 || m > 1) { + return false; + } + dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m); + dim3 threads(kWarpSize, kColsPerThreadBlock); + int blocks_per_K = (k + block_size - 1) / block_size; + int blocks_per_thread_block = blocks_per_K * kColsPerThreadBlock; + int shared_mem_size = sizeof(T) * blocks_per_thread_block + blocks_per_thread_block / 2; + if (shared_mem_size > shared_mem_per_block) { + return false; + } + + if (16 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else if (32 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else if (64 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else if (128 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else { + ORT_THROW("block size ", block_size, " is not supported"); + } + + return true; +} + +template bool TryMatMul4Bits( + float* output, + const float* a_data, + const uint8_t* b_data_quant, + const float* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream); + +template bool TryMatMul4Bits( + half* output, + const half* a_data, + const uint8_t* b_data_quant, + const half* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh new file mode 100644 index 0000000000000..9ccbe4c4d97a8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +bool TryMatMul4Bits( + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc new file mode 100644 index 0000000000000..381316f605fc9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/tensor/dynamic_time_warping.h" +#include "contrib_ops/cuda/tensor/dynamic_time_warping_impl.h" +#include "core/providers/cpu/tensor/utils.h" + +#include +#include + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + DynamicTimeWarping, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("F", DataTypeImpl::GetTensorType()) + .TypeConstraint("I", DataTypeImpl::GetTensorType()), + DynamicTimeWarping); + +Status DynamicTimeWarping::ComputeInternal(OpKernelContext* ctx) const { + const Tensor& input_tensor = *ctx->Input(0); + const auto& input_dims = input_tensor.Shape().GetDims(); + int rank = SafeInt(input_dims.size()); + ORT_ENFORCE(rank == 2 || (rank == 3 && input_dims[0] == 1), "Currently input rank must be 2, or (3 with first dim equal to 1), but got:", rank); + + const size_t rows = SafeInt(input_dims[rank == 3 ? 1 : 0]); + const size_t cols = SafeInt(input_dims[rank == 3 ? 2 : 1]); + size_t max_index_len = 0; + + size_t buffer_size_in_bytes = GetDynamicTimeWarpingBufferSize(1, rows, cols, max_index_len); + IAllocatorUniquePtr buffer = GetScratchBuffer(buffer_size_in_bytes, ctx->GetComputeStream()); + + size_t result_len = 0; + ORT_RETURN_IF_ERROR(LaunchDynamicTimeWarping( + this->Stream(ctx), this->GetDeviceProp(), 1, rows, cols, + input_tensor.Data(), buffer.get(), result_len)); + + Tensor* output_tensor = ctx->Output(0, TensorShape{2LL, SafeInt(result_len)}); + + return CUDA_CALL(cudaMemcpy2DAsync( + output_tensor->MutableData(), result_len * sizeof(int32_t), + buffer.get() + ((max_index_len - result_len) * sizeof(int32_t)), max_index_len * sizeof(int32_t), + result_len * sizeof(int32_t), 2, + cudaMemcpyDeviceToDevice, this->Stream(ctx))); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.h b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.h new file mode 100644 index 0000000000000..3083e19aff6f2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/cuda_kernel.h" +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using onnxruntime::OpKernelContext; +using onnxruntime::OpKernelInfo; +using onnxruntime::cuda::CudaKernel; +class DynamicTimeWarping final : public CudaKernel { + public: + DynamicTimeWarping(const OpKernelInfo& info) : CudaKernel(info) {} + + ~DynamicTimeWarping() = default; + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu new file mode 100644 index 0000000000000..7c3f2963207e6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/tensor/dynamic_time_warping_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/common/common.h" +#include +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +__global__ void DynamicTimeWarpingInitCost(float* cost_buffer, int8_t* trace_buffer, size_t cols_plus_1) { + int r = blockIdx.x; + cost_buffer += cols_plus_1 * r; + for (size_t i = threadIdx.x; i < cols_plus_1; i += blockDim.x) { + cost_buffer[i] = FLT_MAX; + } + if (r == 0) { + for (size_t i = threadIdx.x; i < cols_plus_1; i += blockDim.x) { + trace_buffer[i] = 2; + } + } + if (threadIdx.x == 0) trace_buffer[cols_plus_1 * r] = 1; + if (threadIdx.x == 0 && r == 0) *cost_buffer = 0.0f; +} + +__global__ void DynamicTimeWarpingKernel( + size_t rows, + size_t cols, + size_t max_index_len, + const float* input, + float* cost_buffer, + int8_t* trace_buffer, + int32_t* result_buffer, + size_t* result_len_device +) { + const int diag_max = static_cast(rows + cols); + for (int d = 1; d <= diag_max; d++) { + for (int c = threadIdx.x + 1; c <= cols; c += blockDim.x) { + int r = d - c; + if (r >= 1 && r <= rows) { + int cost_idx = ((r - 1) * (cols + 1) + (c - 1)); //[r - 1, c - 1] + const float c0 = cost_buffer[cost_idx]; + const float c1 = cost_buffer[cost_idx + 1]; // [r - 1, c] + const float c2 = cost_buffer[cost_idx + cols + 1]; // [r, c - 1] + + float cost; + int8_t t; + if (c0 < c1 && c0 < c2) { + cost = c0; + t = 0; + } else if (c1 < c0 && c1 < c2) { + cost = c1; + t = 1; + } else { + cost = c2; + t = 2; + } + cost_idx += ((cols + 1) + 1); + cost_buffer[cost_idx] = cost + input[(r - 1) * cols + (c - 1)]; + trace_buffer[cost_idx] = t; + } + } + __syncthreads(); + } + + //back tracing, reverse append to result buffer + if (threadIdx.x == 0) { + int r = rows - 1; + int c = cols - 1; + int pos = static_cast(max_index_len); // reverse put + while (r >= 0 && c >= 0) { + --pos; + result_buffer[pos] = r; + result_buffer[max_index_len + pos] = c; + const int trace_index = (r + 1) * (cols + 1) + (c + 1); + int8_t t = trace_buffer[trace_index]; + switch (t) { + case 0: r -= 1; c -= 1; break; + case 1: r -= 1; break; + default: c -= 1; break; + } + } + *result_len_device = max_index_len - static_cast(pos); + } +} + +size_t GetDynamicTimeWarpingBufferSize(size_t batch, size_t rows, size_t cols, size_t& max_index_len) { + max_index_len = rows + cols + 1; + size_t cost_buffer_size = ((rows + 1) * (cols + 1)); + return batch * max_index_len * 2 * sizeof(int32_t) + // two index arrays + sizeof(int64_t) + // final index array length + batch* cost_buffer_size * sizeof(float) + // cost buffer + batch* cost_buffer_size * sizeof(int8_t); // trace buffer +} + +Status LaunchDynamicTimeWarping( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t batch, + size_t rows, + size_t cols, + const float* input, + void* buffer, + size_t& result_len +) { + ORT_ENFORCE(batch == 1); + size_t max_index_len = rows + cols + 1; + int32_t* result_buffer = (int32_t*)buffer; + size_t* result_len_device_buf = (size_t*)(result_buffer + (batch * max_index_len * 2)); + float* cost_buffer = (float*)(result_len_device_buf + 1); + int8_t* trace_buffer = (int8_t*)(cost_buffer + ((rows + 1) * (cols + 1))); + + dim3 block(device_prop.maxThreadsPerBlock); + dim3 grid_init((unsigned)SafeInt(rows + 1), (unsigned)SafeInt(batch)); + DynamicTimeWarpingInitCost<<>>(cost_buffer, trace_buffer, cols+1); + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError())); + + dim3 grid(1, (unsigned)SafeInt(batch)); + DynamicTimeWarpingKernel<<>>( + rows, + cols, + max_index_len, + input, + cost_buffer, + trace_buffer, + result_buffer, + result_len_device_buf); + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError())); + + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaMemcpyAsync(&result_len, result_len_device_buf, sizeof(size_t), cudaMemcpyDeviceToHost, stream))); + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError())); + return CUDA_CALL(cudaStreamSynchronize(stream)); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.h b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.h new file mode 100644 index 0000000000000..cb4a0dfb16807 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +size_t GetDynamicTimeWarpingBufferSize(size_t batch, size_t rows, size_t cols, size_t& max_index_len); + +Status LaunchDynamicTimeWarping( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t batch, + size_t rows, + size_t cols, + const float* input, + void* buffer, + size_t& result_len); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold.cc b/onnxruntime/contrib_ops/cuda/tensor/unfold.cc new file mode 100644 index 0000000000000..c38c8c5317f0a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold.cc @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/tensor/unfold.h" +#include "contrib_ops/cuda/tensor/unfold_impl.h" +#include "core/providers/cpu/tensor/utils.h" + +#include +#include + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + UnfoldTensor, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), + UnfoldTensor); + +Status UnfoldTensor::ComputeInternal(OpKernelContext* ctx) const { + const Tensor& input_tensor = *ctx->Input(0); + const auto& input_dims = input_tensor.Shape().GetDims(); + int rank = SafeInt(input_dims.size()); + + int dim = SafeInt(HandleNegativeAxis(dim_, rank)); + ORT_ENFORCE(dim < rank, "input rank:", rank, " is not bigger than attribut specified dim: ", dim); + ORT_ENFORCE(input_dims[dim] >= size_, "dimsize:", input_dims[dim], " is less than unfold size:", size_); + + int64_t leading_dims = std::accumulate(input_dims.begin(), input_dims.begin() + dim, 1LL, std::multiplies()); + int64_t tailing_dims = std::accumulate(input_dims.begin() + (dim + 1), input_dims.end(), 1LL, std::multiplies()); + + std::vector output_dims(rank + 1, 0); + std::copy(input_dims.begin(), input_dims.end(), output_dims.begin()); + output_dims[dim] = (input_dims[dim] - size_) / step_ + 1; + output_dims.back() = size_; + TensorShape output_shape(output_dims); + Tensor* output_tensor = ctx->Output(0, output_shape); + + cudaStream_t stream = this->Stream(ctx); + const cudaDeviceProp& device_prop = this->GetDeviceProp(); + size_t element_size = input_tensor.DataType()->Size(); + return LaunchUnfoldTensor( + stream, device_prop, element_size, input_tensor.DataRaw(), output_tensor->MutableDataRaw(), + leading_dims, input_dims[dim], tailing_dims, size_, step_); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold.h b/onnxruntime/contrib_ops/cuda/tensor/unfold.h new file mode 100644 index 0000000000000..1717687593470 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/cuda_kernel.h" +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using onnxruntime::OpKernelContext; +using onnxruntime::OpKernelInfo; +using onnxruntime::cuda::CudaKernel; +class UnfoldTensor final : public CudaKernel { + public: + UnfoldTensor(const OpKernelInfo& info) : CudaKernel(info) { + dim_ = SafeInt(info.GetAttrOrDefault("dim", -1LL)); + step_ = SafeInt(info.GetAttrOrDefault("step", 1LL)); + ORT_ENFORCE(step_ > 0, "step must greater than zero!"); + + int64_t temp_size; + ORT_ENFORCE(info.GetAttr("size", &temp_size).IsOK()); + size_ = SafeInt(temp_size); + } + + ~UnfoldTensor() = default; + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int dim_; + int size_; + int step_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu new file mode 100644 index 0000000000000..a3c93ceb33c46 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/tensor/unfold_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/common/common.h" +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void UnfoldTensorKernel( + const T* input, + T* output, + int64_t N, + int64_t unfold_size, // stride_tailing_dim_dst + int64_t tailing_dims_size, // stride_fold_dim_dst = tailing_dims_size * unfold_size, stride_append_dim_src = tailing_dims_size + int64_t stride_leading_dst, + int64_t stride_fold_dim_src, + int64_t stride_leading_src +) { + int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) return; + + const int64_t idx_leading = idx / stride_leading_dst; + int64_t n = idx % stride_leading_dst; + const int64_t stride_fold_dim_dst = tailing_dims_size * unfold_size; + const int64_t idx_fold = n / stride_fold_dim_dst; + n %= stride_fold_dim_dst; + const int64_t idx_tailing = n / unfold_size; + const int64_t idx_append = n % unfold_size; + + int64_t idx_src = idx_leading * stride_leading_src + idx_fold * stride_fold_dim_src + idx_tailing + idx_append * tailing_dims_size; + output[idx] = input[idx_src]; +} + + +Status LaunchUnfoldTensor( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t element_size, + const void* input, + void* output, + int64_t leading_dims_size, + int64_t unfold_dim_size, + int64_t tailing_dims_size, + int64_t unfold_size, + int64_t step_size +) { + int64_t TPB = device_prop.maxThreadsPerBlock; + int64_t unfold_dim_size_dst = (unfold_dim_size - unfold_size) / step_size + 1; + int64_t N = leading_dims_size * unfold_dim_size_dst * tailing_dims_size * unfold_size; + int64_t num_blocks = (N + TPB - 1) / TPB; + + int64_t stride_leading_dst = unfold_size * tailing_dims_size * unfold_dim_size_dst; + + int64_t stride_fold_dim_src = tailing_dims_size * step_size; + int64_t stride_leading_src = tailing_dims_size * unfold_dim_size; + + dim3 block((unsigned)SafeInt(TPB)); + dim3 grid((unsigned)SafeInt(num_blocks)); + switch (element_size) { + case 1: + UnfoldTensorKernel<<>>( + (const int8_t*)input, (int8_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 2: + UnfoldTensorKernel<<>>( + (const int16_t*)input, (int16_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 4: + UnfoldTensorKernel<<>>( + (const int32_t*)input, (int32_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 8: + UnfoldTensorKernel<<>>( + (const int64_t*)input, (int64_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 16: + UnfoldTensorKernel<<>>( + (const float4*)input, (float4*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + default: + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unsupported element_size"); + } + + return CUDA_CALL(cudaGetLastError()); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.h b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.h new file mode 100644 index 0000000000000..9e82dccdec23c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +Status LaunchUnfoldTensor( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t element_size, + const void* input, + void* output, + int64_t leading_dims_size, + int64_t tailing_dims_size, + int64_t dim_size, + int64_t unfold_size, + int64_t step_size); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index d18460e016444..2a90e4911f286 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -33,6 +33,28 @@ ONNX_OPERATOR_KERNEL_EX( DataTypeImpl::GetTensorType()}), BeamSearch); +ONNX_OPERATOR_KERNEL_EX( + WhisperBeamSearch, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'input_ids' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 1) // 'max_length' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 2) // 'min_length' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 3) // 'num_beams' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 4) // 'num_return_sequences' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 5) // 'length_penalty' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 6) // 'repetition_penalty' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 9) // 'attention_mask' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU + .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU + .OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + WhisperBeamSearch); + transformers::CudaTensorConsoleDumper g_cuda_dumper; BeamSearch::BeamSearch(const OpKernelInfo& info) @@ -58,7 +80,9 @@ BeamSearch::BeamSearch(const OpKernelInfo& info) GenerationCudaDeviceHelper::UpdateDecoderFeeds, GenerationCudaDeviceHelper::ExpandBuffer, GenerationCudaDeviceHelper::ExpandBuffer, - GenerationCudaDeviceHelper::ExpandBuffer); + GenerationCudaDeviceHelper::ExpandBuffer, + GenerationCudaDeviceHelper::UpdateDecoderCrossQK, + GenerationCudaDeviceHelper::FinalizeDecoderCrossQK); SetConsoleDumper(&g_cuda_dumper); @@ -87,6 +111,60 @@ Status BeamSearch::Compute(OpKernelContext* context) const { return s; } +WhisperBeamSearch::WhisperBeamSearch(const OpKernelInfo& info) + : onnxruntime::contrib::transformers::WhisperBeamSearch(info) { + SetDeviceHelpers(GenerationCudaDeviceHelper::AddToFeeds, + GenerationCudaDeviceHelper::TopK, + GenerationCudaDeviceHelper::DeviceCopy, + GenerationCudaDeviceHelper::DeviceCopy, + GenerationCudaDeviceHelper::ProcessLogits, + GenerationCudaDeviceHelper::ProcessLogits, + GenerationCudaDeviceHelper::InitBeamState, + GenerationCudaDeviceHelper::InitBeamState, + GenerationCudaDeviceHelper::CreateBeamScorer); + +#ifndef USE_ROCM + SetDeviceHelpers_Cuda(GenerationCudaDeviceHelper::ReorderPastState, GenerationCudaDeviceHelper::InitCacheIndir); +#endif + + SetDeviceHelpers_Gpt(GenerationCudaDeviceHelper::UpdateGptFeeds, + GenerationCudaDeviceHelper::UpdateGptFeeds); + + SetDeviceHelpers_EncoderDecoder(GenerationCudaDeviceHelper::UpdateDecoderFeeds, + GenerationCudaDeviceHelper::UpdateDecoderFeeds, + GenerationCudaDeviceHelper::ExpandBuffer, + GenerationCudaDeviceHelper::ExpandBuffer, + GenerationCudaDeviceHelper::ExpandBuffer, + GenerationCudaDeviceHelper::UpdateDecoderCrossQK, + GenerationCudaDeviceHelper::FinalizeDecoderCrossQK); + + SetConsoleDumper(&g_cuda_dumper); + +#ifndef USE_ROCM + cuda_device_prop_ = &reinterpret_cast(info.GetExecutionProvider())->GetDeviceProp(); + + cuda_device_arch_ = static_cast(cuda_device_prop_)->major * 100 + + static_cast(cuda_device_prop_)->minor * 10; +#endif +} + +Status WhisperBeamSearch::ComputeInternal(OpKernelContext* context) const { + return onnxruntime::contrib::transformers::WhisperBeamSearch::Compute(context); +} + +Status WhisperBeamSearch::Compute(OpKernelContext* context) const { + auto s = ComputeInternal(context); + + if (s.IsOK()) { + auto err = cudaGetLastError(); + if (err != cudaSuccess) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA error ", cudaGetErrorName(err), ":", cudaGetErrorString(err)); + } + } + + return s; +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.h b/onnxruntime/contrib_ops/cuda/transformers/beam_search.h index dda8271e3a6a0..a4370abd8af46 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.h @@ -21,6 +21,16 @@ class BeamSearch final : public onnxruntime::contrib::transformers::BeamSearch { Status ComputeInternal(OpKernelContext* context) const; }; +class WhisperBeamSearch final : public onnxruntime::contrib::transformers::WhisperBeamSearch { + public: + WhisperBeamSearch(const OpKernelInfo& info); + + Status Compute(OpKernelContext* context) const override; + + private: + Status ComputeInternal(OpKernelContext* context) const; +}; + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 07a8896210d2c..dbd7fb010462d 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -1315,6 +1315,296 @@ template void BufferExpansionKernelLauncher(const int32_t* input, int chunk_size, cudaStream_t stream); +// Support head_size up to 128 +constexpr unsigned int kTileSize = 32; +constexpr unsigned int kSeqTileSize = 16; + +__global__ void ReorderPastStatesKernel(float4* out_buffer, + const float4* in_buffer, + int batch_size, + int num_heads, + int max_length, + int chunked_head_size) { + __shared__ float4 tile[kSeqTileSize][kTileSize + 1]; + + const int b = blockIdx.z; + const int n = blockIdx.y; + const int s_base = blockIdx.x * kSeqTileSize; + const int s = s_base + threadIdx.y; + const int base_offset = (b * num_heads + n) * max_length * chunked_head_size; + + if (s < max_length) { + const int in_offset = base_offset + s * chunked_head_size + threadIdx.x; + tile[threadIdx.y][threadIdx.x] = in_buffer[in_offset]; + } + + __syncthreads(); + + const int tidx = threadIdx.x + threadIdx.y * chunked_head_size; + const int tidx_x = tidx % kSeqTileSize; + const int tidx_y = tidx / kSeqTileSize; + + const int s2 = s_base + tidx_x; + + if (s2 < max_length) { + const int out_offset = base_offset + tidx_y * max_length + s2; + out_buffer[out_offset] = tile[tidx_x][tidx_y]; + } +} + +void ReorderPastStatesKernelLauncher(void* out_buffer, + const void* in_buffer, + int batch_size, + int num_heads, + int max_length, + int head_size, + int chunk_size, + cudaStream_t stream) { + //[B, N, max_length, H2(head_size/chunk_size), equv_chunk_size] -> [B, N, H2(head_size/chunk_size), max_length, equv_chunk_size] + const int chunked_head_size = head_size / chunk_size; + const dim3 block(chunked_head_size, kSeqTileSize); + const dim3 grid((max_length + kSeqTileSize - 1) / kSeqTileSize, num_heads, batch_size); + if (chunk_size == 4 || chunk_size == 8) { + ReorderPastStatesKernel<<>>(reinterpret_cast(out_buffer), + reinterpret_cast(in_buffer), + batch_size, + num_heads, + max_length, + chunked_head_size); + } else { + ORT_THROW("ReorderPastStatesKernelLauncher only support float or half"); + } +} + +template +__global__ void CopyCrossQKSingleDecodeStepKernel( + T* target, // shape [batchxbeam, layer_head_pair_count, max_length, frame] + T** qk_layer_pointers, + int token_index, + int num_layers, + int num_heads, + const int* cross_qk_layer_head_pairs, + int frames, + int max_length +) { + const int pair = blockIdx.x; + const int layer_head_pair_count = gridDim.x; + const int bbm = blockIdx.y; + cross_qk_layer_head_pairs += (pair * 2); + const int layer = *cross_qk_layer_head_pairs; + const int head = *(cross_qk_layer_head_pairs + 1); + + target += ((int64_t)bbm * layer_head_pair_count + pair) * max_length * frames + ((int64_t)token_index * frames); + T* src = qk_layer_pointers[layer] + ((int64_t)bbm * num_heads + head) * frames; + + for (int tid = threadIdx.x; tid < frames; tid += blockDim.x) { + target[tid] = src[tid]; // use vectorized read write in future if needed + } +} + +void LaunchCopyCrossQKSingleDecodeStep( + cudaStream_t stream, + float* cross_qk_buffer_data, + float** qk_layer_pointers, + int token_index, + int batchxbeam, + int num_layers, + int num_heads, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames, + int max_length +) { + dim3 block(512); + dim3 grid(cross_qk_layer_head_pair_count, batchxbeam); + typedef typename ToCudaType::MappedType CudaT; + + CopyCrossQKSingleDecodeStepKernel<<>>( + (CudaT*)cross_qk_buffer_data, + (CudaT**)qk_layer_pointers, + token_index, + num_layers, + num_heads, + cross_qk_layer_head_pairs, + frames, + max_length + ); +} + + +template +__global__ void CopyDecoderCrossQKAllStepsKernel( + int context_decoding_len, + int num_beams, + int num_return_sequences, + int max_length, + int frames_of_k, + const T* cross_qk_buffer_data, // [batch, num_beams, layer_head_pair_count, max_length, frames] + T* cross_qk_output, // [batch, num_return_sequences, layer_head_pair_count, total_decoding_length, frames] + const int* cache_indir_data, // [batch, num_beams, max_length] + const int32_t* beam_indices +) { + const int pair = blockIdx.y; + const int layer_head_pair_count = gridDim.y; + const int total_decoding_length = gridDim.x; + const int token_decoding_index = blockIdx.x; + const int br = blockIdx.z; + const int batch = br / num_return_sequences; + const int ret_seq_id = br % num_return_sequences; + + // get the real beam index, as the cache_indir_data did not updated in last token + const int src_beam = beam_indices[batch * num_beams + ret_seq_id] % num_beams; + + const int64_t offset_in_cache = ((int64_t)batch * num_beams + src_beam) * max_length + token_decoding_index + context_decoding_len; + int bm_mapped = ((num_beams <= 1) ? 0: ((token_decoding_index == total_decoding_length - 1) ? ret_seq_id : cache_indir_data[offset_in_cache])); + int bi_src = batch * num_beams + bm_mapped; + + T* target = cross_qk_output + + (((int64_t)br * layer_head_pair_count + (int64_t)pair) * total_decoding_length + token_decoding_index) * frames_of_k; + const T* src = cross_qk_buffer_data + + ((int64_t)bi_src * layer_head_pair_count * max_length + (int64_t)pair * max_length + token_decoding_index) * frames_of_k; + for (int tid = threadIdx.x; tid < frames_of_k; tid += blockDim.x) { + target[tid] = src[tid]; // use vectorized read write in future if needed + } +} + +void LaunchFinalizeCrossQK( + cudaStream_t stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + [[maybe_unused]] const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + const int32_t* beam_indices +) { + int64_t br = (int64_t)batch_size * num_return_sequences; + ORT_ENFORCE(br < 65536L && cross_qk_layer_head_pair_count < 65536); + const int total_decoding_length = iteration_number - 1; + dim3 block(512); + dim3 grid(total_decoding_length, cross_qk_layer_head_pair_count, (unsigned)br); + typedef typename ToCudaType::MappedType CudaT; + + CopyDecoderCrossQKAllStepsKernel<<>>( + context_decoding_len, + num_beams, + num_return_sequences, + max_length, + frames_of_k, + (const CudaT*)cross_qk_buffer_data, + (CudaT*)cross_qk_output, + cache_indir_data, + beam_indices); +} + +template +__global__ void ForceDecodingIdsKernel( + float* beam_scores, + const int vocab_size, + const int32_t* force_ids, + int id_len, + int step +) { + const int num_beams = gridDim.y; + const int beam = blockIdx.y; + const int batch = blockIdx.z; + beam_scores += (((int64_t)batch * num_beams + beam)* vocab_size); // move to (batch, beam) + const int32_t id_wanted = force_ids[((int64_t)batch * id_len) + step]; + if (id_wanted < 0 || id_wanted >= vocab_size) return; + + const int32_t elements_per_block = (int32_t)blockDim.x * ElementsPerThreads; + const int32_t block_start_id = blockIdx.x * elements_per_block; + + int32_t token_id = block_start_id + (int)threadIdx.x; + #pragma unroll + for (int elem = 0; elem < ElementsPerThreads; elem++) { + if (token_id < vocab_size) { + beam_scores[token_id] = ((token_id == id_wanted) ? 0.0f : cub::FpLimits::Lowest()); + } + token_id += (int)blockDim.x; + } +} + + +void LaunchForceDecodingIds( + float* beam_scores, + const int batch_size, + const int num_beams, + const int vocab_size, + const int32_t* force_ids, + int id_len, + int step, + cudaStream_t stream +) { + dim3 blocks(512); + constexpr int ElementsPerThreads = 4; + unsigned gridx = static_cast((vocab_size + 512 * ElementsPerThreads - 1) / (512 * ElementsPerThreads)); + dim3 grids(gridx, num_beams, batch_size); + ForceDecodingIdsKernel<<>>( + beam_scores, vocab_size, force_ids, id_len, step + ); +} + +template +__global__ void SaveNoSpeechProbsKernel( + T* result_no_speech_probs, + const float* probs, + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id +) { + int b = blockIdx.x * blockDim.x + threadIdx.x; + if (b < batch_size) { + int64_t src_offset = b * num_beams * vocab_size + no_speech_token_id; + result_no_speech_probs[b] = (T)(probs[src_offset]); + } +} + +template +void LaunchSaveNoSpeechProbs( + T* result_no_speech_probs, /* [batch]*/ + const float* probs, /* [batch, num_beams, vocab_size]*/ + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream +) { + int tpb = 256; + int bpg = (batch_size + 255) / 256; + + typedef typename ToCudaType::MappedType CudaT; + SaveNoSpeechProbsKernel<<>>( + (CudaT*)result_no_speech_probs, probs, batch_size, num_beams, vocab_size, no_speech_token_id); +} + +template void LaunchSaveNoSpeechProbs( + float* result_no_speech_probs, + const float* probs, + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream +); + +template void LaunchSaveNoSpeechProbs( + MLFloat16* result_no_speech_probs, + const float* probs, + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream +); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 8c52f6fd52385..5ed5949196b29 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -213,6 +213,64 @@ void BufferExpansionKernelLauncher(const T* input, int chunk_size, cudaStream_t stream); +void ReorderPastStatesKernelLauncher(void* out_buffer, + const void* in_buffer, + int batch_size, + int num_heads, + int max_length, + int head_size, + int chunk_size, + cudaStream_t stream); + +void LaunchCopyCrossQKSingleDecodeStep( + cudaStream_t stream, + float* cross_qk_buffer_data, + float** qk_layer_pointers, + int token_index, + int batchxbeam, + int num_layers, + int num_heads, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames, + int max_length); + +void LaunchFinalizeCrossQK( + cudaStream_t stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + const int32_t* beam_indices); + +void LaunchForceDecodingIds( + float* beam_scores, + const int batch_size, + const int num_beams, + const int vocab_size, + const int32_t* force_ids, + int id_len, + int step, + cudaStream_t stream); + +template +void LaunchSaveNoSpeechProbs( + T* result_no_speech_probs, /* [batch]*/ + const float* probs, /* [batch, num_beams, vocab_size]*/ + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index e4de33499c6ca..380d561bbb23c 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -13,6 +13,8 @@ #include #include "contrib_ops/cuda/transformers/generation_cuda_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cpu/transformers/logits_processor.h" +#include "contrib_ops/cpu/transformers/generation_shared.h" #include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" #include "contrib_ops/cpu/transformers/subgraph_gpt.h" #include "contrib_ops/cuda/transformers/beam_search_topk.h" @@ -56,19 +58,23 @@ namespace GenerationCudaDeviceHelper { // It might be better to forcefully require the same type since cast node generates // extra overhead. Status ReorderPastState( - const void* cuda_device_prop, + const void*, Tensor& past_state, Tensor& past_state_staging, Stream* stream) { ORT_ENFORCE(stream); cudaStream_t cuda_stream = reinterpret_cast(stream->GetHandle()); - cublasHandle_t cublas_handle = static_cast(stream)->cublas_handle_; const auto& past_state_shape = past_state.Shape(); const auto& past_state_dims = past_state_shape.GetDims(); const bool packed_past = past_state_dims.size() == 5; + size_t batch_size = packed_past ? past_state_dims[1] : past_state_dims[0]; + size_t num_heads = packed_past ? past_state_dims[2] : past_state_dims[1]; + size_t max_length = packed_past ? past_state_dims[3] : past_state_dims[2]; + size_t head_size = packed_past ? past_state_dims[4] : past_state_dims[3]; + // Copy the 'K' values into the temp staging buffer size_t past_state_size = packed_past ? past_state.SizeInBytes() / 2 : past_state.SizeInBytes(); void* past_state_staging_buffer = past_state_staging.MutableDataRaw(); @@ -79,27 +85,16 @@ Status ReorderPastState( // [B, N, head_size / x, max_length, x], where x = 16 / sizeof(T) int64_t chunk_size = static_cast(16 / past_state.DataType()->Size()); - std::vector permutation_vector = {0, 1, 3, 2, 4}; - gsl::span permutation(permutation_vector.data(), 5); - - // "Fake" the shapes of the input and output tensors of the Transpose operation to suit our need - size_t offset = packed_past ? 1 : 0; - TensorShape transpose_input_shape_override = {past_state_shape[offset], - past_state_shape[offset + 1], - past_state_shape[offset + 2], - past_state_shape[offset + 3] / chunk_size, - chunk_size}; - - TensorShape transpose_output_shape_override = {past_state_shape[offset], past_state_shape[offset + 1], - past_state_shape[offset + 3] / chunk_size, past_state_shape[offset + 2], - chunk_size}; - - // TODO(hasesh): Explore perf tuning for this Transpose operation - return onnxruntime::cuda::Transpose::DoTranspose(*static_cast(cuda_device_prop), cuda_stream, - cublas_handle, permutation, - past_state_staging, past_state, - &transpose_input_shape_override, - &transpose_output_shape_override); + cuda::ReorderPastStatesKernelLauncher(past_state.MutableDataRaw(), + past_state_staging_buffer, + static_cast(batch_size), + static_cast(num_heads), + static_cast(max_length), + static_cast(head_size), + static_cast(chunk_size), + cuda_stream); + + return Status::OK(); } Status InitCacheIndir(Tensor& cache_indir, Stream* stream) { @@ -210,7 +205,7 @@ Status AddToFeeds(Stream* ort_stream, ORT_ENFORCE(total_bytes > 0); cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; - auto pinned_buffer = IAllocator::MakeUniquePtr(host_allocator, total_bytes); + auto pinned_buffer = IAllocator::MakeUniquePtr(host_allocator, total_bytes, false, ort_stream); char* pinned_data = static_cast(pinned_buffer.get()); // Copy tensors to one pinned memory buffer (so that we only need copy to GPU once) char* destination = pinned_data; @@ -426,11 +421,21 @@ Status ProcessLogits(const OrtValue& logits, // dumper->Print("next_token_scores after softmax", next_token_scores.data(), batch_size, num_beams, vocab_size); #endif + const bool is_whisper_model = (parameters->model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper); + if (step == 1 && is_whisper_model && parameters->no_speech_probs) { + cuda::LaunchSaveNoSpeechProbs( + (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token, cuda_stream); + } + + // NOTE: currently we treat extra decoding ids are same + int extra_decoding_len = static_cast(parameters->extra_decoding_ids.size() / parameters->batch_size); + const bool need_handle_extra_decoding_ids = is_whisper_model && (!parameters->extra_decoding_ids.empty()) && (extra_decoding_len >= step); + cuda::LaunchLogitsProcessKernel( next_token_scores.data(), parameters->vocab_mask.data(), - step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only. - nullptr, // parameters->presence_mask.data(), + (step > extra_decoding_len + 1) ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only. + nullptr, // parameters->presence_mask.data(), parameters->presence_penalty, parameters->temperature, parameters->batch_size, @@ -445,6 +450,50 @@ Status ProcessLogits(const OrtValue& logits, // parameters->no_repeat_ngram_size, cuda_stream); + // Whisper time stamp generation. + // TODO: implement it on GPU + bool gen_timestamp = is_whisper_model && + (parameters->logits_processor == onnxruntime::contrib::transformers::IGenerationParameters::kLogitsProcessorTypeWhisper); + if (gen_timestamp) { + // Copy next token scores to cpu memory, copy Sequences to cpu + std::vector cpu_next_token_scores(next_token_scores.size()); + gsl::span cpu_next_token_scores_span(cpu_next_token_scores.data(), cpu_next_token_scores.size()); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_next_token_scores.data(), + next_token_scores.data(), + next_token_scores.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(const_cast(sequences->GetSequence(0).data()), + sequences->GetCurrentDeviceSequences().data(), + sequences->GetSequence(0).size_bytes() * batch_beam_size, + cudaMemcpyDeviceToHost, + cuda_stream)); + constexpr int max_initial_timestamp_index = 50; + onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, max_initial_timestamp_index); + onnxruntime::contrib::transformers::NextTokenScores next_token_scores_timestamp({cpu_next_token_scores_span, batch_beam_size, vocab_size}); + + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + time_logit_processor.Process(sequences, next_token_scores_timestamp); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(next_token_scores.data(), + cpu_next_token_scores.data(), + next_token_scores.size_bytes(), + cudaMemcpyHostToDevice, + cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + } + + if (need_handle_extra_decoding_ids && !parameters->extra_decoding_ids.empty()) { + cuda::LaunchForceDecodingIds( + next_token_scores.data(), + parameters->batch_size, + parameters->num_beams, + parameters->vocab_size, + parameters->extra_decoding_ids.data(), + static_cast(parameters->extra_decoding_ids.size() / parameters->batch_size), + step - 1, + cuda_stream); + } + #ifdef DEBUG_GENERATION dumper->Print("next_token_scores after logits process", next_token_scores.data(), batch_size, num_beams, vocab_size); #endif @@ -807,13 +856,11 @@ Status GreedySearchProcessLogits( // Sequences generated by beam scorer is currently stored in CPU. // Copy sequences to device only when repetition penalty or no repeat ngram is used in kernel - BufferUniquePtr sequences_buffer; + IAllocatorUniquePtr sequences_buffer; int current_sequence_length = sequences->GetSequenceLength(); if (parameters->repetition_penalty != 1.0f) { size_t bytes = SafeInt(sizeof(int32_t)) * batch_beam_size * parameters->max_length; - void* data = allocator->Alloc(bytes); - BufferUniquePtr temp_buffer(data, BufferDeleter(allocator)); - sequences_buffer = std::move(temp_buffer); + sequences_buffer = IAllocator::MakeUniquePtr(allocator, bytes, false, stream); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sequences_buffer.get(), sequences->GetSequence(0).data(), bytes, cudaMemcpyHostToDevice, cuda_stream)); } @@ -1196,14 +1243,14 @@ Status UpdateDecoderFeeds( if (past_present_share_buffer) { // Update past sequence length input - const ptrdiff_t past_sequence_length_idx = 2 * (static_cast(last_outputs.size()) - t5_decoder_first_present_output_idx) + t5_decoder_first_past_input_idx; + const ptrdiff_t past_sequence_length_idx = 2 * num_present_tensors + t5_decoder_first_past_input_idx; *(next_inputs[past_sequence_length_idx].GetMutable()->MutableData()) = current_length - 1; // Update beam search specific input for DecoderMaskedSelfAttention (cache indirection) if present // If the last input is not `past_sequence_length`, then the beam search specific inputs // for `DecoderMaskedSelfAttention` is present - if (need_cache_indir) { + if (need_cache_indir && num_beams > 1) { ORT_ENFORCE(!beam_indices_gpu.empty(), "Beam indices must be present on CUDA while using DecoderMaskedMultiHeadAttention with BeamSearch"); // The cache indirection feed comes 2 feeds after the `past_sequence_length` feed @@ -1528,6 +1575,93 @@ template Status ExpandBuffer( OrtValue& expanded, bool only_copy_shape, int max_sequence_length); + +Status UpdateDecoderCrossQK( + int iteration_number, + Stream* stream, + OrtValue* cross_qks, + IAllocatorUniquePtr& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator) { + cudaStream_t cuda_stream = stream ? static_cast(stream->GetHandle()) : nullptr; + + if (qk_layer_pointers.get() == nullptr) { + // Put all the qk pointers into gpu, as they did not change in following decoding steps + // also this help to use single kernel to process each step + qk_layer_pointers = IAllocator::MakeUniquePtr(allocator, static_cast(num_layers), false, stream); + std::vector qk_layer_data(num_layers, nullptr); + for (int layer = 0; layer < num_layers; layer++) { + qk_layer_data[layer] = cross_qks[layer].GetMutable()->MutableData(); + } + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync((void*)qk_layer_pointers.get(), qk_layer_data.data(), sizeof(qk_layer_data[0]) * num_layers, + cudaMemcpyHostToDevice, cuda_stream)); + } + + auto cross_qk_layer_shape = cross_qks[0].GetMutable()->Shape(); + int64_t batchxbeam = cross_qk_layer_shape[0]; + int64_t num_heads = cross_qk_layer_shape[1]; + int64_t frames = cross_qk_layer_shape[3]; + + cuda::LaunchCopyCrossQKSingleDecodeStep( + cuda_stream, + cross_qk_buffer_data, + qk_layer_pointers.get(), + iteration_number - 2, + static_cast(batchxbeam), + num_layers, + static_cast(num_heads), + cross_qk_layer_head_pair_count, + cross_qk_layer_head_pairs, + static_cast(frames), + max_length); + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + + return Status::OK(); +} + +Status FinalizeDecoderCrossQK( + Stream* stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + gsl::span beam_indices_gpu) { + cudaStream_t cuda_stream = stream ? static_cast(stream->GetHandle()) : nullptr; + + cuda::LaunchFinalizeCrossQK( + cuda_stream, + iteration_number, + context_decoding_len, + batch_size, + num_beams, + max_length, + cross_qk_layer_head_pair_count, + cross_qk_layer_head_pairs, + frames_of_k, + cross_qk_buffer_data, + cross_qk_output, + num_return_sequences, + cache_indir_data, + beam_indices_gpu.data()); + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + + return Status::OK(); +} + } // namespace GenerationCudaDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h index f5f062d7a101b..7a718eb9f66c1 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h @@ -150,6 +150,34 @@ Status ExpandBuffer( bool only_copy_shape, int max_sequence_length = 0); +Status UpdateDecoderCrossQK( + int iteration_number, + Stream* stream, + OrtValue* cross_qks, + IAllocatorUniquePtr& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator); + +Status FinalizeDecoderCrossQK( + Stream* stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + gsl::span beam_indices); + } // namespace GenerationCudaDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc b/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc index d9014ca8f5c24..812ab0b1bcae6 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc @@ -48,10 +48,12 @@ GreedySearch::GreedySearch(const OpKernelInfo& info) SetConsoleDumper(&g_cuda_dumper_greedysearch); +#ifndef USE_ROCM cuda_device_prop_ = &reinterpret_cast(info.GetExecutionProvider())->GetDeviceProp(); cuda_device_arch_ = static_cast(cuda_device_prop_)->major * 100 + static_cast(cuda_device_prop_)->minor * 10; +#endif } Status GreedySearch::ComputeInternal(OpKernelContext* context) const { diff --git a/onnxruntime/contrib_ops/js/bert/attention.cc b/onnxruntime/contrib_ops/js/bert/attention.cc new file mode 100644 index 0000000000000..723ff00aa815e --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/attention.cc @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "attention.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + Attention, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + Attention); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/attention.h b/onnxruntime/contrib_ops/js/bert/attention.h new file mode 100644 index 0000000000000..0fa823befa9b2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/attention.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/bert/attention_base.h" +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::contrib::AttentionBase; +using onnxruntime::js::JsKernel; + +class Attention : public JsKernel, AttentionBase { + public: + explicit Attention(const OpKernelInfo& info) : JsKernel(info), AttentionBase(info, false) { + std::vector qkv_sizes(qkv_hidden_sizes_.size()); + if (qkv_hidden_sizes_.size() > 0) { + std::transform(qkv_hidden_sizes_.begin(), qkv_hidden_sizes_.end(), qkv_sizes.begin(), + [](int64_t sz) { return gsl::narrow_cast(sz); }); + } + + JSEP_INIT_KERNEL_ATTRIBUTE(Attention, ({ + "numHeads" : $1, + "isUnidirectional" : $2, + "maskFilterValue" : $3, + "scale" : $4, + "doRotary" : $5, + "qkvHiddenSizes" : $6 ? (Array.from(HEAP32.subarray(Number($7), Number($7) + $6))) : [], + "pastPresentShareBuffer" : !!$8, + }), + static_cast(num_heads_), + static_cast(is_unidirectional_), + static_cast(mask_filter_value_), + static_cast(scale_), + static_cast(do_rotary_), + static_cast(qkv_hidden_sizes_.size()), + reinterpret_cast((qkv_sizes.size() > 0) ? qkv_sizes.data() : nullptr) >> 2, + static_cast(past_present_share_buffer_)); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc b/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc new file mode 100644 index 0000000000000..c43f8b7f18465 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "multi_head_attention.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + MultiHeadAttention, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + MultiHeadAttention); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/multi_head_attention.h b/onnxruntime/contrib_ops/js/bert/multi_head_attention.h new file mode 100644 index 0000000000000..6c63a2ffed4b2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/multi_head_attention.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/bert/attention_base.h" +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::contrib::AttentionBase; +using onnxruntime::js::JsKernel; + +class MultiHeadAttention : public JsKernel, AttentionBase { + public: + explicit MultiHeadAttention(const OpKernelInfo& info) : JsKernel(info), AttentionBase(info, false) { + JSEP_INIT_KERNEL_ATTRIBUTE(MultiHeadAttention, ({ + "numHeads" : $1, + "isUnidirectional" : $2, + "maskFilterValue" : $3, + "scale" : $4, + "doRotary" : $5, + }), + static_cast(num_heads_), + static_cast(is_unidirectional_), + static_cast(mask_filter_value_), + static_cast(scale_), + static_cast(do_rotary_)); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bias_add.cc b/onnxruntime/contrib_ops/js/bias_add.cc new file mode 100644 index 0000000000000..9e70dead6a5da --- /dev/null +++ b/onnxruntime/contrib_ops/js/bias_add.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "bias_add.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + BiasAdd, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + BiasAdd); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bias_add.h b/onnxruntime/contrib_ops/js/bias_add.h new file mode 100644 index 0000000000000..62a4df9bcdf34 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bias_add.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; +JSEP_KERNEL_IMPL(BiasAdd, BiasAdd); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bias_split_gelu.cc b/onnxruntime/contrib_ops/js/bias_split_gelu.cc new file mode 100644 index 0000000000000..e16aa4367d1c7 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bias_split_gelu.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "bias_split_gelu.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + BiasSplitGelu, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + BiasSplitGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bias_split_gelu.h b/onnxruntime/contrib_ops/js/bias_split_gelu.h new file mode 100644 index 0000000000000..3b3b41c0ca1f3 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bias_split_gelu.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; +JSEP_KERNEL_IMPL(BiasSplitGelu, BiasSplitGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/fused_conv.cc b/onnxruntime/contrib_ops/js/fused_conv.cc new file mode 100644 index 0000000000000..76402f0681976 --- /dev/null +++ b/onnxruntime/contrib_ops/js/fused_conv.cc @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/operators/conv.h" +namespace onnxruntime { +namespace contrib { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + FusedConv, + kMSDomain, + 1, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + onnxruntime::js::Conv); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 0bf6a4a168e68..498a9f5679eb5 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -7,8 +7,13 @@ namespace onnxruntime { namespace contrib { namespace js { +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -18,8 +23,14 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo}; + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/contrib_ops/js/skip_layer_norm.cc b/onnxruntime/contrib_ops/js/skip_layer_norm.cc index ee315f9b31e3b..f949326e1dc95 100644 --- a/onnxruntime/contrib_ops/js/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/js/skip_layer_norm.cc @@ -7,14 +7,16 @@ namespace onnxruntime { namespace contrib { namespace js { +using onnxruntime::js::JsepSupportedFloatTypes; + ONNX_OPERATOR_KERNEL_EX( SkipLayerNormalization, kMSDomain, 1, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("U", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", JsepSupportedFloatTypes()) + .TypeConstraint("U", JsepSupportedFloatTypes()), SkipLayerNorm); } // namespace js diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index 246b66078537a..78983ac95e672 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -838,7 +838,7 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { Nop{}); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); if constexpr (USE_MASK) { ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh index cbf24ee2f5487..ea9040aa7875f 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh @@ -58,7 +58,7 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() { auto zero = ToHipType::FromFloat(0.0f); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->alpha != one || params->beta != zero || params->bias == nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr", params->Signature()); + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr"); auto nop = Nop{}; auto addfastgelu = AddFastGelu{}; @@ -67,7 +67,7 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() { params->lda, params->ldb, std::array{0}, params->ldc, nop, nop, addfastgelu); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; @@ -95,7 +95,7 @@ auto GetCKGemmFastGeluTypeStringAndOps() { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->alpha != one || params->beta != zero || params->bias != nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr", params->Signature()); + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr"); auto nop = Nop{}; auto fastgelu = FastGelu{}; @@ -108,7 +108,7 @@ auto GetCKGemmFastGeluTypeStringAndOps() { params->ldc, nop, nop, fastgelu); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc index c665da89af36c..e82e15a304f4c 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc @@ -72,6 +72,12 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) { channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); } +Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + return Status::OK(); +} + Status GroupNorm::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* gamma = context->Input(1); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh index e87813fb19956..fb7091592c16e 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -34,17 +34,17 @@ constexpr int NumReduceDim = 3; template auto GetCKGroupNormNHWCTypeStringAndOps() { - using InDataType = typename CKDataTypeAdaptor::type; - using OutDataType = typename CKDataTypeAdaptor::type; - using AccDataType = typename CKDataTypeAdaptor::type; + using XDataType = typename CKDataTypeAdaptor::type; + using YDataType = typename CKDataTypeAdaptor::type; + using SaveMeanInvStdDataType = typename CKDataTypeAdaptor::type; using GammaDataType = float; using BetaDataType = float; using Activation = std::conditional_t; std::vector>>> ret; - for (auto&& impl : internal::GetDeviceGroupNormInstances()) { + for (auto&& impl : internal::GetDeviceGroupNormInstances()) { std::string swish_suffix = WithSwish ? "_Swish" : "_Pass"; auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix; auto invoker = impl->MakeInvokerPointer(); @@ -69,6 +69,8 @@ auto GetCKGroupNormNHWCTypeStringAndOps() { gamma_beta_strides, // gammaStrides gamma_beta_strides, // betaStrides in_out_strides, // yStrides + {0, 0}, // saveMeanStrides + {0, 0}, // saveInvStdStrides reduce_dims, // reduceDims params->epsilon, params->src, @@ -79,7 +81,7 @@ auto GetCKGroupNormNHWCTypeStringAndOps() { nullptr, activation); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh index 88443478cf521..19b081881dcec 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh @@ -6,8 +6,8 @@ #ifdef USE_COMPOSABLE_KERNEL #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" #include "ck/utility/data_type.hpp" namespace onnxruntime { @@ -21,102 +21,104 @@ using F32 = float; using Swish = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; -using ck::tensor_operation::device::DeviceNormalization; // the interface -using ck::tensor_operation::device::DeviceNormalizationImpl; // the implementation +using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface +using ck::tensor_operation::device::DeviceNormalizationFwdImpl; // the implementation + +// See https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/1fefd82ed8/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp template using device_normalization_f32_instances = std::tuple< // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, OutElementwise, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl // clang-format on >; template -using device_normalization_f16_instances = std::tuple< +using device_normalization_f16_instances = // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, OutElementwise, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl + std::tuple < + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl // clang-format on >; // Use this function to get implementation -template -std::vector>> +std::vector>> GetDeviceGroupNormInstances() { return {}; } template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances< - F16, F32, F32, F32, F16, Swish, 5, 3>(); + F16, F32, F32, F16, F32, Swish, 5, 3>(); template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances< - F16, F32, F32, F32, F16, Pass, 5, 3>(); + F16, F32, F32, F16, F32, Pass, 5, 3>(); template <> -std::vector>> GetDeviceGroupNormInstances< F32, F32, F32, F32, F32, Swish, 5, 3>(); template <> -std::vector>> GetDeviceGroupNormInstances< F32, F32, F32, F32, F32, Pass, 5, 3>(); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu index d1dd78e3452da..6718f29268031 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu @@ -4,7 +4,6 @@ #ifdef USE_COMPOSABLE_KERNEL #include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" namespace onnxruntime { namespace contrib { @@ -12,9 +11,9 @@ namespace rocm { namespace internal { template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f16_instances{}); @@ -23,9 +22,9 @@ GetDeviceGroupNormInstances() { } template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f16_instances{}); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu index 97baed34a341d..9b0ccab17b4c1 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu @@ -4,7 +4,6 @@ #ifdef USE_COMPOSABLE_KERNEL #include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" namespace onnxruntime { namespace contrib { @@ -12,9 +11,9 @@ namespace rocm { namespace internal { template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances() { - std::vector>> instances; + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f32_instances{}); @@ -23,9 +22,9 @@ GetDeviceGroupNormInstances() { } template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances() { - std::vector>> instances; + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f32_instances{}); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index 526d220d4be24..b7b9441ac997d 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -77,7 +77,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { params->epsilon}; // Grid dim is (batch_count, groups, 1) - return LaunchTritonKernel(params->stream, i, params->n, params->groups, 1, &args, sizeof(args)); + return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); }; ret.emplace_back(std::make_pair(metadata->name, std::move(impl))); } diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 7bc0f99081169..0f8fe68de717a 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -29,6 +29,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RemovePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RestorePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RestorePadding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Rfft); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Rfft); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Rfft); @@ -52,6 +60,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Attention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Attention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BeamSearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Crop); @@ -61,12 +73,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GreedySearch); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, NhwcConv); @@ -113,6 +124,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLayerNormalization); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedGelu); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QuantizeWithOrder); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DequantizeWithOrder); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedAttention); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GemmFastGelu); @@ -139,6 +161,7 @@ KernelCreateInfo BuildKernelCreateInfo() { return info; } +// clang-format off Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing @@ -162,70 +185,73 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // These ops were experimental ops in onnx domain which have been removed now. We add them here as // contrib ops to maintain backward compatibility - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -238,7 +264,6 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -249,16 +274,25 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo - + // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // TransposedMatMul is still here for backward compatibility BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -278,6 +312,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, #endif + }; for (auto& function_table_entry : function_table) { @@ -289,6 +324,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { return Status::OK(); } +// clang-format on } // namespace rocm } // namespace contrib diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index a23409292bb74..655d5014f3d60 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -22,6 +22,14 @@ #define HWCAP_ASIMDDP (1 << 20) #endif +#ifndef HWCAP2_I8MM +#define HWCAP2_I8MM (1 << 13) +#endif + +#ifndef HWCAP2_SVEI8MM +#define HWCAP2_SVEI8MM (1 << 9) +#endif + #endif // ARM #endif // Linux @@ -135,38 +143,41 @@ void CPUIDInfo::ArmLinuxInit() { LOGS_DEFAULT(WARNING) << "Failed to init pytorch cpuinfo library, may cause CPU EP performance degradation due to undetected CPU features."; return; } + is_hybrid_ = cpuinfo_get_uarchs_count() > 1; + has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); + has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); + has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); + has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + + const uint32_t core_cnt = cpuinfo_get_cores_count(); + core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); + is_armv8_narrow_ld_.resize(core_cnt, false); + for (uint32_t c = 0; c < core_cnt; c++) { + const struct cpuinfo_processor* proc = cpuinfo_get_processor(c); + if (proc == nullptr) { + continue; + } + const struct cpuinfo_core* corep = proc->core; + if (corep == nullptr) { + continue; + } + auto coreid = proc->linux_id; + auto uarch = corep->uarch; + core_uarchs_[coreid] = uarch; + if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 || + uarch == cpuinfo_uarch_cortex_a55) { + is_armv8_narrow_ld_[coreid] = true; + } + } #else pytorch_cpuinfo_init_ = false; -#endif + has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); + has_fp16_ |= has_arm_neon_dot_; - if (pytorch_cpuinfo_init_) { - is_hybrid_ = cpuinfo_get_uarchs_count() > 1; - has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); - has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); - const uint32_t core_cnt = cpuinfo_get_cores_count(); - core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); - is_armv8_narrow_ld_.resize(core_cnt, false); - for (uint32_t c = 0; c < core_cnt; c++) { - const struct cpuinfo_processor* proc = cpuinfo_get_processor(c); - if (proc == nullptr) { - continue; - } - const struct cpuinfo_core* corep = proc->core; - if (corep == nullptr) { - continue; - } - auto coreid = proc->linux_id; - auto uarch = corep->uarch; - core_uarchs_[coreid] = uarch; - if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 || - uarch == cpuinfo_uarch_cortex_a55) { - is_armv8_narrow_ld_[coreid] = true; - } - } - } else { - has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); - has_fp16_ |= has_arm_neon_dot_; - } + has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); + has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + +#endif } #elif defined(_WIN32) @@ -260,6 +271,9 @@ void CPUIDInfo::ArmWindowsInit() { has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); has_fp16_ |= has_arm_neon_dot_; + /* TODO: implement them when hw+sw is available for testing these features */ + has_arm_neon_i8mm_ = false; + has_arm_sve_i8mm_ = false; } #endif /* (arm or arm64) and windows */ diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 386db347c669d..a15c75104b83a 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -28,6 +28,8 @@ class CPUIDInfo { // ARM bool HasArmNeonDot() const { return has_arm_neon_dot_; } + bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } + bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } uint32_t GetCurrentCoreIdx() const; @@ -121,6 +123,8 @@ class CPUIDInfo { bool has_arm_neon_dot_{false}; bool has_fp16_{false}; + bool has_arm_neon_i8mm_{false}; + bool has_arm_sve_i8mm_{false}; #ifdef CPUIDINFO_ARCH_X86 diff --git a/onnxruntime/core/common/cpuid_uarch.cc b/onnxruntime/core/common/cpuid_uarch.cc index 52baad739441b..16634b2bc8744 100644 --- a/onnxruntime/core/common/cpuid_uarch.cc +++ b/onnxruntime/core/common/cpuid_uarch.cc @@ -3,7 +3,8 @@ #include "core/common/cpuid_uarch.h" -#include "core/common/logging/logging.h" +#include // For std::cerr. + // Writing to stderr instead of logging because logger may not be initialized yet. namespace onnxruntime { @@ -137,7 +138,7 @@ void decodeMIDR( break; // #endif /* ARM */ default: - LOGS_DEFAULT(WARNING) << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } } break; @@ -156,7 +157,7 @@ void decodeMIDR( break; // #endif default: - LOGS_DEFAULT(WARNING) << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #if (defined(_M_ARM64) || defined(__aarch64__)) && !defined(__ANDROID__) @@ -172,7 +173,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_thunderx2; break; default: - LOGS_DEFAULT(WARNING) << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif @@ -187,7 +188,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_cortex_a76; break; default: - LOGS_DEFAULT(WARNING) << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #if defined(_M_ARM) || defined(__arm__) @@ -199,7 +200,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_xscale; break; default: - LOGS_DEFAULT(WARNING) << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif /* ARM */ @@ -215,7 +216,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_carmel; break; default: - LOGS_DEFAULT(WARNING) << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; #if !defined(__ANDROID__) @@ -225,7 +226,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_xgene; break; default: - LOGS_DEFAULT(WARNING) << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; #endif @@ -297,7 +298,7 @@ void decodeMIDR( break; // #endif /* ARM64 && !defined(__ANDROID__) */ default: - LOGS_DEFAULT(WARNING) << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; case 'S': @@ -343,8 +344,9 @@ void decodeMIDR( *uarch = cpuinfo_uarch_exynos_m5; break; default: - LOGS_DEFAULT(WARNING) << "unknown Samsung CPU variant 0x" - << std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Samsung CPU variant 0x" + << std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr) + << " ignored\n"; } break; // #if defined(_M_ARM) || defined(__arm__) @@ -355,12 +357,12 @@ void decodeMIDR( *uarch = cpuinfo_uarch_pj4; break; default: - LOGS_DEFAULT(WARNING) << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif /* ARM */ default: - LOGS_DEFAULT(WARNING) << "unknown CPU uarch from MIDR value: 0x" << std::hex << midr; + std::cerr << "unknown CPU uarch from MIDR value: 0x" << std::hex << midr << "\n"; } } diff --git a/onnxruntime/core/common/string_utils.h b/onnxruntime/core/common/string_utils.h index 6e0eb460d2a63..eca1221e84cb8 100644 --- a/onnxruntime/core/common/string_utils.h +++ b/onnxruntime/core/common/string_utils.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include @@ -37,5 +38,32 @@ inline InlinedVector SplitString(std::string_view string_to_sp return result; } +/** + * Trim a string from start inplace. + * @param s The string to trim. + */ +inline void TrimStringFromLeft(std::string& s) { + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { return !std::isspace(ch); })); +} + +/** + * Trim a string from end inplace. + * @param s The string to trim. + */ +inline void TrimStringFromRight(std::string& s) { + s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base(), s.end()); +} + +/** + * Trim a string from both ends. + * @param s The string to trim. + * @return The trimmed string. + */ +inline std::string TrimString(std::string s) { + TrimStringFromRight(s); + TrimStringFromLeft(s); + return s; +} + } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/common/threadpool.cc b/onnxruntime/core/common/threadpool.cc index f29ab19608934..10e117267e14b 100644 --- a/onnxruntime/core/common/threadpool.cc +++ b/onnxruntime/core/common/threadpool.cc @@ -562,7 +562,7 @@ static ptrdiff_t CalculateParallelForBlock(const ptrdiff_t n, const Eigen::Tenso constexpr ptrdiff_t max_oversharding_factor = 4; ptrdiff_t block_size = Eigen::numext::mini( n, - Eigen::numext::maxi(Eigen::divup(n, max_oversharding_factor * num_threads), static_cast(block_size_f))); + Eigen::numext::maxi(Eigen::numext::div_ceil(n, max_oversharding_factor * num_threads), static_cast(block_size_f))); const ptrdiff_t max_block_size = Eigen::numext::mini(n, 2 * block_size); if (block_align) { @@ -571,19 +571,19 @@ static ptrdiff_t CalculateParallelForBlock(const ptrdiff_t n, const Eigen::Tenso block_size = Eigen::numext::mini(n, new_block_size); } - ptrdiff_t block_count = Eigen::divup(n, block_size); + ptrdiff_t block_count = Eigen::numext::div_ceil(n, block_size); // Calculate parallel efficiency as fraction of total CPU time used for // computations: double max_efficiency = - static_cast(block_count) / (Eigen::divup(block_count, num_threads) * num_threads); + static_cast(block_count) / (Eigen::numext::div_ceil(block_count, num_threads) * num_threads); // Now try to increase block size up to max_block_size as long as it // doesn't decrease parallel efficiency. for (ptrdiff_t prev_block_count = block_count; max_efficiency < 1.0 && prev_block_count > 1;) { // This is the next block size that divides size into a smaller number // of blocks than the current block_size. - ptrdiff_t coarser_block_size = Eigen::divup(n, prev_block_count - 1); + ptrdiff_t coarser_block_size = Eigen::numext::div_ceil(n, prev_block_count - 1); if (block_align) { ptrdiff_t new_block_size = block_align(coarser_block_size); assert(new_block_size >= coarser_block_size); @@ -593,11 +593,11 @@ static ptrdiff_t CalculateParallelForBlock(const ptrdiff_t n, const Eigen::Tenso break; // Reached max block size. Stop. } // Recalculate parallel efficiency. - const ptrdiff_t coarser_block_count = Eigen::divup(n, coarser_block_size); + const ptrdiff_t coarser_block_count = Eigen::numext::div_ceil(n, coarser_block_size); assert(coarser_block_count < prev_block_count); prev_block_count = coarser_block_count; const double coarser_efficiency = - static_cast(coarser_block_count) / (Eigen::divup(coarser_block_count, num_threads) * num_threads); + static_cast(coarser_block_count) / (Eigen::numext::div_ceil(coarser_block_count, num_threads) * num_threads); if (coarser_efficiency + 0.01 >= max_efficiency) { // Taking it. block_size = coarser_block_size; diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 0bf27fdf5e5dc..9556e056dedc0 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -320,7 +320,7 @@ class PlannerImpl { return false; } - const auto& alias_map = ci.kernel_def->Alias(); + const auto alias_map = GetAliasMap(node, ci); auto input_args = node.InputDefs(); for (auto& pair : alias_map) { if (pair.second == output_arg_num) { @@ -829,6 +829,34 @@ class PlannerImpl { return p_provider->GetOrtDeviceByMemType(utils::IsInputOnCpu(node, &kernel_create_info, input_index) ? OrtMemTypeCPUInput : OrtMemTypeDefault); } + std::vector> GetAliasMap(const Node& node, const KernelCreateInfo& kernel_create_info) { + ORT_ENFORCE(kernel_create_info.kernel_def != nullptr, "KernelDef is null for node: ", node.Name()); +#ifdef ENABLE_TRAINING_TORCH_INTEROP + if ((node.OpType().compare("PythonOp") == 0 || node.OpType().compare("PythonOpGrad") == 0) && + node.Domain() == kMSDomain) { + const auto& attrs = node.GetAttributes(); + auto attr_it = attrs.find("tensor_reuse_map"); + if (attr_it != attrs.end()) { + const auto& inplace_map = attr_it->second.ints(); + std::vector> alias_map; + alias_map.reserve(inplace_map.size()); + for (int i = 0; i < inplace_map.size(); ++i) { + int output_index = i; + int input_index = inplace_map[i]; + if (input_index == -1) { + // skip because no reuse for this output + continue; + } + alias_map.emplace_back(std::make_pair(input_index, output_index)); + } + return alias_map; + } + } +#endif + + return kernel_create_info.kernel_def->Alias(); + } + void GeneratePlanForWeightsHelper(const GraphViewer& graph_viewer, const InitializedTensorSet& weights, const KernelCreateInfoMap& kernel_create_info_map, @@ -1084,7 +1112,7 @@ class PlannerImpl { } bool found_reusable = false; - const auto& alias_map = ci.kernel_def->Alias(); + const auto alias_map = GetAliasMap(*node, ci); auto input_args = node->InputDefs(); for (auto* input_arg : input_args) { OrtValueIndex input_idx_global{}; diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc index 3d971e6aa29a2..ef68b88187e08 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.cc +++ b/onnxruntime/core/framework/fallback_cpu_capability.cc @@ -9,6 +9,7 @@ #include "onnx/defs/data_type_utils.h" #include "core/framework/op_kernel.h" +#include "core/framework/utils.h" using namespace ONNX_NAMESPACE::Utils; @@ -77,7 +78,7 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe ORT_THROW_IF_ERROR(node->ForEachWithIndex( node->OutputDefs(), [&](const NodeArg& node_arg, size_t out_index) { - if (kernel_info->kernel_def->IsOutputOnCpu(out_index)) { + if (utils::IsOutputOnCpu(*node, kernel_info, out_index)) { cpu_output_args.insert(&node_arg); auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name()); for (auto& consumer_node : consumer_nodes) { diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index dede1ecc95885..e4fe0c7564548 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -13,7 +13,9 @@ #include "core/framework/kernel_registry_manager.h" #include "core/framework/kernel_registry.h" #include "core/graph/function.h" +#include "core/graph/function_utils.h" #include "core/graph/graph_viewer.h" +#include "core/graph/model.h" // uncomment this line to count non-CUDA ops in ONNX domain // #define COUNT_NON_CUDA_OPS @@ -129,6 +131,21 @@ struct GetCapabilityForEPParams { std::reference_wrapper debug_graph_fn; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) }; + +auto get_capabilities = [](const IExecutionProvider& ep, + const GraphViewer& graph_viewer, + const IExecutionProvider::IKernelLookup& kernel_lookup) { + auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup); + + // In theory an EP could return an empty capability. Remove those. + capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(), + [](const std::unique_ptr& capability) { + return !capability || !capability->sub_graph; + }), + capabilities.end()); + + return capabilities; +}; } // namespace static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { @@ -143,21 +160,6 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - auto get_capabilities = [](const IExecutionProvider& ep, - const GraphViewer& graph_viewer, - const IExecutionProvider::IKernelLookup& kernel_lookup) { - auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup); - - // In theory an EP could return an empty capability. Remove those. - capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(), - [](const std::unique_ptr& capability) { - return !capability || !capability->sub_graph; - }), - capabilities.end()); - - return capabilities; - }; - const auto& kernel_registry_mgr = params.kernel_registry_mgr.get(); const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type); const KernelLookup kernel_lookup{ep_type, @@ -177,9 +179,9 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - // Run layout transformer only for EPs other than CPU EP and provided the preferred layout is NHWC + // Run layout transformer for EPs with preferred layout of NHWC // CPU EP layout transformation happens later when level 3 transformers are run. - if (params.mode != GraphPartitioner::Mode::kAssignOnly && + if (params.mode != GraphPartitioner::Mode::kAssignOnly && params.transform_layout.get() && current_ep.GetPreferredLayout() == DataLayout::NHWC) { for (auto& capability : capabilities) { TryAssignNodes(graph, *capability->sub_graph, ep_type); @@ -239,6 +241,26 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { } #if !defined(ORT_MINIMAL_BUILD) + +// This function queries the capabilities for a given EP, but it does not assign the nodes. +// It also does not perform layout transformation. This will be done during normal partitioning. +static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, + const KernelRegistryManager& kernel_registry_mgr, + const IExecutionProvider& current_ep, + std::vector>& capabilities) { + const auto& ep_type = current_ep.Type(); + + const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type); + const KernelLookup kernel_lookup{ep_type, + kernel_registries_for_ep, + kernel_registry_mgr.GetKernelTypeStrResolver()}; + + // TODO: Provide EP with a capability to look inside the functions. + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); + + return Status::OK(); +} + /** * Check if a node can be placed on a specific provider. * Do nothing if the node is already assigned @@ -518,7 +540,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { // successfully inlined, we re-run the partitioner on the modified graph. // NOTE: Inlining the function will change the nodes in the Graph instance, so we can't do that while iterating // using graph.Nodes(). - std::vector nodes_to_inline; + InlinedVector nodes_to_inline; for (auto& node : graph.Nodes()) { if (node.GetExecutionProviderType().empty() && node.CanBeInlined()) { nodes_to_inline.push_back(&node); @@ -533,6 +555,85 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { return Status::OK(); } +static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers, + const KernelRegistryManager& kernel_registry_mgr, + Graph& graph, + InlinedHashSet& not_inlined, + size_t& inlined_count) { + // handle testing edge case where optimizers or constant lifting results in graph with no nodes. + // doing it here saves all providers checking for this in GetCapability + if (graph.NumberOfNodes() == 0) { + return Status::OK(); + } + + for (auto& node : graph.Nodes()) { + for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { + Graph* subgraph = entry.second; + // we pass through the FuncManager from the top level graph + ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, + kernel_registry_mgr, + *subgraph, + not_inlined, + inlined_count)); + } + } + + // Gather the candidates + InlinedVector inline_candidates; + for (auto& node : graph.Nodes()) { + if (node.CanBeInlined()) { + inline_candidates.push_back(node.Index()); + } + } + + if (inline_candidates.empty()) { + return Status::OK(); + } + + // Find out all the nodes that are already taken + const GraphViewer graph_viewer(graph); + + InlinedHashSet claimed_by_ep; + for (const auto& ep : execution_providers) { + std::vector> capabilities; + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, capabilities)); + for (auto& capability : capabilities) { + const auto& nodes = capability->sub_graph->nodes; + if (nodes.size() == 1) { + // Single node capability. + ORT_IGNORE_RETURN_VALUE(claimed_by_ep.insert(nodes[0])); + } else { + // Make sure none is claimed by other EPs mirroring the logic in PartitionOnnxFormatModelImpl. + if (std::all_of(nodes.cbegin(), nodes.cend(), [&claimed_by_ep](NodeIndex node_index) { + return claimed_by_ep.count(node_index) == 0; + })) { + claimed_by_ep.insert(nodes.cbegin(), nodes.cend()); + } + } + } + } + + // TODO: Insert version check. We need to collect all the versions + // that imported by the model. If the version is not supported by + // the model, we can not inline it. + + for (auto node_index : inline_candidates) { + auto* node = graph.GetNode(node_index); + if (node != nullptr) { + if (claimed_by_ep.count(node_index) == 0) { + ORT_RETURN_IF_ERROR(graph.InlineFunction(*node)); + ++inlined_count; + } else { + // OpType is the same as function name. + auto function_id = function_utils::GetFunctionIdentifier(node->Domain(), node->OpType()); + ORT_IGNORE_RETURN_VALUE(not_inlined.insert(std::move(function_id))); + } + } + } + + return Status::OK(); +} + static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode, const ExecutionProviders& execution_providers, KernelRegistryManager& kernel_registry_manager) { @@ -693,6 +794,50 @@ static Status PartitionOrtFormatModel(const PartitionParams& partition_params, return Status::OK(); } +#ifndef ORT_MINIMAL_BUILD + +Status GraphPartitioner::InlineFunctionsAOT(Model& model, + const ExecutionProviders& execution_providers, + const KernelRegistryManager& kernel_registry_manager, + const logging::Logger& logger) const { + const auto local_functions_num = model.GetModelLocalFunctionTemplates().size(); + const bool is_there_local_functions = local_functions_num > 0; + + if (!is_there_local_functions) { + LOGS(logger, INFO) << "This model does not have any local functions defined. AOT Inlining is not performed"; + return Status::OK(); + } + + auto& graph = model.MainGraph(); + InlinedHashSet not_inlined; + do { + size_t inlined_count = 0; + ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, + kernel_registry_manager, + graph, + not_inlined, + inlined_count)); + + if (inlined_count == 0) { + break; + } + + ORT_RETURN_IF_ERROR(graph.Resolve()); + } while (true); + + model.RemoveLocalFunctionsProtos(not_inlined); + + LOGS(logger, INFO) + << "AOT inlining completed. (" << (local_functions_num - model.GetModelLocalFunctionTemplates().size()) + << ") functions of (" + << local_functions_num + << ") pruned."; + + return Status::OK(); +} + +#endif + Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, Mode mode, diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 36a27e906c651..4fc85c2588260 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -12,6 +12,7 @@ namespace onnxruntime { class ExecutionProviders; class KernelRegistryManager; +class Model; class GraphPartitioner { public: @@ -33,6 +34,28 @@ class GraphPartitioner { Mode mode = Mode::kNormal, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; +#ifndef ORT_MINIMAL_BUILD + ///

+ // Ahead of Time Function inlining. The main purpose of the function is to inline as many + // functions as possible and delete locally defined functions to reduce the size of the model. + // This would make other optimizations to be more effective. + // + // This function performs GetCapability on the graph and its subgraphs bottom up + // and inlines any functions that are not claimed by any of the execution providers. + // This function does not attempt to run layout transformation, and it does not assign EPs. + // The latter will be done by graph partitioning after Level1 optimizations are done. + /// + /// model instance + /// execution providers considered + /// registry manager + /// session logger + /// + Status InlineFunctionsAOT(Model& model, + const ExecutionProviders& execution_providers, + const KernelRegistryManager& kernel_registry_manager, + const logging::Logger& logger) const; +#endif + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphPartitioner); diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index 38c8a4a4e3d5e..b2ef853119588 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -62,9 +62,15 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node, auto create_error_message = [&node, &status](const std::string& prefix) { std::ostringstream errormsg; - errormsg << prefix << node.OpType() << "(" << node.SinceVersion() << ")"; - if (!node.Name().empty()) errormsg << " (node " << node.Name() << "). "; - if (!status.IsOK()) errormsg << status.ErrorMessage(); + errormsg << prefix; + const auto& domain = node.Domain(); + if (!domain.empty()) { + errormsg << domain << "."; + } + errormsg << node.OpType() << "(" << node.SinceVersion() << ")" + << " (node:'" << node.Name() << "' ep:'" << node.GetExecutionProviderType() << "'). "; + if (!status.IsOK()) + errormsg << status.ErrorMessage(); return errormsg.str(); }; diff --git a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc index ea93db58339c7..4f5fa9910b5df 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc +++ b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc @@ -53,128 +53,200 @@ Status AddLayoutTransformationRequiredOpsToKernelTypeStrResolver(KernelTypeStrRe // clang-format off constexpr uint8_t kLayoutTransformationRequiredOpsKernelTypeStrResolverBytes[] = { 0x10, 0x00, 0x00, 0x00, 0x6b, 0x74, 0x73, 0x72, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, - 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0xbc, 0x06, 0x00, 0x00, - 0x4c, 0x02, 0x00, 0x00, 0xe0, 0x01, 0x00, 0x00, 0xe0, 0x00, 0x00, 0x00, 0x14, 0x06, 0x00, 0x00, - 0x88, 0x01, 0x00, 0x00, 0xb8, 0x05, 0x00, 0x00, 0x1c, 0x05, 0x00, 0x00, 0x18, 0x07, 0x00, 0x00, - 0xcc, 0x04, 0x00, 0x00, 0x0c, 0x01, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x54, 0x05, 0x00, 0x00, - 0x3c, 0x06, 0x00, 0x00, 0xf8, 0x02, 0x00, 0x00, 0x7c, 0x02, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x38, 0x03, 0x00, 0x00, 0xec, 0xf8, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x1a, 0x00, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, + 0x4c, 0x0b, 0x00, 0x00, 0xac, 0x08, 0x00, 0x00, 0xd0, 0x0a, 0x00, 0x00, 0x10, 0x06, 0x00, 0x00, + 0xa8, 0x07, 0x00, 0x00, 0x18, 0x03, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, + 0x44, 0x07, 0x00, 0x00, 0x9c, 0x01, 0x00, 0x00, 0xf8, 0x07, 0x00, 0x00, 0x78, 0x09, 0x00, 0x00, + 0x14, 0x01, 0x00, 0x00, 0x50, 0x06, 0x00, 0x00, 0x60, 0x02, 0x00, 0x00, 0xf4, 0x08, 0x00, 0x00, + 0x8c, 0x03, 0x00, 0x00, 0x9c, 0x02, 0x00, 0x00, 0x84, 0x06, 0x00, 0x00, 0xcc, 0x03, 0x00, 0x00, + 0x60, 0x05, 0x00, 0x00, 0xb8, 0x01, 0x00, 0x00, 0x1c, 0x03, 0x00, 0x00, 0x08, 0x04, 0x00, 0x00, + 0xe0, 0x09, 0x00, 0x00, 0x8c, 0xf4, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, + 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x34, 0x00, 0x00, 0x00, 0x00, 0xb4, 0xf4, 0xff, 0xff, + 0x08, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xda, 0xf4, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x9c, 0xf4, 0xff, 0xff, + 0xd8, 0xf4, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x60, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, + 0x72, 0x3a, 0x31, 0x30, 0x00, 0x00, 0x00, 0x00, 0x10, 0xf5, 0xff, 0xff, 0xa4, 0x0a, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xfc, 0xf4, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x2c, 0xf5, 0xff, 0xff, 0xb0, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x4e, 0xf5, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x48, 0xf5, 0xff, 0xff, 0xc8, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x38, 0xf5, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0x30, 0xf5, 0xff, 0xff, 0x6c, 0xf5, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, + 0x31, 0x39, 0x00, 0x00, 0x9c, 0xf5, 0xff, 0xff, 0x3c, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xc2, 0xf5, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x94, 0xf5, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xc4, 0xf5, 0xff, 0xff, + 0xe8, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xb4, 0xf5, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xac, 0xf5, 0xff, 0xff, + 0xe8, 0xf5, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, + 0x79, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, 0x10, 0xf6, 0xff, 0xff, 0xac, 0x05, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x36, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xf8, 0xf5, 0xff, 0xff, 0x34, 0xf6, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, + 0x50, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, + 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, + 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, 0x74, 0xf6, 0xff, 0xff, + 0x38, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x64, 0xf6, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x5c, 0xf6, 0xff, 0xff, + 0x98, 0xf6, 0xff, 0xff, 0x40, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xbe, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x90, 0xf6, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xc0, 0xf6, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0xe4, 0xf6, 0xff, 0xff, + 0x2c, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x0a, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xcc, 0xf6, 0xff, 0xff, + 0x08, 0xf7, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, + 0x73, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x30, 0xf7, 0xff, 0xff, 0xe0, 0x08, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x56, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x18, 0xf7, 0xff, 0xff, 0x54, 0xf7, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x00, + 0x78, 0xf7, 0xff, 0xff, 0x98, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x9e, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x60, 0xf7, 0xff, 0xff, 0x9c, 0xf7, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x4e, 0x68, 0x77, 0x63, 0x4d, 0x61, - 0x78, 0x50, 0x6f, 0x6f, 0x6c, 0x3a, 0x31, 0x00, 0x20, 0xf9, 0xff, 0xff, 0xf0, 0x06, 0x00, 0x00, + 0x78, 0x50, 0x6f, 0x6f, 0x6c, 0x3a, 0x31, 0x00, 0xd0, 0xf7, 0xff, 0xff, 0x40, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x0e, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x08, 0xf9, 0xff, 0xff, 0x44, 0xf9, 0xff, 0xff, + 0xf6, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb8, 0xf7, 0xff, 0xff, 0xf4, 0xf7, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x31, - 0x00, 0x00, 0x00, 0x00, 0x6c, 0xf9, 0xff, 0xff, 0xa4, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x5a, 0xf9, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x54, 0xf9, 0xff, 0xff, 0x90, 0xf9, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, - 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x00, 0xb4, 0xf9, 0xff, 0xff, - 0x5c, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xa2, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x9c, 0xf9, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x00, 0x1c, 0xf8, 0xff, 0xff, 0xf4, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x42, 0xf8, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x04, 0xf8, 0xff, 0xff, 0x40, 0xf8, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0x00, 0x00, + 0x68, 0xf8, 0xff, 0xff, 0xa8, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x8e, 0xf8, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x50, 0xf8, 0xff, 0xff, 0x8c, 0xf8, 0xff, 0xff, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0xf4, 0x00, 0x00, 0x00, 0xc8, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, + 0x0c, 0x01, 0x00, 0x00, 0x94, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00, + 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, + 0x74, 0x3a, 0x51, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x43, 0x6f, 0x6e, 0x76, 0x3a, 0x31, 0x00, + 0xd8, 0xf8, 0xff, 0xff, 0xdc, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xc4, 0xf8, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xf4, 0xf8, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x33, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x22, 0xf9, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xf4, 0xf8, 0xff, 0xff, 0x07, 0x00, 0x00, 0x00, 0x24, 0xf9, 0xff, 0xff, + 0xe4, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x10, 0xf9, 0xff, 0xff, 0x06, 0x00, 0x00, 0x00, 0x40, 0xf9, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x77, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x38, 0xf9, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, + 0x68, 0xf9, 0xff, 0xff, 0x70, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0xf9, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, + 0x60, 0xf9, 0xff, 0xff, 0x03, 0x00, 0x00, 0x00, 0x90, 0xf9, 0xff, 0xff, 0x1c, 0x05, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x80, 0xf9, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x78, 0xf9, 0xff, 0xff, 0xb4, 0xf9, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x34, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa8, 0xf9, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0xd8, 0xf9, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, - 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0xfa, 0xff, 0xff, 0xb4, 0x01, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x48, 0xfa, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0x1c, 0xfa, 0xff, 0xff, 0xf4, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0a, 0xfa, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x04, 0xfa, 0xff, 0xff, 0x40, 0xfa, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x34, 0x00, 0x00, 0x00, 0x00, - 0x68, 0xfa, 0xff, 0xff, 0x3c, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x56, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x50, 0xfa, 0xff, 0xff, 0x8c, 0xfa, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, - 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0xb4, 0xfa, 0xff, 0xff, - 0x00, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xfc, 0xfa, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xd0, 0xfa, 0xff, 0xff, 0x40, 0x05, 0x00, 0x00, + 0x38, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, + 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x04, 0xfa, 0xff, 0xff, + 0x84, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xf0, 0xf9, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x20, 0xfa, 0xff, 0xff, 0xf0, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xbe, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb8, 0xfa, 0xff, 0xff, 0xf4, 0xfa, 0xff, 0xff, + 0x46, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x08, 0xfa, 0xff, 0xff, 0x44, 0xfa, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, - 0x31, 0x31, 0x00, 0x00, 0x1c, 0xfb, 0xff, 0xff, 0x98, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x64, 0xfb, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0x38, 0xfb, 0xff, 0xff, 0xd8, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x26, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x20, 0xfb, 0xff, 0xff, 0x5c, 0xfb, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, - 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, - 0x88, 0xfb, 0xff, 0xff, 0x88, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x76, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x70, 0xfb, 0xff, 0xff, 0xac, 0xfb, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x61, 0x78, 0x65, 0x73, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x00, 0xfc, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xd4, 0xfb, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, - 0x31, 0x00, 0x00, 0x00, 0xfc, 0xfb, 0xff, 0xff, 0x14, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xea, 0xfb, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0xe4, 0xfb, 0xff, 0xff, 0x20, 0xfc, 0xff, 0xff, 0x28, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x38, 0x01, 0x00, 0x00, 0xdc, 0x00, 0x00, 0x00, - 0xa8, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, - 0x48, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, - 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x43, 0x6f, 0x6e, - 0x76, 0x3a, 0x31, 0x00, 0x6c, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x54, 0x34, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xbc, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x90, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x79, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe4, 0xfc, 0xff, 0xff, 0x06, 0x00, 0x00, 0x00, - 0xb8, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, - 0x78, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x0c, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xe0, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x33, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xd6, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x3c, 0xfd, 0xff, 0xff, 0x07, 0x00, 0x00, 0x00, 0x10, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x32, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x64, 0xfd, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, - 0x6c, 0xfd, 0xff, 0xff, 0x03, 0x00, 0x00, 0x00, 0x40, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x77, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x94, 0xfd, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, - 0x68, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x54, 0x31, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xbc, 0xfd, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x58, 0xfd, 0xff, 0xff, 0x94, 0xfd, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, - 0xb8, 0xfd, 0xff, 0xff, 0x58, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa6, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0xa0, 0xfd, 0xff, 0xff, 0xdc, 0xfd, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, - 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0xff, - 0xa0, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xf2, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xec, 0xfd, 0xff, 0xff, - 0x28, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, - 0x73, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x50, 0xfe, 0xff, 0xff, 0xc0, 0x01, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x3e, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x38, 0xfe, 0xff, 0xff, 0x74, 0xfe, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x36, - 0x00, 0x00, 0x00, 0x00, 0x9c, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x92, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x8c, 0xfe, 0xff, 0xff, - 0xc8, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, - 0x79, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, 0xf0, 0xfe, 0xff, 0xff, 0x20, 0x01, 0x00, 0x00, + 0x31, 0x31, 0x00, 0x00, 0x6c, 0xfa, 0xff, 0xff, 0xc4, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0xfa, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x88, 0xfa, 0xff, 0xff, 0x88, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xae, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x70, 0xfa, 0xff, 0xff, 0xac, 0xfa, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, + 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x00, 0x00, 0xd0, 0xfa, 0xff, 0xff, 0x40, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xde, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xd8, 0xfe, 0xff, 0xff, 0x14, 0xff, 0xff, 0xff, + 0xf6, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb8, 0xfa, 0xff, 0xff, 0xf4, 0xfa, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, - 0x00, 0x00, 0x00, 0x00, 0x3c, 0xff, 0xff, 0xff, 0xd4, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x2a, 0xff, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x24, 0xff, 0xff, 0xff, 0x60, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x1c, 0xfb, 0xff, 0xff, 0xf4, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x42, 0xfb, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x04, 0xfb, 0xff, 0xff, 0x40, 0xfb, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, + 0x68, 0xfb, 0xff, 0xff, 0xa8, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x8e, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x50, 0xfb, 0xff, 0xff, 0x8c, 0xfb, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, + 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x36, 0x00, 0x00, 0x00, 0x00, 0xb4, 0xfb, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe2, 0xfb, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xa4, 0xfb, 0xff, 0xff, 0xe0, 0xfb, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, + 0x08, 0xfc, 0xff, 0xff, 0x08, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x2e, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xf0, 0xfb, 0xff, 0xff, 0x2c, 0xfc, 0xff, 0xff, 0x04, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x18, 0xfc, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x48, 0xfc, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, + 0x31, 0x30, 0x00, 0x00, 0x7c, 0xfc, 0xff, 0xff, 0x30, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0xfc, 0xff, 0xff, 0x94, 0xfc, 0xff, 0xff, + 0x44, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xba, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x8c, 0xfc, 0xff, 0xff, + 0x02, 0x00, 0x00, 0x00, 0xbc, 0xfc, 0xff, 0xff, 0x4c, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa8, 0xfc, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0xd8, 0xfc, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x4c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, + 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x39, + 0x00, 0x00, 0x00, 0x00, 0x0c, 0xfd, 0xff, 0xff, 0xcc, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x32, 0xfd, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x04, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x34, 0xfd, 0xff, 0xff, + 0x78, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x24, 0xfd, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x1c, 0xfd, 0xff, 0xff, + 0x58, 0xfd, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x40, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, + 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x80, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x61, 0x78, 0x65, 0x73, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x78, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0xa8, 0xfd, 0xff, 0xff, 0x68, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xce, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x90, 0xfd, 0xff, 0xff, 0xcc, 0xfd, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, + 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x79, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x28, 0xfe, 0xff, 0xff, 0x84, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0xff, 0x40, 0xfe, 0xff, 0xff, 0x98, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x66, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x38, 0xfe, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0x68, 0xfe, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, + 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, + 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0xa4, 0xfe, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x31, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x9c, 0xfe, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x94, 0xfe, 0xff, 0xff, 0xd0, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x32, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xfe, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xd0, 0xfe, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, - 0x88, 0xff, 0xff, 0xff, 0x88, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x76, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x70, 0xff, 0xff, 0xff, 0xac, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x54, 0x69, 0x6e, 0x64, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0xdc, 0xff, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, - 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00, + 0x28, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x54, 0x69, 0x6e, 0x64, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x20, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x50, 0xff, 0xff, 0xff, 0xc0, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x76, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x38, 0xff, 0xff, 0xff, 0x74, 0xff, 0xff, 0xff, + 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x84, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, + 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, + 0x00, 0x00, 0x00, 0x00, 0xac, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x78, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xa4, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xd4, 0xff, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x79, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, }; // clang-format on diff --git a/onnxruntime/core/framework/op_node_proto_helper.cc b/onnxruntime/core/framework/op_node_proto_helper.cc index 38d67eb0e0c72..c3deb94300e78 100644 --- a/onnxruntime/core/framework/op_node_proto_helper.cc +++ b/onnxruntime/core/framework/op_node_proto_helper.cc @@ -182,7 +182,7 @@ ORT_DEFINE_GET_ATTRS_SPAN_SPECIALIZATION(float, floats) ORT_DEFINE_GET_ATTRS_SPAN_SPECIALIZATION(int64_t, ints) template -MUST_USE_RESULT Status OpNodeProtoHelper::GetAttrs(const std::string& name, TensorShapeVector& out) const { +Status OpNodeProtoHelper::GetAttrs(const std::string& name, TensorShapeVector& out) const { gsl::span span; Status status = this->GetAttrsAsSpan(name, span); if (status.IsOK()) { @@ -193,7 +193,7 @@ MUST_USE_RESULT Status OpNodeProtoHelper::GetAttrs(const std::string& na } template -MUST_USE_RESULT Status OpNodeProtoHelper::GetAttrsStringRefs( +Status OpNodeProtoHelper::GetAttrsStringRefs( const std::string& name, std::vector>& refs) const { const AttributeProto* attr = TryGetAttribute(name); diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index aff90b8d40bde..8deeb4c2b8b64 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -148,6 +148,10 @@ struct SessionOptions { std::shared_ptr custom_op_libs; void AddCustomOpLibraryHandle(PathString library_name, void* library_handle); #endif + + // User specified logging func and param + OrtLoggingFunction user_logging_function = nullptr; + void* user_logging_param = nullptr; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index f0e5fbbd38721..6244d426450a2 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1046,7 +1046,8 @@ Status SessionState::CreateSubgraphSessionState() { auto subgraph_session_state = std::make_unique(*subgraph, execution_providers_, thread_pool_, inter_op_thread_pool_, data_transfer_mgr_, - logger_, profiler_, sess_options_, nullptr, allocators_); + logger_, profiler_, sess_options_, + prepacked_weights_container_, allocators_); // Pass fused function manager to subgraph subgraph_session_state->fused_funcs_mgr_.SetFusedFuncs(fused_funcs_mgr_); diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index df3a7afebc176..df11fe8302aef 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -455,11 +455,10 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& // utils::CopyOneInputAcrossDevices is happy. auto& input_map = session_state.GetInputNodeInfoMap(); - auto end_map = input_map.cend(); for (const auto& graph_input : graph_inputs) { const auto& name = graph_input->Name(); - if (input_map.find(name) == end_map) { + if (input_map.find(name) == input_map.cend()) { // dummy entry for an input that we didn't find a use of in the graph. log it in case that's a bug. // utils::CopyOneInputAcrossDevices will use the input OrtValue as is given we don't believe it's used anywhere. LOGS(session_state.Logger(), INFO) << (graph.IsSubgraph() ? "Subgraph" : "Graph") << " input with name " diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc index a9e5038e19e02..36f03a9b1046a 100644 --- a/onnxruntime/core/framework/tensor.cc +++ b/onnxruntime/core/framework/tensor.cc @@ -27,80 +27,68 @@ int64_t GetSizeFromStrides(const TensorShape& shape, gsl::span st } // namespace #endif -size_t Tensor::CalculateTensorStorageSize(MLDataType p_type, - const TensorShape& shape, - gsl::span strides) { -#ifdef ENABLE_STRIDED_TENSORS - int64_t shape_size = 1; - if (shape.NumDimensions() > 0 && !strides.empty()) { - ORT_ENFORCE(shape.NumDimensions() == strides.size(), "Length of strides doesn't match with tensor dimension size."); - shape_size = GetSizeFromStrides(shape, strides); - } else { - shape_size = shape.Size(); - } -#else - ORT_ENFORCE(strides.empty(), "Strided tensor is supported for training only for now."); +size_t Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape) { int64_t shape_size = shape.Size(); -#endif - if (shape_size < 0) ORT_THROW("shape.Size() must >=0"); + if (shape_size < 0) + ORT_THROW("shape.Size() must >=0"); if (shape_size > 0) { SafeInt len = 0; - if (!IAllocator::CalcMemSizeForArray(SafeInt(shape_size), p_type->Size(), &len)) + if (!IAllocator::CalcMemSizeForArray(SafeInt(shape_size), elt_type->Size(), &len)) ORT_THROW("tensor failed memory size calculation"); return len; } + return 0; } -Tensor::Tensor(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, +Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, ptrdiff_t offset, gsl::span strides) - : alloc_info_(alloc) { - ORT_ENFORCE(p_type != nullptr); - Init(p_type, shape, p_data, nullptr, offset, strides); + : alloc_info_(location) { + ORT_ENFORCE(elt_type != nullptr); + Init(elt_type, shape, p_data, nullptr, offset, strides); } -Tensor::Tensor(MLDataType p_type, const TensorShape& shape, std::shared_ptr allocator, - gsl::span strides) +Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, std::shared_ptr allocator) : alloc_info_(allocator->Info()) { - ORT_ENFORCE(p_type != nullptr); - size_t len = Tensor::CalculateTensorStorageSize(p_type, shape, strides); + ORT_ENFORCE(elt_type != nullptr); + size_t len = Tensor::CalculateTensorStorageSize(elt_type, shape); void* p_data = nullptr; if (len > 0) { p_data = allocator->Alloc(len); } - Init(p_type, shape, p_data, allocator, 0L, strides); + Init(elt_type, shape, p_data, allocator, 0L); } -Tensor::Tensor(MLDataType p_type, const TensorShape& shape, void* p_data, std::shared_ptr deleter, +Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, std::shared_ptr deleter, ptrdiff_t offset, gsl::span strides) : alloc_info_(deleter->Info()) { - ORT_ENFORCE(p_type != nullptr); - Init(p_type, shape, p_data, deleter, offset, strides); + ORT_ENFORCE(elt_type != nullptr); + Init(elt_type, shape, p_data, deleter, offset, strides); } void Tensor::InitOrtValue(MLDataType elt_type, const TensorShape& shape, std::shared_ptr allocator, - OrtValue& ort_value, gsl::span strides) { - auto p_tensor = std::make_unique(elt_type, shape, std::move(allocator), strides); + OrtValue& ort_value) { + auto p_tensor = std::make_unique(elt_type, shape, std::move(allocator)); auto ml_tensor = DataTypeImpl::GetType(); ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); } -void Tensor::InitOrtValue(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, +void Tensor::InitOrtValue(MLDataType elt_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, OrtValue& ort_value, ptrdiff_t offset, gsl::span strides) { auto ml_tensor = DataTypeImpl::GetType(); - auto p_tensor = std::make_unique(p_type, shape, p_data, location, offset, strides); + auto p_tensor = std::make_unique(elt_type, shape, p_data, location, offset, strides); ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); } -void Tensor::InitOrtValue(MLDataType p_type, const TensorShape& shape, +void Tensor::InitOrtValue(MLDataType elt_type, const TensorShape& shape, void* p_data, std::shared_ptr allocator, OrtValue& ort_value, ptrdiff_t offset, gsl::span strides) { auto ml_tensor = DataTypeImpl::GetType(); - auto p_tensor = std::make_unique(p_type, shape, p_data, std::move(allocator), offset, strides); + auto p_tensor = std::make_unique(elt_type, shape, p_data, std::move(allocator), offset, strides); ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); } @@ -123,27 +111,31 @@ size_t Tensor::SizeInBytes() const { return ret; } -void Tensor::Init(MLDataType p_type, const TensorShape& shape, void* p_raw_data, AllocatorPtr deleter, ptrdiff_t offset, - gsl::span strides) { +void Tensor::Init(MLDataType elt_type, const TensorShape& shape, void* p_raw_data, AllocatorPtr deleter, + ptrdiff_t offset, gsl::span strides) { int64_t shape_size = shape.Size(); - if (shape_size < 0) ORT_THROW("shape.Size() must >=0"); - dtype_ = p_type->AsPrimitiveDataType(); + if (shape_size < 0) + ORT_THROW("shape.Size() must >=0"); + + dtype_ = elt_type->AsPrimitiveDataType(); ORT_ENFORCE(dtype_ != nullptr, - "Tensor is expected to contain one of the primitive data types. Got: ", DataTypeImpl::ToString(p_type)); + "Tensor is expected to contain one of the primitive data types. Got: ", + DataTypeImpl::ToString(elt_type)); shape_ = shape; p_data_ = p_raw_data; - // if caller passed in a deleter, that means this tensor own this buffer - // we will release the buffer when this tensor is deconstructed. + // if caller passed in a deleter we now own p_data_ and must free it in the dtor buffer_deleter_ = std::move(deleter); // for string tensors, if this tensor own the buffer (caller passed in the deleter) // do the placement new for strings on pre-allocated buffer. if (buffer_deleter_ && IsDataTypeString()) { utils::ConstructStrings(p_data_, shape_size); } + byte_offset_ = offset; + #ifdef ENABLE_STRIDED_TENSORS if (shape.NumDimensions() > 0 && !strides.empty()) { - ORT_ENFORCE(shape.NumDimensions() == strides.size(), "Length of strides doesn't match with tensor dimension size."); + ORT_ENFORCE(shape.NumDimensions() == strides.size(), "Length of strides doesn't match tensor dimension size."); strides_.assign(strides.begin(), strides.end()); is_contiguous_ = CheckIsContiguous(); } @@ -156,13 +148,21 @@ Tensor::Tensor(Tensor&& other) noexcept : p_data_(other.p_data_), buffer_deleter_(other.buffer_deleter_), shape_(other.shape_), +#ifdef ENABLE_STRIDED_TENSORS + strides_(other.strides_), + is_contiguous_(other.is_contiguous_), +#endif dtype_(other.dtype_), alloc_info_(other.alloc_info_), byte_offset_(other.byte_offset_) { - other.dtype_ = DataTypeImpl::GetType()->AsPrimitiveDataType(); - other.shape_ = TensorShape(std::vector(1, 0)); other.p_data_ = nullptr; other.buffer_deleter_ = nullptr; + other.dtype_ = DataTypeImpl::GetType()->AsPrimitiveDataType(); + other.shape_ = TensorShape(std::vector(1, 0)); +#ifdef ENABLE_STRIDED_TENSORS + other.strides_ = {}; + other.is_contiguous_ = true; +#endif other.byte_offset_ = 0; } @@ -170,19 +170,28 @@ Tensor& Tensor::operator=(Tensor&& other) noexcept { if (this != &other) { ReleaseBuffer(); - dtype_ = other.dtype_; + p_data_ = other.p_data_; + buffer_deleter_ = other.buffer_deleter_; shape_ = other.shape_; +#ifdef ENABLE_STRIDED_TENSORS + strides_ = other.strides_; + is_contiguous_ = other.is_contiguous_; +#endif + dtype_ = other.dtype_; alloc_info_ = other.alloc_info_; byte_offset_ = other.byte_offset_; - p_data_ = other.p_data_; - buffer_deleter_ = other.buffer_deleter_; - other.dtype_ = DataTypeImpl::GetType()->AsPrimitiveDataType(); - other.shape_ = TensorShape(std::vector(1, 0)); other.p_data_ = nullptr; - other.byte_offset_ = 0; other.buffer_deleter_ = nullptr; + other.shape_ = TensorShape(std::vector(1, 0)); +#ifdef ENABLE_STRIDED_TENSORS + other.strides_ = {}; + other.is_contiguous_ = true; +#endif + other.dtype_ = DataTypeImpl::GetType()->AsPrimitiveDataType(); + other.byte_offset_ = 0; } + return *this; } diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index f3e1acbbe523d..6e11bfe1ac8ea 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -85,6 +85,16 @@ ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, return nullptr; } +ORT_API_STATUS_IMPL(OrtApis::SetSymbolicDimensions, + _In_ struct OrtTensorTypeAndShapeInfo* info, + _In_ const char** names, _In_ size_t dim_params_length) { + info->dim_params.clear(); + for (size_t idx = 0; idx < dim_params_length; ++idx) { + info->dim_params.push_back(names[idx]); + } + return nullptr; +} + ORT_API_STATUS_IMPL(OrtApis::GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* this_ptr, _Out_ size_t* out) { API_IMPL_BEGIN diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 08ed811d9ac38..fd32aaedcc2ee 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -87,6 +87,7 @@ bool operator!=(const ONNX_NAMESPACE::TensorShapeProto_Dimension& l, } // namespace ONNX_NAMESPACE +namespace onnxruntime { namespace { // This function doesn't support string tensors @@ -162,9 +163,9 @@ static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_prot // Uses the tensor_proto_dir to construct the full path for external data. If tensor_proto_dir == nullptr // then uses the current directory instead. // This function does not unpack string_data of an initializer tensor -static Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const ORTCHAR_T* tensor_proto_dir, - std::vector& unpacked_tensor) { +Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const ORTCHAR_T* tensor_proto_dir, + std::vector& unpacked_tensor) { std::basic_string external_file_path; onnxruntime::FileOffsetType file_offset; SafeInt tensor_byte_size; @@ -184,10 +185,52 @@ static Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tenso return Status::OK(); } + +// TODO(unknown): Change the current interface to take Path object for model path +// so that validating and manipulating path for reading external data becomes easy +Status TensorProtoToOrtValueImpl(const Env& env, const ORTCHAR_T* model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + const MemBuffer* m, AllocatorPtr alloc, + OrtValue& value) { + if (m && m->GetBuffer() == nullptr) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "MemBuffer has not been allocated."); + } + + // to construct a Tensor with std::string we need to pass an allocator to the Tensor ctor + // as the contents of each string needs to be allocated and freed separately. + ONNXTensorElementDataType ele_type = utils::GetTensorElementType(tensor_proto); + if (ele_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING && (m || !alloc)) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "string tensor requires allocator to be provided."); + } + + // Note: We permit an empty tensor_shape_vec, and treat it as a scalar (a tensor of size 1). + TensorShape tensor_shape = GetTensorShapeFromTensorProto(tensor_proto); + const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); + + std::unique_ptr tensor; + + if (m) { + tensor = std::make_unique(type, tensor_shape, m->GetBuffer(), m->GetAllocInfo()); + + if (tensor->SizeInBytes() > m->GetLen()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The preallocated buffer is too small. Requires ", + tensor->SizeInBytes(), ", Got ", m->GetLen()); + } + } else { + tensor = std::make_unique(type, tensor_shape, alloc); + } + + ORT_RETURN_IF_ERROR(TensorProtoToTensor(env, model_path, tensor_proto, *tensor)); + + auto ml_tensor = DataTypeImpl::GetType(); + value.Init(tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); + return Status::OK(); +} + } // namespace -namespace onnxruntime { namespace utils { + #if !defined(ORT_MINIMAL_BUILD) static Status UnpackTensorWithExternalDataImpl(const ONNX_NAMESPACE::TensorProto& tensor, const ORTCHAR_T* tensor_proto_dir, @@ -810,7 +853,7 @@ Status GetExtDataFromTensorProto(const Env& env, const ORTCHAR_T* model_path, * @param env * @param model_path * @param tensor_proto tensor data in protobuf format - * @param tensorp pre-allocated tensor object, where we store the data + * @param tensor pre-allocated tensor object, where we store the data * @return */ Status TensorProtoToTensor(const Env& env, const ORTCHAR_T* model_path, @@ -824,7 +867,7 @@ Status TensorProtoToTensor(const Env& env, const ORTCHAR_T* model_path, const DataTypeImpl* const source_type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); if (source_type->Size() > tensor.DataType()->Size()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "TensorProto type ", DataTypeImpl::ToString(source_type), - " can not be writen into Tensor type ", DataTypeImpl::ToString(tensor.DataType())); + " can not be written into Tensor type ", DataTypeImpl::ToString(tensor.DataType())); } // find raw data in proto buf @@ -900,44 +943,18 @@ Status TensorProtoToTensor(const Env& env, const ORTCHAR_T* model_path, return Status::OK(); } -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 6239) -#endif -// TODO: Change the current interface to take Path object for model path -// so that validating and manipulating path for reading external data becomes easy -Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* model_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, - const MemBuffer& m, OrtValue& value) { - if (m.GetBuffer() == nullptr) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "TensorProtoToMLValue() must take a pre-allocated MemBuffer!"); - } - - ONNXTensorElementDataType ele_type = utils::GetTensorElementType(tensor_proto); - if (ele_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "string tensor can not use pre-allocated buffer"); - } - - // Note: We permit an empty tensor_shape_vec, and treat it as a scalar (a tensor of size 1). - TensorShape tensor_shape = GetTensorShapeFromTensorProto(tensor_proto); - const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); - std::unique_ptr tensorp = std::make_unique(type, tensor_shape, m.GetBuffer(), m.GetAllocInfo()); - if (tensorp->SizeInBytes() > m.GetLen()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The preallocated buffer is too small. Requires ", - tensorp->SizeInBytes(), ", Got ", m.GetLen()); - } - - ORT_RETURN_IF_ERROR(TensorProtoToTensor(env, model_path, tensor_proto, *tensorp)); +Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + const MemBuffer& m, OrtValue& value) { + return TensorProtoToOrtValueImpl(env, model_path, tensor_proto, &m, nullptr, value); +} - auto ml_tensor = DataTypeImpl::GetType(); - value.Init(tensorp.release(), ml_tensor, ml_tensor->GetDeleteFunc()); - return Status::OK(); +Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + AllocatorPtr alloc, OrtValue& value) { + return TensorProtoToOrtValueImpl(env, model_path, tensor_proto, nullptr, alloc, value); } -#ifdef _MSC_VER -#pragma warning(pop) -#pragma warning(disable : 6239) -#endif + #define CASE_TYPE(X) \ case ONNX_NAMESPACE::TensorProto_DataType_##X: \ return ONNX_TENSOR_ELEMENT_DATA_TYPE_##X; diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 6f1b60dd2982e..000502ba47594 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -39,13 +39,28 @@ TensorShape GetTensorShapeFromTensorShapeProto(const ONNX_NAMESPACE::TensorShape TensorShape GetTensorShapeFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto); /** - * deserialize a TensorProto into a preallocated memory buffer. - * \param tensor_proto_path A local file path of where the 'input' was loaded from. Can be NULL if the tensor proto doesn't - * have any external data or it was loaded from current working dir. This path could be either a - * relative path or an absolute path. + * deserialize a TensorProto into a preallocated memory buffer on CPU. + * \param tensor_proto_path A local file path of where the 'input' was loaded from. + * Can be NULL if the tensor proto doesn't have external data or it was loaded from + * the current working dir. This path could be either a relative path or an absolute path. + * \return Status::OK on success with 'value' containing the Tensor in CPU based memory. */ -common::Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* tensor_proto_path, - const ONNX_NAMESPACE::TensorProto& input, const MemBuffer& m, OrtValue& value); +common::Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* tensor_proto_path, + const ONNX_NAMESPACE::TensorProto& input, + const MemBuffer& m, OrtValue& value); + +/** + * deserialize a TensorProto into a buffer on CPU allocated using 'alloc'. + * \param tensor_proto_path A local file path of where the 'input' was loaded from. + * Can be NULL if the tensor proto doesn't have external data or it was loaded from + * the current working dir. This path could be either a relative path or an absolute path. + * \param alloc Allocator to use for allocating the buffer. Must allocate CPU based memory. + * \return Status::OK on success with 'value' containing the Tensor in CPU based memory. + */ +common::Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* tensor_proto_path, + const ONNX_NAMESPACE::TensorProto& input, + AllocatorPtr alloc, OrtValue& value); + /** * @brief Deserialize a TensorProto into a preallocated empty Tensor * @param env diff --git a/onnxruntime/core/framework/tunable.h b/onnxruntime/core/framework/tunable.h index 96b4cc53a022c..6d2dd641f6bc6 100644 --- a/onnxruntime/core/framework/tunable.h +++ b/onnxruntime/core/framework/tunable.h @@ -232,14 +232,15 @@ class TunableOp { return timer.Duration() / num_iter; } - static bool IsSupported(Op& op, const ParamsT* param) { - Status status = op.IsSupported(param); + // Filter all Status, only OK and TUNABLE_OP_UNSUPPORTED is left, other error status will be thrown, and to be + // processed by onnxruntime. We return Status to avoid the construction of op and params signature string. + static Status IsSupported(Op& op, const ParamsT* params) { + Status status = op.IsSupported(params); if (status.Category() == common::StatusCategory::NONE && status.Code() == common::StatusCode::INVALID_ARGUMENT) { - LOGS_DEFAULT(VERBOSE) << "unsupported reason: " << status.ErrorMessage(); - return false; + return status; } ORT_THROW_IF_ERROR(status); - return true; + return status; } protected: @@ -250,9 +251,9 @@ class TunableOp { int FindFastestImpl(const ParamsT* params, const std::vector>& candidates) { ITuningContext* ctx = params->TuningContext(); auto op_sig = Signature(); - auto param_sig = params->Signature(); - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl for " << op_sig << '(' << param_sig << ')'; - auto min_time = std::numeric_limits::infinity(); + auto params_sig = params->Signature(); + LOGS_DEFAULT(VERBOSE) << "finding fastest for " << op_sig << '(' << params_sig << ')'; + auto min_duration_ms = std::numeric_limits::infinity(); int id = -1; constexpr const int max_tuning_iter = 100; @@ -260,30 +261,32 @@ class TunableOp { for (size_t i = 0; i < candidates.size(); i++) { auto& candidate = const_cast&>(candidates[i]); - if (!IsSupported(candidate, params)) { - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl found unsupported " << op_sig << '(' << param_sig << ") id=" << i; + auto status = IsSupported(candidate, params); + if (!status.IsOK()) { + LOGS_DEFAULT(VERBOSE) << "├──unsupported id=" << i << ", " << op_sig << '(' << params_sig << ")"; + LOGS_DEFAULT(VERBOSE) << "│ reason: " << status.ErrorMessage(); continue; } WarmUp(candidate, params); auto approx_duration = Profile(candidate, params, approx_num_iter); - if (approx_duration > 2 * min_time) { - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl skip slow instance " << op_sig << '(' << param_sig << ") id=" << i; + if (approx_duration > 2 * min_duration_ms) { + LOGS_DEFAULT(VERBOSE) << "├──skip slow instance id=" << i; continue; } int tuning_iter = std::max(1, int(std::min(double(max_tuning_iter), ctx->GetMaxTuningDurationMs() / approx_duration))); - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl run instance " << op_sig << '(' << param_sig << ") id=" << i << " " << tuning_iter << " times."; - - auto time = Profile(candidate, params, tuning_iter); - if (time < min_time) { - min_time = time; + auto duration_ms = Profile(candidate, params, tuning_iter); + if (duration_ms < min_duration_ms) { + LOGS_DEFAULT(VERBOSE) << "├──found better instance, new best id=" << i << ", old id=" << id << ". " + << duration_ms << "ms, " << tuning_iter << " iters."; + min_duration_ms = duration_ms; id = static_cast(i); } } ORT_ENFORCE(id >= 0, "Could not find viable op"); - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl for " << op_sig << '(' << param_sig << ") found fastest with id=" << id; + LOGS_DEFAULT(VERBOSE) << "└──found fastest with id=" << id << " for " << op_sig << '(' << params_sig << ")"; std::this_thread::sleep_for(std::chrono::milliseconds(50)); return id; } diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index b6dd8517341bb..23fe5e1cd3d96 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -68,6 +68,7 @@ bool ProviderIsCpuBased(const std::string& provider_type) { provider_type == onnxruntime::kSnpeExecutionProvider || provider_type == onnxruntime::kQnnExecutionProvider || provider_type == onnxruntime::kXnnpackExecutionProvider || + provider_type == onnxruntime::kAzureExecutionProvider || provider_type == onnxruntime::utils::kInternalTestingExecutionProvider; } @@ -1024,7 +1025,32 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) overload_name = attrs.at("overload_name").s(); } - return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index); + return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true); + } +#else + ORT_UNUSED_PARAMETER(node); +#endif + + return false; +} + +bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) { + if (p_kci && p_kci->kernel_def->IsOutputOnCpu(index)) { + return true; + } + +#ifdef ENABLE_ATEN + if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && + node.Domain() == kPytorchAtenDomain) { + const auto& attrs = node.GetAttributes(); + ORT_ENFORCE(utils::HasString(attrs.at("operator"))); + std::string op_name = attrs.at("operator").s(); + std::string overload_name = ""; + if (attrs.find("overload_name") != attrs.end() && utils::HasString(attrs.at("overload_name"))) { + overload_name = attrs.at("overload_name").s(); + } + + return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false); } #else ORT_UNUSED_PARAMETER(node); diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index ea6a629f87cb8..f0b1b9109d405 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -121,6 +121,7 @@ common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFet bool sync_subgraph_fetches = false); bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index); +bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index); template constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index e5956a575d73d..b97fb0d2899fc 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -171,10 +171,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c *output_shape.add_dim() = query_dims[1]; *output_shape.add_dim() = query_dims[2] * query_dims[4]; updateOutputShape(ctx, 0, output_shape); - return; - } - - if (hasInputShape(ctx, 2)) { + } else if (hasInputShape(ctx, 2)) { auto& value_shape = getInputShape(ctx, 2); auto& value_dims = value_shape.dim(); if (value_dims.size() != 3 && value_dims.size() != 4) { @@ -192,10 +189,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c ? (dmmha_packing ? value_dims[2] / 3 : value_dims[2]) : value_dims[1] * value_dims[3]; updateOutputShape(ctx, 0, output_shape); - return; - } - - if (hasInputShape(ctx, 1)) { + } else if (hasInputShape(ctx, 1)) { auto& key_shape = getInputShape(ctx, 1); if (key_shape.dim().size() == 5) { // packed KV ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput(ctx); @@ -217,7 +211,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c propagateElemTypeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); } else { if (sequence_length > 0 && past_dims[2].has_dim_value()) { - int64_t total_sequence_length = sequence_length + past_shape.dim(3).dim_value(); + int64_t total_sequence_length = sequence_length + past_dims[2].dim_value(); ONNX_NAMESPACE::TensorShapeProto present_shape; for (auto& dim : past_dims) { @@ -233,6 +227,59 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c } } +void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) { + // Output 0 has shape (batch_size, sequence_length, hidden_size) + + // Q, K and V: + // Input 0 (query) has shape (batch_size, sequence_length, hidden_size) + // Input 1 (key) has shape (batch_size, kv_sequence_length, kv_hidden_size) + // Input 2 (value) has shape (batch_size, kv_sequence_length, kv_hidden_size) + + // Type inference + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + + // Shape inference + if (hasInputShape(ctx, 0)) { + auto& query_shape = getInputShape(ctx, 0); + auto& query_dims = query_shape.dim(); + + if (query_dims.size() != 3) { + fail_shape_inference("Inputs 0 (query) shall be 3 dimensions"); + } + + if (hasInputShape(ctx, 2)) { + auto& value_shape = getInputShape(ctx, 2); + auto& value_dims = value_shape.dim(); + if (value_dims.size() != 3) { + fail_shape_inference("Inputs 2 (value) shall be 3 dimensions"); + } + + ONNX_NAMESPACE::TensorShapeProto output_shape; + *output_shape.add_dim() = query_dims[0]; + *output_shape.add_dim() = query_dims[1]; + *output_shape.add_dim() = query_dims[2]; + updateOutputShape(ctx, 0, output_shape); + return; + } else { + fail_shape_inference("Missing input 2 (value)"); + } + } + + if (ctx.getNumOutputs() > 1) { // has present output + if (hasInputShape(ctx, past_key_index)) { + auto& past_shape = getInputShape(ctx, past_key_index); + auto& past_dims = past_shape.dim(); + if (past_dims.size() != 4) { + fail_shape_inference("The past_key input shall be 4 dimensions"); + } + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1); + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); + ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index, 1); + ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); + } + } +} + constexpr const char* Attention_ver1_doc = R"DOC( Multi-Head Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT). @@ -746,6 +793,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "The value to be filled in the attention mask. Default value is -10000.0f", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("output_qk", + "Need output the cross attention MatMul(Q, K)", + AttributeProto::INT, + OPTIONAL_VALUE) .Input(0, "query", "Query with shape (batch_size, 1, hidden_size) or packed QKV with shape " @@ -823,7 +874,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T") .Output(1, "present_key", - "past state for key with shape (batch_size, num_heads, total_sequence_length, head_size). " + "present state for key with shape (batch_size, num_heads, total_sequence_length, head_size). " "If past_present_share_buffer is set, " "its shape is (batch_size, num_heads, max_sequence_length, head_size), " "while effective_seq_length = (past_sequence_length + kv_sequence_length).", @@ -831,12 +882,18 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Output(2, "present_value", - "past state for value with shape (batch_size, num_heads, total_sequence_length, head_size). " + "present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). " "If past_present_share_buffer is set, " "its shape is (batch_size, num_heads, max_sequence_length, head_size), " "while effective_seq_length = (past_sequence_length + kv_sequence_length).", "T", OpSchema::Optional) + .Output(3, + "qk", + "normalized Q * K, of shape (batch_size, num_heads, 1, head_size). ", + "V", + OpSchema::Optional) + .TypeConstraint("V", {"tensor(float)"}, "Constrain qk output types to float32 tensors.") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") @@ -889,7 +946,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(4, "key_padding_mask", - "Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)", + "Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), " + "or (batch_size, sequence_length, total_sequence_length)", "M", OpSchema::Optional) .Input(5, @@ -930,6 +988,80 @@ ONNX_MS_OPERATOR_SET_SCHEMA( MultiHeadAttentionTypeAndShapeInference(ctx, 6); })); +constexpr const char* GroupQueryAttention_ver1_doc = R"DOC( +Group Query Self/Cross Attention. + +Supports different number of heads for q and kv. Only supports causal or local attention. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + GroupQueryAttention, 1, + OpSchema() + .SetDoc(GroupQueryAttention_ver1_doc) + .Attr("num_heads", "Number of attention heads for q", AttributeProto::INT) + .Attr("kv_num_heads", "Number of attention heads for k and v", AttributeProto::INT) + .Attr("scale", + "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", + AttributeProto::FLOAT, + OPTIONAL_VALUE) + .Attr("local_window_size", + "left_window_size for local attention (like Mistral). Default value is -1 meaning unused.", + AttributeProto::INT, + static_cast(-1)) + .Input(0, + "query", + "Query with shape (batch_size, sequence_length, hidden_size)", + "T") + .Input(1, + "key", + "Key with shape (batch_size, kv_sequence_length, kv_hidden_size) ", + "T") + .Input(2, + "value", + "Value with shape (batch_size, kv_sequence_length, kv_hidden_size)", + "T") + .Input(3, + "past_key", + "past state key with support for format BNSH. When past_key uses same tensor as present_key" + "(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.", + "T", + OpSchema::Optional) + .Input(4, + "past_value", + "past state value with support for format BNSH. When past_value uses same tensor as present_value" + "(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.", + "T", + OpSchema::Optional) + .Input(5, + "seqlens_k", + "1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.", + "M") + .Input(6, + "total_sequence_length", + "Scalar tensor of total sequence length (past + new).", + "M") + .Output(0, + "output", + "3D output tensor with shape (batch_size, sequence_length, hidden_size)", + "T") + .Output(1, + "present_key", + "present state key with support for format BNSH. When past_key uses same tensor as present_key" + "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" + "kv_sequence_length.", + "T") + .Output(2, + "present_value", + "present state value with support for format BNSH. When past_value uses same tensor as present_value" + "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" + "kv_sequence_length.", + "T") + .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output to float tensors.") + .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + GroupQueryAttentionTypeAndShapeInference(ctx, 3); + })); + constexpr const char* Longformer_Attention_doc = R"DOC( Longformer Self Attention with a local context and a global context. Tokens attend locally: Each token attends to its W previous tokens and W succeeding tokens with W being the window length. A selected few tokens @@ -994,6 +1126,49 @@ ONNX_MS_OPERATOR_SET_SCHEMA( DecoderAttentionTypeAndShapeInference(ctx); })); +constexpr const char* RotaryEmbedding_ver1_doc = R"DOC( +RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices +that are multiplied to query and key before the inner product of query and key is taken. +)DOC"; +ONNX_MS_OPERATOR_SET_SCHEMA( + RotaryEmbedding, 1, + OpSchema() + .SetDoc(RotaryEmbedding_ver1_doc) + .Attr("scale", + "Custom scale will be used if specified. Default value is 1.0", + AttributeProto::FLOAT, + OPTIONAL_VALUE) + .Attr("interleaved", + "Rotate using interleaved pattern. Default value is 0 (False).", + AttributeProto::INT, + OPTIONAL_VALUE) + .Input(0, + "input", + "3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)", + "T") + .Input(1, + "position_ids", + "1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)", + "M") + .Input(2, + "cos_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T") + .Input(3, + "sin_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T") + .Output(0, + "output", + "tensor with same shape as input.", + "T") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + propagateShapeFromInputToOutput(ctx, 0, 0); + })); + constexpr const char* EmbedLayerNormalization_ver1_doc = R"DOC( EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing. The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding, diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 9e63e0d5e83f6..59adfc523c860 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -79,6 +79,397 @@ void RegisterCollectiveOps() { .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateShapeAndTypeFromFirstInput(ctx); }); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedMatMul) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Input(0, "A", "N-dimensional matrix A", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input(1, "B", "N-dimensional matrix B", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Output(0, "Y", "Matrix multiply results from A * B", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint( + "T", + { + "tensor(float16)", + "tensor(float)", + }, + "Constrain input and output types to float tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedSlice) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Input( + 0, + "data", + "Tensor of data to extract slices from.", + "T", + OpSchema::Single, + true, + 1, + OpSchema::Differentiable) + .Input( + 1, + "starts", + "1-D tensor of starting indices of corresponding axis in `axes`", + "Tind", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Input( + 2, + "ends", + "1-D tensor of ending indices (exclusive) of corresponding axis in `axes`", + "Tind", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Input( + 3, + "axes", + "1-D tensor of axes that `starts` and `ends` apply to. Negative value means counting dimensions " + "from the back. Accepted range is [-r, r-1] where r = rank(data). Behavior is undefined if an " + "axis is repeated.", + "Tind", + OpSchema::Optional, + true, + 1, + OpSchema::NonDifferentiable) + .Input( + 4, + "steps", + "1-D tensor of slice step of corresponding axis in `axes`. " + "Negative value means slicing backward. 'steps' cannot be 0. " + "Defaults to 1s.", + "Tind", + OpSchema::Optional, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Sliced data tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types.") + .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types"); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReshape) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr( + "allowzero", + "(Optional) By default, when any value in the 'shape' input is equal to zero " + "the corresponding dimension value is copied from the input tensor dynamically. " + "allowzero=1 indicates that if any value in the 'shape' input is set to zero, " + "the zero value is honored, similar to NumPy.", + AttributeProto::INT, + static_cast(0)) + .Input(0, "data", "An input tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "Specified shape for output.", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "reshaped", "Reshaped data.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedExpand) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceSum) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceMax) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceMean) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedUnsqueeze) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "axes", + "A 1-D tensor indicates the axes to add.", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedSqueeze) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "axes", + "A 1-D tensor indicates the axes to add.", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); } } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index a79203a94a3a7..4c0d78f0ee297 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -415,6 +415,7 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { // output 0 (sequences) shape: (batch_size, num_return_sequences, max_length) // output 1 (sequences_scores) shape: (batch_size, num_return_sequences) // output 2 (scores) shape: (max_length - sequence_length, batch_size, num_beams, vocab_size) + // output 3 (cross_attention): shape: (batch_size, num_return_sequences, Layers, Heads, max_length, Frames) if (!hasInputShape(ctx, 0)) { return; } @@ -1060,6 +1061,78 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GridSample, 1, updateOutputShape(ctx, 0, {N, C, H_out, W_out}); })); +ONNX_MS_OPERATOR_SET_SCHEMA( + UnfoldTensor, 1, + OpSchema() + .SetDoc("Returns a tensor which contains all slices of size size from input tensor in the dimension dim. " + "Step between two slices is given by step. " + "If sizedim is the size of dimension dim for input tensor, the size of dimension dim in " + "the returned tensor will be (sizedim - size) / step + 1. " + "An additional dimension of size size is appended in the returned tensor.") + .Attr("dim", "specify the dimension to unfold", AttributeProto::INT, static_cast(-1)) + .Attr("size", "specify the size", AttributeProto::INT) + .Attr("step", "specify the step.", AttributeProto::INT, static_cast(1)) + .Input(0, "input", "input tensor", "T") + .Output(0, "output", "Output tensor.", "T") + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Allow inputs and outputs to be any kind of tensor.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + + if (!hasInputShape(ctx, 0)) return; + auto& input_shape = getInputShape(ctx, 0); + const int rank = input_shape.dim_size(); + int64_t dim = getAttribute(ctx, "dim", -1); + dim = HandleNegativeAxis(dim, rank); + if (!input_shape.dim(static_cast(dim)).has_dim_value()) { + return; + } + int64_t dim_size = input_shape.dim(static_cast(dim)).dim_value(); + + const int64_t step = getAttribute(ctx, "step", -1); + if (step <= 0) { + fail_shape_inference("size attribute in UnfoldTensor must greater than 0.") + } + int64_t size = -1; + auto size_proto = ctx.getAttribute("size"); + if (!(size_proto)) { + fail_shape_inference("size attribute in UnfoldTensor not specified!") + } + size = size_proto->i(); + if (size > dim_size || size <= 0) { + fail_shape_inference("size attribute in UnfoldTensor not positive and less than the dim size!") + } + + ONNX_NAMESPACE::TensorShapeProto output_shape; + for (int d = 0; d < rank; d++) { + if (d == dim) { + output_shape.add_dim()->set_dim_value((dim_size - size) / step + 1); + } else { + *output_shape.add_dim() = input_shape.dim(d); + } + } + output_shape.add_dim()->set_dim_value(size); + updateOutputShape(ctx, 0, output_shape); + })); + +ONNX_MS_OPERATOR_SET_SCHEMA( + DynamicTimeWarping, 1, + OpSchema() + .SetDoc("Input is cost matrix where each value in input[r][c] is the cost for pass the point (r, c). From current point" + "(r, c), points (r+1, c), (r+1, c+1) or (r, c+1) could be arrived in next move. Given such cost matrix, return " + "dynamic time wrapping of shape [2, x], where the path made by all points (output[0][t], output[1][t])" + "have the lowest cost among all paths from (0, 0) to (M-1, N-1).") + .Input(0, "input", "Input cost tensor, it must be 2D tensor of shape M x N, or 1 x M x N", "F") + .Output(0, "output", "Output tensor. shape is [2, x], where max(M, N) <= x < M + N", "I") + .TypeConstraint("F", {"tensor(float)"}, "Constrain to float tensors.") + .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT32); + ONNX_NAMESPACE::TensorShapeProto resultShape; + resultShape.add_dim()->set_dim_value(2); + resultShape.add_dim(); + updateOutputShape(ctx, 0, resultShape); + })); + ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, OpSchema() .SetDoc("Beam Search for text generation. Supports GPT-2 decoder.") @@ -1110,6 +1183,122 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, BeamSearchShapeInference(ctx); })); +ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, + OpSchema() + .SetDoc("Beam Search for whisper model, especiall with cross_qk features etc.") + .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) + .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) + .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) + .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) + .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) + .Attr("model_type", "Must be 2 for whisper", AttributeProto::INT, static_cast(2)) + .Attr("encoder", "The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE) + .Attr("init_decoder", + "The subgraph for the first decoding run. It will be called once before `decoder` subgraph. " + "This is relevant only for the GPT2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs", + AttributeProto::GRAPH, OPTIONAL_VALUE) + .Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH) + .Attr("vocab_size", + "Size of the vocabulary. " + "If not provided, it will be inferred from the decoder subgraph's output shape", + AttributeProto::INT, static_cast(-1)) + .Attr("decoder_output_cross_qk", "If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_speech_token", + "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", + AttributeProto::INT, OPTIONAL_VALUE) + .Input(0, "input_ids", "The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)", "F") + .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") + .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) + .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I") + .Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I") + .Input(5, "length_penalty", + "Exponential penalty to the length. Default value 1.0 means no penalty." + "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences." + "Shape is (1,)", + "T", OpSchema::Optional) + .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) + .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) + .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) + .Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional) + .Input(12, "cross_qk_layer_head", + "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all" + "its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]", + "I", OpSchema::Optional) + .Input(13, "extra_decoding_ids", + "Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len)." + "In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) " + "are treated as stop of the extra_decoding_ids for corresponding batch.", + "I", OpSchema::Optional) + .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") + .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) + .Output(2, "scores", + "Processed beam scores for each vocabulary token at each generation step." + "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." + "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", + "T", OpSchema::Optional) + .Output(3, "cross_qk", + "Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, " + "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers," + "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]." + "If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]", + "V", OpSchema::Optional) + .Output(4, "non_speech_probs", + "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token." + "Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph." + "The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]", + "T", OpSchema::Optional) + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.") + .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.") + .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") + .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types") + .TypeConstraint("V", {"tensor(float)"}, "Constrain cross_qk to float32 tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + BeamSearchShapeInference(ctx); + if (ctx.getNumOutputs() > 3) { + ONNX_NAMESPACE::updateOutputElemType(ctx, 3, ONNX_NAMESPACE::TensorProto::FLOAT); + } + if (!hasInputShape(ctx, 0)) { + return; + } + auto& input_ids_shape = getInputShape(ctx, 0); + auto& input_ids_dims = input_ids_shape.dim(); + int64_t batch_size = input_ids_dims[0].dim_value(); + int64_t sequence_length = input_ids_dims[1].dim_value(); + + const auto max_length = ctx.getInputData(1); + const auto num_return_sequences = ctx.getInputData(4); + if (max_length == nullptr || num_return_sequences == nullptr) { // not initializer + return; + } + int max_length_value = 0; + if (!ParseScalar(max_length, max_length_value) || max_length_value <= 0) { + fail_shape_inference("Failed to parse max_length or it is not positive integer scalar"); + } + + int num_return_sequences_value = 0; + if (!ParseScalar(num_return_sequences, num_return_sequences_value) || num_return_sequences_value <= 0) { + fail_shape_inference("Failed to parse num_return_sequences or it is not positive integer scalar"); + } + + if (ctx.getNumOutputs() > 3) { + ONNX_NAMESPACE::TensorShapeProto cross_attn_shape; + cross_attn_shape.add_dim()->set_dim_value(batch_size); + cross_attn_shape.add_dim()->set_dim_value(num_return_sequences_value); + cross_attn_shape.add_dim(); // num of layer is unknown, no need to calc it from subgraph here + cross_attn_shape.add_dim(); // num of head is unknown, no need to calc it from subgraph here + cross_attn_shape.add_dim()->set_dim_value(max_length_value); + cross_attn_shape.add_dim()->set_dim_value(sequence_length); + updateOutputShape(ctx, 3, cross_attn_shape); + } + if (ctx.getNumOutputs() > 4) { + ONNX_NAMESPACE::TensorShapeProto non_speech_probs_shape; + non_speech_probs_shape.add_dim()->set_dim_value(batch_size); + updateOutputShape(ctx, 4, non_speech_probs_shape); + } + })); + ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1, OpSchema() .SetDoc("Greedy Search for text generation.") @@ -1186,6 +1375,27 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, GreedySearchShapeInference(ctx); })); +constexpr const char* MoE_ver1_doc = R"DOC( + Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1, + GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) + usually uses top 32 experts. + )DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, + OpSchema() + .SetDoc(MoE_ver1_doc) + .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) + .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) + .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") + .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") + .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T") + .Input(3, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T") + .Input(4, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional) + .Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1, OpSchema() .Input(0, "X", "input", "T") @@ -2384,6 +2594,154 @@ ONNX_MS_OPERATOR_SET_SCHEMA(CropAndResize, 1, a fixed size = [crop_height, crop_width]. The result is a 4-D tensor [num_boxes, crop_height, crop_width, depth]. The resizing is corner aligned.)DOC")); +#if !defined(DISABLE_FLOAT8_TYPES) +#define GEMM_FLOAT8_TYPES \ + { "tensor(float8e4m3fn)", "tensor(float8e5m2)", "tensor(float16)", "tensor(bfloat16)", "tensor(float)" } +#else +#define GEMM_FLOAT8_TYPES \ + { "tensor(float16)", "tensor(bfloat16)", "tensor(float)" } +#endif + +ONNX_MS_OPERATOR_SET_SCHEMA(GemmFloat8, 1, + OpSchema() + .SetDoc(R"DOC(Generic Gemm for float and float 8.)DOC") + .Attr( + "transA", + "Whether A should be transposed. Float 8 only supprted transA=0.", + AttributeProto::INT, + static_cast(0)) + .Attr( + "transB", + "Whether B should be transposed. Float 8 only supprted transB=1.", + AttributeProto::INT, + static_cast(0)) + .Attr( + "alpha", + "Scalar multiplier for the product of input tensors A * B.", + AttributeProto::FLOAT, + 1.0f) + .Attr( + "beta", + "Scalar multiplier for the product of input bias C.", + AttributeProto::FLOAT, + 0.0f) + .Attr( + "dtype", + "Output Type. Same definition as attribute 'to' for operator Cast.", + AttributeProto::INT, + static_cast(1)) + .Attr( + "activation", + "Activation function, RELU or GELU or NONE (default).", + AttributeProto::STRING, + OPTIONAL_VALUE) + .Input( + 0, + "A", + "Input tensor A. " + "The shape of A should be (M, K) if transA is 0, " + "or (K, M) if transA is non-zero.", + "TA") + .Input( + 1, + "B", + "Input tensor B. " + "The shape of B should be (K, N) if transB is 0, " + "or (N, K) if transB is non-zero.", + "TB") + .Input( + 2, + "C", + "Input tensor C.", + "TC", + OpSchema::Optional) + .Input( + 3, + "scaleA", + "Scale of tensor A if A is float 8 tensor", + "TS", + OpSchema::Optional) + .Input( + 4, + "scaleB", + "Scale of tensor B if B is float 8 tensor", + "TS", + OpSchema::Optional) + .Input( + 5, + "scaleY", + "Scale of the output tensor if A or B is float 8.", + "TS", + OpSchema::Optional) + .Output(0, "Y", "Output tensor of shape (M, N).", "TR") + .TypeConstraint( + "TA", + GEMM_FLOAT8_TYPES, + "Constrain type to input A.") + .TypeConstraint( + "TB", + GEMM_FLOAT8_TYPES, + "Constrain type to input B.") + .TypeConstraint( + "TC", + {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, + "Constrain type to input C.") + .TypeConstraint( + "TR", + GEMM_FLOAT8_TYPES, + "Constrain type to result type.") + .TypeConstraint("TS", {"tensor(float)"}, + "Constrain type for all input scales (scaleA, scaleB, scaleY).") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0, TensorProto::FLOAT); + if (!hasNInputShapes(ctx, 2)) { + return; + } + auto transAAttr = ctx.getAttribute("transA"); + bool transA = transAAttr ? static_cast(transAAttr->i()) != 0 : false; + auto transBAttr = ctx.getAttribute("transB"); + bool transB = transBAttr ? static_cast(transBAttr->i()) != 0 : false; + auto& first_input_shape = getInputShape(ctx, 0); + auto& second_input_shape = getInputShape(ctx, 1); + if (first_input_shape.dim_size() != 2) { + fail_shape_inference("First input does not have rank 2"); + } + if (second_input_shape.dim_size() != 2) { + fail_shape_inference("Second input does not have rank 2"); + } + updateOutputShape(ctx, 0, {first_input_shape.dim(transA ? 1 : 0), second_input_shape.dim(transB ? 0 : 1)}); + })); + +static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, + int64_t K, + int64_t N, + bool transB) { + int input_a_idx = 0; + if (!hasInputShape(ctx, input_a_idx)) { + return; + } + + const auto& a_shape = ctx.getInputType(input_a_idx)->tensor_type().shape(); + if (a_shape.dim_size() == 0) { + fail_shape_inference("Input tensors of wrong rank (0)."); + } + + // TODO: check B shape + + const auto& dim_last = a_shape.dim(a_shape.dim_size() - 1); + ONNX_NAMESPACE::TensorShapeProto resultShape; + if (dim_last.has_dim_value() && dim_last.dim_value() != (transB ? K : N)) { + fail_shape_inference("Incompatible dimensions for matrix multiplication"); + } + + for (int i = 0; i < a_shape.dim_size() - 1; ++i) { + *resultShape.add_dim() = a_shape.dim(i); + } + resultShape.add_dim()->set_dim_value(transB ? N : K); + + *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = resultShape; +} + void RegisterContribSchemas() { ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(AttnLSTM, RegisterAttnLSTMContribOpSchema); ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(Range, RegisterRangeOpSchema); @@ -2844,6 +3202,83 @@ void RegisterContribSchemas() { propagateElemTypeFromInputToOutput(ctx, 0, 0); }); + ONNX_CONTRIB_OPERATOR_SCHEMA(EPContext) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc("Onnx node container for EP context.") + .Attr( + "main_context", + "Usually each single EPContext associate with a graph partition." + "But for some case like QNN, it has single EPContext contains all partitions." + "In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context." + "The path is relative to this Onnx file. Default is 1.", + AttributeProto::INT, + static_cast(1)) + .Attr( + "ep_cache_context", + "payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0.", + AttributeProto::STRING, + OPTIONAL_VALUE) + .Attr( + "embed_mode", + "1: indicate ep_cache_context is the context content. 0: indicate ep_cache_context is the file path to the context content." + "The path is relative to this Onnx file. Default is 1.", + AttributeProto::INT, + static_cast(1)) + .Attr( + "ep_sdk_version", + "(Optional) SDK version used to convert the model.", + AttributeProto::STRING, + OPTIONAL_VALUE) + .Attr( + "partition_name", + "(Optional) partitioned graph name.", + AttributeProto::STRING, + OPTIONAL_VALUE) + .Attr( + "source", + "(Optional) the source used to generate the engine/context cache file. Ort EP or native SDK tool chain", + AttributeProto::STRING, + OPTIONAL_VALUE) + .Attr("notes", "(Optional) Some notes for the model", AttributeProto::STRING, OPTIONAL_VALUE) + .AllowUncheckedAttributes() + .Input( + 0, + "inputs", + "List of tensors for inputs", + "T", + OpSchema::Variadic, + true, + 1, + OpSchema::NonDifferentiable) + .Output( + 0, + "outputs", + "One or more outputs, list of tensors for outputs", + "T", + OpSchema::Variadic, + true, + 1, + OpSchema::NonDifferentiable) + .TypeConstraint( + "T", + {"tensor(int8)", + "tensor(int16)", + "tensor(int32)", + "tensor(int64)", + "tensor(uint8)", + "tensor(uint16)", + "tensor(uint32)", + "tensor(uint64)", + "tensor(float16)", + "tensor(float)", + "tensor(double)"}, + "Constrain input and output types.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + }); + static const char* BitmaskDropout_ver1_doc = R"DOC( BitmaskDropout takes an input floating-point tensor, an optional input ratio (floating-point scalar) and an optional input training_mode (boolean scalar). It produces two tensor outputs: output (floating-point tensor) and mask (optional `Tensor`). If `training_mode` is true then the output Y will be a random dropout. @@ -2895,6 +3330,119 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt t } }); + static const char* MatMulNBits_ver1_doc = R"DOC( +MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7).It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwisely along dimension 0 (e.g. column) with block size specified by attribute block_size. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's scale and zero point are specified by input scales and zero_points. + +Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: +- n_blocks_per_col = (K + block_size - 1) / block_size +- blob_size = block_size / 8 * bits + + For a block blob. It is stored in format: + struct Blob { + uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization + uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization + uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization + } + +Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] +Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: + - [(N * n_blocks_per_col + 1) / 2] if bits <=4 + - [N * n_blocks_per_col] if bits > 4 + +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(MatMulNBits_ver1_doc) + .Attr("K", "size of each input feature", AttributeProto::INT) + .Attr("N", "size of each output feature", AttributeProto::INT) + .Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT) + .Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Input(0, "A", "The input tensor, not quantized", "T1") + .Input(1, "B", "1-dimensional data blob", "T2") + .Input(2, "scales", "quantization scale", "T1") + .Input(3, "zero_points", "quantization zero points", "T2", OpSchema::Optional) + .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + // Shape inference + int64_t in_features = getAttribute(ctx, "K", -1); + int64_t out_features = getAttribute(ctx, "N", -1); + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, true); + }); + + static const char* MatMulBnb4_ver1_doc = R"DOC( +MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's quantization constants or scales are specified by input 'absmax'. + + Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. + Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + + + 1. (Default value) transB=True (Majorly used for forward pass) + Shape of A: [D0, D1, ..., Dn, K] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + transposed_dequant_B = dequant_B^T + output = A @ transposed_dequant_B + + Shape of output: [D0, D1, ..., Dn, N] + + 2. transB=False (Majorly used for backward pass) + Shape of A: [D0, D1, ..., Dn, N] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + output = A @ dequant_B + + Shape of output: [D0, D1, ..., Dn, K] + +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulBnb4) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(MatMulBnb4_ver1_doc) + .Attr("K", "size of each input feature", AttributeProto::INT) + .Attr("N", "size of each output feature", AttributeProto::INT) + .Attr("block_size", "number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Attr("quant_type", "quantization data type. 0 for FP4, 1 for NF4.", AttributeProto::INT) + .Attr("training_mode", + "Indicate if the ops run in training_mode, by default, False.", + AttributeProto::INT, + static_cast(0)) + .Attr("transB", "Whether B should be transposed on the last two dimensions before doing multiplication. Default to be 1.", + AttributeProto::INT, static_cast(1)) + .Input(0, "A", "The input tensor, not quantized", "T1") + .Input(1, "B", "1-dimensional quantized data for weight", "T2") + .Input(2, "absmax", "quantization constants", "T1") + .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float/half_float/brain_float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + // Shape inference + int64_t in_features = getAttribute(ctx, "K", -1); + int64_t out_features = getAttribute(ctx, "N", -1); + bool transB = getAttribute(ctx, "transB", 1) != 0; + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, transB); + }); + #ifdef ENABLE_ATEN ONNX_CONTRIB_OPERATOR_SCHEMA(ATen) .SetDomain(kPytorchAtenDomain) @@ -3003,7 +3551,7 @@ Having this op allows runtime to do operator re-ordering to reduce compute FLOPs } #endif -#ifdef USE_MPI +#ifdef ORT_USE_NCCL RegisterCollectiveOps(); #endif } diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.h b/onnxruntime/core/graph/contrib_ops/contrib_defs.h index 2d5b4f8e76cc2..5b3904669f9fc 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.h +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.h @@ -53,7 +53,7 @@ void RegisterContribSchemas(); void RegisterNchwcSchemas(); void RegisterQuantizationSchemas(); -#if defined(USE_MPI) +#if defined(ORT_USE_NCCL) void RegisterCollectiveOps(); #endif diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc index c2f5edaa6149b..f81c3b8e0182c 100644 --- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -42,7 +42,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "The number of groups of channels. It should be a divisor of the number of channels C", AttributeProto::INT) .Attr("activation", - "Activation after group normalization: 0 for None, 1 for Swish", + "Activation after group normalization: 0 for None, 1 for SiLU", AttributeProto::INT) .Attr("channels_last", "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.", @@ -68,6 +68,85 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); +constexpr const char* SkipGroupNorm_ver1_doc = R"DOC( +This operator element-wise adds x, skip and bias, then apply group normalization and optional activation. + +This operator transforms input according to + s = x + skip + bias + y = gamma * (s - mean) / sqrt(variance + epsilon) + beta + +The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. +The num_channels must be divisible by num_groups. +The mean and standard-deviation of s are calculated separately over the each group. +The weight and bias are per-channel affine transform parameter vectors of size num_channels. + +The activation attribute can be used to enable activation after group normalization. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + SkipGroupNorm, 1, + OpSchema() + .SetDoc(SkipGroupNorm_ver1_doc) + .Attr("epsilon", "The epsilon value to use to avoid division by zero", + AttributeProto::FLOAT, static_cast(1e-5)) + .Attr("groups", + "The number of groups of channels. It should be a divisor of the number of channels C", + AttributeProto::INT) + .Attr("activation", + "Activation after group normalization: 0 for None, 1 for SiLU", + AttributeProto::INT) + .Attr("channels_last", + "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.", + AttributeProto::INT, + static_cast(1)) + .Input(0, + "X", + "Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 " + " or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels," + " and H and W are the height and width of the data", + "T") + .Input(1, + "gamma", + "1D gamma tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(2, + "beta", + "1D beta tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(3, + "skip", + "4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)", + "T") + .Input(4, + "bias", + "1D bias tensor. Dimensions are (C), where C is number of channels", + "T", + OpSchema::Optional) + .Output(0, + "Y", + "The output tensor of the same shape as X", + "T") + .Output(1, + "S", + "The element-wise sum of input x, skip and bias tensors. It has the same shape as X", + "T", + OpSchema::Optional) + .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X, skip, bias and output Y, S types to float tensors.") + .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateElemTypeFromInputToOutput(ctx, 0, 1); + } + + if (hasInputShape(ctx, 0)) { + propagateShapeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateShapeFromInputToOutput(ctx, 0, 1); + } + } + })); + constexpr const char* BiasSplitGelu_ver1_doc = R"DOC( A fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left tensor multiplies the Gelu activation result of right tensor. diff --git a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc index 3ce7c40e754dc..c8960578f9e3d 100644 --- a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc +++ b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc @@ -90,13 +90,18 @@ void RegisterNHWCSchemaWithActivation(const RegistrationFunc& f, ::ONNX_NAMESPAC void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function& fn) { // if the operator may be fused with an activation, use the WITH_ACTIVATION variant to add optional attributes // for the activation parameters. - // For now we only register operators from opset 11 on. Models can easily have their opset updated using ONNX tools + // We mainly register operators from opset 11 on . Models can easily have their opset updated using ONNX tools // so supporting older opsets is unnecessary. + // Older opsets are included on a per-operator basis as needed. // NOTE: This should be in sync with GetLayoutSensitiveOps in - // /onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc + // /onnxruntime/core/optimizer/transpose_optimization/transpose_optimizer.cc + REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 7); + REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 10); REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 11); + REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 19); + REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, BatchNormalization, 7); REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, BatchNormalization, 9); REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, BatchNormalization, 14); REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, BatchNormalization, 15); @@ -106,16 +111,18 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); @@ -181,7 +190,9 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); #endif fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); @@ -194,8 +205,10 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); @@ -203,11 +216,14 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); } }; } // namespace contrib diff --git a/onnxruntime/core/graph/function_utils.cc b/onnxruntime/core/graph/function_utils.cc index 7477f48088a15..a266c9ab04a2e 100644 --- a/onnxruntime/core/graph/function_utils.cc +++ b/onnxruntime/core/graph/function_utils.cc @@ -373,7 +373,8 @@ class Inliner { // Replace given name with a unique version of the name, and cache the // renaming-binding in current scope. void make_unique(std::string& name) { - auto new_name = prefix_ + name; + auto new_name{prefix_}; + new_name.append("_").append(name); auto& current_scope = rename_scopes_.back(); current_scope[name] = new_name; name = std::move(new_name); @@ -410,7 +411,7 @@ class Inliner { std::string rename_as = actuals.Get(i); if constexpr (isOutput) { if (rename_as.empty()) - rename_as.assign(prefix_).append(formal); + rename_as.assign(prefix_).append("_").append(formal); } current_scope[formal] = rename_as; if (!rename_as.empty()) @@ -420,7 +421,7 @@ class Inliner { std::string& formal = *formals.Mutable(i); std::string rename_as; if constexpr (isOutput) { - rename_as.assign(prefix_).append(formal); + rename_as.assign(prefix_).append("_").append(formal); } current_scope[formal] = rename_as; if (!rename_as.empty()) @@ -431,7 +432,7 @@ class Inliner { // Process a node: void transform(NodeProto& n) { if (!n.name().empty()) - n.set_name(prefix_ + n.name()); + n.set_name(prefix_ + "_" + n.name()); for (auto& x : *n.mutable_input()) { rename(x, false); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 383c1d689d3c3..d489a59c4b798 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -860,18 +860,18 @@ Status Node::LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge& fbs_node_e } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) -void Node::Init(const std::string& name, - const std::string& op_type, - const std::string& description, - const std::vector& input_args, - const std::vector& output_args, +void Node::Init(std::string_view name, + std::string_view op_type, + std::string_view description, + gsl::span input_args, + gsl::span output_args, const NodeAttributes* attributes, - const std::string& domain) { + std::string_view domain) { name_ = name; op_type_ = op_type; description_ = description; - definitions_.input_defs = input_args; - definitions_.output_defs = output_args; + definitions_.input_defs.assign(input_args.begin(), input_args.end()); + definitions_.output_defs.assign(output_args.begin(), output_args.end()); domain_ = domain; can_be_saved_ = true; priority_ = 0; @@ -984,6 +984,7 @@ bool Node::ClearAttribute(const std::string& attr_name) { graph_->SetGraphProtoSyncNeeded(); return attributes_.erase(attr_name) > 0; } + #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) int Node::PruneRemovableAttributes(gsl::span removable_attributes) { @@ -1145,7 +1146,8 @@ Graph::Graph(const Model& owning_model, IOnnxRuntimeOpSchemaCollectionPtr schema_registry, const logging::Logger& logger, bool strict_shape_type_inference) - : Graph(owning_model, graph_proto, domain_to_version, ir_version, schema_registry, nullptr, nullptr, logger, strict_shape_type_inference) {} + : Graph(owning_model, graph_proto, domain_to_version, ir_version, + schema_registry, nullptr, nullptr, logger, strict_shape_type_inference) {} Graph::Graph(const Model& owning_model, GraphProto* graph_proto, const std::unordered_map& domain_to_version, Version ir_version, @@ -3261,8 +3263,8 @@ Node& Graph::AddNode(const std::string& name, gsl::span output_args, const NodeAttributes* attributes, const std::string& domain) { - std::vector inputs; - std::vector outputs; + InlinedVector inputs; + InlinedVector outputs; inputs.resize(input_args.size()); outputs.resize(output_args.size()); int i = 0; @@ -3907,7 +3909,8 @@ Node& Graph::CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std // kernel lookup works as per usual, if not using an existing schema. if (sub_graph.schema_source == IndexedSubGraph::SourceOfSchema::EXISTING) { ORT_ENFORCE(SetOpSchemaFromRegistryForNode(fused_node), - "Schema was not found for fused node. Domain:", fused_node.Domain(), " OpType:", fused_node.OpType()); + "Schema was not found for fused node. Domain:", fused_node.Domain(), " OpType:", fused_node.OpType(), + " SinceVersion:", fused_node.SinceVersion()); } else if (IndexedSubGraph::SourceOfSchema::REUSE_OR_CREATE == sub_graph.schema_source) { auto schema_key = GenerateSchemaKey(sub_graph); if (reusable_fused_schema_map_.count(schema_key) == 0) { @@ -4019,69 +4022,424 @@ Node& Graph::FuseSubGraph(const IndexedSubGraph& sub_graph, return fused_node; } -Status Graph::InlineFunction(Node& callnode) { - const auto& model_path = ModelPath(); - auto output_edges = callnode.GetRelationships().output_edges; - for (const auto& output_edge : output_edges) { - RemoveEdge(callnode.Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), output_edge.GetDstArgIndex()); +Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& node_proto, + std::optional new_name) { + const gsl::not_null tensor{graph_proto_->add_initializer()}; + ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(node_proto, ModelPath(), *tensor, node_proto.output(0))); + + if (new_name.has_value()) { + tensor->set_name(std::string(new_name.value())); } - // create a uniq_identifier to append to every node name and intermediate input\outputs - // to make sure there are no unintended duplicates - std::stringstream ss; - ss << "_inline_" << callnode.OpType(); - auto uniq_identifier = GenerateNodeName(ss.str()); - // Replace a (function-call) node by an inlined graph. - if (!callnode.GetFunctionBody()) { - // This is the normal use-case: inlining a FunctionProto (representing - // a model-local function or a schema-defined function). - FunctionProto inlined_fp; - ORT_ENFORCE(callnode.TryGetFunctionProto(inlined_fp), "Node has no function body and cannot be inlined."); - function_utils::Specialize(inlined_fp, callnode, uniq_identifier); + auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor); + ORT_ENFORCE(insert_result.second, "Constant node name: ", tensor->name(), + " conflicts with graph initializer. Check that the node names have been made unique."); + if (GetNodeArg(tensor->name()) == nullptr) { + TypeProto t{TypeProtoFromTensorProto(*tensor)}; + ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor->name(), &t)); + } + +#if !defined(DISABLE_SPARSE_TENSORS) + if (node_proto.attribute(0).type() == AttributeProto_AttributeType_SPARSE_TENSOR) { + ORT_IGNORE_RETURN_VALUE(sparse_tensor_names_.emplace(tensor->name())); + } +#endif - auto to_node_arg = [this](const std::string& name) { - return &this->GetOrCreateNodeArg(name, nullptr); - }; + return Status::OK(); +} - // Process constant nodes first and create NodeArg for these as they become initializers - // It is important for the initializers to have NodeArg created, first they are needed - // if the initializer is unused and removed, second if the node depends on the initializer, - // we can have Type attached to it. - for (const auto& inlined_node : inlined_fp.node()) { - if (inlined_node.op_type() == kConstant) { - // Copy constant nodes _value to name_to_initial_tensor_ - const gsl::not_null tensor{graph_proto_->add_initializer()}; - ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(inlined_node, model_path, *tensor, inlined_node.output(0))); - auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor); - ORT_ENFORCE(insert_result.second, "Constant node name: ", tensor->name(), " in inlined function: ", - inlined_fp.name(), " conflicts with graph initializer. Check Specializing code."); - TypeProto t{TypeProtoFromTensorProto(*tensor)}; - ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor->name(), &t)); +static void ReassignSubgraphDependentNodeArgs(const InlinedHashMap& name_to_nodearg, + Graph& graph) { + for (auto& node : graph.Nodes()) { + if (node.ContainsSubgraph()) { + for (auto& [name, subgraph] : node.GetAttributeNameToMutableSubgraphMap()) { + ReassignSubgraphDependentNodeArgs(name_to_nodearg, *subgraph); } } - for (const auto& inlined_node : inlined_fp.node()) { - if (inlined_node.op_type() != kConstant) { - InlinedVector inputs; - InlinedVector outputs; + // NodeArgs need to be updated + for (auto& input_def : node.MutableInputDefs()) { + if (input_def->Exists()) { + auto hit = name_to_nodearg.find(input_def->Name()); + if (hit != name_to_nodearg.cend()) { + // Make sure we create a local to this subgraph definition + const auto* new_name_arg = hit->second; + input_def = &graph.GetOrCreateNodeArg(new_name_arg->Name(), input_def->TypeAsProto()); + } + } + } + } +} + +Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const logging::Logger& logger) { + static const std::string then_branch{"then_branch"}; + static const std::string else_branch{"else_branch"}; + Graph* sub_graph; + if (condition_value) { + sub_graph = if_node.GetMutableGraphAttribute(then_branch); + } else { + sub_graph = if_node.GetMutableGraphAttribute(else_branch); + } + + if (sub_graph == nullptr) { + auto str = MakeString("Unable to constant fold If node: '", if_node.Name(), "' Unable to fetch: ", + (condition_value ? then_branch : else_branch)); + LOGS(logger, WARNING) << str; + return Status::OK(); + } + + Graph& graph_to_inline = *sub_graph; + + std::string unique_id{"_if_"}; + if (condition_value) { + unique_id.append(then_branch); + } else { + unique_id.append(else_branch); + } + + unique_id = GenerateNodeName(unique_id); + + auto make_unique = [&unique_id](const std::string& name) { + return unique_id + '_' + name; + }; + + // Check if the name is an input or implicit input. + // These are not renamed, and we do not need to adjust subgraphs for them. + // Implicit inputs would cover both If node input and implicit inputs. + // Reason: there are no explicit inputs to the subgraphs, and the subgraph's + // implicit inputs must be covered by the implicit inputs of the If node. + InlinedHashMap outer_scope_values; + const auto& if_implicit_inputs = if_node.MutableImplicitInputDefs(); + outer_scope_values.reserve(if_implicit_inputs.size()); + + for (auto* input : if_implicit_inputs) { + const auto& name = input->Name(); + ORT_IGNORE_RETURN_VALUE(outer_scope_values.emplace(name, input)); + } + + // Name mapping from the graph to inline to the graph we are inlining into + // we also use this to process any subgraphs in the graph we are inlining + InlinedHashMap name_to_nodearg; + + // We are going to map the outputs of the graph to inline to the outputs of the If node. + // They are assumed to be in the same order. + const auto& node_output_defs = if_node.MutableOutputDefs(); + const auto& graph_output_defs = graph_to_inline.GetOutputs(); + for (size_t i = 0; i < graph_output_defs.size(); ++i) { + name_to_nodearg.emplace(graph_output_defs[i]->Name(), node_output_defs[i]); + } + + // Move initializers from the subgraph to the destination graph. + for (int i = 0, limit = graph_to_inline.graph_proto_->initializer_size(); i < limit; ++i) { + auto* initializer = graph_to_inline.graph_proto_->mutable_initializer(i); + const std::string src_name = initializer->name(); + +#if !defined(DISABLE_SPARSE_TENSORS) + bool has_sparse_origin = false; + if (!graph_to_inline.sparse_tensor_names_.empty()) { + auto hit = graph_to_inline.sparse_tensor_names_.find(src_name); + if (hit != graph_to_inline.sparse_tensor_names_.cend()) { + has_sparse_origin = true; + // Erase the entry that will be invalidated + graph_to_inline.sparse_tensor_names_.erase(hit); + } + } +#endif + + graph_to_inline.name_to_initial_tensor_.erase(src_name); + const gsl::not_null tensor{graph_proto_->add_initializer()}; + *tensor = std::move(*initializer); + + // Check if this is an output of the graph + auto hit = name_to_nodearg.find(src_name); + if (hit != name_to_nodearg.cend()) { + // We rename it to If node output. + tensor->set_name(hit->second->Name()); + } else { + NodeArg* node_arg = graph_to_inline.GetNodeArg(src_name); + assert(node_arg != nullptr); + auto new_name = GenerateNodeArgName(make_unique(src_name)); + NodeArg& new_arg = GetOrCreateNodeArg(new_name, node_arg->TypeAsProto()); + ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(src_name, &new_arg)); + tensor->set_name(std::move(new_name)); + } + + auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor); + ORT_ENFORCE(insert_result.second, "Initializer name: ", tensor->name(), " from graph: ", + graph_to_inline.Name(), " conflicts with graph initializer. Check name generation above."); + +#if !defined(DISABLE_SPARSE_TENSORS) + if (has_sparse_origin) { + ORT_IGNORE_RETURN_VALUE(sparse_tensor_names_.emplace(tensor->name())); + } +#endif + } + + // Look up nodes that would be providing input to our nodes (implicit and explicit) + // and any nodes that take the output of our nodes (used to be If output) + // Map of NodeArg name to pair of Node* and input index in the destination node + using NodeAndIndex = std::pair, int>; + using ArgNameToNodeMap = InlinedHashMap; + ArgNameToNodeMap input_args; + // Map of NodeArg name to pair of Node* and output index in the source node. + ArgNameToNodeMap output_args; + + auto map_defs = [](Node& node, ArgNameToNodeMap& map, bool input) { + const auto defs = (input) ? node.InputDefs() : node.OutputDefs(); + map.reserve(map.size() + defs.size()); + int arg_pos = -1; + for (auto* node_arg : defs) { + ++arg_pos; + if (node_arg->Exists()) { + map.emplace(node_arg->Name(), std::make_pair(&node, arg_pos)); + } + } + }; + + const bool is_this_main_graph = (parent_graph_ == nullptr); + // Map the inputs and outputs of the If node to the nodes in the graph to inline. + if (!is_this_main_graph) { + for (auto& node : Nodes()) { + if (node.Index() == if_node.Index()) { + continue; + } + map_defs(node, input_args, true); + map_defs(node, output_args, false); + } + } + + auto* non_existing_arg = &GetOrCreateNodeArg(std::string(), nullptr); + // We want to make sure we get nodes in topological order + // because Constant folding may cause the nodes appear in + // a different order. + InlinedVector new_nodes; + GraphViewer graph(graph_to_inline); + for (const auto node_idx : graph.GetNodesInTopologicalOrder()) { + // GraphViewer filters out nullptrs + auto* node = graph_to_inline.GetNode(node_idx); + assert(node->OpType() != kConstant); + + // Inputs + // Chop off trailing non-existing defs, but preserve non-existing in the middle + auto& input_defs = node->MutableInputDefs(); + auto last_existing = std::find_if(input_defs.rbegin(), input_defs.rend(), + [](const NodeArg* node_arg) { return node_arg->Exists(); }); + input_defs.resize(std::distance(input_defs.begin(), last_existing.base())); + + InlinedVector new_input_defs; + for (auto* input_def : node->InputDefs()) { + if (input_def->Exists()) { + // Check if this is one of the implicit graph inputs + // then re-assign the def to the outer scope value. + const auto& input_name = input_def->Name(); + auto outer_hit = outer_scope_values.find(input_name); + if (outer_hit != outer_scope_values.cend()) { + // get/create local definition + NodeArg* outer_arg = outer_hit->second; + auto& this_scope_arg = GetOrCreateNodeArg(outer_arg->Name(), input_def->TypeAsProto()); + new_input_defs.push_back(&this_scope_arg); + } else { + auto hit = name_to_nodearg.find(input_name); + if (hit != name_to_nodearg.cend()) { + // This is other node output in the dest graph, + // constant node or initializer that was renamed. + new_input_defs.push_back(hit->second); + } else { + ORT_THROW("Node's: ", node->Name(), " input: ", input_name, + " is not If node's input or previous node output in this subgraph"); + } + } + } else { + new_input_defs.push_back(non_existing_arg); + } + } + + // Outputs + // Chop off trailing non-existing defs + auto& output_defs = node->MutableOutputDefs(); + last_existing = std::find_if(output_defs.rbegin(), output_defs.rend(), + [](const NodeArg* node_arg) { return node_arg->Exists(); }); + output_defs.resize(std::distance(output_defs.begin(), last_existing.base())); + + InlinedVector new_output_defs; + for (auto* output_def : node->OutputDefs()) { + if (output_def->Exists()) { + const auto& output_name = output_def->Name(); + auto hit = name_to_nodearg.find(output_name); + if (hit != name_to_nodearg.cend()) { + // This is one of the If node outputs, simply reassign the def. + // If node defs are already in the destination graph + new_output_defs.push_back(hit->second); + } else { + // We generate an output to downstream nodes. + auto new_name = GenerateNodeArgName(make_unique(output_name)); + NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto()); + new_output_defs.push_back(&new_arg); + ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg)); + } + } else { + new_output_defs.push_back(non_existing_arg); + } + } + + const auto new_node_name = GenerateNodeName(make_unique(node->OpType())); + Node& new_node = AddNode(new_node_name, node->OpType(), node->Description(), + new_input_defs, + new_output_defs, + nullptr, + node->Domain()); - for (const auto& tensor_name : inlined_node.input()) - inputs.push_back(to_node_arg(tensor_name)); + new_node.SetSinceVersion(node->SinceVersion()); + new_node.op_ = node->op_; - for (const auto& tensor_name : inlined_node.output()) - outputs.push_back(to_node_arg(tensor_name)); + if (!is_this_main_graph) { + map_defs(new_node, input_args, true); + map_defs(new_node, output_args, false); + new_nodes.push_back(&new_node); + } + + if (node->ContainsSubgraph()) { + auto& subgraphs = node->MutableSubgraphs(); + + // Check if any of this node implicit inputs of this graph is in the renaming map + // that would mean they come from the destination graph, not from the parent + // of the destination graph. + int renames_subgraph_names = 0; + auto& implicit_defs = node->MutableImplicitInputDefs(); + for (auto& input_def : implicit_defs) { + auto hit = name_to_nodearg.find(input_def->Name()); + if (hit != name_to_nodearg.cend()) { + input_def = hit->second; + ++renames_subgraph_names; + } + } + + for (auto& subgraph : subgraphs) { + if (renames_subgraph_names > 0) { + // We need to rename the subgraph node names + // because they may refer to the implicit inputs + // that were renamed. + ReassignSubgraphDependentNodeArgs(name_to_nodearg, *subgraph); + } + subgraph->parent_node_ = &new_node; + subgraph->parent_graph_ = this; + } + + new_node.MutableSubgraphs() = std::move(subgraphs); + new_node.GetMutableMapOfAttributeNameToSubgraph() = std::move(node->GetMutableMapOfAttributeNameToSubgraph()); + new_node.MutableImplicitInputDefs() = std::move(implicit_defs); + } + + new_node.GetMutableAttributes() = std::move(node->GetMutableAttributes()); + } - onnxruntime::NodeAttributes new_attr_map; - new_attr_map.reserve(inlined_node.attribute_size()); - for (const auto& node_attr : inlined_node.attribute()) { - onnx::AttributeProto attr_copy = node_attr; - new_attr_map[node_attr.name()] = std::move(attr_copy); + // Let's rebuild local connections, so next time a GraphViewer is able to perform topological sort. + // We only need to do so if this graph is not the main graph, because the main graph is going to resolve + // and it is not possible to inline the same nodes again. + if (!is_this_main_graph) { + for (auto* node : new_nodes) { + int arg_pos = -1; + for (auto* input_def : node->InputDefs()) { + ++arg_pos; + auto hit = output_args.find(input_def->Name()); + if (hit != output_args.cend()) { + // The input to this node is an output from a previous node in this graph. + // Create relationship between this node (node), and the node providing the output (output_node). + const auto& [producer, src_idx] = hit->second; + AddEdge(producer->Index(), node->Index(), src_idx, arg_pos); } - AddNode(inlined_node.name(), inlined_node.op_type(), - inlined_node.doc_string(), inputs, outputs, &new_attr_map, inlined_node.domain()); } + + // Check if any of the outputs for inlined nodes are inputs to other nodes in the graph. + // (outputs of If node) + arg_pos = -1; + for (auto& output_def : node->OutputDefs()) { + ++arg_pos; + auto hit = input_args.find(output_def->Name()); + if (hit != input_args.cend()) { + // The output of this node is an input to another node in this graph. + // Create relationship between this node (node), and the node using the input (input_node). + const auto& [consumer, dst_idx] = hit->second; + AddEdge(node->Index(), consumer->Index(), arg_pos, dst_idx); + } + } + } + } + + LOGS(logger, INFO) << "Constant folded (inlined) " << (condition_value ? then_branch : else_branch) + << " for If node: " << if_node.Name(); + + return Status::OK(); +} + +Status Graph::InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline) { + auto to_node_arg = [this](const std::string& name) { + return &this->GetOrCreateNodeArg(name, nullptr); + }; + + // Process constant nodes first and create NodeArg for these as they become initializers + // It is important for the initializers to have NodeArg created, first they are needed + // if the initializer is unused and removed, second if the node depends on the initializer, + // we can have Type attached to it. + InlinedVector non_constant_nodes; + non_constant_nodes.reserve(func_to_inline.node_size()); + for (const auto& inlined_node : func_to_inline.node()) { + if (inlined_node.op_type() == kConstant) { + // Copy constant nodes _value to name_to_initial_tensor_ + ORT_RETURN_IF_ERROR(AddConstantProtoAsInitializer(inlined_node, std::nullopt)); + } else { + non_constant_nodes.push_back(&inlined_node); } + } + + for (const auto* inlined_node : non_constant_nodes) { + InlinedVector inputs; + InlinedVector outputs; + + for (const auto& tensor_name : inlined_node->input()) + inputs.push_back(to_node_arg(tensor_name)); + + for (const auto& tensor_name : inlined_node->output()) + outputs.push_back(to_node_arg(tensor_name)); + + onnxruntime::NodeAttributes new_attr_map; + new_attr_map.reserve(inlined_node->attribute_size()); + for (const auto& node_attr : inlined_node->attribute()) { + new_attr_map.insert_or_assign(node_attr.name(), node_attr); + } + ORT_IGNORE_RETURN_VALUE(AddNode(inlined_node->name(), inlined_node->op_type(), + inlined_node->doc_string(), inputs, outputs, + &new_attr_map, inlined_node->domain())); + } + + return Status::OK(); +} + +Status Graph::InlineFunction(Node& callnode) { + // Remove output edges. Requirement for RemoveNode() below. + auto output_edges = callnode.GetRelationships().output_edges; // copy so RemoveEdge doesn't invalidate iterator + for (const auto& output_edge : output_edges) { + RemoveEdge(callnode.Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), output_edge.GetDstArgIndex()); + } + + // create a uniq_identifier to append to every node name and intermediate input\outputs + // to make sure there are no unintended duplicates + std::string base_uniq_identifier{"_inlfunc_"}; + base_uniq_identifier.append(callnode.OpType()); + const auto uniq_identifier = GenerateNodeName(base_uniq_identifier); + // Replace a (function-call) node by an inlined graph. + if (!callnode.GetFunctionBody()) { + // This is the normal use-case: inlining a FunctionProto (representing + // a model-local function or a schema-defined function). + ONNX_NAMESPACE::FunctionProto inlined_fp; + ORT_ENFORCE(callnode.TryGetFunctionProto(inlined_fp), "Node has no function body and cannot be inlined."); + + // Make all the names unique and resolve nested graphs inputs to the outer scope. + function_utils::Specialize(inlined_fp, callnode, uniq_identifier); + + // In this case, global Resolve() will take care of everything. + ORT_RETURN_IF_ERROR(InlineFunctionProto(inlined_fp)); } else { // Uncommon scenario. Inlining a node representing a fused sub-graph. // TODO: Unclear that this feature is needed. Can this be removed? @@ -4115,15 +4473,7 @@ Status Graph::InlineFunction(Node& callnode) { // Copy constant nodes _value to name_to_initial_tensor_ ONNX_NAMESPACE::NodeProto subgraph_node_proto{}; subgraph_node.ToProto(subgraph_node_proto); - const gsl::not_null tensor{graph_proto_->add_initializer()}; - ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(subgraph_node_proto, model_path, *tensor, subgraph_node_proto.output(0))); - auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor); - ORT_ENFORCE(insert_result.second, "Constant node name: ", tensor->name(), " in inlined subgraph: ", - subgraph.Name(), " conflicts with graph initializer. Check Specializing code."); - if (GetNodeArg(tensor->name()) == nullptr) { - TypeProto t{TypeProtoFromTensorProto(*tensor)}; - ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor->name(), &t)); - } + ORT_RETURN_IF_ERROR(AddConstantProtoAsInitializer(subgraph_node_proto, std::nullopt)); } } diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 05747a7e5124d..b3935e69ad7b1 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -41,6 +41,35 @@ namespace onnxruntime { #if !defined(ORT_MINIMAL_BUILD) +void Model::RemoveLocalFunctionsProtos(const InlinedHashSet& retained) { + auto* local_functions = model_proto_.mutable_functions(); + if (retained.empty()) { + model_local_function_templates_maps_.clear(); + model_local_functions_.clear(); + local_functions->erase(local_functions->begin(), local_functions->end()); + } else { + const auto retained_end = retained.cend(); + for (auto it = model_local_functions_.begin(); + it != model_local_functions_.end();) { + if (retained.find(it->first) == retained_end) { + model_local_function_templates_maps_.erase(it->first); + it = model_local_functions_.erase(it); + } else { + ++it; + } + } + + for (auto it = local_functions->begin(); it != local_functions->end();) { + const auto function_id = function_utils::GetFunctionIdentifier(it->domain(), it->name()); + if (retained.find(function_id) == retained_end) { + it = local_functions->erase(it); + } else { + ++it; + } + } + } +} + static constexpr int DEFAULT_PROTOBUF_BLOCK_SIZE = 4 * 1024 * 1024; Model::Model(const std::string& graph_name, @@ -95,10 +124,10 @@ Model::Model(const std::string& graph_name, for (auto& func : model_local_functions) { auto func_ptr = model_proto_.add_functions(); func_ptr->CopyFrom(func); - model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func_ptr->domain(), func_ptr->name()), func_ptr); + model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func_ptr->domain(), func_ptr->name()), + func_ptr); } - model_local_function_templates_.reserve(model_proto_.functions().size()); model_local_function_templates_maps_.reserve(model_proto_.functions().size()); for (auto& func : model_proto_.functions()) { auto func_schema_ptr = function_utils::CreateSchema(func.domain(), @@ -111,8 +140,8 @@ Model::Model(const std::string& graph_name, auto func_template_ptr = std::make_unique(); func_template_ptr->op_schema_ = std::move(func_schema_ptr); func_template_ptr->onnx_func_proto_ = &func; - model_local_function_templates_.push_back(std::move(func_template_ptr)); - model_local_function_templates_maps_[function_utils::GetFunctionIdentifier(func.domain(), func.name())] = model_local_function_templates_.back().get(); + model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), + std::move(func_template_ptr)); } // need to call private ctor so can't use make_shared @@ -203,6 +232,14 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, } } + // special-case the internal NHWC domain as it must match the ONNX opset if not explicitly imported + if (domain_to_version.find(kMSInternalNHWCDomain) == domain_to_version.end()) { + auto onnx_version = domain_to_version.find(kOnnxDomain); + if (onnx_version != domain_to_version.end()) { + domain_to_version[kMSInternalNHWCDomain] = onnx_version->second; + } + } + auto domain_map = allow_official_onnx_release_only_final ? schema_registry->GetLastReleasedOpsetVersions(false) : schema_registry->GetLatestOpsetVersions(false); @@ -220,7 +257,6 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), &func); } - model_local_function_templates_.reserve(model_proto_.functions().size()); model_local_function_templates_maps_.reserve(model_proto_.functions().size()); for (auto& func : model_proto_.functions()) { auto func_schema_ptr = function_utils::CreateSchema(func.domain(), @@ -233,9 +269,7 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, auto func_template_ptr = std::make_unique(); func_template_ptr->op_schema_ = std::move(func_schema_ptr); func_template_ptr->onnx_func_proto_ = &func; - model_local_function_templates_.push_back(std::move(func_template_ptr)); - model_local_function_templates_maps_[function_utils::GetFunctionIdentifier(func.domain(), func.name())] = - model_local_function_templates_.back().get(); + model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), std::move(func_template_ptr)); } // create instance. need to call private ctor so can't use make_unique @@ -244,7 +278,7 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, logger, options.strict_shape_type_inference)); } -const InlinedHashMap& Model::GetModelLocalFunctionTemplates() const { +const NodeHashMap>& Model::GetModelLocalFunctionTemplates() const { return model_local_function_templates_maps_; } @@ -332,7 +366,7 @@ const Graph& Model::MainGraph() const noexcept { } #if !defined(ORT_MINIMAL_BUILD) -ModelProto Model::ToProto() { +ModelProto Model::ToProto() const { // We want to return back the original proto // To that end invoke const overload of ToGraphProto() // that returns by value and, therefore, allows us to filter @@ -346,7 +380,7 @@ ModelProto Model::ToProto() { ModelProto Model::ToGraphProtoWithExternalInitializers(const std::string& external_file_name, const PathString& file_path, - size_t initializer_size_threshold) { + size_t initializer_size_threshold) const { ModelProto result(model_proto_); const auto& graph = *graph_; *(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name, diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 6bdb68dd734f0..4ce6660b794bc 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -139,7 +139,7 @@ class Model { // Returns empty string if not specified. const std::string GraphDocString() const; - const InlinedHashMap& GetModelLocalFunctionTemplates() const; + const NodeHashMap>& GetModelLocalFunctionTemplates() const; #else // Get model's IR version. @@ -182,14 +182,14 @@ class Model { #if !defined(ORT_MINIMAL_BUILD) // Get model's serialization proto data. - ONNX_NAMESPACE::ModelProto ToProto(); + ONNX_NAMESPACE::ModelProto ToProto() const; // Get model's serialization proto data. // Save initializer larger than the given threshold (in bytes) into an external binary file // with the given name. This function is useful to avoid hitting the size limit of protobuf files. ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::string& external_file_name, const PathString& file_path, - size_t initializer_size_threshold); + size_t initializer_size_threshold) const; #ifdef _WIN32 static common::Status Save(Model& model, const std::wstring& file_path); @@ -291,6 +291,13 @@ class Model { common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, flatbuffers::Offset& model) const; + /// + /// Frees local function definitions in the model, excluding those in the `retained` set. + /// Called from GraphPartitioner::InlineFunctionsAOT. + /// + /// contains function IDs that should not be removed. + void RemoveLocalFunctionsProtos(const InlinedHashSet& retained); + #endif // !defined(ORT_MINIMAL_BUILD) static common::Status LoadFromOrtFormat(const onnxruntime::fbs::Model& fbs_model, @@ -312,14 +319,12 @@ class Model { // this map will be used for the local functions' schema's type/shape inference. // This container is used by ONNX code and must be an std::unordered_map. std::unordered_map model_local_functions_; - // this is the container that host the generated schemas for model local functions. - // the generated schemare will be used for graph resolving and type/shape inference. - // those schemas' type/shape inference will reference to the model_local_functions_ as context, - // so need to keep them with same lifetime. - InlinedVector> model_local_function_templates_; // this is the map from function id to the local function template. // this map will be used by graph to instantiate the function body. - InlinedHashMap model_local_function_templates_maps_; + // Defined as a node based map so the memory is released when not all of the functions + // are inlined and removed. + NodeHashMap> model_local_function_templates_maps_; + #else // properties that would normally come from ModelProto std::string producer_version_; diff --git a/onnxruntime/core/mickey/README.md b/onnxruntime/core/mickey/README.md new file mode 100644 index 0000000000000..7e8d30cd1805b --- /dev/null +++ b/onnxruntime/core/mickey/README.md @@ -0,0 +1,6 @@ +# About Mickey + +Playful name for a template library of high performance cuda code that +are often shared by various AI operators. The intention is to make this +header files only, with no binary impact unless it is instantiated +where it is needed. diff --git a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/prepack_sm80.h new file mode 100644 index 0000000000000..e291ab39e8aa3 --- /dev/null +++ b/onnxruntime/core/mickey/blk_q4/prepack_sm80.h @@ -0,0 +1,325 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * prepack_sm80.h + * + * Abstract: + * Prepack weights and quantization parameters (scales and offsets) for + * GEMM, where activations are fp16 or bf16, and weights are block-wise + * 4b quantized values, specifically for Ampere GPUs. + * + * Prepacking enables faster loading of weights and quantization parameters + * into tensor cores, and faster dequantization of weights. + * + * Only supports fp16 for now, bfloat16 support will be added later. + */ + +#pragma once + +#include "core/common/common.h" +#include "core/util/matrix_layout.h" + +namespace onnxruntime { +namespace cuda { + +/** + * @brief Blockwise quantization methods + * @tparam ElementT source data type, fp16 + * @tparam block_size number of elemenets quantized together + * @tparam qbits number of bits in each quantized element + * @tparam Columnwise true: elements in a block come from one single column + * false: elements in a block come from one single row + */ +template < + typename ElementT, + int block_size, + int qbits, + bool Columnwise, + bool ExtraBoundsCheck = false> +struct BlockwiseQuantization { + static_assert(qbits == 4, "Only 4b block quantization is supported!"); + static_assert(sizeof(ElementT) == 2, "Only 16b floating point types are supported!"); + + using QuantBlocking = + std::conditional_t, + MatrixShape<1, block_size>>; + + using ElementW = uint8_t; // <- Weight is int4, uint8 for two of them + // We pack 4 weights into one 16b element, so we can leverage cutlass tile iterators + // for async share memory loading, and minimizing bank conflict during matrix loading + using ElementWPack = ElementT; + using LayoutWPack = ColumnMajorLayout; // <- layout of packed weight, must be column major + + // Current Ampere kernel use 8b zero point, need to shrink it to 4b in the future + using ElementQOffset = uint8_t; + + // Layout of the quantization parameters (scales and zero points) + // Major on the dimension that has the most parameters per squarish weight block. + // E.g. for column-wise quantization, a [64, 64] block has [2, 64] parameters, + // where each row has more data, so we use row major layout so that warp threads + // can use less load instructions to load more parameters. + using LayoutQmeta = + typename std::conditional::type; + + /** + * @brief Get quantized weight tensor dimensions. + * Actual weight type is int4, we use ElementW = uint8 to avoid possible compilation + * troubles. Since the layout is column major, we are packing 2 weights in a column + * into one int8 + */ + static inline auto get_quant_weights_shape(int rows, int columns) { + return make_Position(rows / 2, columns); + } + + static inline auto get_quant_meta_shape(int rows, int columns) { + return make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn); + } + + /** + * @brief Prepack weight matrix to facilitate matrix loading, depending on MMA + * instruction layout. + * + * The weight matrix is int4, yet we want to leverage existing fp16/bf16 + * tile loading and MMA layout code in CUTLASS. So we group 4 int4 into 2 + * bytes, pretending it's fp16. This grouping must be done in a way to be + * easily unpacked into tiles that match the MMA instruction layout. + * For MMA instruction <16, 8, 16>, each instruction processes 2 8x8 tiles, + * vertically stacked on the K dimension. And MmaTensorOpMultiplicandTileIterator + * loads a tile. + * + * So we stack 2x2 tiles on a 3rd dimeansion, and reshape them in a HWC fashion: + * T0, T2 + * T1, T3 + * ==> + * T0[0, 0], T1[0, 0], T2[0, 0], T3[0, 0] + * T0[1, 0], T1[1, 0], T2[1, 0], T3[1, 0] + * T0[2, 0], T1[2, 0], T2[2, 0], T3[2, 0] + * T0[3, 0], T1[3, 0], T2[3, 0], T3[3, 0] + * ... + * T0[0, 7], T1[0, 7], T2[0, 7], T3[0, 7] + * T0[1, 7], T1[1, 7], T2[1, 7], T3[1, 7] + * T0[2, 7], T1[2, 7], T2[2, 7], T3[2, 7] + * T0[3, 7], T1[3, 7], T2[3, 7], T3[3, 7] + * + * This pack a 8x16 int8 tile into a 16x8 int8 tile, i.e. a 8x8 16b tile + */ + static void prepack_weights( + int rows, + int columns, + const gsl::span& weights, // <- int4 weights, column major + const gsl::span& weights_prepacked // <- int4 prepacked weights tensor, same size buffer + ) { + ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0 && + (rows % QuantBlocking::kRow) == 0 && + (columns % QuantBlocking::kColumn) == 0, + "Does not support odd number of rows or columns!"); + ORT_ENFORCE(weights.size() == size_t(rows * columns / 2), + "Weight tensor shape mismatch!"); + ORT_ENFORCE(weights_prepacked.size() == weights.size(), + "Prepacked Weight tensor buffer should be the same size!"); + + const MatrixRef + tensor_weight(weights, make_Position(rows / 2, columns)); + const MatrixRef + tensor_weight_prepacked(weights_prepacked, make_Position(rows, columns / 2)); + + // TODO(fuchen)!! parallized this. + auto t0_base = make_Position(0, 0); + auto t1_base = make_Position(4, 0); + auto t2_base = make_Position(0, 8); + auto t3_base = make_Position(4, 8); + for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) { + for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) { + // Packing from a 8x16 tile to a 16x8 tile + auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16); + auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8); + for (int col = 0; col < 8; ++col) { + for (int row = 0; row < 4; ++row) { + auto cord = make_Position(row, col); + auto packed_cord = packed_tile_base + make_Position(row * 4, col); // packed tile is 16x8 + uint8_t buf[4]; + buf[0] = tensor_weight.at(dtile_base + t0_base + cord); + buf[1] = tensor_weight.at(dtile_base + t1_base + cord); + buf[2] = tensor_weight.at(dtile_base + t2_base + cord); + buf[3] = tensor_weight.at(dtile_base + t3_base + cord); + + // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights + // are in different b16 register at the same positions. This makes it easier to convert to + // fp16x2 format in a b32 register + + tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0); + tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0); + } + } + } + } + } + + /** + * @brief We rearrange the values of the quantization scale and offset tensors + * to facilitate faster loading to tensor core, only 16b gemm, and (1,n) + * block quantization. + */ + static constexpr bool ShouldRearrangeMeta = sizeof(ElementT) == 2 && QuantBlocking::kRow == 1; + + static void prepack_quant_scales( + size_t rows, + size_t columns, + const gsl::span& scales, // <- quant scales, column major layout + const gsl::span& scales_prepacked // <- quant scales prepacked, same size buffer + ) { + auto meta_shape = get_quant_meta_shape(rows, columns); + ORT_ENFORCE(scales.size() == size_t(meta_shape.product()), + "Quantization scale tensor shape mismatch!"); + ORT_ENFORCE(scales_prepacked.size() == size_t(meta_shape.product()), + "Prepacked quantization scale tensor buffer should be the same size!"); + + MatrixRef tensor_scale(scales, meta_shape); + MatrixRef tensor_scale_prepacked(scales_prepacked, meta_shape); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (sizeof(ElementT) == 2 && QuantBlocking::kRow == 1) { + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two separate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + + for (int col = 0; col < tensor_scale.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); + tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); + tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); + tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); + } + } + } + } else { + // In all other cases, we don't prepack scale or offset + // Potential transpose if the prepacked layout is different from the original layout + for (int col = 0; col < tensor_scale.shape()[1]; ++col) { + for (int row = 0; row < tensor_scale.shape()[0]; ++row) { + tensor_scale_prepacked.at(row, col) = tensor_scale.at(row, col); + } + } + } + } + + static void prepack_quant_offsets( + size_t rows, + size_t columns, + const gsl::span& offsets, // <- quant offsets, int4, column major layout + const gsl::span& offsets_prepacked // <- quant offsets prepacked, double size buffer + ) { + auto meta_shape = get_quant_meta_shape(rows, columns); + + ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0, + "Does not support odd number of rows or columns!"); + ORT_ENFORCE(offsets_prepacked.size() == size_t(meta_shape.product()), + "Wrong buffer size for prepacked quantization offsets!"); + ORT_ENFORCE(offsets.size() == size_t(((meta_shape[0] + 1) / 2) * meta_shape[1]), + "Quantization offset tensor shape mismatch!"); + + MatrixRef + tensor_offset(offsets, make_Position((meta_shape[0] + 1) / 2, meta_shape[1])); + MatrixRef tensor_offset_prepacked(offsets_prepacked, meta_shape); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (sizeof(ElementT) == 2 && QuantBlocking::kRow == 1) { + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two separate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row_blk = 0; row_blk < meta_shape[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own + // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to + // convert to fp16x2 format in a b32 register + uint8_t pair01 = tensor_offset.at(src_idx / 2, col); + uint8_t pair89 = tensor_offset.at((src_idx + 8) / 2, col); + tensor_offset_prepacked.at(dst_idx + 0, col) = pair01 & 0xf; + tensor_offset_prepacked.at(dst_idx + 1, col) = pair89 & 0xf; + tensor_offset_prepacked.at(dst_idx + 2, col) = pair01 >> 4; + tensor_offset_prepacked.at(dst_idx + 3, col) = pair89 >> 4; + } + } + } + } else { + // In all other cases, we don't prepack scale or offset + // Potential transpose if the prepacked layout is different from the original layout + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + uint8_t pair01 = tensor_offset.at(row / 2, col); + tensor_offset_prepacked.at(row + 0, col) = pair01 & 0xf; + if (row + 1 < meta_shape[0]) { + tensor_offset_prepacked.at(row + 1, col) = pair01 >> 4; + } + } + } + } + } +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/mlas/.clang-format b/onnxruntime/core/mlas/.clang-format index 4a89ef98cf049..16ad8bd8a7234 100644 --- a/onnxruntime/core/mlas/.clang-format +++ b/onnxruntime/core/mlas/.clang-format @@ -2,10 +2,12 @@ BasedOnStyle: Google IndentWidth: 4 -ColumnLimit: 100 +# Setting ColumnLimit to 0 so developer choices about where to break lines are maintained. +# Developers are responsible for adhering to the 120 character maximum. +ColumnLimit: 0 +AlignAfterOpenBracket: BlockIndent AlwaysBreakAfterReturnType: TopLevel AlwaysBreakTemplateDeclarations: Yes BinPackParameters: false BreakBeforeBraces: Linux ... - diff --git a/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h b/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h new file mode 100644 index 0000000000000..7ea29eb091318 --- /dev/null +++ b/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h @@ -0,0 +1,33 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + mlas_gemm_postprocessor.h + +Abstract: + + This module contains a base class for custom postprocessing following a + GEMM. + +--*/ + +#pragma once + +template +class MLAS_GEMM_POSTPROCESSOR +{ + public: + virtual void Process(T* C, /**< the address of matrix to process */ + size_t RangeStartM, /**< the start row index of matrix */ + size_t RangeStartN, /**< the start col index of matrix */ + size_t RangeCountM, /**< the element count per row to process */ + size_t RangeCountN, /**< the element count per col to process */ + size_t ldc /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_GEMM_POSTPROCESSOR() {} +}; diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index 65b48a3009e72..316344ad8c214 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -21,6 +21,7 @@ Module Name: #pragma once #include "mlas.h" +#include "mlas_gemm_postprocessor.h" #include #include @@ -39,7 +40,7 @@ typedef enum { * @brief Computes the number of bytes required to pack and int4-quantize * a weight matrix * @param QType type of block quantization - * @param N the number of columns of matrix B. + * @param N the number of columns of matrix B. * @param K the number of rows of matrix B. * @return size of the packing buffer, 0 if the operation is not yet supported. */ @@ -53,11 +54,11 @@ MlasQ4GemmPackBSize( /** * @brief Prepack and Quantize fp32 weight tensor to int4 blocks - * + * * @param QType type of block quantization * @param PackedBuf destination buffer * @param FpData the pointer to fp32 matrix - * @param N the number of columns of matrix B. + * @param N the number of columns of matrix B. * @param K the number of rows of matrix B. * @param ldb leading dimension of B */ @@ -95,22 +96,6 @@ MlasQ4GemmUnPackB( ); -template -class MLAS_GEMM_POSTPROCESSOR -{ - public: - virtual void Process(T*, /**< the address of matrix to process */ - size_t, /**< the start row index of matrix */ - size_t, /**< the start col index of matrix */ - size_t, /**< the element count per row to process */ - size_t, /**< the element count per col to process */ - size_t /**< the leading dimension of matrix */ - ) const = 0; - - virtual ~MLAS_GEMM_POSTPROCESSOR() {} -}; - - /** * @brief Data parameters for Q4 GEMM routine * C = A * B + Bias @@ -229,3 +214,147 @@ MlasQ8Q4GemmBatch( const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool ); + + +//////////////////////////////////////////////////////////// +// Blockwise quantization and dequantization where quantization +// parameters are packed into separate buffers. +// + +/** + * @brief For quantization type , and + * matrix shape [rows, columns], compute the shape of the + * quantization parameter matrix [meta_rows, meta_cols] +*/ +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + +/** + * @brief For quantization type , and + * matrix shape [rows, columns], compute the shape of the + * quantized matrix [q_rows, q_cols]. The quantized matrix + * is in column major layout, with bits packed on the column. + * + * @tparam T + * @tparam qbits + * @param block_size + * @param columnwise + * @param rows + * @param columns + * @param q_rows + * @param q_cols +*/ +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + +/** + * @brief Compute the sizes of the quantized data and quantization parameter buffers. + * + * @param qbits The bit width of each quantized value. + * @param block_size The number of quantized values in a block. + * @param columnwise Whether a block contains values from a matrix column (true) or row (false). + * @param rows Number of matrix rows. + * @param columns Number of matrix columns. + * @param[out] q_data_size_in_bytes The size in bytes of the quantized data. + * @param[out] q_scale_num_elements The size in elements of the scale quantization parameters. + * @param[out] q_zero_point_size_in_bytes The size in bytes of the zero point quantization parameters. Optional. + * + * If the qbits or block_size values are unsupported the output sizes will be zero. + */ +void MLASCALL +MlasBlockwiseQuantizedBufferSizes( + int qbits, + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +); + + +/** + * @brief Blockwise 4 bits quantization, resulting elements and quantization + * parameters (scales, zero points) are packed into separate matrices + * all in column major layout for faster access during subsequent matrix + * multiplication. + * + * @tparam ElementT type of the input matrix element, usually floating point + * @tparam qbits number of bits used for quantization, 4 for int4 + * + * @param dst points to the quantized matrix, shape [rows, columns] column major + * @param scales points to the scales matrix, column major + * @param zero_points points to the zero_points matrix, column major + * @param src points to the floating point matrix, to be quantized, row major shape [rows, columns] + * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point + * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row + * @param rows + * @param columns + * @param leading_dimension + * @param thread_pool +*/ +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + ElementT* scales, + uint8_t* zero_points, + const ElementT* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + + +/** + * @brief Blockwise 4 bits dequantization, quantized elements and quantization + * parameters (scales, zero points) are from separate matrices packed + * in column major layout. Output is a floating point matrix in column + * major layout for faster access during subsequent matrix multiplication. + * + * @tparam ElementT type of the dequantized matrix element, usually floating point + * @tparam qbits number of bits used for quantization, 4 for int4 + * + * @param dst points to dequantized matrix shape [rows, columns] column major + * @param src points to quantized matrix, column major + * @param scales points to quantization scales, column major + * @param zero_points points to quantization zero points, column major + * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point + * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row + * @param rows + * @param columns + * @param thread_pool +*/ +template +void +MlasDequantizeBlockwise( + ElementT* dst, + const uint8_t* src, + const ElementT* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool + ); diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h new file mode 100644 index 0000000000000..9620dd42d1da9 --- /dev/null +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -0,0 +1,79 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + mlas_qnbit.h + +Abstract: + + This module contains the public data structures and procedure prototypes + for blocked n-bit quantized GEMM. + + N-bit block quantization is used to compress weight tensors of large + language models. + +--*/ + +#pragma once + +#include "mlas.h" +#include "mlas_gemm_postprocessor.h" + +/** + * @brief Data parameters for float/n-bit quantized int GEMM routine. + */ +struct MLAS_SQNBIT_GEMM_DATA_PARAMS { + const float* A = nullptr; ///< address of A (float32 matrix) + size_t lda = 0; ///< leading dimension of A + const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values) + const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + bool IsBPacked = false; ///< whether B values are packed in an optimized format for the computation + const float* Bias = nullptr; ///< optional address of Bias, vector size N + float* C = nullptr; ///< address of result matrix + size_t ldc = 0; ///< leading dimension of C + + ///< optional post processing to apply to result matrix + MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; +}; + +/** + * @brief Batched GEMM: C = A * B + Bias + * A must be a float32 matrix + * B must be a quantized and packed n-bit int matrix + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool optional thread pool to use + */ +void MLASCALL +MlasSQNBitGemmBatch( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkBitWidth, + size_t BlkLen, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool = nullptr +); + +/** + * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + */ +bool MLASCALL +MlasIsSQNBitGemmAvailable( + size_t BlkBitWidth, + size_t BlkLen +); diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S b/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S new file mode 100644 index 0000000000000..e18846c89030e --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S @@ -0,0 +1,922 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmS8S8KernelSmmla.s + +Abstract: + + This module implements the kernels for the Int8 precision matrix/matrix + multiply operation (QGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the smmla kernel. d8-d15, x19-x30 need save +// + .equ .LMlasQgemmKernel_backup_x19_x20, 0 + .equ .LMlasQgemmKernel_backup_x21_x22, 16 + .equ .LMlasQgemmKernel_backup_x23_x24, 32 + .equ .LMlasQgemmKernel_backup_x25_x26, 48 + .equ .LMlasQgemmKernel_backup_x27_x28, 64 + .equ .LMlasQgemmKernel_backup_d8_d9, 80 + .equ .LMlasQgemmKernel_backup_d10_d11, 96 + .equ .LMlasQgemmKernel_backup_d12_d13, 112 + .equ .LMlasQgemmKernel_backup_d14_d15, 128 + .equ .LMlasQgemmKernel_SavedRegisters, 144 + .equ .LMlasQgemmKernel_SavedRegisters_Neg, -144 + + +// +// Init Row Accumulators +// +// Generates the code to initialize the accumulators for a single row of the output +// block. +// +// +// Accumulators are initialized to ZeroPointB * RowSum + ColumnSum +// x7 for RowSumsBuffer pointer +// x10 for ColumnSumBuffer pointer +// x11 for ZeroPointB buffer pointer +// +// v12~v13 for RowSums values +// v14~v15 for ColumnSums values +// v0~v3 for ZeroPointB values +// + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, RowSumReg + + mul v7.4s, v\RowSumReg\().4s, v8.4s + mov v\Vec1Reg\().16b, v7.16b + add v\Vec1Reg\().4s, v\Vec1Reg\().4s, v0.4s +.if \Columns\() > 2 + mul v7.4s, v\RowSumReg\().4s, v9.4s + mov v\Vec2Reg\().16b, v7.16b + add v\Vec2Reg\().4s, v\Vec2Reg\().4s, v1.4s +.endif +.if \Columns\() > 4 + mul v7.4s, v\RowSumReg\().4s, v10.4s + mov v\Vec3Reg\().16b, v7.16b + add v\Vec3Reg\().4s, v\Vec3Reg\().4s, v2.4s +.endif +.if \Columns\() > 6 + mul v7.4s, v\RowSumReg\().4s, v11.4s + mov v\Vec4Reg\().16b, v7.16b + add v\Vec4Reg\().4s, v\Vec4Reg\().4s, v3.4s +.endif + + .endm + + +// +// InitBlockAccumulators +// +// Generates the code to initialize the accumulators for 8x8 output +// block. +// + .macro InitBlockAccumulators Mode, Columns, Rows + + ld1 {v14.4s},[x10],#16 // load ColumnSumBuffer[0] +.if \Columns\() > 4 + ld1 {v15.4s},[x10],#16 // load ColumnSumBuffer[4] +.endif + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast column + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + dup v4.4s,v15.s[0] // broadcast column + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + // v8~v11 will anyway get set in MatrixA loading, so they are free to use now + movi v8.4s, #1 + movi v9.4s, #1 + movi v10.4s, #1 + movi v11.4s, #1 + + cbz x11,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB + + ld1 {v4.4s},[x11],#16 // load ZeroPointB[0] + ld1 {v5.4s},[x11],#16 // load ZeroPointB[4] + + dup v6.4s, v4.s[0] + dup v7.4s, v4.s[1] + zip1 v8.4s, v6.4s, v7.4s + + dup v6.4s, v4.s[2] + dup v7.4s, v4.s[3] + zip1 v9.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + zip1 v10.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[2] + dup v7.4s, v5.s[3] + zip1 v11.4s, v6.4s, v7.4s + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB: + dup v4.4s, v12.s[0] //boardcast RowSums + dup v5.4s, v12.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),16,17,18,19,6 +.if \Rows\() > 2 + dup v4.4s, v12.s[2] //boardcast RowSums + dup v5.4s, v12.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),20,21,22,23,6 +.endif +.if \Rows\() > 4 + dup v4.4s,v13.s[0] // broadcast row sums + dup v5.4s,v13.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),24,25,26,27,6 +.endif +.if \Rows\() > 6 + dup v4.4s,v13.s[2] // broadcast row sums + dup v5.4s,v13.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + InitRowAccumulators \Columns\(),28,29,30,31,6 +.endif + + .endm + + +// LoadPackedMatrixABy16Elements +// +// Generates the code to load 16 elements from matrix A. +// + .macro LoadPackedMatrixABy16Elements Rows +.if \Rows\() == 1 + ldr q8,[x0],#8 +.else + ldr q8,[x0],#16 + +.if \Rows\() > 2 + ldr q9,[x0],#16 +.endif + +.if \Rows\() > 4 + ldr q10,[x0],#16 +.endif + +.if \Rows\() > 6 + ldr q11,[x0],#16 +.endif +.endif + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + smmla v\Vec1Reg\().4s, \MatrixAReg\().16b, v4.16b +.if \Columns\() > 2 + smmla v\Vec2Reg\().4s, \MatrixAReg\().16b, v5.16b +.endif +.if \Columns\() > 4 + smmla v\Vec3Reg\().4s, \MatrixAReg\().16b, v6.16b +.endif +.if \Columns\() > 6 + smmla v\Vec4Reg\().4s, \MatrixAReg\().16b, v7.16b +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. +// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(), \Columns\(),\Rows\() + + sub x9,x3,#1 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#1 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#1 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. +// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + add v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + add v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),8,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),4,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmQuantCopyPackA. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmQuantCopyPackB. + + C (x2) - Supplies the address of matrix C. + + PackedCountK (x3) - Supplies the number of packed columns from matrix A and + the number of packed rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc (x6) - Supplies the first dimension of matrix C. + + RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values + have been pre-scaled by the zero point offset of matrix B if the offset + is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. + + ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. + +Return Value: + + Returns the number of rows handled. + +--*/ + + .macro QgemmS8S8KernelSmmlaFunction Mode + + FUNCTION_ENTRY MlasGemmS8S8KernelSmmla\Mode\() + + ldr x10,[sp, #0] + ldr x11,[sp,#8] + + stp x19, x20, [sp, #.LMlasQgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + + add x13,x2,x6,lsl #2 // compute matrix C plus 1 row + add x14,x13,x6,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x6,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x6,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x6,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x6,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x6,lsl #2 // compute matrix C plus 7 rows + + mov x8,x0 // save matrix A + +// +// Process 8 rows of the matrices. +// + ld1 {v12.4s},[x7],#16 // load row sum 1 ~ 4 + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ld1 {v13.4s},[x7],#16 // load row sum 5 ~ 8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasQgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + QgemmS8S8KernelSmmlaFunction Zero + QgemmS8S8KernelSmmlaFunction Add + + .end diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S b/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S new file mode 100644 index 0000000000000..baf6e21e6ff06 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S @@ -0,0 +1,922 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8X8KernelUmmla.s + +Abstract: + + This module implements the kernels for the Int8 precision matrix/matrix + multiply operation (QGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the ummla kernel. d8-d15, x19-x30 need save +// + .equ .LMlasQgemmKernel_backup_x19_x20, 0 + .equ .LMlasQgemmKernel_backup_x21_x22, 16 + .equ .LMlasQgemmKernel_backup_x23_x24, 32 + .equ .LMlasQgemmKernel_backup_x25_x26, 48 + .equ .LMlasQgemmKernel_backup_x27_x28, 64 + .equ .LMlasQgemmKernel_backup_d8_d9, 80 + .equ .LMlasQgemmKernel_backup_d10_d11, 96 + .equ .LMlasQgemmKernel_backup_d12_d13, 112 + .equ .LMlasQgemmKernel_backup_d14_d15, 128 + .equ .LMlasQgemmKernel_SavedRegisters, 144 + .equ .LMlasQgemmKernel_SavedRegisters_Neg, -144 + + +// +// Init Row Accumulators +// +// Generates the code to initialize the accumulators for a single row of the output +// block. +// +// +// Accumulators are initialized to ZeroPointB * RowSum + ColumnSum +// x7 for RowSumsBuffer pointer +// x10 for ColumnSumBuffer pointer +// x11 for ZeroPointB buffer pointer +// +// v12~v13 for RowSums values +// v14~v15 for ColumnSums values +// v0~v3 for ZeroPointB values +// + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, RowSumReg + + mul v7.4s, v\RowSumReg\().4s, v8.4s + mov v\Vec1Reg\().16b, v7.16b + add v\Vec1Reg\().4s, v\Vec1Reg\().4s, v0.4s +.if \Columns\() > 2 + mul v7.4s, v\RowSumReg\().4s, v9.4s + mov v\Vec2Reg\().16b, v7.16b + add v\Vec2Reg\().4s, v\Vec2Reg\().4s, v1.4s +.endif +.if \Columns\() > 4 + mul v7.4s, v\RowSumReg\().4s, v10.4s + mov v\Vec3Reg\().16b, v7.16b + add v\Vec3Reg\().4s, v\Vec3Reg\().4s, v2.4s +.endif +.if \Columns\() > 6 + mul v7.4s, v\RowSumReg\().4s, v11.4s + mov v\Vec4Reg\().16b, v7.16b + add v\Vec4Reg\().4s, v\Vec4Reg\().4s, v3.4s +.endif + + .endm + + +// +// InitBlockAccumulators +// +// Generates the code to initialize the accumulators for 8x8 output +// block. +// + .macro InitBlockAccumulators Mode, Columns, Rows + + ld1 {v14.4s},[x10],#16 // load ColumnSumBuffer[0] +.if \Columns\() > 4 + ld1 {v15.4s},[x10],#16 // load ColumnSumBuffer[4] +.endif + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast column + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + dup v4.4s,v15.s[0] // broadcast column + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + // v8~v11 will anyway get set in MatrixA loading, so they are free to use now + movi v8.4s, #1 + movi v9.4s, #1 + movi v10.4s, #1 + movi v11.4s, #1 + + cbz x11,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB + + ld1 {v4.4s},[x11],#16 // load ZeroPointB[0] + ld1 {v5.4s},[x11],#16 // load ZeroPointB[4] + + dup v6.4s, v4.s[0] + dup v7.4s, v4.s[1] + zip1 v8.4s, v6.4s, v7.4s + + dup v6.4s, v4.s[2] + dup v7.4s, v4.s[3] + zip1 v9.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + zip1 v10.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[2] + dup v7.4s, v5.s[3] + zip1 v11.4s, v6.4s, v7.4s + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB: + dup v4.4s, v12.s[0] //boardcast RowSums + dup v5.4s, v12.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),16,17,18,19,6 +.if \Rows\() > 2 + dup v4.4s, v12.s[2] //boardcast RowSums + dup v5.4s, v12.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),20,21,22,23,6 +.endif +.if \Rows\() > 4 + dup v4.4s,v13.s[0] // broadcast row sums + dup v5.4s,v13.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),24,25,26,27,6 +.endif +.if \Rows\() > 6 + dup v4.4s,v13.s[2] // broadcast row sums + dup v5.4s,v13.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + InitRowAccumulators \Columns\(),28,29,30,31,6 +.endif + + .endm + + +// LoadPackedMatrixABy16Elements +// +// Generates the code to load 16 elements from matrix A. +// + .macro LoadPackedMatrixABy16Elements Rows +.if \Rows\() == 1 + ldr q8,[x0],#8 +.else + ldr q8,[x0],#16 + +.if \Rows\() > 2 + ldr q9,[x0],#16 +.endif + +.if \Rows\() > 4 + ldr q10,[x0],#16 +.endif + +.if \Rows\() > 6 + ldr q11,[x0],#16 +.endif +.endif + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + ummla v\Vec1Reg\().4s, \MatrixAReg\().16b, v4.16b +.if \Columns\() > 2 + ummla v\Vec2Reg\().4s, \MatrixAReg\().16b, v5.16b +.endif +.if \Columns\() > 4 + ummla v\Vec3Reg\().4s, \MatrixAReg\().16b, v6.16b +.endif +.if \Columns\() > 6 + ummla v\Vec4Reg\().4s, \MatrixAReg\().16b, v7.16b +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. +// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(), \Columns\(),\Rows\() + + sub x9,x3,#1 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#1 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#1 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. +// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + add v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + add v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),8,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),4,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmQuantCopyPackA. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmQuantCopyPackB. + + C (x2) - Supplies the address of matrix C. + + PackedCountK (x3) - Supplies the number of packed columns from matrix A and + the number of packed rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc (x6) - Supplies the first dimension of matrix C. + + RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values + have been pre-scaled by the zero point offset of matrix B if the offset + is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. + + ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. + +Return Value: + + Returns the number of rows handled. + +--*/ + + .macro QgemmU8X8KernelUmmlaFunction Mode + + FUNCTION_ENTRY MlasGemmU8X8KernelUmmla\Mode\() + + ldr x10,[sp, #0] + ldr x11,[sp,#8] + + stp x19, x20, [sp, #.LMlasQgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + + add x13,x2,x6,lsl #2 // compute matrix C plus 1 row + add x14,x13,x6,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x6,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x6,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x6,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x6,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x6,lsl #2 // compute matrix C plus 7 rows + + mov x8,x0 // save matrix A + +// +// Process 8 rows of the matrices. +// + ld1 {v12.4s},[x7],#16 // load row sum 1 ~ 4 + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ld1 {v13.4s},[x7],#16 // load row sum 5 ~ 8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasQgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + QgemmU8X8KernelUmmlaFunction Zero + QgemmU8X8KernelUmmlaFunction Add + + .end diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index b6ac4a1ca1d6c..6c859e4e4f44b 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -184,11 +184,17 @@ class MLASCPUIDInfo bool IsCurrentCoreArmv8NarrowLd() const { return false; } + bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } + + bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + private: MLASCPUIDInfo(); bool has_arm_neon_dot_{false}; bool has_fp16_{false}; + bool has_arm_neon_i8mm_{false}; + bool has_arm_sve_i8mm_{false}; }; using MLAS_CPUIDINFO = MLASCPUIDInfo; @@ -856,6 +862,8 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmQuantDispatchDefault; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10; @@ -882,14 +890,30 @@ extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchNeon; extern const MLAS_CONV_SYM_DISPATCH MlasConvSymU8DispatchDot; extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchDot; +// +// Quantized 8-bit integer/quantized 4-bit integer matrix/matrix multiply dispatch structure. +// + struct MLAS_Q8Q4GEMM_DISPATCH; extern const MLAS_Q8Q4GEMM_DISPATCH MlasQ8Q4GemmDispatchAvx512vnni; +// +// Float/quantized 4-bit integer matrix/matrix multiply dispatch structure. +// + struct MLAS_FPQ4GEMM_DISPATCH; extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512; +// +// Float/quantized n-bit integer matrix/matrix multiply dispatch structure. +// + +struct MLAS_SQNBIT_GEMM_DISPATCH; + +extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; + // // Quantized depthwise convolution kernels. // @@ -1021,6 +1045,8 @@ struct MLAS_PLATFORM { const MLAS_FPQ4GEMM_DISPATCH* FpQ4GemmDispatch{nullptr}; const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; + + const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; }; inline @@ -1069,6 +1095,23 @@ MlasTrySimpleParallel( const std::function& Work ); + +/** + * @brief Distribute many iterations of work over a thread pool if supported. + * This function is for small workloads in non-performance critical situation. + * + * @param ThreadPool [IN] Optional thread pool. Ignored when using OpenMP + * @param Iterations [IN] Total number of iterations + * @param Work [IN] Logic for computing a range of iterations [begin, end) + */ +void +MlasTryBatchParallel( + MLAS_THREADPOOL * ThreadPool, + const std::ptrdiff_t Iterations, + const std::function& Work + ); + + inline ptrdiff_t MlasGetMaximumThreadCount( diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 96bc1d8010bed..fec56c6ee063f 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -52,6 +52,14 @@ MLASCPUIDInfo::MLASCPUIDInfo() #define HWCAP_ASIMDDP (1 << 20) #endif +#ifndef HWCAP2_I8MM +#define HWCAP2_I8MM (1 << 13) +#endif + +#ifndef HWCAP2_SVEI8MM +#define HWCAP2_SVEI8MM (1 << 9) +#endif + #if defined(BUILD_MLAS_NO_ONNXRUNTIME) MLASCPUIDInfo::MLASCPUIDInfo() { @@ -59,6 +67,9 @@ MLASCPUIDInfo::MLASCPUIDInfo() // raw hack! Need CPUIDInfo implementation for more precise detection has_fp16_ = has_arm_neon_dot_; + + has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); + has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); } #endif @@ -449,6 +460,7 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. @@ -458,12 +470,16 @@ Return Value: #if defined(_WIN32) HasDotProductInstructions = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); -#elif !defined(__APPLE__) // The next few lines result in an EXC_BAD_INSTRUCTION runtime error on a M1 Mac so we - // disable it there. - uint64_t isar0_el1; - asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :); - HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; #else + // Use the cpuinfo value which is read from sysctl and has some additional special cases. + // https://github.com/pytorch/cpuinfo/blob/959002f82d7962a473d8bf301845f2af720e0aa4/src/arm/mach/init.c#L369-L379 + // Do NOT use ID_AA64ISAR0_EL1. It causes illegal instruction errors on Mac M1 and ARMv8-A chips + // as well as failing on other ARM chips as it is an EL1 level register that requires extra + // privileges to read. + // + // uint64_t isar0_el1; + // asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :); + // HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot(); #endif @@ -476,6 +492,17 @@ Return Value: this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; } +#if defined(__linux__) + // + // Check if the processor supports ASIMD I8MM instructions. + // + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUmmla; + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchUmmla; + this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchSmmla; + } +#endif + #endif // MLAS_TARGET_ARM64 #if defined(MLAS_TARGET_POWER) this->GemmFloatKernel = MlasSgemmKernel; diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index 85c0d13006126..b5784ecb56d01 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -294,3 +294,790 @@ MlasQ4GemmUnPackB( return MlasQ4GemmUnPackBImpl(FpData, PackedBuf, N, K, ldb); } } + + + +/*************************************************************** + * The quantization format that pack data and quantization + * parameters into separate buffers. + */ + + +template < + int Row_, ///< rows of a matrix + int Column_ ///< columns of a matrix + > +struct Shape2D { + static int const kRow = Row_; ///< rows of a matrix + static int const kColumn = Column_; ///< columns of a matrix + static int const kCount = Row_ * Column_; ///< total number of elements in a matrix +}; + + +template +struct BitsTraits { + static_assert(qbits <= 8, "Only BitsTraits are for small number of bits!"); + + static constexpr int kBits = qbits; + static constexpr int kMax = (1 << qbits) - 1; + static constexpr int kMid = 1 << (qbits - 1); + static constexpr float kMaxFp = static_cast(kMax); + + // number of qbit elements to pack into whole bytes + static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; + static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); +}; + + +/** + * @brief Rectify min/max from a set of weights, and convert to scale and zero point + * for quantization + * @tparam ScaleT type of scale, usually floating point of various bits + * @tparam qbits number of int bits used for zero point value + * @param[in] min + * @param[in] max + * @param[out] scale + * @param[out] zp + */ +template +MLAS_FORCEINLINE +void +range2scalezp(float min, float max, ScaleT& scale, uint8_t& zp) +{ + constexpr int zp_max = BitsTraits::kMax; + constexpr float zp_max_fp = BitsTraits::kMaxFp; + + min = std::min(min, 0.0f); + max = std::max(max, 0.0f); + + float scale_f = (max - min) / zp_max; + + float zero_point_fp = min; + if (scale_f != 0.0f) { + zero_point_fp = 0.f - min / scale_f; + } + + if (zero_point_fp < 0.0f) { + zp = 0; + } else if (zero_point_fp > zp_max_fp) { + zp = zp_max; + } else { + zp = (uint8_t)roundf(zero_point_fp); + } + scale = ScaleT(scale_f); +} + +template +MLAS_FORCEINLINE +void +range2scale(float min, float max, ScaleT& scale) +{ + constexpr int mid_v = BitsTraits::kMid; + constexpr float mid_fp = static_cast(-mid_v); + + max = fabsf(max) > fabsf(min) ? max : min; + + scale = ScaleT(max / mid_fp); +}; + + +/** + * @brief Blockwise quantization methods + * @tparam ElementT source data type, e.g. fp32/fp16 + * @tparam block_size number of elemenets quantized together + * @tparam qbits number of bits in each quantized element + * @tparam Columnwise true: elements in a block come from one single column + * false: elements in a block come from one single row + */ +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +struct BlockwiseQuantizer { + // To support other qbits, need to add bit packing code for + // storing to dst and zero points + static_assert(qbits == 4, "Only 4b block quantization is supported!"); + + using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; + using ThreadBlk = Shape2D::kPackSize, QuantBlk::kColumn>; + + static + MLAS_FORCEINLINE + void quantizeMetaShape(int rows, int columns, int& meta_rows, int& meta_cols) + { + meta_rows = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + meta_cols = (columns + QuantBlk::kColumn - 1) / QuantBlk::kColumn; + } + + static + MLAS_FORCEINLINE + void quantizedShape(int rows, int columns, int& q_rows, int& q_cols) { + int meta_rows; + int meta_cols; + quantizeMetaShape(rows, columns, meta_rows, meta_cols); + + // quantized matrix is stored in column major, packed by column + q_rows = (meta_rows * QuantBlk::kRow * qbits + 7) / 8; + q_cols = meta_cols * QuantBlk::kColumn; + } + + static MLAS_FORCEINLINE void quantizedBufferSizes( + int rows, int columns, size_t& data_bytes, size_t& scale_num_elements, size_t* zero_point_bytes + ) + { + int meta_rows, meta_cols; + quantizeMetaShape(rows, columns, meta_rows, meta_cols); + int q_rows, q_cols; + quantizedShape(rows, columns, q_rows, q_cols); + + data_bytes = q_rows * q_cols; + scale_num_elements = meta_rows * meta_cols; + + if (zero_point_bytes) { + // this works for qbits == 4 but may need to be updated for other qbits values + *zero_point_bytes = ((meta_rows * qbits + 7) / 8) * meta_cols; + } + } + + /** + * @brief Quantized a Matrix shape [rows, columns], resulting quantized + * and packed data are stored in column major (transposed) + * @param[out] dst pointer to the quantized weights, column major: [columns, rows] + * @param[out] scale pointer to the scales, column major: [columns/QuantBlk::kColumn, rows/QuantBlk::kRow] + * @param[out] zero_points pointer to the zero points, same shape as scale + * @param[in] src pointer to the source matrix, row major: [rows, columns] + * @param rows + * @param columns + * @param leadingDimension stride of the source matrix, i.e. distance from one row to the next + */ + static void quantizeAndTranspose( + uint8_t* dst, + ElementT* scales, + uint8_t* zero_points, + const ElementT* src, + int32_t rows, + int32_t columns, + int32_t leadingDimension, + MLAS_THREADPOOL* thread_pool) + { + // Thread partitioning + const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; + + const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + int q_rows, q_cols; + quantizedShape(rows, columns, q_rows, q_cols); + + MlasTryBatchParallel( + thread_pool, total_thrd_blks, + [&](ptrdiff_t block_idx) { + uint8_t zp_bytes[BitsTraits::kPackSize]; + std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); + + const int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); + const int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); + + const int32_t r = r_blk_idx * ThreadBlk::kRow; + const int32_t c = c_blk_idx * ThreadBlk::kColumn; + + const int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + const int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + + const int meta_row = r / QuantBlk::kRow; + const int meta_col = c / QuantBlk::kColumn; + + // compute scale and zero point + for (int kpack = 0; kpack < BitsTraits::kPackSize; kpack++) { + + // scan a single block to extract range [min, max] + float min = std::numeric_limits::max(); + float max = -min; + const int row_start = r + kpack * QuantBlk::kRow; + const int row_end = std::min(row_start + QuantBlk::kRow, r_end); + for (int i = row_start; i < row_end; ++i) { + for (int j = c; j < c_end; ++j) { + const float v = static_cast(src[i * leadingDimension + j]); + if (v < min) min = v; + if (v > max) max = v; + } + } + + // store scale and zero point at quant parameter matrix position + if (row_start < row_end) { + const int32_t meta_idx = meta_col * row_blks + meta_row + kpack; + if (zero_points == nullptr) { + range2scale(min, max, scales[meta_idx]); + } else { + range2scalezp(min, max, scales[meta_idx], zp_bytes[kpack]); + } + } + } + + // !! 4b specific code as we need to pack 2 4b numbers into one byte + if (zero_points != nullptr) { + const int32_t meta_idx = meta_col * ((row_blks + 1) / 2) + meta_row / 2; + zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4); + } + + for (int32_t j = c; j < c_end; ++j) { + const int32_t meta_c = j / QuantBlk::kColumn; + for (int32_t i = r; i < r_end; i += 2) { + const int32_t meta_r = i / QuantBlk::kRow; + const float scale = static_cast(scales[meta_c * row_blks + meta_r]); + const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; + const int8_t zp = zp_bytes[meta_r & 1]; + const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 1]; + + const float v0 = static_cast(src[i * leadingDimension + j]); + const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp), + 0.0f, BitsTraits::kMaxFp); + + uint8_t vi1 = (uint8_t)zp; + if (i + 1 < r_end) { + float reciprocal_scale1 = reciprocal_scale; + if constexpr (QuantBlk::kRow == 1) { + const float scale1 = + static_cast(scales[meta_c * row_blks + meta_r + 1]); + reciprocal_scale1 = scale1 ? 1.0f / scale1 : 0.0f; + } + const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); + vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f, + BitsTraits::kMaxFp); + } + + // !! 4b specific code + dst[j * q_rows + i / 2] = (vi0 & 0xf) | (vi1 << 4); + } + } + }); + } + + /** + * @brief Dequantize a column major quantized matrix, and store the result in a column major + * matrix for use in GEMM + * @param[out] dst pointer to the dequantized matrix, column major: [columns, rows] + * @param[in] weights pointer to the quantized weights, column major: [columns, rows] + * @param[in] scales pointer to the scales of quantized blocks, column major layout + * @param[in] zero_points pointer to the zero points of quantized blocks, packed column major + * scales + * @param[in] rows + * @param[in] columns + */ + static void dequantize( + ElementT* dst, + const uint8_t* weights, + const ElementT* scales, + const uint8_t* zero_points, + int32_t rows, + int32_t columns, + MLAS_THREADPOOL* thread_pool) + { + // Thread partitioning + const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; + + const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + int q_rows, q_cols; + quantizedShape(rows, columns, q_rows, q_cols); + + MlasTryBatchParallel( + thread_pool, total_thrd_blks, + [&](ptrdiff_t block_idx) { + int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); + int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); + + int32_t r = r_blk_idx * ThreadBlk::kRow; + int32_t c = c_blk_idx * ThreadBlk::kColumn; + + int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + + for (int32_t j = c; j < c_end; ++j) { + const int32_t meta_col = j / QuantBlk::kColumn; + + // !! 4b specific code + // the whole loop is 4b specific due to sub 8 bit packing + // and unpacking. We can potentially make this qbits generic + // by wraping the packing/unpacking code like cutlass::Array + for (int32_t i = r; i < r_end; i += 2) { + const int32_t meta_row = i / QuantBlk::kRow; + + const float scale0 = + static_cast(scales[meta_col * row_blks + meta_row]); + + const int zp_pair = + (zero_points == nullptr) + ? 0x88 + : zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2]; + const int zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf); + + const uint8_t vi0 = weights[j * q_rows + i / 2] & 0xf; + const float v0 = (static_cast(vi0) - zp0) * scale0; + + dst[j * rows + i] = static_cast(v0); + if ((i + 1) < r_end) { + float scale1 = scale0; + int zp1 = zp0; + if constexpr (QuantBlk::kRow == 1) { + scale1 = + static_cast(scales[meta_col * row_blks + meta_row + 1]); + zp1 = (zp_pair >> 4) & 0xf; + } + const uint8_t vi1 = weights[j * q_rows + i / 2] >> 4; + const float v1 = (static_cast(vi1) - zp1) * scale1; + dst[j * rows + (i + 1)] = static_cast(v1); + } + } + } + }); + } +}; + + +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ) +{ + switch (block_size) { + case 16: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } + break; + } + case 32: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape( + rows, columns, meta_rows, meta_cols); + } + break; + } + case 64: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } + break; + } + case 128: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } + break; + } + case 256: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } + break; + } + default: + meta_rows = 0; + meta_cols = 0; + break; + } +} + + + +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ) +{ + switch (block_size) { + case 16: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + case 32: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape( + rows, columns, q_rows, q_cols); + } + break; + } + case 64: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + case 128: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + case 256: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + default: + q_rows = 0; + q_cols = 0; + break; + } +} + + +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + +void MLASCALL +MlasBlockwiseQuantizedBufferSizes( + int qbits, + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +) +{ + q_data_size_in_bytes = q_scale_num_elements = 0; + if (q_zero_point_size_in_bytes) { + *q_zero_point_size_in_bytes = 0; + } + + if (qbits == 4) { + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 32: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 64: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 128: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 256: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + default: + // Only block size 16, 32, 64, 128, 256 are supported. + break; + } + } +} + + +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + T* scales, + uint8_t* zero_points, + const T* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ) +{ + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 32: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 64: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 128: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 256: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + default: + // Only block size 16, 32, 64, 128, 256 are supported. + break; + } +} + +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + float* scales, + uint8_t* zero_points, + const float* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + MLAS_FP16* scales, + uint8_t* zero_points, + const MLAS_FP16* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + + +template +void +MlasDequantizeBlockwise( + T* dst, + const uint8_t* src, + const T* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool + ) +{ + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } + break; + case 32: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } + break; + case 64: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } + break; + case 128: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, + rows, columns, thread_pool); + } + break; + case 256: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, + rows, columns, thread_pool); + } + break; + default: + // Only block size 16, 32, 64, 128, 256 are supported. + break; + } +} + +template +void +MlasDequantizeBlockwise( + float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool + ); diff --git a/onnxruntime/core/mlas/lib/q4gemm.h b/onnxruntime/core/mlas/lib/q4gemm.h index 1562f9c0b4236..b1b51dd53c4fc 100644 --- a/onnxruntime/core/mlas/lib/q4gemm.h +++ b/onnxruntime/core/mlas/lib/q4gemm.h @@ -90,7 +90,7 @@ MlasQ4GemmOperation( if (DataParams->OutputProcessor != nullptr) { DataParams->OutputProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, RowsHandled, CountN, ldc); } diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp new file mode 100644 index 0000000000000..c41f43ca22d18 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp @@ -0,0 +1,964 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_smmla.cpp + +Abstract: + + This module implements smmla QGEMM kernel. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" + +// +// Define the prototypes of the NEON SMMLA routines written in assembly. +// + +extern "C" { + +size_t MLASCALL +MlasGemmS8S8KernelSmmlaZero(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); + +size_t MLASCALL +MlasGemmS8S8KernelSmmlaAdd(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); +} + +struct MLAS_GEMM_S8S8_KERNEL_SMMLA { + typedef uint8_t PackedAType; + typedef uint8_t PackedBType; + typedef int8_t OffsetAType; + typedef int8_t OffsetBType; + + static constexpr size_t PackedK = 8; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 128, 256}; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{24, 128, 384}; +}; + +constexpr size_t MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SMMLA::Strides; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedStrides; + +template <> +MLAS_FORCEINLINE int32_t +MlasGemmQuantFixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +{ + MLAS_UNREFERENCED_PARAMETER(BIsSigned); + return ZeroPointB; +} + +template <> +void +MlasGemmQuantCopyPackA( + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedAType* D_uint8_t, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned) +{ + int8_t* D = reinterpret_cast(D_uint8_t); + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + int8_t PaddedMatrixAData[64]; + + // + // Process 8 rows of matrix A. + // + // MMLA kernels load 8x8 block of A with four vector registers. So A is packed + // a series of 64 byte vectors where eight rows are interleaved with the + // following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // [ E0 E1 E2 E3 E4 E5 E6 E7 ] + // [ F0 F1 F2 F3 F4 F5 F6 F7 ] + // [ G0 G1 G2 G3 G4 G5 G6 G7 ] + // [ H0 H1 H2 H3 H4 H5 H6 H7 ] + // + // ... + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + while (CountM >= 8) { + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + const int8_t* a2 = a0 + lda * 2; + const int8_t* a3 = a0 + lda * 3; + const int8_t* a4 = a0 + lda * 4; + const int8_t* a5 = a0 + lda * 5; + const int8_t* a6 = a0 + lda * 6; + const int8_t* a7 = a0 + lda * 7; + + size_t k = CountK; + int32x4_t RowSums0 = vmovq_n_s32(0); + int32x4_t RowSums1 = vmovq_n_s32(0); + + while (k >= 16) { + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + int64x2_t v2 = vld1q_s64(reinterpret_cast(a2)); + a2 += 16; + int64x2_t v3 = vld1q_s64(reinterpret_cast(a3)); + a3 += 16; + int64x2_t v4 = vld1q_s64(reinterpret_cast(a4)); + a4 += 16; + int64x2_t v5 = vld1q_s64(reinterpret_cast(a5)); + a5 += 16; + int64x2_t v6 = vld1q_s64(reinterpret_cast(a6)); + a6 += 16; + int64x2_t v7 = vld1q_s64(reinterpret_cast(a7)); + a7 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + int64x2_t z2 = vzip1q_s64(v2, v3); + int64x2_t z3 = vzip2q_s64(v2, v3); + + int64x2_t z4 = vzip1q_s64(v4, v5); + int64x2_t z5 = vzip2q_s64(v4, v5); + int64x2_t z6 = vzip1q_s64(v6, v7); + int64x2_t z7 = vzip2q_s64(v6, v7); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z2)); + vst1q_s8(&D[32], vreinterpretq_s8_s64(z4)); + vst1q_s8(&D[48], vreinterpretq_s8_s64(z6)); + vst1q_s8(&D[64], vreinterpretq_s8_s64(z1)); + vst1q_s8(&D[80], vreinterpretq_s8_s64(z3)); + vst1q_s8(&D[96], vreinterpretq_s8_s64(z5)); + vst1q_s8(&D[112], vreinterpretq_s8_s64(z7)); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z2))); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z3))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z4))); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z5))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), + vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z6))); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z7))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), + vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 128; + k -= 16; + } + + while (k >= 8) { + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + int64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + int64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + int64x1_t v4 = *reinterpret_cast(a4); + a4 += 8; + int64x1_t v5 = *reinterpret_cast(a5); + a5 += 8; + int64x1_t v6 = *reinterpret_cast(a6); + a6 += 8; + int64x1_t v7 = *reinterpret_cast(a7); + a7 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + *reinterpret_cast(&D[32]) = v4; + *reinterpret_cast(&D[40]) = v5; + *reinterpret_cast(&D[48]) = v6; + *reinterpret_cast(&D[56]) = v7; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + int64x2_t z45 = vcombine_s64(v4, v5); + int64x2_t z67 = vcombine_s64(v6, v7); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z45))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), + vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z67))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), + vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 64; + k -= 8; + } + + if (k > 0) { + // + // zero pad the remaining columns to 8 + // + int8_t* d = D; + + vst1q_s8(d, vmovq_n_s8(0)); + vst1q_s8(&d[16], vmovq_n_s8(0)); + vst1q_s8(&d[32], vmovq_n_s8(0)); + vst1q_s8(&d[48], vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d[32] = *a4++; + d[40] = *a5++; + d[48] = *a6++; + d[56] = *a7++; + d += 1; + k -= 1; + } + d = D; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v4 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v5 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v6 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v7 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + int64x2_t z45 = vcombine_s64(v4, v5); + int64x2_t z67 = vcombine_s64(v6, v7); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z45))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), + vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z67))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), + vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 64; + } + + vst1q_s32(RowSumBuffer, RowSums0); + vst1q_s32(&RowSumBuffer[4], RowSums1); + + RowSumBuffer += 8; + + A = A + lda * 8; + CountM -= 8; + } + + // + // Process four rows of matrix A. + // + // The buffer is packed as a series of 32 byte vectors where four rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + if (CountM >= 4) { + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + const int8_t* a2 = a1 + lda; + const int8_t* a3 = a2 + lda; + + size_t k = CountK; + int32x4_t RowSums = vmovq_n_s32(0); + + while (k >= 16) { + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + int64x2_t v2 = vld1q_s64(reinterpret_cast(a2)); + a2 += 16; + int64x2_t v3 = vld1q_s64(reinterpret_cast(a3)); + a3 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + int64x2_t z2 = vzip1q_s64(v2, v3); + int64x2_t z3 = vzip2q_s64(v2, v3); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z2)); + vst1q_s8(&D[32], vreinterpretq_s8_s64(z1)); + vst1q_s8(&D[48], vreinterpretq_s8_s64(z3)); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + int32x4_t RowSumsH_pada = vmovq_n_s32(0); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z2))); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z3))); + + int32x4_t RowSumsH_ext = vextq_s32(RowSumsH_pada, RowSumsH_pada, 1); + int32x4_t RowSumsH_add = vaddq_s32(RowSumsH_pada, RowSumsH_ext); + int32x2_t RowSumsH = {vdups_laneq_s32(RowSumsH_add, 0), + vdups_laneq_s32(RowSumsH_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSumsL, RowSumsH)); + + D += 64; + k -= 16; + } + + while (k >= 8) { + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + int64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + int64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + int32x4_t RowSumsH_pada = vmovq_n_s32(0); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSumsH_ext = vextq_s32(RowSumsH_pada, RowSumsH_pada, 1); + int32x4_t RowSumsH_add = vaddq_s32(RowSumsH_pada, RowSumsH_ext); + int32x2_t RowSumsH = {vdups_laneq_s32(RowSumsH_add, 0), + vdups_laneq_s32(RowSumsH_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSumsL, RowSumsH)); + + D += 32; + k -= 8; + } + + if (k > 0) { + // + // Copy the remaining bytes with zero padding. + // + int8_t* d = D; + + vst1q_s8(d, vmovq_n_s8(0)); + vst1q_s8(&d[16], vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d += 1; + k -= 1; + } + + d = D; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSums0L, RowSums0H)); + + D += 32; + } + + vst1q_s32(RowSumBuffer, RowSums); + RowSumBuffer += 4; + + A = A + lda * 4; + CountM -= 4; + } + + // + // Process two rows of matrix A. + // + // The buffer is packed as a series of 16 byte vectors where two rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + if (CountM >= 2) { + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + + size_t k = CountK; + int32x2_t RowSums = vmov_n_s32(0); + + while (k >= 16) { + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z1)); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + D += 32; + k -= 16; + } + + while (k >= 8) { + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + + int64x2_t z01 = vcombine_s64(v0, v1); + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + D += 16; + k -= 8; + } + + if (k > 0) { + // + // Zero pad the remaining elements to make 8 columns. + // + + int8_t* d = PaddedMatrixAData; + vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + + d += 1; + k -= 1; + } + + d = PaddedMatrixAData; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01 = vcombine_s64(v0, v1); + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + int8x16_t PackedVector = vld1q_s8(PaddedMatrixAData); + vst1q_s8(D, PackedVector); + + D += 16; + } + + vst1_s32(RowSumBuffer, RowSums); + RowSumBuffer += 2; + + A = A + lda * 2; + CountM -= 2; + } + + // + // Process one row of matrix A. + // + // The buffer is packed as a series of 8 byte with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of 8, then the vector is padded + // with zeroes. + // + + if (CountM > 0) { + // No need to pad the rows to 2, the .S takes care of zero pdding + const int8_t* a = reinterpret_cast(A); + size_t k = CountK; + int32x4_t RowSums = vmovq_n_s32(0); + + while (k >= 16) { + int8x16_t v = vld1q_s8(a); + a += 16; + + vst1q_s8(D, v); + + RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); + + D += 16; + k -= 16; + } + + if (k > 0) { + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); + + for (size_t kk = 0; kk < k; kk++) { + PaddedMatrixAData[kk] = a[kk]; + } + + int8x16_t v = vld1q_s8(PaddedMatrixAData); + vst1q_s8(D, v); + + RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); + } + + *RowSumBuffer = int32_t(vaddvq_s32(RowSums)); + } +} + +MLAS_FORCEINLINE +void +MlasGemmS8S8CopyPackBProcessSmmla(int8_t* D, int8x8_t BytesRow[8], int32x4_t ColumnSums[2]) +{ + int8x16_t v02 = vcombine_s8(BytesRow[0], BytesRow[2]); + int8x16_t v13 = vcombine_s8(BytesRow[1], BytesRow[3]); + + int8x16_t v46 = vcombine_s8(BytesRow[4], BytesRow[6]); + int8x16_t v57 = vcombine_s8(BytesRow[5], BytesRow[7]); + + int8x16x2_t zw1 = vzipq_s8(v02, v13); + int16x8x2_t zd1 = vzipq_s16(vreinterpretq_s16_s8(zw1.val[0]), vreinterpretq_s16_s8(zw1.val[1])); + + int8x16x2_t zw2 = vzipq_s8(v46, v57); + int16x8x2_t zd2 = vzipq_s16(vreinterpretq_s16_s8(zw2.val[0]), vreinterpretq_s16_s8(zw2.val[1])); + + int32x4x2_t zd3 = + vzipq_s32(vreinterpretq_s32_s16(zd1.val[0]), vreinterpretq_s32_s16(zd2.val[0])); + int32x4x2_t zd4 = + vzipq_s32(vreinterpretq_s32_s16(zd1.val[1]), vreinterpretq_s32_s16(zd2.val[1])); + + vst1q_s8(&D[0], vreinterpretq_s8_s32(zd3.val[0])); + vst1q_s8(&D[16], vreinterpretq_s8_s32(zd3.val[1])); + vst1q_s8(&D[32], vreinterpretq_s8_s32(zd4.val[0])); + vst1q_s8(&D[48], vreinterpretq_s8_s32(zd4.val[1])); + + int32x4_t ColSums0L_pada = vmovq_n_s32(0); + ColSums0L_pada = vpadalq_s16(ColSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd3.val[0]))); + int32x4_t ColSums0L_ext = vextq_s32(ColSums0L_pada, ColSums0L_pada, 1); + int32x4_t ColSums0L_add = vaddq_s32(ColSums0L_pada, ColSums0L_ext); + int32x2_t ColSums0L = {vdups_laneq_s32(ColSums0L_add, 0), vdups_laneq_s32(ColSums0L_add, 2)}; + + int32x4_t ColSums0H_pada = vmovq_n_s32(0); + ColSums0H_pada = vpadalq_s16(ColSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd3.val[1]))); + int32x4_t ColSums0H_ext = vextq_s32(ColSums0H_pada, ColSums0H_pada, 1); + int32x4_t ColSums0H_add = vaddq_s32(ColSums0H_pada, ColSums0H_ext); + int32x2_t ColSums0H = {vdups_laneq_s32(ColSums0H_add, 0), vdups_laneq_s32(ColSums0H_add, 2)}; + + ColumnSums[0] = vaddq_s32(ColumnSums[0], vcombine_s32(ColSums0L, ColSums0H)); + + int32x4_t ColSums1L_pada = vmovq_n_s32(0); + ColSums1L_pada = vpadalq_s16(ColSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd4.val[0]))); + int32x4_t ColSums1L_ext = vextq_s32(ColSums1L_pada, ColSums1L_pada, 1); + int32x4_t ColSums1L_add = vaddq_s32(ColSums1L_pada, ColSums1L_ext); + int32x2_t ColSums1L = {vdups_laneq_s32(ColSums1L_add, 0), vdups_laneq_s32(ColSums1L_add, 2)}; + + int32x4_t ColSums1H_pada = vmovq_n_s32(0); + ColSums1H_pada = vpadalq_s16(ColSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd4.val[1]))); + int32x4_t ColSums1H_ext = vextq_s32(ColSums1H_pada, ColSums1H_pada, 1); + int32x4_t ColSums1H_add = vaddq_s32(ColSums1H_pada, ColSums1H_ext); + int32x2_t ColSums1H = {vdups_laneq_s32(ColSums1H_add, 0), vdups_laneq_s32(ColSums1H_add, 2)}; + + ColumnSums[1] = vaddq_s32(ColumnSums[1], vcombine_s32(ColSums1L, ColSums1H)); +} + +template <> +void +MlasGemmQuantCopyPackB(MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedBType* Dst, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned) +{ + MLAS_UNREFERENCED_PARAMETER(BIsSigned); + int8_t* D = reinterpret_cast(Dst); + const int8x16_t ZeroVector = vmovq_n_s8(0); + int8x8_t BytesRow[8]; + + // + // Copy data from matrix B into the destination buffer 8x2 blocks at a + // time. + // + // + while (CountN >= 8) { + const int8_t* b = reinterpret_cast(B); + size_t k = CountK; + int32x4_t ColumnSums[2]; + + ColumnSums[0] = vmovq_n_s32(0); + ColumnSums[1] = vmovq_n_s32(0); + + while (k >= 8) { + BytesRow[0] = vld1_s8(&b[ldb * 0]); + BytesRow[1] = vld1_s8(&b[ldb * 1]); + BytesRow[2] = vld1_s8(&b[ldb * 2]); + BytesRow[3] = vld1_s8(&b[ldb * 3]); + BytesRow[4] = vld1_s8(&b[ldb * 4]); + BytesRow[5] = vld1_s8(&b[ldb * 5]); + BytesRow[6] = vld1_s8(&b[ldb * 6]); + BytesRow[7] = vld1_s8(&b[ldb * 7]); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + b += ldb * 8; + k -= 8; + } + + if (k > 0) { + // Pad k to 8 + + BytesRow[0] = vld1_s8(&b[ldb * 0]); + BytesRow[1] = (k >= 2) ? vld1_s8(&b[ldb * 1]) : vget_low_s8(ZeroVector); + BytesRow[2] = (k >= 3) ? vld1_s8(&b[ldb * 2]) : vget_low_s8(ZeroVector); + BytesRow[3] = (k >= 4) ? vld1_s8(&b[ldb * 3]) : vget_low_s8(ZeroVector); + BytesRow[4] = (k >= 5) ? vld1_s8(&b[ldb * 4]) : vget_low_s8(ZeroVector); + BytesRow[5] = (k >= 6) ? vld1_s8(&b[ldb * 5]) : vget_low_s8(ZeroVector); + BytesRow[6] = (k >= 7) ? vld1_s8(&b[ldb * 6]) : vget_low_s8(ZeroVector); + BytesRow[7] = vget_low_s8(ZeroVector); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + } + + // Zero pad the output buffer to a multiple of PackedK if the above + // processed an odd number of four row bundles. + // + vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]); + vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]); + + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + const int8_t* b = reinterpret_cast(B); + size_t k = CountK; + int8_t PaddedMatrixBData[64]; + int32x4_t ColumnSums[2]; + + vst1q_s8(&PaddedMatrixBData[0], ZeroVector); + vst1q_s8(&PaddedMatrixBData[16], ZeroVector); + vst1q_s8(&PaddedMatrixBData[32], ZeroVector); + vst1q_s8(&PaddedMatrixBData[48], ZeroVector); + + ColumnSums[0] = vmovq_n_s32(0); + ColumnSums[1] = vmovq_n_s32(0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. + // + + while (k > 0) { + const int8_t* bcopy0 = &b[ldb * 0]; + const int8_t* bcopy1 = &b[ldb * 1]; + const int8_t* bcopy2 = &b[ldb * 2]; + const int8_t* bcopy3 = &b[ldb * 3]; + const int8_t* bcopy4 = &b[ldb * 4]; + const int8_t* bcopy5 = &b[ldb * 5]; + const int8_t* bcopy6 = &b[ldb * 6]; + const int8_t* bcopy7 = &b[ldb * 7]; + + if (k >= 8) { + b += ldb * 8; + k -= 8; + + } else { + vst1q_s8(&PaddedMatrixBData[0], ZeroVector); + vst1q_s8(&PaddedMatrixBData[16], ZeroVector); + vst1q_s8(&PaddedMatrixBData[32], ZeroVector); + vst1q_s8(&PaddedMatrixBData[48], ZeroVector); + + bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[56]; + bcopy2 = (k >= 3) ? bcopy2 : &PaddedMatrixBData[56]; + bcopy3 = (k >= 4) ? bcopy3 : &PaddedMatrixBData[56]; + bcopy4 = (k >= 5) ? bcopy4 : &PaddedMatrixBData[56]; + bcopy5 = (k >= 6) ? bcopy5 : &PaddedMatrixBData[56]; + bcopy6 = (k >= 7) ? bcopy6 : &PaddedMatrixBData[56]; + bcopy7 = &PaddedMatrixBData[56]; + + k = 0; + } + + int8_t* padded = PaddedMatrixBData; + int8_t* padded_end = padded + CountN; + do { + padded[0] = *bcopy0++; + padded[8] = *bcopy1++; + padded[16] = *bcopy2++; + padded[24] = *bcopy3++; + padded[32] = *bcopy4++; + padded[40] = *bcopy5++; + padded[48] = *bcopy6++; + padded[56] = *bcopy7++; + + } while (++padded < padded_end); + + BytesRow[0] = vld1_s8(&PaddedMatrixBData[0]); + BytesRow[1] = vld1_s8(&PaddedMatrixBData[8]); + BytesRow[2] = vld1_s8(&PaddedMatrixBData[16]); + BytesRow[3] = vld1_s8(&PaddedMatrixBData[24]); + BytesRow[4] = vld1_s8(&PaddedMatrixBData[32]); + BytesRow[5] = vld1_s8(&PaddedMatrixBData[40]); + BytesRow[6] = vld1_s8(&PaddedMatrixBData[48]); + BytesRow[7] = vld1_s8(&PaddedMatrixBData[56]); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + } + + vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]); + vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]); + } +} + +template <> +MLAS_FORCEINLINE size_t +MlasGemmQuantKernel(const MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedAType* A, + const MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) +{ + size_t RowsHandled; + + if (ZeroMode) { + RowsHandled = MlasGemmS8S8KernelSmmlaZero(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } else { + RowsHandled = MlasGemmS8S8KernelSmmlaAdd(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } + + return RowsHandled; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla = { + MlasGemmQuantOperation, + MlasGemmQuantPackedOperation, + MlasGemmQuantCopyPackB, + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedK, + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedStrides.K, + 8}; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_ummla.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_ummla.cpp new file mode 100644 index 0000000000000..3936154432ac7 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_ummla.cpp @@ -0,0 +1,967 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_ummla.cpp + +Abstract: + + This module implements ummla QGEMM kernel. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" + +// +// Define the prototypes of the NEON UMMLA routines written in assembly. +// + +extern "C" { + +size_t MLASCALL +MlasGemmU8X8KernelUmmlaZero(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); + +size_t MLASCALL +MlasGemmU8X8KernelUmmlaAdd(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); +} + +struct MLAS_GEMM_U8X8_KERNEL_UMMLA { + typedef uint8_t PackedAType; + typedef uint8_t PackedBType; + typedef uint8_t OffsetAType; + typedef uint8_t OffsetBType; + + static constexpr size_t PackedK = 8; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 128, 256}; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{24, 128, 384}; +}; + +constexpr size_t MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_UMMLA::Strides; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedStrides; + +template <> +MLAS_FORCEINLINE int32_t +MlasGemmQuantFixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +{ + if (BIsSigned) { + ZeroPointB = MLAS_GEMM_U8X8_KERNEL_UMMLA::OffsetBType(ZeroPointB ^ 0x80); + } + + return ZeroPointB; +} + +template <> +void +MlasGemmQuantCopyPackA(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned) +{ + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + uint8_t PaddedMatrixAData[64]; + + // + // Process 8 rows of matrix A. + // + // MMLA kernels load 8x8 block of A with four vector registers. So A is packed + // a series of 64 byte vectors where eight rows are interleaved with the + // following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // [ E0 E1 E2 E3 E4 E5 E6 E7 ] + // [ F0 F1 F2 F3 F4 F5 F6 F7 ] + // [ G0 G1 G2 G3 G4 G5 G6 G7 ] + // [ H0 H1 H2 H3 H4 H5 H6 H7 ] + // + // ... + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + while (CountM >= 8) { + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + const uint8_t* a2 = a0 + lda * 2; + const uint8_t* a3 = a0 + lda * 3; + const uint8_t* a4 = a0 + lda * 4; + const uint8_t* a5 = a0 + lda * 5; + const uint8_t* a6 = a0 + lda * 6; + const uint8_t* a7 = a0 + lda * 7; + + size_t k = CountK; + uint32x4_t RowSums0 = vmovq_n_u32(0); + uint32x4_t RowSums1 = vmovq_n_u32(0); + + while (k >= 16) { + uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); + a0 += 16; + uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); + a1 += 16; + uint64x2_t v2 = vld1q_u64(reinterpret_cast(a2)); + a2 += 16; + uint64x2_t v3 = vld1q_u64(reinterpret_cast(a3)); + a3 += 16; + uint64x2_t v4 = vld1q_u64(reinterpret_cast(a4)); + a4 += 16; + uint64x2_t v5 = vld1q_u64(reinterpret_cast(a5)); + a5 += 16; + uint64x2_t v6 = vld1q_u64(reinterpret_cast(a6)); + a6 += 16; + uint64x2_t v7 = vld1q_u64(reinterpret_cast(a7)); + a7 += 16; + + uint64x2_t z0 = vzip1q_u64(v0, v1); + uint64x2_t z1 = vzip2q_u64(v0, v1); + uint64x2_t z2 = vzip1q_u64(v2, v3); + uint64x2_t z3 = vzip2q_u64(v2, v3); + + uint64x2_t z4 = vzip1q_u64(v4, v5); + uint64x2_t z5 = vzip2q_u64(v4, v5); + uint64x2_t z6 = vzip1q_u64(v6, v7); + uint64x2_t z7 = vzip2q_u64(v6, v7); + + vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); + vst1q_u8(&D[16], vreinterpretq_u8_u64(z2)); + vst1q_u8(&D[32], vreinterpretq_u8_u64(z4)); + vst1q_u8(&D[48], vreinterpretq_u8_u64(z6)); + vst1q_u8(&D[64], vreinterpretq_u8_u64(z1)); + vst1q_u8(&D[80], vreinterpretq_u8_u64(z3)); + vst1q_u8(&D[96], vreinterpretq_u8_u64(z5)); + vst1q_u8(&D[112], vreinterpretq_u8_u64(z7)); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z2))); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z3))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); + + uint32x4_t RowSums1L_pada = vmovq_n_u32(0); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z4))); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z5))); + + uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); + uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); + uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), + vdups_laneq_u32(RowSums1L_add, 2)}; + + uint32x4_t RowSums1H_pada = vmovq_n_u32(0); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z6))); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z7))); + + uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); + uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); + uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), + vdups_laneq_u32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); + + D += 128; + k -= 16; + } + + while (k >= 8) { + uint64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + uint64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + uint64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + uint64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + uint64x1_t v4 = *reinterpret_cast(a4); + a4 += 8; + uint64x1_t v5 = *reinterpret_cast(a5); + a5 += 8; + uint64x1_t v6 = *reinterpret_cast(a6); + a6 += 8; + uint64x1_t v7 = *reinterpret_cast(a7); + a7 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + *reinterpret_cast(&D[32]) = v4; + *reinterpret_cast(&D[40]) = v5; + *reinterpret_cast(&D[48]) = v6; + *reinterpret_cast(&D[56]) = v7; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + uint64x2_t z45 = vcombine_u64(v4, v5); + uint64x2_t z67 = vcombine_u64(v6, v7); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); + + uint32x4_t RowSums1L_pada = vmovq_n_u32(0); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z45))); + + uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); + uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); + uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), + vdups_laneq_u32(RowSums1L_add, 2)}; + + uint32x4_t RowSums1H_pada = vmovq_n_u32(0); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z67))); + + uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); + uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); + uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), + vdups_laneq_u32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); + + D += 64; + k -= 8; + } + + if (k > 0) { + // + // zero pad the remaining columns to 8 + // + uint8_t* d = D; + + vst1q_u8(d, vmovq_n_u8(0)); + vst1q_u8(&d[16], vmovq_n_u8(0)); + vst1q_u8(&d[32], vmovq_n_u8(0)); + vst1q_u8(&d[48], vmovq_n_u8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d[32] = *a4++; + d[40] = *a5++; + d[48] = *a6++; + d[56] = *a7++; + d += 1; + k -= 1; + } + d = D; + uint64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v4 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v5 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v6 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v7 = *reinterpret_cast(d); + d = d + 8; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + uint64x2_t z45 = vcombine_u64(v4, v5); + uint64x2_t z67 = vcombine_u64(v6, v7); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); + + uint32x4_t RowSums1L_pada = vmovq_n_u32(0); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z45))); + + uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); + uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); + uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), + vdups_laneq_u32(RowSums1L_add, 2)}; + + uint32x4_t RowSums1H_pada = vmovq_n_u32(0); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z67))); + + uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); + uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); + uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), + vdups_laneq_u32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); + + D += 64; + } + + vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums0)); + vst1q_s32(&RowSumBuffer[4], vreinterpretq_s32_u32(RowSums1)); + + RowSumBuffer += 8; + + A = A + lda * 8; + CountM -= 8; + } + + // + // Process four rows of matrix A. + // + // The buffer is packed as a series of 32 byte vectors where four rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + if (CountM >= 4) { + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + const uint8_t* a2 = a1 + lda; + const uint8_t* a3 = a2 + lda; + + size_t k = CountK; + uint32x4_t RowSums = vmovq_n_u32(0); + + while (k >= 16) { + uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); + a0 += 16; + uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); + a1 += 16; + uint64x2_t v2 = vld1q_u64(reinterpret_cast(a2)); + a2 += 16; + uint64x2_t v3 = vld1q_u64(reinterpret_cast(a3)); + a3 += 16; + + uint64x2_t z0 = vzip1q_u64(v0, v1); + uint64x2_t z1 = vzip2q_u64(v0, v1); + uint64x2_t z2 = vzip1q_u64(v2, v3); + uint64x2_t z3 = vzip2q_u64(v2, v3); + + vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); + vst1q_u8(&D[16], vreinterpretq_u8_u64(z2)); + vst1q_u8(&D[32], vreinterpretq_u8_u64(z1)); + vst1q_u8(&D[48], vreinterpretq_u8_u64(z3)); + + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + uint32x4_t RowSumsH_pada = vmovq_n_u32(0); + RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z2))); + RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z3))); + + uint32x4_t RowSumsH_ext = vextq_u32(RowSumsH_pada, RowSumsH_pada, 1); + uint32x4_t RowSumsH_add = vaddq_u32(RowSumsH_pada, RowSumsH_ext); + uint32x2_t RowSumsH = {vdups_laneq_u32(RowSumsH_add, 0), + vdups_laneq_u32(RowSumsH_add, 2)}; + + RowSums = vaddq_u32(RowSums, vcombine_u32(RowSumsL, RowSumsH)); + + D += 64; + k -= 16; + } + + while (k >= 8) { + uint64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + uint64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + uint64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + uint64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + uint32x4_t RowSumsH_pada = vmovq_n_u32(0); + RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSumsH_ext = vextq_u32(RowSumsH_pada, RowSumsH_pada, 1); + uint32x4_t RowSumsH_add = vaddq_u32(RowSumsH_pada, RowSumsH_ext); + uint32x2_t RowSumsH = {vdups_laneq_u32(RowSumsH_add, 0), + vdups_laneq_u32(RowSumsH_add, 2)}; + + RowSums = vaddq_u32(RowSums, vcombine_u32(RowSumsL, RowSumsH)); + + D += 32; + k -= 8; + } + + if (k > 0) { + // + // Copy the remaining bytes with zero padding. + // + uint8_t* d = D; + + vst1q_u8(d, vmovq_n_u8(0)); + vst1q_u8(&d[16], vmovq_n_u8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d += 1; + k -= 1; + } + + d = D; + uint64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums = vaddq_u32(RowSums, vcombine_u32(RowSums0L, RowSums0H)); + + D += 32; + } + + vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums)); + RowSumBuffer += 4; + + A = A + lda * 4; + CountM -= 4; + } + + // + // Process two rows of matrix A. + // + // The buffer is packed as a series of 16 byte vectors where two rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + if (CountM >= 2) { + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + + size_t k = CountK; + uint32x2_t RowSums = vmov_n_u32(0); + + while (k >= 16) { + uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); + a0 += 16; + uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); + a1 += 16; + + uint64x2_t z0 = vzip1q_u64(v0, v1); + uint64x2_t z1 = vzip2q_u64(v0, v1); + + vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); + vst1q_u8(&D[16], vreinterpretq_u8_u64(z1)); + + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + RowSums = vadd_u32(RowSums, RowSumsL); + + D += 32; + k -= 16; + } + + while (k >= 8) { + uint64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + uint64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + RowSums = vadd_u32(RowSums, RowSumsL); + + D += 16; + k -= 8; + } + + if (k > 0) { + // + // Zero pad the remaining elements to make 8 columns. + // + + uint8_t* d = PaddedMatrixAData; + vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + + d += 1; + k -= 1; + } + + d = PaddedMatrixAData; + uint64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + RowSums = vadd_u32(RowSums, RowSumsL); + + uint8x16_t PackedVector = vld1q_u8(PaddedMatrixAData); + vst1q_u8(D, PackedVector); + + D += 16; + } + + vst1_s32(RowSumBuffer, vreinterpret_s32_u32(RowSums)); + RowSumBuffer += 2; + + A = A + lda * 2; + CountM -= 2; + } + + // + // Process one row of matrix A. + // + // The buffer is packed as a series of 8 byte with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of 8, then the vector is padded + // with zeroes. + // + + if (CountM > 0) { + // No need to pad the rows to 2, the .S takes care of zero pdding + const uint8_t* a = A; + size_t k = CountK; + uint32x4_t RowSums = vmovq_n_u32(0); + + while (k >= 16) { + uint8x16_t v = vld1q_u8(a); + a += 16; + + vst1q_u8(D, v); + + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); + + D += 16; + k -= 16; + } + + if (k > 0) { + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); + + for (size_t kk = 0; kk < k; kk++) { + PaddedMatrixAData[kk] = a[kk]; + } + + uint8x16_t v = vld1q_u8(PaddedMatrixAData); + vst1q_u8(D, v); + + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); + } + + *RowSumBuffer = int32_t(vaddvq_u32(RowSums)); + } +} + +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackBProcessUmmla(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* D, + uint8x8_t BytesRow[8], + uint8x16_t BitFlipVector, + uint32x4_t ColumnSums[2]) +{ + uint8x16_t v02 = veorq_u8(vcombine_u8(BytesRow[0], BytesRow[2]), BitFlipVector); + uint8x16_t v13 = veorq_u8(vcombine_u8(BytesRow[1], BytesRow[3]), BitFlipVector); + + uint8x16_t v46 = veorq_u8(vcombine_u8(BytesRow[4], BytesRow[6]), BitFlipVector); + uint8x16_t v57 = veorq_u8(vcombine_u8(BytesRow[5], BytesRow[7]), BitFlipVector); + + uint8x16x2_t zw1 = vzipq_u8(v02, v13); + uint16x8x2_t zd1 = + vzipq_u16(vreinterpretq_u16_u8(zw1.val[0]), vreinterpretq_u16_u8(zw1.val[1])); + + uint8x16x2_t zw2 = vzipq_u8(v46, v57); + uint16x8x2_t zd2 = + vzipq_u16(vreinterpretq_u16_u8(zw2.val[0]), vreinterpretq_u16_u8(zw2.val[1])); + + uint32x4x2_t zd3 = + vzipq_u32(vreinterpretq_u32_u16(zd1.val[0]), vreinterpretq_u32_u16(zd2.val[0])); + uint32x4x2_t zd4 = + vzipq_u32(vreinterpretq_u32_u16(zd1.val[1]), vreinterpretq_u32_u16(zd2.val[1])); + + vst1q_u8(&D[0], vreinterpretq_u8_u32(zd3.val[0])); + vst1q_u8(&D[16], vreinterpretq_u8_u32(zd3.val[1])); + vst1q_u8(&D[32], vreinterpretq_u8_u32(zd4.val[0])); + vst1q_u8(&D[48], vreinterpretq_u8_u32(zd4.val[1])); + + uint32x4_t ColSums0L_pada = vmovq_n_u32(0); + ColSums0L_pada = vpadalq_u16(ColSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd3.val[0]))); + uint32x4_t ColSums0L_ext = vextq_u32(ColSums0L_pada, ColSums0L_pada, 1); + uint32x4_t ColSums0L_add = vaddq_u32(ColSums0L_pada, ColSums0L_ext); + uint32x2_t ColSums0L = {vdups_laneq_u32(ColSums0L_add, 0), vdups_laneq_u32(ColSums0L_add, 2)}; + + uint32x4_t ColSums0H_pada = vmovq_n_u32(0); + ColSums0H_pada = vpadalq_u16(ColSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd3.val[1]))); + uint32x4_t ColSums0H_ext = vextq_u32(ColSums0H_pada, ColSums0H_pada, 1); + uint32x4_t ColSums0H_add = vaddq_u32(ColSums0H_pada, ColSums0H_ext); + uint32x2_t ColSums0H = {vdups_laneq_u32(ColSums0H_add, 0), vdups_laneq_u32(ColSums0H_add, 2)}; + + ColumnSums[0] = vaddq_u32(ColumnSums[0], vcombine_u32(ColSums0L, ColSums0H)); + + uint32x4_t ColSums1L_pada = vmovq_n_u32(0); + ColSums1L_pada = vpadalq_u16(ColSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd4.val[0]))); + uint32x4_t ColSums1L_ext = vextq_u32(ColSums1L_pada, ColSums1L_pada, 1); + uint32x4_t ColSums1L_add = vaddq_u32(ColSums1L_pada, ColSums1L_ext); + uint32x2_t ColSums1L = {vdups_laneq_u32(ColSums1L_add, 0), vdups_laneq_u32(ColSums1L_add, 2)}; + + uint32x4_t ColSums1H_pada = vmovq_n_u32(0); + ColSums1H_pada = vpadalq_u16(ColSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd4.val[1]))); + uint32x4_t ColSums1H_ext = vextq_u32(ColSums1H_pada, ColSums1H_pada, 1); + uint32x4_t ColSums1H_add = vaddq_u32(ColSums1H_pada, ColSums1H_ext); + uint32x2_t ColSums1H = {vdups_laneq_u32(ColSums1H_add, 0), vdups_laneq_u32(ColSums1H_add, 2)}; + + ColumnSums[1] = vaddq_u32(ColumnSums[1], vcombine_u32(ColSums1L, ColSums1H)); +} + +template <> +void +MlasGemmQuantCopyPackB(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned) +{ + const uint8x16_t BitFlipVector = vdupq_n_u8(BIsSigned ? 0x80 : 0); + uint8x8_t BytesRow[8]; + + // + // Copy data from matrix B into the destination buffer 8x2 blocks at a + // time. + // + // + while (CountN >= 8) { + const uint8_t* b = B; + size_t k = CountK; + uint32x4_t ColumnSums[2]; + ColumnSums[0] = vmovq_n_u32(0); + ColumnSums[1] = vmovq_n_u32(0); + + while (k >= 8) { + BytesRow[0] = vld1_u8(&b[ldb * 0]); + BytesRow[1] = vld1_u8(&b[ldb * 1]); + BytesRow[2] = vld1_u8(&b[ldb * 2]); + BytesRow[3] = vld1_u8(&b[ldb * 3]); + BytesRow[4] = vld1_u8(&b[ldb * 4]); + BytesRow[5] = vld1_u8(&b[ldb * 5]); + BytesRow[6] = vld1_u8(&b[ldb * 6]); + BytesRow[7] = vld1_u8(&b[ldb * 7]); + + MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); + + D += 64; + b += ldb * 8; + k -= 8; + } + + if (k > 0) { + // Pad k to 8 + + BytesRow[0] = vld1_u8(&b[ldb * 0]); + BytesRow[1] = (k >= 2) ? vld1_u8(&b[ldb * 1]) : vget_low_u8(BitFlipVector); + BytesRow[2] = (k >= 3) ? vld1_u8(&b[ldb * 2]) : vget_low_u8(BitFlipVector); + BytesRow[3] = (k >= 4) ? vld1_u8(&b[ldb * 3]) : vget_low_u8(BitFlipVector); + BytesRow[4] = (k >= 5) ? vld1_u8(&b[ldb * 4]) : vget_low_u8(BitFlipVector); + BytesRow[5] = (k >= 6) ? vld1_u8(&b[ldb * 5]) : vget_low_u8(BitFlipVector); + BytesRow[6] = (k >= 7) ? vld1_u8(&b[ldb * 6]) : vget_low_u8(BitFlipVector); + BytesRow[7] = vget_low_u8(BitFlipVector); + + MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); + + D += 64; + } + + // Zero pad the output buffer to a multiple of PackedK if the above + // processed an odd number of four row bundles. + // + vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); + vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); + + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + const uint8_t* b = B; + size_t k = CountK; + uint8_t PaddedMatrixBData[64]; + uint32x4_t ColumnSums[2]; + + vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[32], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[48], BitFlipVector); + + ColumnSums[0] = vmovq_n_u32(0); + ColumnSums[1] = vmovq_n_u32(0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. + // + + while (k > 0) { + const uint8_t* bcopy0 = &b[ldb * 0]; + const uint8_t* bcopy1 = &b[ldb * 1]; + const uint8_t* bcopy2 = &b[ldb * 2]; + const uint8_t* bcopy3 = &b[ldb * 3]; + const uint8_t* bcopy4 = &b[ldb * 4]; + const uint8_t* bcopy5 = &b[ldb * 5]; + const uint8_t* bcopy6 = &b[ldb * 6]; + const uint8_t* bcopy7 = &b[ldb * 7]; + + if (k >= 8) { + b += ldb * 8; + k -= 8; + + } else { + vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[32], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[48], BitFlipVector); + + bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[56]; + bcopy2 = (k >= 3) ? bcopy2 : &PaddedMatrixBData[56]; + bcopy3 = (k >= 4) ? bcopy3 : &PaddedMatrixBData[56]; + bcopy4 = (k >= 5) ? bcopy4 : &PaddedMatrixBData[56]; + bcopy5 = (k >= 6) ? bcopy5 : &PaddedMatrixBData[56]; + bcopy6 = (k >= 7) ? bcopy6 : &PaddedMatrixBData[56]; + bcopy7 = &PaddedMatrixBData[56]; + + k = 0; + } + + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + do { + padded[0] = *bcopy0++; + padded[8] = *bcopy1++; + padded[16] = *bcopy2++; + padded[24] = *bcopy3++; + padded[32] = *bcopy4++; + padded[40] = *bcopy5++; + padded[48] = *bcopy6++; + padded[56] = *bcopy7++; + + } while (++padded < padded_end); + + BytesRow[0] = vld1_u8(&PaddedMatrixBData[0]); + BytesRow[1] = vld1_u8(&PaddedMatrixBData[8]); + BytesRow[2] = vld1_u8(&PaddedMatrixBData[16]); + BytesRow[3] = vld1_u8(&PaddedMatrixBData[24]); + BytesRow[4] = vld1_u8(&PaddedMatrixBData[32]); + BytesRow[5] = vld1_u8(&PaddedMatrixBData[40]); + BytesRow[6] = vld1_u8(&PaddedMatrixBData[48]); + BytesRow[7] = vld1_u8(&PaddedMatrixBData[56]); + + MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); + + D += 64; + } + + vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); + vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); + } +} + +template <> +MLAS_FORCEINLINE size_t +MlasGemmQuantKernel(const MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) +{ + size_t RowsHandled; + + if (ZeroMode) { + RowsHandled = MlasGemmU8X8KernelUmmlaZero(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } else { + RowsHandled = MlasGemmU8X8KernelUmmlaAdd(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } + + return RowsHandled; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla = { + MlasGemmQuantOperation, + MlasGemmQuantPackedOperation, + MlasGemmQuantCopyPackB, + MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedK, + MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedStrides.K, + 8}; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp new file mode 100644 index 0000000000000..f964b1affec31 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -0,0 +1,144 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch. +--*/ + +#include "sqnbitgemm.h" + +namespace +{ + +// Get quantization variant based on `BlkBitWidth` and `BlkLen`. +// Return -1 if the input values are unsupported. +int32_t +GetDispatchQuantVariant(size_t BlkBitWidth, size_t BlkLen) +{ + int32_t type = -1; + if (BlkBitWidth == 4 && BlkLen == 16) { + type = QuantVariant_BitWidth4_BlockSize16; + } else if (BlkBitWidth == 4 && BlkLen == 32) { + type = QuantVariant_BitWidth4_BlockSize32; + } else if (BlkBitWidth == 4 && BlkLen == 64) { + type = QuantVariant_BitWidth4_BlockSize64; + } else if (BlkBitWidth == 4 && BlkLen == 128) { + type = QuantVariant_BitWidth4_BlockSize128; + } else if (BlkBitWidth == 4 && BlkLen == 256) { + type = QuantVariant_BitWidth4_BlockSize256; + } + + return type; +} + +} // namespace + +void MLASCALL +MlasSQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool +) +{ + const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen); + MLAS_SQNBIT_GEMM_OPERATION* const Operation = GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant]; + + if (ThreadPool == nullptr) { + for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { + auto Data = &DataParams[gemm_i]; + Operation(K, Data, 0, M, 0, N); + } + return; + } + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. + // + + const double Complexity = double(M) * double(N) * double(K) * double(BatchN); + + ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8; + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN; + if (ThreadsPerGemm < 1) { + ThreadsPerGemm = 1; + } + + constexpr size_t StrideM = 128; + + size_t nc = N; + if (ThreadsPerGemm > 1) { + // more than one thread per GEMM + + const size_t BlockedM = MlasDivRoundup(M, StrideM); + const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); + if (max_nc < nc) { + nc = std::min( + nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) * + MLAS_QGEMM_STRIDEN_THREAD_ALIGN + ); + } + } + const size_t StrideN = nc; + + const size_t ThreadCountM = MlasDivRoundup(M, StrideM); + const size_t ThreadCountN = MlasDivRoundup(N, StrideN); + ThreadsPerGemm = ThreadCountM * ThreadCountN; + + MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { + const auto gemm_i = tid / ThreadsPerGemm; + const auto blk_i = tid % ThreadsPerGemm; + auto Data = &DataParams[gemm_i]; + + const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; + const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; + + const size_t RangeStartM = ThreadIdM * StrideM; + const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); + + const size_t RangeStartN = ThreadIdN * StrideN; + const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); + + Operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + }); +} + +bool MLASCALL +MlasIsSQNBitGemmAvailable( + size_t BlkBitWidth, + size_t BlkLen +) +{ + const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen); + if (QuantVariant == -1) { + return false; + } + + if (GetMlasPlatform().SQNBitGemmDispatch == nullptr || + GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant] == nullptr) { + return false; + } + + return true; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h new file mode 100644 index 0000000000000..f8f7dcd43699f --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -0,0 +1,287 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm.h + +Abstract: + + This module includes: + + - Declaration of the set of template functions used to implement a kernel + for a matrix/matrix multiplication, A*B, where A is a float matrix and B is + a n-bit quantized integer matrix (QNBitGemm). + + - A shared kernel driver function template, MlasSQNBitGemmOperation. + + - Kernel dispatch structure. + + The B matrix is block quantized, which means that its values are grouped + into blocks which each have one scale and optional zero point. Each + quantized value in B is n-bits wide. + +--*/ + +#pragma once + +#include "mlas_qnbit.h" +#include "mlasi.h" + +// +// Kernel implementation template declarations +// + +/** + * @brief Multiply float matrix A with quantized n-bit integer matrix B. + * B is block quantized and column major. + * This kernel handles the special case where M, the number of rows of A and C, is 1. + * + * @tparam BlkBitWidth Bit width of each value in a block. + * @tparam BlkLen Number of values in a block. + * @tparam KernelType Hardware-specific kernel type. + * + * @param A Supplies the A matrix. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + */ +template +MLAS_FORCEINLINE void +MlasSQNBitGemmM1Kernel( + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +); + +/** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is block quantized and column major. + * This is equivalent to dequantizing B and then running + * MlasSgemmCopyPackB. + * + * @tparam BlkBitWidth Bit width of each value in a block. + * @tparam BlkLen Number of values in a block. + * @tparam KernelType Hardware-specific kernel type. + * + * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ +template +MLAS_FORCEINLINE void +MlasQNBitBlkDequantBForSgemm( + float* FpData, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB +); + +// +// MlasQNBitGemmOperation and helpers +// + +constexpr MLAS_FORCEINLINE size_t +MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) +{ + return BlkLen * BlkBitWidth / 8; +} + +template +constexpr MLAS_FORCEINLINE size_t +MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) +{ + if constexpr (BlkBitWidth <= 4) { + return MlasDivRoundup(BlkCount, 2); // 2 blocks per byte + } else { + return BlkCount; + } +} + +MLAS_FORCEINLINE void +MlasAddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) +{ + for (size_t m = 0; m < CountM; m++) { + const float* bias = Bias; + float* sum = C; + for (size_t n = 0; n < CountN; n += 4) { + if (CountN - n < 4) { + for (size_t nn = n; nn < CountN; nn++) { + *sum += *bias; + sum++; + bias++; + } + break; + } + + MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum); + acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias)); + MlasStoreFloat32x4(sum, acc_x); + bias += 4; + sum += 4; + } + C += ldc; + } +} + +template +MLAS_FORCEINLINE void MLASCALL +MlasSQNBitGemmOperation( + const size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const float* A = DataParams->A + RangeStartM * lda; + + const uint8_t* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const uint8_t* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; + + if (RangeCountM == 1) { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const float* a_row = A; + const uint8_t* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const uint8_t* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + MlasSQNBitGemmM1Kernel( + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + return; + } + + constexpr size_t StrideN = 32; + size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float); + MlasThreadedBufAlloc(bufsize); + auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); + // + // Step through each slice of matrix B along the N dimension. + // + + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, StrideN); + + // + // Step through each slice of matrix A along the M dimension. + // + const float* a_row = A; + const uint8_t* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const uint8_t* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + MlasQNBitBlkDequantBForSgemm( + dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks + ); + + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true + ); +#else + auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f); +#endif + + if (bias) { + MlasAddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); + } + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, + RowsHandled, CountN, ldc + ); + } + + c_blk += ldc * RowsHandled; + a_row += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } + } +} + +// +// Kernel dispatch structure. +// + +typedef void(MLASCALL MLAS_SQNBIT_GEMM_OPERATION)( + size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + size_t RangeStartM, + size_t RangeCountM, + size_t RangeStartN, + size_t RangeCountN +); + +enum QuantVariant { + QuantVariant_BitWidth4_BlockSize16, + QuantVariant_BitWidth4_BlockSize32, + QuantVariant_BitWidth4_BlockSize64, + QuantVariant_BitWidth4_BlockSize128, + QuantVariant_BitWidth4_BlockSize256, + QuantVariantCount, // Keep this element last and ensure that its value is the number of other QuantVariant values. + // Its value is used as an array size. +}; + +struct MLAS_SQNBIT_GEMM_DISPATCH { + MLAS_SQNBIT_GEMM_OPERATION* Operations[QuantVariantCount] = { + // Initialized to nullptrs. Overwrite in hardware-specific kernel implementation. + }; +}; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp new file mode 100644 index 0000000000000..63afe57dd9137 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -0,0 +1,489 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon.h + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON. + +--*/ + +#include "sqnbitgemm.h" + +#include + +#include +#include +#include + +// +// Hardware-specific kernel type. +// +struct MLAS_SQNBIT_GEMM_KERNEL_NEON { +}; + +namespace +{ + +template +MLAS_FORCEINLINE void +UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) +{ + (f(Indices), ...); +} + +template +MLAS_FORCEINLINE void +UnrolledLoop(IterationFn&& f) +{ + UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); +} + +MLAS_FORCEINLINE float32x4_t +FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) +{ + // aN: aN_0 aN_1 aN_2 aN_3 + + float32x4_t b0 = vzip1q_f32(a0, a1); // a0_0 a1_0 a0_1 a1_1 + float32x4_t b1 = vzip2q_f32(a0, a1); // a0_2 a1_2 a0_3 a1_3 + float32x4_t b2 = vzip1q_f32(a2, a3); // a2_0 a3_0 a2_1 a3_1 + float32x4_t b3 = vzip2q_f32(a2, a3); // a2_2 a3_2 a2_3 a3_3 + + // a0_0 a1_0 a2_0 a3_0 + a0 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); + // a0_1 a1_1 a2_1 a3_1 + a1 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); + // a0_2 a1_2 a3_2 a3_2 + a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); + // a0_3 a1_3 a2_3 a3_3 + a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); + + return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3)); +} + +template +MLAS_FORCEINLINE void +LoadData(const float* src, size_t count, float32x4_t (& dst)[Capacity / 4]) +{ + static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4."); + + assert(count <= Capacity); + + size_t vi = 0; // vector index + + // handle 4 values at a time + while (count > 3) { + dst[vi] = vld1q_f32(src); + + vi += 1; + src += 4; + count -= 4; + } + + // handle remaining values + if (count > 0) { + dst[vi] = vsetq_lane_f32(src[0], dst[vi], 0); + + if (count > 1) { + dst[vi] = vsetq_lane_f32(src[1], dst[vi], 1); + + if (count > 2) { + dst[vi] = vsetq_lane_f32(src[2], dst[vi], 2); + } + } + } +} + +template +MLAS_FORCEINLINE void +ComputeDotProducts( + const float* ARowPtr, + const uint8_t* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const uint8_t* QuantBZeroPointColPtr, + float* SumPtr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* BiasPtr +) +{ + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); + + const uint8x8_t LowMask = vdup_n_u8(0x0F); + + // Manual conversion to float takes place in two steps: + // 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f]. + // This target float range is convenient because the 4-bit source values can be placed directly into the + // target float bits. + // 2. Subtract the conversion offset of 16 from the float result. + + // The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values. + constexpr uint16_t float_high_half_template = 0b0'10000011'0000000; + // sign|exponent|partial mantissa + // +|131: 2^4|~~~~ <- 4 bits go here + + const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template); + + float32x4_t acc[NCols]{}; + + const uint8_t* QuantBData = QuantBDataColPtr; + const float* QuantBScale = QuantBScaleColPtr; + size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + float scale[NCols]; + UnrolledLoop( + [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } + ); + + float offset[NCols]; // Includes zero point and float conversion offset of 16. + if (QuantBZeroPointColPtr != nullptr) { + UnrolledLoop([&](size_t i) { + const uint8_t zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const uint8_t zp = ((QuantBZeroPointIdx & 1) == 1) ? (zp_packed >> 4) : (zp_packed & 0x0F); + offset[i] = 16.0f + zp; + }); + } else { + UnrolledLoop([&](size_t i) { + constexpr float zp = 8.0f; + offset[i] = 16.0f + zp; + }); + } + + constexpr size_t SubBlkLen = 16; // number of block elements to process in one iteration + + for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { + // load A row vector elements + + // load `SubBlkLen` elements from A, padded with 0's if there aren't enough + const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen); + float32x4_t av[4]{}; + LoadData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); + + // load B column vectors + uint8x8_t bv_packed[NCols]; + UnrolledLoop([&](size_t i) { + const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; + bv_packed[i] = vld1_u8(QuantBData + i * StrideQuantBData + b_data_block_offset); + }); + + uint8x8_t bv_u8_unzipped[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8_unzipped[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8_unzipped[i][1] = vand_u8(vshr_n_u8(bv_packed[i], 4), LowMask); + }); + + uint8x8_t bv_u8[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8[i][0] = vzip1_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]); + bv_u8[i][1] = vzip2_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]); + }); + + // dequantize B + + // shift left 3 and widen to 16 bits + uint16x8_t bv_u16[NCols][2]; + UnrolledLoop([&](size_t i) { + constexpr int shift = 3; + bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); + bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); + }); + + // combine 4 bits with float high half template + UnrolledLoop([&](size_t i) { + bv_u16[i][0] = vorrq_u16(bv_u16[i][0], float_high_half_template_v); + bv_u16[i][1] = vorrq_u16(bv_u16[i][1], float_high_half_template_v); + }); + + // `SubBlkLen` floats of B + float32x4_t bv[NCols][4]; + + // shift left 16, widen to 32 bits, and reinterpret as float + UnrolledLoop([&](size_t i) { + constexpr int shift = 16; + bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); + bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); + + bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); + bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); + }); + + // subtract float conversion offset (16) and zero point + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(offset[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + + // multiply by scale + UnrolledLoop([&](size_t i) { + const float32x4_t scale_v = vdupq_n_f32(scale[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); + }); + + // c[m,n] += a[m,k] * b[k,n] + UnrolledLoop<4>([&](size_t j) { + UnrolledLoop([&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); }); + }); + } + + // increment pointers to next block + QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + QuantBScale += 1; + QuantBZeroPointIdx += 1; + } + + if constexpr (NCols == 4) { + float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + + if (BiasPtr != nullptr) { + sum = vaddq_f32(sum, vld1q_f32(BiasPtr)); + } + + vst1q_f32(SumPtr, sum); + } else { + for (size_t i = 0; i < NCols; ++i) { + SumPtr[i] = vaddvq_f32(acc[i]); + if (BiasPtr != nullptr) { + SumPtr[i] += BiasPtr[i]; + } + } + } +} + +} // namespace + +// +// MlasSQNBitGemmKernel and helpers. +// + +template +MLAS_FORCEINLINE void +MlasSQNBitGemmM1KernelNeon( + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t NCols = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const uint8_t* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols; + + while (nblk >= 0) { + ComputeDotProducts( + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next `NCols` columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if (QuantBZeroPointColPtr != nullptr) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols : 0; + SumPtr += NCols; + + nblk -= NCols; + } + + // left over columns less than `NCols`? + nblk += NCols; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts( + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if (QuantBZeroPointColPtr != nullptr) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +#define SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(BlkBitWidth, BlkLen) \ + template <> \ + MLAS_FORCEINLINE void \ + MlasSQNBitGemmM1Kernel( \ + const float* A, \ + const uint8_t* QuantBData, \ + const float* QuantBScale, \ + const uint8_t* QuantBZeroPoint, \ + float* C, \ + size_t CountN, \ + size_t CountK, \ + size_t BlockStrideQuantB, \ + const float* Bias \ + ) \ + { \ + return MlasSQNBitGemmM1KernelNeon( \ + A, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, \ + BlockStrideQuantB, Bias \ + ); \ + } + +SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 16) +SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 32) +SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 64) +SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 128) +SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 256) + +#undef SPECIALIZE_SQNBIT_GEMM_M1_KERNEL + +// +// MlasQNBitBlkDequantBForSgemm and helpers. +// + +template +MLAS_FORCEINLINE void +MlasQNBitBlkDequantBForSgemmNeon( + float* FpData, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB +) +{ + auto impl0_reference = [&]() { + static_assert(BlkBitWidth == 4); + + float* Dst = FpData; + + const uint8_t* QuantBDataCol = QuantBData; + const float* QuantBScaleCol = QuantBScale; + const uint8_t* QuantBZeroPointCol = QuantBZeroPoint; + + for (size_t n = 0; n < CountN; n += 16) { + const size_t nnlen = std::min(CountN - n, size_t{16}); + + for (size_t nn = 0; nn < nnlen; ++nn) { + for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) { + const size_t kklen = std::min(CountK - k, BlkLen); + + const uint8_t* b_data = + QuantBDataCol + k_blk_idx * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const float b_s = QuantBScaleCol[k_blk_idx]; + const uint8_t b_z = + (QuantBZeroPointCol != nullptr) + ? ((k_blk_idx & 1) == 1) + ? QuantBZeroPointCol[k_blk_idx / 2] >> 4 + : QuantBZeroPointCol[k_blk_idx / 2] & 0x0F + : 8; + + for (size_t kk = 0; kk < kklen; ++kk) { + const uint8_t b_packed = b_data[kk / 2]; + const uint8_t b_byte = ((kk & 1) == 1) ? b_packed >> 4 : b_packed & 0x0F; + const float b_value = (b_byte - b_z) * b_s; + + Dst[(k + kk) * 16 + nn] = b_value; + } + } + + QuantBDataCol += BlockStrideQuantB * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + QuantBScaleCol += BlockStrideQuantB; + if (QuantBZeroPointCol != nullptr) { + QuantBZeroPointCol += MlasQNBitZeroPointsForBlksSizeInBytes(BlockStrideQuantB); + } + } + + // zero out any remaining columns + + if (nnlen < 16) { + for (size_t k = 0; k < CountK; ++k) { + std::fill_n(Dst + (k * 16) + nnlen, 16 - nnlen, 0.0f); + } + } + + Dst += CountK * 16; + } + }; + + impl0_reference(); +} + +#define SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(BlkBitWidth, BlkLen) \ + template <> \ + MLAS_FORCEINLINE void \ + MlasQNBitBlkDequantBForSgemm( \ + float* FpData, \ + const uint8_t* QuantBData, \ + const float* QuantBScale, \ + const uint8_t* QuantBZeroPoint, \ + size_t CountN, \ + size_t CountK, \ + size_t BlockStrideQuantB \ + ) \ + { \ + MlasQNBitBlkDequantBForSgemmNeon( \ + FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB \ + ); \ + } + +SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 16) +SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 32) +SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 64) +SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 128) +SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 256) + +#undef SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM + +// +// Kernel dispatch structure definition. +// + +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + d.Operations[QuantVariant_BitWidth4_BlockSize16] = MlasSQNBitGemmOperation<4, 16, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + d.Operations[QuantVariant_BitWidth4_BlockSize32] = MlasSQNBitGemmOperation<4, 32, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + d.Operations[QuantVariant_BitWidth4_BlockSize64] = MlasSQNBitGemmOperation<4, 64, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + d.Operations[QuantVariant_BitWidth4_BlockSize128] = MlasSQNBitGemmOperation<4, 128, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + d.Operations[QuantVariant_BitWidth4_BlockSize256] = MlasSQNBitGemmOperation<4, 256, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/threading.cpp b/onnxruntime/core/mlas/lib/threading.cpp index ecdc5250ebf0e..dc5daf998d3be 100644 --- a/onnxruntime/core/mlas/lib/threading.cpp +++ b/onnxruntime/core/mlas/lib/threading.cpp @@ -93,3 +93,41 @@ MlasTrySimpleParallel( MLAS_THREADPOOL::TrySimpleParallelFor(ThreadPool, Iterations, Work); #endif } + + +void +MlasTryBatchParallel( + MLAS_THREADPOOL * ThreadPool, + const std::ptrdiff_t Iterations, + const std::function& Work) +{ + // + // Execute the routine directly if only one iteration is specified. + // + if (Iterations == 1) { + Work(0); + return; + } + +#if defined(BUILD_MLAS_NO_ONNXRUNTIME) + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + // + // Fallback to OpenMP or a serialized implementation. + // + + // + // Execute the routine for the specified number of iterations. + // + for (ptrdiff_t tid = 0; tid < Iterations; tid++) { + Work(tid); + } +#else + // + // Schedule the threaded iterations using the thread pool object. + // + + MLAS_THREADPOOL::TryBatchParallelFor(ThreadPool, Iterations, Work, 0); +#endif + +} \ No newline at end of file diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc index 094ea1e24dd92..9c98ed6d3e114 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc @@ -338,8 +338,8 @@ std::optional IsSupportedGather(Graph& graph, Node& node, auto axis = static_cast(node.GetAttributes().at("axis").i()); axis = axis < 0 ? axis + data_rank : axis; size_t dim_size = static_cast(indices_shape->dim_size()); - bool is_single_value_1d_tensor = dim_size != 0 && (dim_size == 1 && utils::HasDimValue(indices_shape->dim(0)) && - indices_shape->dim(0).dim_value() == 1); + bool is_single_value_1d_tensor = dim_size == 1 && utils::HasDimValue(indices_shape->dim(0)) && + indices_shape->dim(0).dim_value() == 1; if (dim_size != 0 && !is_single_value_1d_tensor) { if (dim_size == 1 && utils::HasDimValue(data_shape->dim(axis)) && data_shape->dim(axis).dim_value() > indices_shape->dim(0).dim_value()) { diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc index 716b027068ba1..23f7c45fba4ba 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc @@ -3,6 +3,7 @@ #ifdef ENABLE_TRAINING +#include #include "core/optimizer/utils.h" #include "core/optimizer/compute_optimizer/upstream_reshape_actors.h" @@ -282,6 +283,23 @@ bool LayerNormalizationReshapeActor::PreCheck( return propagate_input_indices.size() > 0; } +bool LayerNormalizationReshapeActor::PostProcess( + Graph& /* graph */, Node& current_node, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) { + auto axis = static_cast(current_node.GetAttributes().at("axis").i()); + // When Reshape(from 3D to 2D, with the first two dimensions be merged) upstream a LayerNormalization, + // The axis attribute of LayerNormalization should be decreased by 1 if it is greater than 1. + if (axis > 1) { + auto new_axis = axis - 1; + auto& attributes = current_node.GetMutableAttributes(); + attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast(new_axis)); + } + return true; +} + template class SimplePointwiseReshapeActor; template class SimplePointwiseReshapeActor; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h index 05bcbabe9ba4c..de50a56fd8781 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h @@ -111,13 +111,11 @@ class UpStreamReshapeOperatorActorBase : public UpStreamOperatorActorBase { * So far, we don't have requirements to override PostProcess function. */ - bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, - const logging::Logger& /* logger */, - std::vector& /* propagate_input_indices */, - const std::unordered_map>& /* all_input_cmp_rets */, - const std::unordered_map& /* new_reshape_infos */) { - return true; - } + virtual bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) = 0; }; // The inputs are broad-cast-able. The outputs should have the same shape (fully broadcasted shape) @@ -133,6 +131,14 @@ class SimplePointwiseReshapeActor : public UpStreamReshapeOperatorActorBase { std::vector& propagate_input_indices, std::unordered_map>& all_input_cmp_rets, std::function& shape_update_func) override; + + bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) override { + return true; + } }; class MatMulReshapeActor : public UpStreamReshapeOperatorActorBase { @@ -145,6 +151,14 @@ class MatMulReshapeActor : public UpStreamReshapeOperatorActorBase { std::vector& propagate_input_indices, std::unordered_map>& all_input_cmp_rets, std::function& shape_update_func) override; + + bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) override { + return true; + } }; class LayerNormalizationReshapeActor : public UpStreamReshapeOperatorActorBase { @@ -157,6 +171,12 @@ class LayerNormalizationReshapeActor : public UpStreamReshapeOperatorActorBase { std::vector& propagate_input_indices, std::unordered_map>& all_input_cmp_rets, std::function& shape_update_func) override; + + bool PostProcess(Graph& /* graph */, Node& current_node, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) override; }; /** diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index f46273f2680a9..e3a2f2d74c0d4 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -4,6 +4,7 @@ #include #include "core/optimizer/constant_folding.h" +#include "core/optimizer/initializer.h" #include "core/optimizer/utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/optimizer_execution_frame.h" @@ -90,6 +91,45 @@ static bool ConstantFoldShapeNode(Graph& graph, Node& node) { return is_concrete_shape; // convert to constant if this is true } +// This function inlines the appropriate subgraph. It does not literally fold it. +static Status ConstantFoldIfNode(Graph& graph, Node& if_node, const logging::Logger& logger, bool& folded) { + folded = false; + // First, find out which subgraph to inline + // We need to fetch the constant argument. + assert(if_node.InputDefs().size() == 1); + const auto* condition_def = if_node.InputDefs()[0]; + + // We need to check if the condition is a constant. + constexpr bool check_outer_scope_true = true; + const ONNX_NAMESPACE::TensorProto* initializer = + graph.GetConstantInitializer(condition_def->Name(), check_outer_scope_true); + if (initializer == nullptr) { + return Status::OK(); + } + + // This is a boolean initializer with a single element. + Initializer condition{*initializer}; + ORT_RETURN_IF_NOT(condition.size() == 1, "If node condition initializer: `", condition_def->Name(), + "' is expected to have a single boolean element"); + + const bool condition_value = *condition.data(); + + auto status = graph.InlineIfSubgraph(condition_value, if_node, logger); + + if (!status.IsOK()) { + LOGS(logger, WARNING) << "Unable to constant fold. InlineIfSubgraph failed " + << " node '" << if_node.Name() << "': " + << status.ErrorMessage(); + return status; + } + + graph_utils::RemoveNodeOutputEdges(graph, if_node); + graph.RemoveNode(if_node.Index()); + + folded = true; + return status; +} + Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { bool have_updated_nodes = false; GraphViewer graph_viewer(graph); @@ -118,7 +158,20 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, } bool converted_to_constant = false; - if (node->OpType().compare("Shape") == 0) { + if (node->OpType().compare("If") == 0) { + // This process constant folds the If node only, + // but inlines the nodes of the corresponding branch graph. + // It does not convert the node to a constant in a common sense. + // We call it constant folding because the `If` node constant condition + // may enable us to inline the corresponding branch graph. + bool folded = false; + ORT_RETURN_IF_ERROR(ConstantFoldIfNode(graph, *node, logger, folded)); + if (folded) { + // Node removal is done within ConstantFoldIfNode() + modified = true; + have_updated_nodes = true; + } + } else if (node->OpType().compare("Shape") == 0) { converted_to_constant = ConstantFoldShapeNode(graph, *node); } else { InitializedTensorSet constant_inputs; diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index c090ab2a6cc9b..d27603e4ab3a1 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -4,7 +4,7 @@ #include "core/optimizer/conv_activation_fusion.h" #include - +#include #include "core/common/inlined_containers.h" #include "core/framework/tensorprotoutils.h" #include "core/mlas/inc/mlas.h" @@ -174,9 +174,29 @@ using NTO = NodesToOptimize; class FuseConvActivationAction : public ReplaceWithNew { private: - std::string OpType(const RuntimeState&) const override { return "FusedConv"; } + std::string OpType(const RuntimeState& runtime_state) const override { + const auto& domain = runtime_state.selected_nodes.Target().Domain(); + const auto& op_type = runtime_state.selected_nodes.Target().OpType(); + if (domain == kOnnxDomain) { + if (op_type == "Conv") { + return "FusedConv"; + } + } else if (domain == kMSDomain) { + if (op_type == "NhwcConv") { + return "NhwcFusedConv"; + } + } else if (domain == kMSInternalNHWCDomain) { + if (op_type == "Conv") { + return "Conv"; + } + } + ORT_THROW("Unsupported operator: ", op_type, " and domain: ", domain); + } - std::string Domain(const RuntimeState&) const override { return kMSDomain; } + std::string Domain(const RuntimeState& runtime_state) const override { + auto domain = runtime_state.selected_nodes.Target().Domain(); + return domain == kOnnxDomain ? kMSDomain : domain; + } NodeAttributes ExtraAttributes(const RuntimeState& state) const override { NodeAttributes extra_fused_conv_attributes; @@ -260,8 +280,11 @@ void RegisterConvActivationFusionRules(SelectorActionRegistry& registry) { const auto name = "ConvAct"; auto action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) + const std::string msInternalNHWCDomainConv = SelectorActionRegistry::OpVersionsMapKey("Conv", kMSInternalNHWCDomain); + const std::string msDomainConv = SelectorActionRegistry::OpVersionsMapKey("NhwcConv", kMSDomain); auto selector = std::make_unique(); - registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}}, + + registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}, {msInternalNHWCDomainConv, {11}}, {msDomainConv, {1}}}, std::move(selector), std::move(action)); #else registry.RegisterAction(name, std::move(action)); diff --git a/onnxruntime/core/optimizer/conv_add_act_fusion.cc b/onnxruntime/core/optimizer/conv_add_act_fusion.cc index 7c8bfeaec5f0f..6f90eaf07ef4d 100644 --- a/onnxruntime/core/optimizer/conv_add_act_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_act_fusion.cc @@ -287,12 +287,9 @@ class FuseConvAddActivationAction : public ReplaceWithNew { void RegisterConvAddActivationFusionRules(SelectorActionRegistry& registry) { auto action = std::make_unique(); auto selector = std::make_unique(); - registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}}, + std::string msDomainNhwcFusedConv = SelectorActionRegistry::OpVersionsMapKey("NhwcFusedConv", kMSDomain); + registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}, {msDomainNhwcFusedConv, {1, 11}}}, std::move(selector), std::move(action)); - auto action_nhwc = std::make_unique(); - auto selector_nhwc = std::make_unique(); - registry.RegisterSelectorAndAction("NhwcFusedConvAct", {{"NhwcFusedConv", {1, 11}}}, - std::move(selector_nhwc), std::move(action_nhwc)); } SelectorActionRegistry CreateSelectorActionRegistry() { diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc index b994028cbca13..4903bc1d6b961 100644 --- a/onnxruntime/core/optimizer/gather_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -9,7 +9,8 @@ namespace onnxruntime { -bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const { +bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, + int64_t& indices_n_dims) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { return false; @@ -53,6 +54,22 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + InlinedVector node_args; + for (auto node_arg : graph.GetInputs()) { + if (node_arg && graph.GetConsumerNodes(node_arg->Name()).size() > 1) { + node_args.push_back(node_arg); + } + } + + for (auto entry : graph.GetAllInitializedTensors()) { + if (graph.GetConsumerNodes(entry.first).size() > 1) { + auto node_arg = graph.GetNodeArg(entry.first); + if (node_arg) { + node_args.push_back(node_arg); + } + } + } + for (auto node_index : node_topology_list) { auto* p_node = graph.GetNode(node_index); if (p_node == nullptr) continue; // we removed the node as part of an earlier fusion @@ -73,7 +90,11 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le size_t output_count = node.GetOutputEdgesCount(); if (output_count <= 1) continue; - auto shape = node.MutableOutputDefs()[0]->Shape(); + node_args.push_back(node.OutputDefs()[0]); + } + + for (const NodeArg* node_arg : node_args) { + auto shape = node_arg->Shape(); if (!shape) continue; int64_t rank = static_cast(shape->dim_size()); @@ -81,11 +102,14 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le bool first_edge = true; int64_t split_axis = 0; int64_t indices_n_dims = -1; - InlinedVector gather_outputs(output_count, nullptr); + auto consumers = graph.GetConsumerNodes(node_arg->Name()); + size_t consumer_count = consumers.size(); + InlinedVector gather_outputs(consumer_count, nullptr); InlinedVector> nodes_to_fuse; - for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) { + for (auto consumer : consumers) { int64_t index, axis, dims; - if (!IsSupportedGather(graph, *it, index, axis, dims)) { + if (!consumer || consumer->InputDefs()[0] != node_arg || + !IsSupportedGather(graph, *consumer, index, axis, dims)) { can_fuse = false; break; } @@ -99,7 +123,7 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (axis < 0) axis += rank; if (first_edge) { auto dim = shape->dim(static_cast(axis)); - if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast(output_count)) { + if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast(consumer_count)) { can_fuse = false; break; } @@ -109,12 +133,12 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le can_fuse = false; break; } - if (index < 0) index += static_cast(output_count); - if (index < 0 || index >= static_cast(output_count) || gather_outputs[static_cast(index)]) { + if (index < 0) index += static_cast(consumer_count); + if (index < 0 || index >= static_cast(consumer_count) || gather_outputs[static_cast(index)]) { can_fuse = false; break; } - Node& gather_node = *graph.GetNode(it->Index()); + Node& gather_node = *graph.GetNode(consumer->Index()); nodes_to_fuse.emplace_back(gather_node); gather_outputs[static_cast(index)] = gather_node.MutableOutputDefs()[0]; } @@ -122,8 +146,8 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (!can_fuse) continue; ONNX_NAMESPACE::TypeProto split_output_type; - const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( - node.MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()); + const ONNX_NAMESPACE::TensorProto_DataType element_type = + static_cast(node_arg->TypeAsProto()->tensor_type().elem_type()); split_output_type.mutable_tensor_type()->set_elem_type(element_type); for (int64_t i = 0; i < rank; ++i) { if (i == split_axis) { @@ -136,16 +160,17 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le InlinedVector split_outputs; bool add_squeeze_node = indices_n_dims == 0; if (add_squeeze_node) { - for (size_t i = 0; i < output_count; ++i) { + for (size_t i = 0; i < consumer_count; ++i) { split_outputs.emplace_back( &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type)); } } - Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", - {node.MutableOutputDefs()[0]}, add_squeeze_node ? split_outputs : gather_outputs); + Node& split_node = + graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", + {graph.GetNodeArg(node_arg->Name())}, add_squeeze_node ? split_outputs : gather_outputs); split_node.AddAttribute("axis", split_axis); - split_node.SetExecutionProviderType(node.GetExecutionProviderType()); + split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); // Squeeze-11, Squeee-13, Split-13, Split-18 have different schemas. int onnx_opset_version = -1; @@ -155,16 +180,16 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (onnx_opset_version < 13) { if (add_squeeze_node) { - for (size_t i = 0; i < output_count; ++i) { + for (size_t i = 0; i < consumer_count; ++i) { Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", "Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]}); squeeze_node.AddAttribute("axes", std::vector{split_axis}); - squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); } } } else { if (onnx_opset_version >= 18) { - split_node.AddAttribute("num_outputs", static_cast(output_count)); + split_node.AddAttribute("num_outputs", static_cast(consumer_count)); } if (add_squeeze_node) { @@ -176,11 +201,11 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t)); NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto); - for (size_t i = 0; i < output_count; ++i) { + for (size_t i = 0; i < consumer_count; ++i) { Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", "Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]}); - squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); } } } diff --git a/onnxruntime/core/optimizer/graph_transformer.cc b/onnxruntime/core/optimizer/graph_transformer.cc index ba580b8105875..37093496a66fa 100644 --- a/onnxruntime/core/optimizer/graph_transformer.cc +++ b/onnxruntime/core/optimizer/graph_transformer.cc @@ -12,6 +12,7 @@ Status GraphTransformer::Apply(Graph& graph, bool& modified, const logging::Logg // ORT_RETURN_IF_ERROR(graph.Resolve()); auto status = ApplyImpl(graph, modified, 0, logger); + LOGS(logger, INFO) << "GraphTransformer " << Name() << " modified: " << modified << " with status: " << status; ORT_RETURN_IF_ERROR(status); #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 54511aa02a57c..3d6251a694cfb 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -50,6 +50,8 @@ #include "core/optimizer/matmul_integer_to_float.h" #include "core/optimizer/matmul_scale_fusion.h" #include "core/optimizer/matmul_transpose_fusion.h" +#include "core/optimizer/matmul_bn_fusion.h" +#include "core/optimizer/pad_fusion.h" #include "core/optimizer/nchwc_transformer.h" #include "core/optimizer/noop_elimination.h" #include "core/optimizer/not_where_fusion.h" @@ -75,7 +77,6 @@ #include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h" #include "orttraining/core/optimizer/bitmask_dropout_replacement.h" #include "orttraining/core/optimizer/sce_loss_grad_bias_fusion.h" -#include "orttraining/core/optimizer/memory_optimizer.h" #endif #ifdef ENABLE_TRITON #include "orttraining/core/optimizer/triton_fusion.h" @@ -127,6 +128,8 @@ InlinedVector> GenerateRewriteRules( rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); break; @@ -189,6 +192,8 @@ InlinedVector> GenerateTransformers( const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; #endif const InlinedHashSet dml_ep = {onnxruntime::kDmlExecutionProvider}; + AllocatorPtr cpu_allocator = std::make_shared(); + switch (level) { case TransformerLevel::Level1: { // RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run) @@ -240,13 +245,14 @@ InlinedVector> GenerateTransformers( // run TransposeOptimizer last as it works in a slightly different way by moving Transpose nodes around. // shouldn't affect the end result - just easier to debug any issue if it's last. - // local CPU allocator is enough as this allocator is finally passed to a local tensor. - // We will also benefit by using a local allocator as we don't need to pass allocator as parameter for EP API refactor - AllocatorPtr cpu_allocator = std::make_shared(); transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); } break; case TransformerLevel::Level2: { + // we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be + // applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2). + transformers.emplace_back(std::make_unique(std::move(cpu_allocator), kCpuExecutionProvider)); + const bool enable_quant_qdq_cleanup = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableQuantQDQCleanup, "0") == "1"; #if !defined(DISABLE_CONTRIB_OPS) @@ -265,11 +271,12 @@ InlinedVector> GenerateTransformers( onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kDmlExecutionProvider}; - const InlinedHashSet cpu_cuda_rocm_acl_armnn_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kAclExecutionProvider, - onnxruntime::kArmNNExecutionProvider}; + const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kArmNNExecutionProvider, + onnxruntime::kJsExecutionProvider}; #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = @@ -291,7 +298,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_ep)); transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); @@ -346,18 +353,6 @@ InlinedVector> GenerateTransformers( // fusions might be prevented if this one removes a Q/DQ node too early. transformers.emplace_back(std::make_unique(enable_quant_qdq_cleanup)); -#ifdef ENABLE_TRAINING - // Put memory optimization transformer at last (which is done after most of fusions are done) by intention. - // Known issue: after memory optimization is completed, if some fusion happens, it is possible that the - // node priority got changed. This may disorder the execution order of nodes to recompute. - // TODO(pengwa): need to fix this issue. - const std::string enable_memory_optimizer = - session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerEnabler, ""); - const std::string probe_level = - session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerProbeLevel, "0"); - transformers.emplace_back(std::make_unique(enable_memory_optimizer, probe_level)); -#endif - } break; case TransformerLevel::Level3: { @@ -366,16 +361,16 @@ InlinedVector> GenerateTransformers( if (MlasNchwcGetBlockSize() > 1) { transformers.emplace_back(std::make_unique()); } - AllocatorPtr cpu_allocator = std::make_shared(); + auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); if (nhwc_transformer->IsActive()) { transformers.emplace_back(std::move(nhwc_transformer)); } - // NCHWCtransformer should have a higher priority versus this. Because NCHWCtransformer also do the similar things - // of fusion patterns and target on CPU. However, NCHWCtransformer will reorder the layout to nchwc which is only available for - // x86-64 cpu, not edge cpu like arm. But This transformer could be used by opencl-ep/cpu-ep. So - // we will prefer NhwcTransformer once ort runs on x86-64 CPU, otherwise ConvAddActivationFusion is enabled. + + // NchwcTransformer must have a higher priority than ConvAddActivationFusion. NchwcTransformer does similar + // fusions targeting CPU but also reorders the layout to NCHWc which is expected to be more efficient but is + // only available on x86-64. // PR #6351 implemented similar fusion-pattern for CUDA only, and can only fuse conv-add-relu, // while we can fuse more activation. transformers.emplace_back(std::make_unique(cpu_ep)); diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index 9cdc0d9ef0473..9e807ddc7be59 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -3,22 +3,23 @@ #include "core/optimizer/initializer.h" -#include "core/common/gsl.h" +#include +#include +#include "core/common/gsl.h" #include "core/common/path.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_external_data_info.h" #include "core/platform/env.h" -#include - namespace onnxruntime { Initializer::Initializer(ONNX_NAMESPACE::TensorProto_DataType data_type, std::string_view name, gsl::span dims) : name_(name), - data_(DataTypeImpl::TensorTypeFromONNXEnum(data_type)->GetElementType(), dims, std::make_shared()) { + data_(DataTypeImpl::TensorTypeFromONNXEnum(data_type)->GetElementType(), dims, + std::make_shared()) { if (!data_.IsDataTypeString()) { memset(data_.MutableDataRaw(), 0, data_.SizeInBytes()); } @@ -39,7 +40,8 @@ Initializer::Initializer(const ONNX_NAMESPACE::TensorProto& tensor_proto, const auto proto_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); // This must be pre-allocated - Tensor w(DataTypeImpl::TensorTypeFromONNXEnum(proto_data_type)->GetElementType(), proto_shape, std::make_shared()); + Tensor w(DataTypeImpl::TensorTypeFromONNXEnum(proto_data_type)->GetElementType(), proto_shape, + std::make_shared()); ORT_THROW_IF_ERROR(utils::TensorProtoToTensor(Env::Default(), model_path.ToPathString().c_str(), tensor_proto, w)); data_ = std::move(w); } @@ -289,7 +291,11 @@ Initializer& Initializer::sqrt() { namespace { template struct ScaleByAxis { - void operator()(Tensor& data, const Tensor& scalers, const size_t block_size, const size_t num_blocks) const { + void operator()(Tensor& data, + const Tensor& scalers, + const size_t block_size, + const size_t num_blocks, + const bool column_major) const { ToNumeric to_numeric; const auto scaler_size = scalers.Shape().Size(); T* dst = data.MutableData(); @@ -301,24 +307,32 @@ struct ScaleByAxis { } } else { for (size_t block_offset = 0, i = 0; i < num_blocks; i++) { - const auto numeric_scaler = to_numeric(scalers_data[i]); - for (size_t j = 0; j < block_size; ++j, ++block_offset) { - dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + if (column_major) { + for (size_t j = 0; j < block_size; ++j, ++block_offset) { + const auto numeric_scaler = to_numeric(scalers_data[j]); + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } + } else { + const auto numeric_scaler = to_numeric(scalers_data[i]); + for (size_t j = 0; j < block_size; ++j, ++block_offset) { + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } } } } } }; - } // namespace -void Initializer::scale_by_axis(const Initializer& scalers, int axis) { +void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool column_major) { ORT_ENFORCE(axis >= 0, "Axis must be non-negative"); const size_t block_size = narrow(data_.Shape().SizeFromDimension(gsl::narrow_cast(axis))); const size_t num_blocks = size() / block_size; - ORT_ENFORCE(scalers.size() == 1 || scalers.size() == num_blocks, "Invalid other(scalers) size"); + ORT_ENFORCE(scalers.size() == 1 || + (column_major ? scalers.size() == block_size : scalers.size() == num_blocks), + "Invalid other(scalers) size"); utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, scalers.data_, block_size, num_blocks); + t_disp.Invoke(data_, scalers.data_, block_size, num_blocks, column_major); } #endif // ORT_EXTENDED_MINIMAL_BUILD } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/initializer.h b/onnxruntime/core/optimizer/initializer.h index dfe054ba1aced..78e3fd6a3d24e 100644 --- a/onnxruntime/core/optimizer/initializer.h +++ b/onnxruntime/core/optimizer/initializer.h @@ -86,7 +86,7 @@ class Initializer final { Initializer& sqrt(); - void scale_by_axis(const Initializer& other, int axis); + void scale_by_axis(const Initializer& other, int axis, bool column_major = false); #endif // ORT_EXTENDED_MINIMAL_BUILD private: std::string name_; diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 7c087ec77d9fe..959fcd6efdc3c 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -32,7 +32,7 @@ onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph, int64_t to_type, onnxruntime::ProviderType providerType) { // insert cast op to cast input - std::string node_name = graph.GenerateNodeName("InsertedCast_" + old_arg->Name()); + std::string node_name = graph.GenerateNodeName("InsertedPrecisionFreeCast_" + old_arg->Name()); auto* new_arg = &graph.GetOrCreateNodeArg(node_name, new_type); @@ -235,7 +235,8 @@ enum TypeGroup { Unknown = -1, Bool = 0, Integer = 1, - Float = 2, + Unsigned = 2, + Float = 3, }; TypeGroup GetTypeGroup(DataType type) { @@ -243,11 +244,14 @@ TypeGroup GetTypeGroup(DataType type) { return Bool; } - if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)" || - *type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { + if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)") { return Integer; } + if (*type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { + return Unsigned; + } + if (*type == "tensor(bfloat16)" || *type == "tensor(double)" || *type == "tensor(float)" || *type == "tensor(float16)") { return Float; } @@ -255,6 +259,22 @@ TypeGroup GetTypeGroup(DataType type) { return Unknown; } +int BitLength(DataType type) { + if (*type == "tensor(bool)") { + return 1; + } else if (*type == "tensor(uint8)" || *type == "tensor(int8)") { + return 8; + } else if (*type == "tensor(int16)" || *type == "tensor(uint16)" || *type == "tensor(bfloat16)" || *type == "tensor(float16)") { + return 16; + } else if (*type == "tensor(int32)" || *type == "tensor(uint32)" || *type == "tensor(float)") { + return 32; + } else if (*type == "tensor(int64)" || *type == "tensor(uint64)" || *type == "tensor(double)") { + return 64; + } else { + return -1; + } +} + /** Transformer to remove duplicate Cast nodes. */ class RemoveDuplicateCastTransformer : public GraphTransformer { public: @@ -262,6 +282,48 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { } private: + static bool UnsafeCast(DataType src_type, DataType dst_type, const Node& node) { + // This is not a complete cast optimisation pass, and is more conservative than it could be. + // For instance, certain integral -> floating point casts could be optimised but this is left to an explicit cast optimisation pass. + + // The comparison with "InsertedPrecisionFreeCast_" reflects cast nodes that are inserted by InsertCastTransformer. + // Such casts should not be considered as loss of precision - the inserted upcasts (f16 -> f32) and downcasts (f32 -> f16) are inserted to support kernels when on a CPU EP without F16 support. + auto src_type_group = GetTypeGroup(src_type); + auto dst_type_group = GetTypeGroup(dst_type); + if (Unknown == src_type_group || Unknown == dst_type_group) { + return true; + } + + // Do not remove any signed -> unsigned cast. + if ((src_type_group != Bool && src_type_group != Unsigned) && Unsigned == dst_type_group) { + return true; + } + + // Do not remove any floating point -> non floating point cast. + if (Float == src_type_group && Float != dst_type_group) { + return true; + } + + auto src_bit_length = BitLength(src_type); + auto dst_bit_length = BitLength(dst_type); + + // unsigned integer -> integer cast may overflow if the destination integer is smaller or equal to the source integer. + if (Unsigned == src_type_group && Integer == dst_type_group) { + return dst_bit_length <= src_bit_length; + } + + // integral -> floating cast may overflow if integer cannot be encoded in the mantissa. This check could be more precise. + if ((Integer == src_type_group || Unsigned == src_type_group) && Float == dst_type_group) { + return dst_bit_length <= src_bit_length; + } + + if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) { + return true; + } + + return src_bit_length > dst_bit_length && (node.Name().compare(0, 26, "InsertedPrecisionFreeCast_")); + } + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override { auto output_args = graph.GetOutputs(); InlinedHashSet graph_outputs; @@ -293,17 +355,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { // - for each consumer cast node, it meets above condition for this optimization. auto src_type = node.InputDefs()[0]->Type(); auto dst_type = node.OutputDefs()[0]->Type(); - TypeGroup src_type_group = GetTypeGroup(src_type); - TypeGroup dst_type_group = GetTypeGroup(dst_type); - if (src_type_group == Unknown || dst_type_group == Unknown) { - continue; - } - - bool loss_precision_cast = false; - if (src_type_group > dst_type_group) { - loss_precision_cast = true; - } + bool loss_precision_cast = UnsafeCast(src_type, dst_type, node); size_t num_children = node.GetOutputEdgesCount(); bool inconsistent_casts = false; @@ -312,10 +365,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { if (output_node.OpType() == "Cast") { auto src_type1 = output_node.InputDefs()[0]->Type(); auto dst_type1 = output_node.OutputDefs()[0]->Type(); - TypeGroup src_type_group1 = GetTypeGroup(src_type1); - TypeGroup dst_type_group1 = GetTypeGroup(dst_type1); - if (src_type_group1 == Unknown || dst_type_group1 == Unknown || - (loss_precision_cast && dst_type_group1 > src_type_group1)) { + if (loss_precision_cast && UnsafeCast(dst_type1, src_type1, output_node)) { inconsistent_casts = true; break; } diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 2d12c407e6e31..4505d4afdf1e0 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -13,27 +13,102 @@ using namespace onnx_transpose_optimization; namespace onnxruntime { namespace layout_transformation { +namespace { +// Cost check for aggressively pushing the Transpose nodes involved in the layout transformation further out. +CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const api::NodeRef& node, + const std::vector& perm, + const std::unordered_set& outputs_leading_to_transpose) { + // we aggressively push the layout transpose nodes. + // Exception: pushing through a Concat can result in Transpose nodes being added to multiple other inputs which + // can potentially be worse for performance. Use the cost check in that case. + if (node.OpType() != "Concat" && + (perm == ChannelFirstToLastPerm(perm.size()) || perm == ChannelLastToFirstPerm(perm.size()))) { + return CostCheckResult::kPushTranspose; + } + + // for other nodes use the default ORT cost check + return OrtEPCostCheck(graph, node, perm, outputs_leading_to_transpose); +} + +#if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS +const std::unordered_set& GetCUDALayoutSensitiveOps() { + static std::unordered_set cuda_nhwc_ops = []() { + return std::unordered_set{ + "BatchNormalization", + "Conv", + "ConvTranspose", + "GlobalMaxPool", + "MaxPool", + "GlobalAveragePool", + "AveragePool", + }; + }(); + return cuda_nhwc_ops; +} +#endif + +/// +/// Default function for checking if a node should have its layout changed. Allows EP specific adjustments to the +/// default set of layout sensitive operators if required. +/// +/// Longer term, if required, the EP API could allow the EP to provide a delegate to plugin EP specific logic so we +/// don't hardcode it here. +/// +/// Node to check +/// true if the node should have its layout converted to NHWC. +bool ConvertNodeLayout(const api::NodeRef& node) { + // skip if op is not an ONNX or contrib op + auto domain = node.Domain(); + if (domain != kOnnxDomain && domain != kMSDomain) { + return false; + } + + const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); + + // handle special cases +#if defined(USE_JSEP) + // TODO(fs-eire): Remove special case handing of JSEP once NHWC Resize implementation is fixed + if (node.GetExecutionProviderType() == kJsExecutionProvider) { + if (node.OpType() == "Resize") { + // leave Resize as-is pending bugfix for NHWC implementation. this means the node will remain in the ONNX domain + // with the original input layout. + return false; + } + } +#endif + +#if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS + if (node.GetExecutionProviderType() == kCudaExecutionProvider) { + if (layout_sensitive_ops.count(node.OpType())) { + const auto& cuda_nhwc_ops = GetCUDALayoutSensitiveOps(); + if (!cuda_nhwc_ops.count(node.OpType())) { + return false; + } + } + } +#endif + + return layout_sensitive_ops.count(node.OpType()) != 0; +} +} // namespace // Layout sensitive NCHW ops. TransformLayoutForEP will wrap these with Transpose nodes to convert the input // data to NHWC and output data back to NCHW, and move the op to the internal NHWC domain (kMSInternalNHWCDomain). -// The EP requesting these ops MUST be able to handle the node with the operator in the kMSInternalNHWCDomain. +// The EP requesting these ops MUST be able to handle the node with the operator in the kMSInternalNHWCDomain domain. // Once all the layout sensitive ops requested by the EP are wrapped the transpose optimizer will attempt to remove // as many of the layout transposes as possible. const std::unordered_set& GetORTLayoutSensitiveOps() { static std::unordered_set ort_layout_sensitive_ops = []() { const auto& layout_sensitive_ops = onnx_transpose_optimization::GetLayoutSensitiveOps(); std::unordered_set ort_specific_ops = - { "FusedConv", - "QLinearAveragePool", - "QLinearGlobalAveragePool" -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) || defined(USE_WEBNN) - // The CUDA/ROCM Resize kernel is layout sensitive as it only handles NCHW input. - // The CPU kernel and ONNX spec are not limited to handling NCHW input so are not layout sensitive, and - // onnx_layout_transformation::HandleResize is used. - , - "Resize" -#endif - }; + { + "FusedConv", + "QLinearAveragePool", + "QLinearGlobalAveragePool", + // Whilst the ONNX spec doesn't specify a layout for Resize, we treat it as layout sensitive by default + // as EPs tend to only support one layout. + "Resize", + }; ort_specific_ops.insert(layout_sensitive_ops.cbegin(), layout_sensitive_ops.cend()); return ort_specific_ops; @@ -42,45 +117,21 @@ const std::unordered_set& GetORTLayoutSensitiveOps() { return ort_layout_sensitive_ops; } -// Cost check for aggressively pushing the Transpose nodes involved in the layout transformation further out. -static CostCheckResult -PostLayoutTransformCostCheck(const api::GraphRef& graph, const api::NodeRef& node, - const std::vector& perm, - const std::unordered_set& outputs_leading_to_transpose) { - // we aggressively push the layout transpose nodes. - // Exception: pushing through a Concat can result in Transpose nodes being added to multiple other inputs which - // can potentially be worse for performance. Use the cost check in that case. - if (node.OpType() != "Concat" && - (perm == ChannelFirstToLastPerm(perm.size()) || perm == ChannelLastToFirstPerm(perm.size()))) { - return CostCheckResult::kPushTranspose; - } - - // for other nodes use the default ORT cost check - return OrtEPCostCheck(graph, node, perm, outputs_leading_to_transpose); -} - Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvider& execution_provider, AllocatorPtr cpu_allocator, const DebugGraphFn& debug_graph_fn) { // We pass in nullptr for the new_node_ep param as new nodes will be assigned by the graph partitioner after // TransformLayoutForEP returns. - // sub graph recurse will be added later. + // sub graph recurse will be added later auto api_graph = MakeApiGraph(graph, cpu_allocator, /*new_node_ep*/ nullptr); - const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); // to convert to NHWC we need to wrap layout sensitive nodes to Transpose from NCHW to NHWC and back. for (auto& node : api_graph->Nodes()) { - if (layout_sensitive_ops.count(node->OpType())) { - if (node->GetExecutionProviderType() != execution_provider.Type()) { - continue; - } - - auto domain = node->Domain(); - // Skip if domain is incorrect - if (domain != kOnnxDomain && domain != kMSDomain) { - continue; - } + if (node->GetExecutionProviderType() != execution_provider.Type()) { + continue; + } + if (ConvertNodeLayout(*node)) { // if already transformed then change the domain to kMSInternalNHWCDomain this way the EP // knows this op is in the expected format. if (node->GetAttributeIntDefault("channels_last", 0) == 1) { @@ -137,7 +188,6 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm}); } - // TODO: Technically Resize doesn't need to change domain as the ONNX Resize spec is not layout sensitive. SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain); modified = true; } diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h b/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h index 91e21b655f8bd..cfa02c916b73f 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h @@ -20,6 +20,10 @@ inline constexpr std::array kLayoutTransformationPotentiallyAddedOps = { // @@region_begin(extended_minimal_build_required_kernels)@@ // kOnnxDomain ops + OpIdentifierWithStringViews{kOnnxDomain, "DequantizeLinear", 10}, + OpIdentifierWithStringViews{kOnnxDomain, "DequantizeLinear", 13}, + OpIdentifierWithStringViews{kOnnxDomain, "DequantizeLinear", 19}, + // OpIdentifierWithStringViews{kOnnxDomain, "DequantizeLinear", 21}, pending CPU EP adding support OpIdentifierWithStringViews{kOnnxDomain, "Gather", 1}, OpIdentifierWithStringViews{kOnnxDomain, "Gather", 11}, OpIdentifierWithStringViews{kOnnxDomain, "Gather", 13}, @@ -28,6 +32,10 @@ inline constexpr std::array kLayoutTransformationPotentiallyAddedOps = { OpIdentifierWithStringViews{kOnnxDomain, "Identity", 14}, OpIdentifierWithStringViews{kOnnxDomain, "Identity", 16}, OpIdentifierWithStringViews{kOnnxDomain, "Identity", 19}, + OpIdentifierWithStringViews{kOnnxDomain, "QuantizeLinear", 10}, + OpIdentifierWithStringViews{kOnnxDomain, "QuantizeLinear", 13}, + OpIdentifierWithStringViews{kOnnxDomain, "QuantizeLinear", 19}, + // OpIdentifierWithStringViews{kOnnxDomain, "QuantizeLinear", 21}, pending CPU EP adding support OpIdentifierWithStringViews{kOnnxDomain, "Squeeze", 1}, OpIdentifierWithStringViews{kOnnxDomain, "Squeeze", 11}, OpIdentifierWithStringViews{kOnnxDomain, "Squeeze", 13}, @@ -39,8 +47,10 @@ inline constexpr std::array kLayoutTransformationPotentiallyAddedOps = { #if !defined(DISABLE_CONTRIB_OPS) // kMSDomain ops + OpIdentifierWithStringViews{kMSDomain, "DequantizeLinear", 1}, OpIdentifierWithStringViews{kMSDomain, "NhwcMaxPool", 1}, OpIdentifierWithStringViews{kMSDomain, "QLinearConv", 1}, + OpIdentifierWithStringViews{kMSDomain, "QuantizeLinear", 1}, #endif // !defined(DISABLE_CONTRIB_OPS) // @@region_end(extended_minimal_build_required_kernels)@@ diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc new file mode 100644 index 0000000000000..e944522c9c338 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -0,0 +1,230 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/matmul_bn_fusion.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +namespace { +const std::vector>> ignorable_nodes{ + {"Reshape", {1, 5, 13, 14, 19}}, + {"Transpose", {1, 13}}}; +const std::pair> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}}; +} // namespace + +bool NodeIsIgnorable(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) { + const Node* curr_node = graph.GetNode(curr_node_index); + + // curr_node has different execution provider then it's parent or + // has output edge != 1 (this condition will handle the case when ignorable node + // is graph output i.e. a graph like this "MatMul->Transpose") + if (curr_node->GetExecutionProviderType() != root_node.GetExecutionProviderType() || + curr_node->GetOutputEdgesCount() != 1) { + return false; + } + + // curr_node can be any of the ignorable_nodes. + for (size_t index = 0; index < ignorable_nodes.size(); index++) { + if (graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, ignorable_nodes[index].first, ignorable_nodes[index].second)) { + return true; + } + } + + return false; +} + +std::optional MatchPath(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) { + while (NodeIsIgnorable(graph, root_node, curr_node_index)) { + curr_node_index = graph.GetNode(curr_node_index)->OutputNodesBegin()->Index(); + } + + // curr_node is neither ignorable nor dest + const Node* curr_node = graph.GetNode(curr_node_index); + if (curr_node->OpType() != dest.first) { + return std::nullopt; + } + + if (curr_node->GetExecutionProviderType() == root_node.GetExecutionProviderType() && + graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, dest.first, dest.second)) { + return curr_node_index; + } + + // either curr_node has different execution provider or + // has invalid opset. + return std::nullopt; +} + +/* + * Given a MatMul node, it will verify the following pattern. + * MatMul GEMM + * | | + * Reshape ^ ---> Reshape ^ + * | | + * Transpose ^ Transpose ^ + * | + * BatchNormalization + * Note: ^ means there can be 0 or any occurrences of that node. + * Few example fusable pattern: + * - MatMul -> Reshape -> Transpose -> BatchNormalization ---> GEMM -> Reshape -> Transpose + * - MatMul -> Reshape -> BatchNormalization ---> GEMM -> Reshape + * - MatMul -> Transpose -> BatchNormalization ---> GEMM -> Transpose + * - MatMul -> Reshape -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Reshape + * - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Transpose -> Reshape + * - MatMul -> BatchNormalization ---> GEMM + * Other Conditions: + * - B tensor of MatMul should be constant. + * - scale, B, mean, var tensors of BatchNormalization should be constant. + * - Every node in the path, except the BatchNormalization, should have only 1 output edge. + */ +bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9, 13}) || + node.GetOutputEdgesCount() != 1) { + return false; + } + + if (graph.NodeProducesGraphOutput(node)) { + return false; + } + + // because is not producing graph output, it means it will have a child node + NodeIndex child_node_index = node.OutputNodesBegin()->Index(); + std::optional batch_norm_index = MatchPath(graph, node, child_node_index); + if (!batch_norm_index.has_value()) { + return false; + } + + const Node* batch_norm_node = graph.GetNode(*batch_norm_index); + + // Check that the appropriate inputs to the Matmul and BN nodes are constants. + if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[1]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[2]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[3]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[4])) { + return false; + } + + // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse. + const auto& output_defs = batch_norm_node->OutputDefs(); + if (output_defs.size() > 1) { + for (size_t i = 1, end = output_defs.size(); i < end; ++i) { + if (output_defs[i] != nullptr && output_defs[i]->Exists()) { + return false; + } + } + } + + return true; +} + +/* + * BatchNormalization: [https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc] + * Scale * ((Input - Mean) / sqrt(Variance + Epsilon)) + Bias // ignore the FusedActivation in the above definition, that's very specific to DML + * Expanding out the terms: + * Output = (Scale / sqrt(Variance + Epsilon)) * Input + (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias + * Here, + * [Scale/sqrt(Variance + Epsilon)] is constant, and let's call it `alpha` + * [(Scale / sqrt(Variance + Epsilon)) * -Mean + Bias] is also constant, and let's call it `beta` + * Output = alpha * Input + beta, Input = B tensor of MatMul. + * + */ +Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + NodeIndex child_node_index = matmul_node.OutputNodesBegin()->Index(); + NodeIndex batch_norm_node_index = MatchPath(graph, matmul_node, child_node_index).value(); + + Node& batch_norm_node = *graph.GetNode(batch_norm_node_index); // need mutable node, that's why extracting node from graph + + // only perform fusion if epsilon is present and is of float_32 type + auto epsilon_attribute = batch_norm_node.GetAttributes().find("epsilon"); + if (epsilon_attribute == batch_norm_node.GetAttributes().end() || + epsilon_attribute->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT) { + return Status::OK(); + } + const float epsilon = epsilon_attribute->second.f(); + + const onnx::TensorProto* scale_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[1]->Name()); + ORT_ENFORCE(scale_tensor); + const onnx::TensorProto* bias_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[2]->Name()); + ORT_ENFORCE(bias_tensor); + const onnx::TensorProto* mean_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[3]->Name()); + ORT_ENFORCE(mean_tensor); + const onnx::TensorProto* var_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[4]->Name()); + ORT_ENFORCE(var_tensor); + const onnx::TensorProto* matmul_b_tensor = graph_utils::GetConstantInitializer(graph, matmul_node.InputDefs()[1]->Name()); + ORT_ENFORCE(matmul_b_tensor); + + if (!optimizer_utils::IsFloatingPointDataType(*matmul_b_tensor) || + !optimizer_utils::IsFloatingPointDataType(*scale_tensor) || + !optimizer_utils::IsFloatingPointDataType(*bias_tensor) || + !optimizer_utils::IsFloatingPointDataType(*mean_tensor) || + !optimizer_utils::IsFloatingPointDataType(*var_tensor) || + scale_tensor->dims_size() != 1 || + bias_tensor->dims_size() != 1 || + mean_tensor->dims_size() != 1 || + var_tensor->dims_size() != 1 || + scale_tensor->dims(0) != matmul_b_tensor->dims(1) || + bias_tensor->dims(0) != matmul_b_tensor->dims(1) || + mean_tensor->dims(0) != matmul_b_tensor->dims(1) || + var_tensor->dims(0) != matmul_b_tensor->dims(1)) { + return Status::OK(); + } + + /* + * temp = scale / sqrt(var + epsilon) + * output = (temp * Input) - ((temp * mean) + bias) + */ + Initializer scale(*scale_tensor, graph.ModelPath()); + Initializer bias(*bias_tensor, graph.ModelPath()); + Initializer mean(*mean_tensor, graph.ModelPath()); + Initializer var(*var_tensor, graph.ModelPath()); + Initializer matmul_b(*matmul_b_tensor, graph.ModelPath()); + + var.add(epsilon); + var.sqrt(); + scale.div(var); // this is the temp + matmul_b.scale_by_axis(scale, 1, true); + + mean.mul(scale); + bias.sub(mean); + + // create B tensorProto for new Gemm node from initializer. + ONNX_NAMESPACE::TensorProto new_gemm_b_tensor(*matmul_b_tensor); + matmul_b.ToProto(new_gemm_b_tensor); + const std::string new_gemm_b_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmul_b_tensor->name()); + new_gemm_b_tensor.set_name(new_gemm_b_name); + NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializer(graph, new_gemm_b_tensor); + + // create bias tensorProto for new Gemm node from initializer. + ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor(*bias_tensor); + bias.ToProto(new_gemm_bias_tensor); + const std::string new_gemm_bias_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias"); + new_gemm_bias_tensor.set_name(new_gemm_bias_name); + NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor); + + Node& gemm_node = graph.AddNode( + graph.GenerateNodeArgName("MatMulBnFusion_Gemm"), + "Gemm", + "Generated from Matmul BatchNormalization fusion", + {matmul_node.MutableInputDefs()[0], &new_gemm_b_node_arg, &new_gemm_bias_node_arg}, + matmul_node.MutableOutputDefs(), + nullptr, + kOnnxDomain); + + // Remove MatMul node. + Node* node = graph.GetNode(matmul_node.Index()); + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(matmul_node.Index()); + + // Delete optional empty output defs. + // Delete BatchNormalization node and update the input of the child of BatchNormalization + batch_norm_node.MutableOutputDefs().resize(1); + NodeIndex batch_norm_parent_index = graph.GetNode(child_node_index)->OpType() == "BatchNormalization" ? gemm_node.Index() : batch_norm_node.InputNodesBegin()->Index(); + graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(batch_norm_parent_index), batch_norm_node); + + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + return Status::OK(); +} +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.h b/onnxruntime/core/optimizer/matmul_bn_fusion.h new file mode 100644 index 0000000000000..7a43483cf37d4 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { +/* + * This fusion submerges a BatchNormalization operator to it's super + * precedding MatMul operator, if and only if MatmulBNFusion::SatisfyCondition() + * is true. + */ +class MatmulBNFusion : public RewriteRule { + public: + MatmulBNFusion() : RewriteRule("MatMul_BatchNormalization_Fusion") {} + + std::vector TargetOpTypes() const noexcept override { + return {"MatMul"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc index fc7e694b6a69b..46041bca9dcc1 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc @@ -49,19 +49,13 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, InitializedTensorSet::const_iterator it = initialized_tensor_set.find(arg.Name()); if (it != initialized_tensor_set.cend()) { const auto& tensor_proto = *(it->second); - size_t cpu_tensor_length; - ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &cpu_tensor_length)); OrtValue ort_value; - std::unique_ptr data = std::make_unique(cpu_tensor_length); - std::unique_ptr p_tensor; - ORT_RETURN_IF_ERROR(utils::TensorProtoToMLValue(Env::Default(), - model_path.IsEmpty() ? nullptr : model_path.ToPathString().c_str(), - tensor_proto, - MemBuffer(data.get(), cpu_tensor_length, allocator_ptr_->Info()), - ort_value)); - - initializers_[idx] = ort_value; - buffer_for_initialized_tensors_[idx] = std::move(data); + ORT_RETURN_IF_ERROR( + utils::TensorProtoToOrtValue(Env::Default(), + model_path.IsEmpty() ? nullptr : model_path.ToPathString().c_str(), + tensor_proto, allocator_ptr_, ort_value)); + + initializers_[idx] = std::move(ort_value); } return Status::OK(); @@ -72,7 +66,6 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, ort_value_name_idx_map_.Reserve(num_inputs_outputs); ort_value_idx_nodearg_map_.reserve(num_inputs_outputs); initializers_.reserve(initialized_tensor_set.size()); - buffer_for_initialized_tensors_.reserve(initialized_tensor_set.size()); for (auto* node : nodes) { ORT_THROW_IF_ERROR(onnxruntime::Node::ForEachWithIndex(node->InputDefs(), initialize_maps)); diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.h b/onnxruntime/core/optimizer/optimizer_execution_frame.h index e1b8a91545f64..13cf9e652c404 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.h +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.h @@ -70,7 +70,6 @@ class OptimizerExecutionFrame final : public IExecutionFrame { OrtValueNameIdxMap ort_value_name_idx_map_; std::unordered_map ort_value_idx_nodearg_map_; std::unordered_map initializers_; - InlinedHashMap> buffer_for_initialized_tensors_; std::unique_ptr node_index_info_; const IExecutionProvider& execution_provider_; const std::function& is_sparse_initializer_func_; diff --git a/onnxruntime/core/optimizer/pad_fusion.cc b/onnxruntime/core/optimizer/pad_fusion.cc new file mode 100644 index 0000000000000..b25e7618802dd --- /dev/null +++ b/onnxruntime/core/optimizer/pad_fusion.cc @@ -0,0 +1,128 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/pad_fusion.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +/* + * It matches following pattern: + * Pad + * | + * Conv/MaxPool + */ +bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { + // if Pad has input axis, don't fuse it. + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Pad", {1, 2, 11, 13, 18, 19}) || + node.GetOutputEdgesCount() != 1 || + node.InputDefs().size() > 3) { + return false; + } + + if (graph.NodeProducesGraphOutput(node)) { + return false; + } + + const Node& child_node = *node.OutputNodesBegin(); + if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) && + !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) { + return false; + } + + // Don't fuse if MaxPool has optional output indices tensor because output indices tensor + // does not incorporate pad values. Basically if we allow the fusion, then dimension values + // of input tensor < dimension values of input tensor without fusion. + // This will cause the range of values for output indices tensor to be less than what it + // should have been. + + if (child_node.OutputDefs().size() > 1) { + return false; + } + + // conv or maxpool node must use explicit padding to perform this fusion. + if (child_node.GetAttributes().find("auto_pad") != child_node.GetAttributes().end() && + child_node.GetAttributes().at("auto_pad").s() != "NOTSET") { + return false; + } + + const NodeAttributes& pad_attributes = node.GetAttributes(); + if (pad_attributes.find("mode") != pad_attributes.end() && + pad_attributes.at("mode").s() != "constant") { + return false; + } + + // Since opset 11, and moved to inputs. + // Both of these should be initializer because we have to verify the values. + if (node.SinceVersion() >= 11) { + if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) || + (node.InputDefs().size() > 2 && !graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[2]))) { + return false; + } + + // constant_value should be zero because Conv and MaxPool allow only 0 as padding value. + if (node.InputDefs().size() > 2) { + const auto* pad_constant_value_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name()); + Initializer pad_constant_value{*pad_constant_value_proto, graph.ModelPath()}; + if (std::any_of(pad_constant_value.DataAsByteSpan().begin(), pad_constant_value.DataAsByteSpan().end(), [](const uint8_t byte) { return byte != 0; })) { + return false; + } + } + } else { + if (pad_attributes.find("value") != pad_attributes.end() && + pad_attributes.at("value").f() != 0.0) { + return false; + } + } + + return true; +} + +/* + * - For 1st two dimension Pads array's value should be zero and for rest of them values should >= 0 + */ +Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + std::vector pads_values; + + if (pad_node.SinceVersion() >= 11) { + const auto* pads_proto = graph_utils::GetConstantInitializer(graph, pad_node.InputDefs()[1]->Name()); + Initializer pads{*pads_proto, graph.ModelPath()}; + pads_values.assign(pads.DataAsSpan().begin(), pads.DataAsSpan().end()); + } else { + pads_values.assign(pad_node.GetAttributes().at("pads").ints().begin(), pad_node.GetAttributes().at("pads").ints().end()); + } + + assert(static_cast(pads_values.size()) == (2 * static_cast(pad_node.InputDefs()[0]->Shape()->dim_size()))); + + uint32_t pads_size = static_cast(pads_values.size()); + // check if padding is applied only on feature dims + if (pads_values[0] != 0 || pads_values[1] != 0 || pads_values[pads_size / 2] != 0 || + pads_values[pads_size / 2 + 1] != 0) { + return Status::OK(); + } + + // check if padding is only positive + if (std::any_of(pads_values.begin(), pads_values.end(), [](int64_t value) { return value < 0; })) { + return Status::OK(); + } + + Node& child_node = *graph.GetNode(pad_node.OutputNodesBegin()->Index()); + auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints(); + uint32_t child_pads_size = static_cast(child_pads->size()); + + for (uint32_t pads_index = 2, child_index = 0; pads_index < pads_size / 2; pads_index++, child_index++) { + child_pads->Set(child_index, child_pads->Get(child_index) + pads_values[pads_index]); + uint32_t mirrored_child_index = child_index + (child_pads_size / 2); + uint32_t mirrored_pad_index = pads_index + (pads_size / 2); + child_pads->Set(mirrored_child_index, child_pads->Get(mirrored_child_index) + pads_values[mirrored_pad_index]); + } + + graph_utils::RemoveNodeOutputEdges(graph, pad_node); + graph_utils::ReplaceNodeInput(child_node, 0, *pad_node.MutableInputDefs()[0]); + graph.RemoveNode(pad_node.Index()); + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + return Status::OK(); +} +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/pad_fusion.h b/onnxruntime/core/optimizer/pad_fusion.h new file mode 100644 index 0000000000000..a1b6978a83d1e --- /dev/null +++ b/onnxruntime/core/optimizer/pad_fusion.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { +/* + * This fusion submerges a Pad operator to it's child + * Conv or MaxPool operator, if and only if PadFusion::SatisfyCondition() + * is true. + */ +class PadFusion : public RewriteRule { + public: + PadFusion() : RewriteRule("Pad_Fusion") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Pad"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index f42766267b0f9..3d2a81ce7f8cd 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -87,12 +87,19 @@ std::vector WhereMoves() { MoveAll(q, ArgType::kOutput)}; return moves; } -QDQReplaceWithNew SplitReplacer() { +QDQReplaceWithNew SplitReplacer(bool has_split_as_input) { NTO::NodeLocation dq{NTO::NodeType::kInput, 0}; + NTO::NodeLocation target{NTO::NodeType::kTarget, 0}; NTO::NodeLocation q{NTO::NodeType::kOutput, 0}; - std::vector moves{ - MoveAndAppend(dq, ArgType::kInput, 0, ArgType::kInput), - MoveAll(q, ArgType::kOutput)}; + std::vector moves{MoveAndAppend(dq, ArgType::kInput, 0, ArgType::kInput)}; + + if (has_split_as_input) { + // Move the optional split input to the new node. + moves.push_back(MoveAndAppend(target, ArgType::kInput, 1, ArgType::kInput, true)); + } + + moves.push_back(MoveAll(q, ArgType::kOutput)); + return QDQReplaceWithNew(kOnnxDomain, "Split", std::move(moves)); } @@ -247,7 +254,12 @@ MatMulReplaceWithQLinear::MatMulReplaceWithQLinear() } Status SplitReplaceWithQuant::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { - return SplitReplacer().Run(graph, selected_nodes); + const auto& target_node = selected_nodes.Target(); + const auto& input_defs = target_node.InputDefs(); + + // The 'split' attribute became an optional input at opset 13. + bool has_split_as_input = target_node.SinceVersion() >= 13 && input_defs.size() == 2; + return SplitReplacer(has_split_as_input).Run(graph, selected_nodes); } Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 0e383c3031ca6..29178fe87f75c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -20,7 +20,7 @@ void SplitQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { const std::string action_name{"dropSplitQDQ"}; std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) - std::unique_ptr selector = std::make_unique(); + std::unique_ptr selector = std::make_unique(true /*req_equal_quant_params*/); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"Split", {}}}, std::move(selector), diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 5015e48fdb7b8..15b501c667046 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -253,7 +253,39 @@ void InputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder builder.num_input_defs = 1; // set to 1 as the first input is variadic } -void OutputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { +bool SplitNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 1)) { + return false; + } + + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + const Node& dq_node = *dq_nodes.front(); + int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + // All Q outputs should have same data type and (optionally) equal quantization parameters as the input. + for (size_t q_idx = 0; q_idx < q_nodes.size(); q_idx++) { + const Node& q_node = *q_nodes[q_idx]; + + if (dt_input != q_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()) { + return false; + } + + if (req_equal_quant_params_ && + !IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath())) { + return false; + } + } + + return true; +} + +void SplitSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { builder.num_output_defs = 1; // set to 1 as the first output is variadic } @@ -443,7 +475,6 @@ bool InstanceAndLayerNormalizationNodeGroupSelector::Check(const GraphViewer& gr } int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - int32_t dt_scale = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); int32_t dt_bias = 0; bool has_bias = false; // bias is optional for LayerNorm @@ -453,9 +484,9 @@ bool InstanceAndLayerNormalizationNodeGroupSelector::Check(const GraphViewer& gr } int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - // Input, output, and scale need to be the same type. The bias is int32. + // Input, output, need to be the same type. The bias is int32. + // Scale can be different with input for a16w8 case return (dt_input == dt_output) && - (dt_input == dt_scale) && (has_bias ? dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32 : true); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index be7f7e0288eda..d0d7fb2c2af17 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -115,6 +115,24 @@ class VariadicNodeGroupSelector : public NodeGroupSelector { bool allow_16bit_; }; +// DQ node -> Split -> multiple Q nodes with equal quantization types. +// Optionally, the selector can require all input and output quantization parameters to be +// equal and constant. +class SplitNodeGroupSelector : public NodeGroupSelector { + public: + explicit SplitNodeGroupSelector(bool req_equal_quant_params = false) + : req_equal_quant_params_(req_equal_quant_params) {} + + private: + bool Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; + + bool req_equal_quant_params_; // If true, only selects a node group if the input and output + // quantization parameters are all equal/constant, which enables the + // optimizer to drop the Q/DQ ops if the group is assigned to the CPU EP. +}; + // DQ nodes for X, W and optionally B -> node -> Q class ConvNodeGroupSelector : public NodeGroupSelector { public: @@ -288,10 +306,11 @@ class InputVariadicSelector : public BaseSelector { void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; -// DQ -> node -> Variadic Q nodes -class OutputVariadicSelector : public BaseSelector { +// DQ -> Split -> variadic Q nodes +class SplitSelector : public BaseSelector { public: - OutputVariadicSelector() : BaseSelector(std::make_unique()) {} + SplitSelector(bool req_equal_quant_params = false) + : BaseSelector(std::make_unique(req_equal_quant_params)) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index 3f1b2f0458bc0..544fe82a268c8 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -27,14 +27,17 @@ void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops } /* static methods to return different operator's OpVersionMap */ + +// These are operators that do not change the data and therefore the input DQ and +// output Q have the same scale and zero_point. static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { return {{"Gather", {}}, {"Reshape", {}}, + {"Expand", {}}, {"Flatten", {}}, {"Transpose", {}}, {"MaxPool", {12}}, {"Resize", {}}, - {"Split", {}}, {"Squeeze", {}}, {"Unsqueeze", {}}, {"Tile", {}}}; @@ -80,7 +83,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() { {"Neg", {}}, {"DepthToSpace", {}}, {"SpaceToDepth", {}}, - {"Clip", {}}}; + {"Clip", {}}, + {"LpNormalization", {}}}; } static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() { return {{"Add", {}}, @@ -96,6 +100,9 @@ static const OpVersionsAndSelector::OpVersionsMap GetVariadicOpVersionsMap() { {"Max", {}}, {"Min", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetSplitOpVersionsMap() { + return {{"Split", {}}}; +} static const OpVersionsAndSelector::OpVersionsMap GetConvOpVersionsMap() { return {{"Conv", {}}}; } @@ -169,6 +176,13 @@ void RegisterVariadicSelectors(Selectors& qdq_selectors) { std::move(selector)); } +void RegisterSplitSelector(Selectors& qdq_selectors) { + /* register selectors for Split op */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetSplitOpVersionsMap(), + std::move(selector)); +} + void RegisterConvSelector(Selectors& qdq_selectors) { /* register selector for conv op */ std::unique_ptr selector = std::make_unique(); @@ -246,6 +260,7 @@ void SelectorManager::CreateSelectors() { RegisterUnarySelectors(qdq_selectors_); RegisterBinarySelectors(qdq_selectors_); RegisterVariadicSelectors(qdq_selectors_); + RegisterSplitSelector(qdq_selectors_); RegisterConvSelector(qdq_selectors_); RegisterConvTransposeSelector(qdq_selectors_); RegisterMatMulSelector(qdq_selectors_); diff --git a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc index e182b6c695d2f..546d52b6f1682 100644 --- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc @@ -3,9 +3,10 @@ #include "core/optimizer/selectors_actions/selector_action_transformer.h" -#include #include +#include #include +#include #include #include "core/graph/op_identifier_utils.h" @@ -56,9 +57,9 @@ const SelectorActionRegistry::Entry* SelectorActionRegistry::LookUp(const std::s } #if !defined(ORT_MINIMAL_BUILD) -auto SelectorActionRegistry::LookUpByOpType(const std::string& op_type) const +auto SelectorActionRegistry::LookUpByOpTypeAndDomain(const std::string& op_type, const std::string& domain) const -> std::vector> { - const auto [range_begin, range_end] = op_type_to_entry_.equal_range(op_type); + const auto [range_begin, range_end] = op_type_to_entry_.equal_range(OpVersionsMapKey(op_type, domain)); std::vector> result{}; result.reserve(std::distance(range_begin, range_end)); std::transform(range_begin, range_end, std::back_inserter(result), @@ -93,20 +94,15 @@ static Status MatchAndProcess( Status status = Status::OK(); do { - // TODO: for now this just needs to support ONNX and Micrsoft Domain ops. - // If we ever had a transformer that was going to target non-ONNX ops, - // we'd need to rework a few things to include the op domain in the matches - if (node.Domain() != kOnnxDomain && node.Domain() != kMSDomain) { - break; - } - std::optional node_selection_opt{}; const SelectorActionRegistry::Entry* selector_action_entry_ptr = nullptr; - const auto selector_action_entries = selector_action_registry.LookUpByOpType(node.OpType()); + const auto selector_action_entries = + selector_action_registry.LookUpByOpTypeAndDomain(node.OpType(), node.Domain()); + std::string key = SelectorActionRegistry::OpVersionsMapKey(node.OpType(), node.Domain()); for (const auto& entry : selector_action_entries) { // check the supported versions if specified - const auto& versions = entry->ops_and_versions.find(node.OpType())->second; + const auto& versions = entry->ops_and_versions.find(key)->second; if (!versions.empty()) { if (std::find(versions.cbegin(), versions.cend(), node.SinceVersion()) == versions.cend()) { continue; diff --git a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h index 7eb162cc693f1..5caa949ebbe93 100644 --- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h +++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h @@ -38,8 +38,20 @@ struct NodeSelector { // class to manage a set of selector and associated actions class SelectorActionRegistry { public: + // The key is a string representing the op, optionally specifying the domain using ':' as the + // separator with domain as the first part and operator as the second part, ":" or "". + // For ops in kOnnxDomain, the domain should be left unspecified (""). + // For ops in other domains, the domain should be specified (":"). + // Ex: "Conv", "com.microsoft:Conv", "com.ms.internal.nhwc:Conv" using OpVersionsMap = std::unordered_map>; + // Helper function to create a key to OpVersionsMap using domain and op_type. + static std::string OpVersionsMapKey(std::string_view op_type, std::string_view domain = kOnnxDomain) { + return (domain == kOnnxDomain) + ? std::string{op_type} + : std::string{domain} + ":" + std::string{op_type}; + } + struct Entry { Entry(const std::string& name_in, #if !defined(ORT_MINIMAL_BUILD) @@ -95,14 +107,15 @@ class SelectorActionRegistry { #if !defined(ORT_MINIMAL_BUILD) // return registered Entry or nullptr if not found - auto LookUpByOpType(const std::string& op_type) const -> std::vector>; + auto LookUpByOpTypeAndDomain(const std::string& op_type, + const std::string& domain) const -> std::vector>; #endif // !defined(ORT_MINIMAL_BUILD) private: std::unordered_map name_to_entry_; #if !defined(ORT_MINIMAL_BUILD) - // auxiliary mapping to enable lookup by op type + // auxiliary mapping to enable lookup by op type or "domain:op type" std::unordered_multimap op_type_to_entry_; #endif // !defined(ORT_MINIMAL_BUILD) }; diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index 07f391f2ae430..0d7ab70eba613 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "transformer_memcpy.h" +#include "core/common/logging/logging.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/execution_providers.h" #include "core/framework/utils.h" @@ -16,12 +17,12 @@ class TransformerMemcpyImpl { TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider) : graph_(graph), provider_(provider) {} - bool ModifyGraph(const KernelRegistryManager& schema_registries); + bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter); private: void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed); void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries); - void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input); + void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger); bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed); private: @@ -61,11 +62,21 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st // very simple GraphTransformer that uses TransformerMemcpyImpl for each graph // and mainly provides the subgraph recursion functionality -common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { +common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { for (auto& provider : provider_types_) { if (!utils::ProviderIsCpuBased(provider)) { TransformerMemcpyImpl copy_impl(graph, provider); - auto current_modified = copy_impl.ModifyGraph(registry_manager_); + + int copy_node_counter = 0; + auto current_modified = copy_impl.ModifyGraph(registry_manager_, logger, copy_node_counter); + if (copy_node_counter > 0 && provider == kCudaExecutionProvider) { + LOGS(logger, WARNING) << copy_node_counter << " Memcpy nodes are added to the graph " << graph.Name() + << " for " << provider + << ". It might have negative impact on performance (including unable to run CUDA graph). " + << "Set session_options.log_severity_level=1 to see the detail logs before this message."; + } + modified = modified || current_modified; break; } @@ -111,7 +122,9 @@ This transformer does not currently optimize copies between, e.g., two different */ -bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_registries) { +bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_registries, + const logging::Logger& logger, + int& copy_node_counter) { bool modified = false; InitializedTensorSet initializers_consumed; // find defs that require copy @@ -137,19 +150,22 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi // For inputs we need to create a copy node only when the input is connected to both provider // and non-provider nodes. Otherwise utils::CopyInputsAcrossDevices() will do the job. if (provider_input_defs_.count(arg) && non_provider_input_defs_.count(arg)) { - AddCopyNode(const_cast(arg), true); + AddCopyNode(const_cast(arg), true, logger); + copy_node_counter++; modified = true; } for (auto arg : non_provider_output_defs_) if (provider_input_defs_.count(arg)) { - AddCopyNode(arg, true); + AddCopyNode(arg, true, logger); + copy_node_counter++; modified = true; } for (auto arg : provider_output_defs_) if (non_provider_input_defs_.count(arg)) { - AddCopyNode(arg, false); + AddCopyNode(arg, false, logger); + copy_node_counter++; modified = true; } @@ -176,7 +192,8 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi // (the name will be the same as the parent node's implicit input) const auto* node_arg_in_current_graph_level = *provider_input_defs_.find(arg); - AddCopyNode(const_cast(node_arg_in_current_graph_level), true); + AddCopyNode(const_cast(node_arg_in_current_graph_level), true, logger); + copy_node_counter++; modified = true; } } @@ -232,7 +249,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg if (!arg->Exists()) continue; - if (kci && kci->kernel_def->IsOutputOnCpu(i)) + if (utils::IsOutputOnCpu(node, kci, i)) non_provider_output_defs_.insert(arg); else provider_output_defs_.insert(arg); @@ -291,13 +308,13 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it); } if (arg_output_index != -1) { - if (!kci || !kci->kernel_def->IsOutputOnCpu(arg_output_index)) provider_output_nodes_[arg].insert(&it); + if (!kci || !utils::IsOutputOnCpu(it, kci, arg_output_index)) provider_output_nodes_[arg].insert(&it); } } } } -void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input) { +void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger) { // create unique name for new def std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_); @@ -309,6 +326,9 @@ void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input std::string new_node_name = graph_.GenerateNodeName("Memcpy"); const auto op_name = is_input ? "MemcpyFromHost" : "MemcpyToHost"; + LOGS(logger, INFO) << "Add " << op_name << (is_input ? " after " : " before ") << arg->Name() + << " for " << provider_; + auto& new_node = graph_.AddNode(new_node_name, op_name, "Copy from/to host memory", std::vector{src_arg}, std::vector{dst_arg}); @@ -384,8 +404,8 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker // normally initializers are only inputs, but things may change with ops like assign ORT_THROW_IF_ERROR(Node::ForEachWithIndex( p_node->OutputDefs(), - [kci, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) { - if (kci->kernel_def->IsOutputOnCpu(index)) { + [kci, &p_node, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) { + if (utils::IsOutputOnCpu(*p_node, kci, index)) { ORT_ENFORCE(dup_replacements.find(&arg) == dup_replacements.end()); } return Status::OK(); diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 2c11bf144999e..c479b685f9267 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -4,9 +4,11 @@ #include "onnx_transpose_optimization.h" #include +#include #include #include #include +#include #include #include "core/common/gsl.h" @@ -17,6 +19,9 @@ namespace onnx_transpose_optimization { /////// /////// /* Small utilities for editing nodes and manipulating axes/permutations */ +static constexpr bool IsOnnxDomain(std::string_view domain) { + return (domain == onnxruntime::kOnnxDomain) || (domain == onnxruntime::kOnnxDomainAlias); +} static std::vector DataInt64(api::TensorRef& tensor) { std::vector raw_data = tensor.Data(); @@ -93,6 +98,19 @@ static std::unique_ptr MakeSqueezeOrUnsqueeze(int64_t opset, api:: return graph.AddNode(op_type, inputs, /*num_outputs*/ 1); } +// Use to create a QuantizeLinear or DequantizeLinear node. Does not update output ValueInfo. Adds axis if needed. +static std::unique_ptr MakeQOrDQ(api::GraphRef& graph, std::string_view domain, std::string_view op_type, + std::vector inputs, + std::optional axis) { + std::unique_ptr node = graph.AddNode(op_type, inputs, /* num_outputs */ 1, domain); + // only set if provided and not the default + if (axis && axis != 1) { + node->SetAttributeInt("axis", *axis); + } + + return node; +} + // Returns whether perm is a valid permutation (contains each value from 0 to perm.size() - 1 exactly once) static bool IsValidPerm(const std::vector& perm) { size_t rank = perm.size(); @@ -117,6 +135,263 @@ static std::optional> GetPermAttrIfValid(const api::NodeRef return perm; } +static inline bool NormalizeAndValidateAxis(int64_t& axis, size_t rank) { + int64_t rank_int = gsl::narrow_cast(rank); + if (axis < 0) { + axis += rank_int; + } + + return axis >= 0 && axis < rank_int; +} + +/// +/// Check if an output value has a single consumer that is a node. +/// +/// Consumer node if found. +/// True if there is a single consumer node. +static bool OutputValueHasSingleConsumerNode(const api::GraphRef& graph, const api::NodeRef& node, size_t output_idx, + std::unique_ptr& single_consumer) { + auto value = node.Outputs()[output_idx]; + auto consumers = graph.GetValueConsumers(value); + + if (consumers->comprehensive && (consumers->nodes.size() == 1)) { + single_consumer = std::move(consumers->nodes[0]); + } else { + single_consumer.reset(); + } + + return single_consumer != nullptr; +} + +/// return the DQ node if value_name is produced by a DQ node +static std::unique_ptr GetDQIfProducingValue(const api::GraphRef& graph, std::string_view value_name) { + auto maybe_dq_node = graph.GetNodeProducingOutput(value_name); + + return (maybe_dq_node != nullptr && maybe_dq_node->OpType() == "DequantizeLinear") ? std::move(maybe_dq_node) + : std::unique_ptr(); +} + +/// +/// Return a DequantizeLinear node if it's input is a constant initializer and it has a single consumer. +/// In this case the initializer can be updated in-place by UnsqueezeInput or TransposeInput. +/// +/// Current graph +/// Value to check if produced by a DQ node who's input is a constant initializer +/// NodeRef for DQ node if it meets the requirements. +static std::unique_ptr GetDQWithConstInitializerInputAndSingleConsumer(const api::GraphRef& graph, + std::string_view value_name) { + std::unique_ptr result; + auto dq_node = GetDQIfProducingValue(graph, value_name); + + if (dq_node) { + do { + auto dq_input = dq_node->Inputs()[0]; + auto dq_constant = graph.GetConstant(dq_input); + + // input to DQ must be a constant initializer + if (!dq_constant) { + break; + } + + // For now keep it simple and don't support per-axis quantization as that would require updating the axis of + // the DQ node during TransposeInputImpl and UnsqueezeInput. + auto dq_scale = graph.GetConstant(dq_node->Inputs()[1]); + if (!dq_scale || dq_scale->NumElements() != 1) { + break; + } + + // need to know all the initializer consumers as we're potentially going to modify it directly + auto initializer_consumers = graph.GetValueConsumers(dq_input); + if (!initializer_consumers->comprehensive) { + break; + } + + std::unique_ptr consumer; + if (!OutputValueHasSingleConsumerNode(graph, *dq_node, 0, consumer)) { + break; + } + + result = std::move(dq_node); + } while (false); + } + + return result; +} + +/// +/// Insert a Q -> DQ pair after the node following the DQ by using scale and zp info from the preceding DQ node. +/// DQ -> next node => DQ -> next node -> Q -> DQ. +/// This is only called for Transpose and Unsqueeze nodes. +/// +/// DQ node. +/// Node following DQ node. +/// New DQ node at end of DQ -> next_node -> Q -> DQ. +/// True if insert was successful. +static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) { + std::unique_ptr single_consumer_node; + if (!OutputValueHasSingleConsumerNode(graph, dq_node, 0, single_consumer_node)) { + // should never happen as caller should have checked previously + return false; + } + + auto& next_node = *single_consumer_node; + assert(next_node.OpType() == "Transpose" || next_node.OpType() == "Unsqueeze"); + + const auto dq_domain = dq_node.Domain(); + const auto& dq_inputs = dq_node.Inputs(); + const bool is_transpose = next_node.OpType() == "Transpose"; + + const auto scale_input = dq_inputs[1]; + const auto scale_value_info = graph.GetValueInfo(scale_input); + std::optional zp_input; + std::optional> zp_value_info; + + auto scale_shape = scale_value_info->Shape(); + if (!scale_shape && is_transpose) { + // axis potentially needs updating due to the transpose but we don't have the required info to do it. + return false; + } + + if (dq_inputs.size() > 2) { + zp_input = dq_inputs[2]; + zp_value_info = graph.GetValueInfo(zp_input.value()); + } + + // per-axis quantization if not a scalar (shape is empty for scalar). + // note there could be an axis value as the onnx spec says that is ignored for per-tensor quantization, + // so we have to check the shape. + auto update_dq_axis = scale_shape && !scale_shape->empty(); + int64_t axis = dq_node.GetAttributeIntDefault("axis", 1); + + if (update_dq_axis && is_transpose) { + // update axis. + auto perm = GetPermAttrIfValid(next_node); + assert(perm.has_value()); // onnx shape inferencing checks that `perm` is valid + NormalizeAndValidateAxis(axis, scale_shape->size()); + axis = InvertPerm(*perm)[gsl::narrow_cast(axis)]; + } + + auto next_node_output_name = next_node.Outputs()[0]; + auto next_node_output_shape = graph.GetValueInfo(next_node_output_name)->Shape(); + + // setup Q node inputs. we don't connect it to next_node yet as we will move the output of that to the new DQ first. + std::vector inputs = {"", scale_input}; + if (zp_input) { + inputs.push_back(zp_input.value()); + } + + // Add Q + auto new_q_node = MakeQOrDQ(graph, dq_domain, "QuantizeLinear", inputs, axis); + auto q_node_outputs = new_q_node->Outputs(); + + // copy value info from the dq input for the type information, and update the shape to match next_node's output + graph.CopyValueInfo(dq_node.Inputs()[0], q_node_outputs[0]); // Q produces same type as the dq_node input + auto q_node_value_info = graph.GetValueInfo(q_node_outputs[0]); + q_node_value_info->SetShape(next_node_output_shape ? &*next_node_output_shape : nullptr); + + // update input to connect the DQ to the Q we just added. re-use scale and zp. + inputs[0] = new_q_node->Outputs()[0]; + + // Add DQ + auto new_dq_node = MakeQOrDQ(graph, dq_domain, "DequantizeLinear", inputs, axis); + auto dq_node_outputs = new_dq_node->Outputs(); + + // straight copy of value info as the type and shape are the same as next_node's output + graph.CopyValueInfo(next_node_output_name, dq_node_outputs[0]); + + // move next_node output to the new DQ node in case it was a graph output, and connect next_node with the new Q node + graph.MoveOutput(next_node, 0, *new_dq_node, 0); + auto new_next_node_output_name = next_node.Outputs()[0]; + new_q_node->SetInput(0, new_next_node_output_name); + graph.CopyValueInfo(dq_node_outputs[0], new_next_node_output_name); + + return true; +} + +/// +/// Check if a DQ -> Q pair have matching type/scale/zero point. +/// If there's no operator between them, and they match, they are redundant and can be removed. +/// +/// True if they match. +static bool CheckQDQNodePairMatch(const api::GraphRef& graph, + const api::NodeRef& dq_node, const api::NodeRef& q_node) { + bool match = false; + + do { + if (dq_node.Domain() != q_node.Domain()) { + break; + } + + auto t1 = graph.GetValueInfo(dq_node.Inputs()[0])->DType(); + auto t2 = graph.GetValueInfo(q_node.Outputs()[0])->DType(); + + if (t1 == api::DataType::UNDEFINED || t2 == api::DataType::UNDEFINED || t1 != t2) { + break; + } + + auto dq_scale = dq_node.Inputs()[1]; + auto q_scale = q_node.Inputs()[1]; + + if (dq_scale != q_scale) { + auto dq_scale_value = graph.GetConstant(dq_scale); + auto q_scale_value = graph.GetConstant(q_scale); + if (!dq_scale_value || !q_scale_value) { + break; // non-const input + } + + if (dq_scale_value->Data() != q_scale_value->Data()) { + break; + } + } + + auto dq_zp = dq_node.Inputs().size() > 2 ? dq_node.Inputs()[2] : ""; + auto q_zp = q_node.Inputs().size() > 2 ? q_node.Inputs()[2] : ""; + + if (dq_zp != q_zp) { + std::optional> dq_scale_value; + std::optional> q_scale_value; + if (dq_zp != "") { + dq_scale_value = graph.GetConstant(dq_zp); + if (!dq_scale_value.value()) { + break; // non-const input + } + } + + if (q_zp != "") { + q_scale_value = graph.GetConstant(q_zp); + if (!q_scale_value.value()) { + break; // non-const input + } + } + + if (dq_scale_value.has_value() && q_scale_value.has_value()) { + if (dq_scale_value->get()->Data() != q_scale_value->get()->Data()) { + break; + } + } else { + // check the input with a value matches the default zp value of 0 + if (dq_scale_value.has_value()) { + auto data = dq_scale_value->get()->Data(); + if (!std::all_of(data.begin(), data.end(), [](auto value) { return value == 0; })) { + break; + } + } else { + // q_scale_value must have a value to get here + auto data = q_scale_value->get()->Data(); + if (!std::all_of(data.begin(), data.end(), [](auto value) { return value == 0; })) { + break; + } + } + } + } + + match = true; + + } while (false); + + return match; +} + // Adds rank to negative axes and checks that axes are unique and within [0, rank). Returns false if invalid. static bool NormalizeAndValidateAxes(std::vector& axes, size_t rank) { int64_t rank_int = gsl::narrow_cast(rank); @@ -134,15 +409,6 @@ static bool NormalizeAndValidateAxes(std::vector& axes, size_t rank) { return true; } -static inline bool NormalizeAndValidateAxis(int64_t& axis, size_t rank) { - int64_t rank_int = gsl::narrow_cast(rank); - if (axis < 0) { - axis += rank_int; - } - - return axis >= 0 && axis < rank_int; -} - // Read int64 data from attribute or input, depending on whether model opset < provided opset static std::optional> ReadFromAttrOrInput(OptimizerCtx& ctx, api::NodeRef& node, std::string_view attr_name, size_t inp_index, @@ -345,6 +611,12 @@ static std::vector SortedAxesForTransposedInput(const std::vectorShape(); + graph.GetValueInfo(dq.Outputs()[0])->SetShape(&new_shape); +} + /////// /////// /////// /////// @@ -357,51 +629,102 @@ static std::string_view HelpHandleUnsqueeze(HandlerArgs& args, const std::vector // broadcasting. static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, const std::vector& axes) { std::string_view input = node.Inputs()[i]; - // Remove this node as a consumer - node.SetInput(i, ""); std::unique_ptr constant = ctx.graph.GetLocalConstant(input); - auto consumers = ctx.graph.GetValueConsumers(input); + + // allow a constant initializer coming via a DQ node with a single consumer + std::unique_ptr dq_node; + std::string_view constant_dq_input; + + if (!constant) { + // look past a DQ node for a constant initializer. essentially we pretend the DQ node doesn't exist + // to enable directly making changes to the initializer. any nodes added for other consumers of the initializer + // in 'Case 1' are prior to the DQ so we don't break up any QDQ node units. + dq_node = GetDQWithConstInitializerInputAndSingleConsumer(ctx.graph, input); + if (dq_node) { + // underlying string for the input name is in the Node so it's safe to store in string_view constant_dq_input + constant_dq_input = dq_node->Inputs()[0]; + constant = ctx.graph.GetLocalConstant(constant_dq_input); + // remove the DQ node as a consumer of the initializer while we modify things + dq_node->SetInput(0, ""); + } + } + + // Clear the input, which also removes this node's input as a consumer of the value. + // NOTE: the node may have multiple inputs consuming the value. + node.SetInput(i, ""); + auto value_to_modify = dq_node ? constant_dq_input : input; + auto consumers = ctx.graph.GetValueConsumers(value_to_modify); // Case 1: input is a constant with a known list of consumer nodes if (constant != nullptr && consumers->comprehensive) { - // We will reshape the initializer. If there are existing consumers, still reshape it but add Squeeze nodes + // We will reshape the initializer. If there are existing consumers, reshape it and add Squeeze nodes // to counteract its effect. If they later Unsqueeze the same input, the Squeeze nodes will simply be deleted // (see Case 2). if (consumers->nodes.size() > 0) { - auto squeeze_ptr = MakeSqueezeOrUnsqueeze(ctx.opset, ctx.graph, "Squeeze", input, axes); + auto squeeze_ptr = MakeSqueezeOrUnsqueeze(ctx.opset, ctx.graph, "Squeeze", value_to_modify, axes); api::NodeRef& squeeze = *squeeze_ptr; std::string_view sq_out = squeeze.Outputs()[0]; - ctx.graph.CopyValueInfo(input, sq_out); - ReplaceValueReferences(consumers->nodes, input, sq_out); + ctx.graph.CopyValueInfo(value_to_modify, sq_out); + ReplaceValueReferences(consumers->nodes, value_to_modify, sq_out); } + auto new_shape = UnsqueezeShape(constant->Shape(), axes); - ctx.graph.ReshapeInitializer(input, new_shape); - node.SetInput(i, input); + ctx.graph.ReshapeInitializer(value_to_modify, new_shape); + + if (dq_node) { + UpdateDQNodeInputAndShape(ctx.graph, *dq_node, constant_dq_input); + } + + node.SetInput(i, input); // restore the original connection return; } // Case 2: input is a Squeeze node with matching axes std::unique_ptr inp_node = ctx.graph.GetNodeProducingOutput(input); + + // look past a DQ node for a Squeeze to cancel + if (inp_node && inp_node->OpType() == "DequantizeLinear") { + dq_node = std::move(inp_node); + auto dq_input = dq_node->Inputs()[0]; + inp_node = ctx.graph.GetNodeProducingOutput(dq_input); + consumers = ctx.graph.GetValueConsumers(dq_input); + } + if (inp_node != nullptr && inp_node->IsOp("Squeeze")) { const std::vector& inp_node_inputs = inp_node->Inputs(); std::optional> squeeze_axes = std::nullopt; squeeze_axes = ReadFromAttrOrInput(ctx, *inp_node, "axes", /*inp_index*/ 1, /*opset*/ 13); if (squeeze_axes != std::nullopt && *squeeze_axes == axes) { + if (dq_node) { + UpdateDQNodeInputAndShape(ctx.graph, *dq_node, inp_node_inputs[0]); + node.SetInput(i, dq_node->Outputs()[0]); + } else { + node.SetInput(i, inp_node_inputs[0]); + } + // Remove the Squeeze node if possible - if (consumers->comprehensive && consumers->nodes.size() == 0) { + // if there's a DQ node the `consumers` list still includes it so allow for that. + // in that case UpdateDQNodeInputAndShape already updated the input of the DQ node so it's safe to remove it. + if (consumers->comprehensive && consumers->nodes.size() == size_t(dq_node ? 1 : 0)) { ctx.graph.RemoveNode(*inp_node); + if (ctx.opset >= 13 && !ctx.graph.HasValueConsumers(inp_node_inputs[1])) { ctx.graph.RemoveInitializer(inp_node_inputs[1]); } } - node.SetInput(i, inp_node_inputs[0]); + return; } // Axes don't match. Fall through to Case 3. } + // any DQ node special casing doesn't apply anymore, so go back to the original inp_node + if (dq_node) { + inp_node = std::move(dq_node); + } + // Case 3: Add an Unsqueeze node. auto unsqueeze_ptr = MakeSqueezeOrUnsqueeze(ctx.opset, ctx.graph, "Unsqueeze", input, axes); api::NodeRef& unsqueeze = *unsqueeze_ptr; @@ -426,6 +749,10 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons } node.SetInput(i, unsq_out); + + if (inp_node != nullptr && inp_node->OpType() == "DequantizeLinear") { + MakeQDQNodeUnit(ctx.graph, *inp_node); + } } static void Permute1DConstant(api::GraphRef& graph, api::NodeRef& node, api::TensorRef& constant, @@ -453,83 +780,164 @@ static void Permute1DConstant(api::GraphRef& graph, api::NodeRef& node, api::Ten // Replaces ith input to node with transposed value. Might create a new Transpose node, find an existing one, // or transpose an initializer. -void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, - const std::vector& perm, const std::vector& perm_inv) { +static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t i, + const std::vector& perm, const std::vector& perm_inv) { std::string_view input = node.Inputs()[i]; - // Remove this node as a consumer - node.SetInput(i, ""); + // Only local constants are editable std::unique_ptr constant = graph.GetLocalConstant(input); - auto consumers = graph.GetValueConsumers(input); + + // allow a constant initializer coming via a DQ node with a single consumer + std::unique_ptr dq_node; + std::string_view constant_dq_input; + + if (!constant) { + // look past a DQ node for a constant initializer. essentially we pretend the DQ node doesn't exist + // to enable directly making changes to the initializer. any nodes added for other consumers of the initializer + // in 'Case 1' are prior to the DQ so we don't break up any QDQ node units. + dq_node = GetDQWithConstInitializerInputAndSingleConsumer(graph, input); + if (dq_node) { + // underlying string for the input name is in the Node so it's safe to store in string_view constant_dq_input + constant_dq_input = dq_node->Inputs()[0]; + constant = graph.GetLocalConstant(constant_dq_input); + // remove the DQ node as a consumer of the initializer while we modify things + dq_node->SetInput(0, ""); + } + } + + // Clear the input, which also removes this node's input as a consumer of the value. + // NOTE: the node may have multiple inputs consuming the value. + node.SetInput(i, ""); + + auto constant_to_modify = dq_node ? constant_dq_input : input; + auto consumers = graph.GetValueConsumers(constant_to_modify); // Case 1: input is a constant with a known list of consumer nodes if (constant != nullptr && consumers->comprehensive) { - // Input is scalar, return early. - if (constant->Shape().size() == 1 && constant->Shape()[0] == 0) { + // we modify the initializer in-place and need to reconnect things up when we're done. this helper will + // do that when it goes out of scope. if we have manually reconnected, input or constant_dq_input is + // set to an empty string. + auto reconnect_nodes = gsl::finally([i, &node, &dq_node, &input, &constant_dq_input] { + if (!input.empty()) { + node.SetInput(i, input); + } + + if (!constant_dq_input.empty()) { + dq_node->SetInput(0, constant_dq_input); + } + }); + + // If there is only one element return early as the transpose won't change the data + if (constant->NumElements() == 1) { return; } + // This is a special case where the constant is 1D with length == perm. - // TODO: TransposeInitializer should be updated to handle this case. + // e.g. it provides a set of values that are relative to the input axes like the `sizes` input for Resize // Permute1DConstant permutes the constant and adds a new initializer. The old initializer is removed only if // there are no other consumers. if (constant->Shape().size() == 1 && constant->Shape()[0] == gsl::narrow_cast(perm.size())) { - Permute1DConstant(graph, node, *constant, i, input, perm); + auto& node_to_update = dq_node ? *dq_node : node; + Permute1DConstant(graph, node_to_update, *constant, i, constant_to_modify, perm); + + // unset updated input so reconnect_nodes doesn't change it back + if (dq_node) { + constant_dq_input = ""; + } else { + input = ""; + } + return; } + if (consumers->nodes.size() > 0) { // Transpose the initializer. If there are existing consumers, add Transpose nodes to them using perm_inv // to counteract the effect. These Transposes will hopefully be optimized out later. - auto transpose_inv_ptr = MakeTranspose(graph, input, perm_inv); + auto transpose_inv_ptr = MakeTranspose(graph, constant_to_modify, perm_inv); api::NodeRef& transpose_inv = *transpose_inv_ptr; std::string_view transpose_out = transpose_inv.Outputs()[0]; - graph.CopyValueInfo(input, transpose_out); - ReplaceValueReferences(consumers->nodes, input, transpose_out); + graph.CopyValueInfo(constant_to_modify, transpose_out); + ReplaceValueReferences(consumers->nodes, constant_to_modify, transpose_out); + } + + graph.TransposeInitializer(constant_to_modify, perm); + + if (dq_node) { + UpdateDQNodeInputAndShape(graph, *dq_node, constant_to_modify); + constant_dq_input = ""; // DQ input was already updated so we don't need reconnect_nodes to handle it } - graph.TransposeInitializer(input, perm); - node.SetInput(i, input); + return; } // Case 2: input is a Transpose node std::unique_ptr inp_node = graph.GetNodeProducingOutput(input); + + // Look past a DQ for the Transpose + if (inp_node && inp_node->OpType() == "DequantizeLinear") { + dq_node = std::move(inp_node); + auto dq_input = dq_node->Inputs()[0]; + inp_node = graph.GetNodeProducingOutput(dq_input); + consumers = graph.GetValueConsumers(dq_input); + } + if (inp_node != nullptr && inp_node->IsOp("Transpose")) { std::optional> perm2 = GetPermAttrIfValid(*inp_node); if (perm2 != std::nullopt && perm2->size() == perm.size()) { // If they cancel, use pre_transpose_value and remove Transpose if possible. if (*perm2 == perm_inv) { std::string_view pre_transpose_value = inp_node->Inputs()[0]; - if (consumers->comprehensive && consumers->nodes.size() == 0) { + + if (dq_node) { + UpdateDQNodeInputAndShape(graph, *dq_node, pre_transpose_value); + node.SetInput(i, dq_node->Outputs()[0]); + } else { + node.SetInput(i, pre_transpose_value); + } + + // Remove the Transpose node if possible + // if there's a DQ node the `consumers` list still includes it so allow for that. + // in that case UpdateDQNodeInputAndShape already updated the input of the DQ node so it's safe to remove it. + if (consumers->comprehensive && consumers->nodes.size() == size_t(dq_node ? 1 : 0)) { graph.RemoveNode(*inp_node); } - node.SetInput(i, pre_transpose_value); - return; - } else if (*perm2 == perm) { - // we are trying to add a duplicate transpose. - // do nothing and return + return; } - // Otherwise, compose the perm and Transpose pre_transpose_value. Cost is the same and we may be able to remove - // the other Transpose. - const std::vector& perm_combined = ComposePerm(*perm2, perm); - auto transpose_ptr = MakeTranspose(graph, inp_node->Inputs()[0], perm_combined); - api::NodeRef& transpose = *transpose_ptr; - std::string_view transpose_out = transpose.Outputs()[0]; - graph.CopyValueInfo(input, transpose_out); - graph.GetValueInfo(transpose_out)->PermuteDims(perm); - if (consumers->comprehensive && consumers->nodes.size() == 0) { - graph.RemoveNode(*inp_node); + if (!dq_node) { + // Otherwise, compose the perm and Transpose pre_transpose_value. Cost is the same and we may be able to remove + // the other Transpose. + const std::vector& perm_combined = ComposePerm(*perm2, perm); + auto transpose_ptr = MakeTranspose(graph, inp_node->Inputs()[0], perm_combined); + api::NodeRef& transpose = *transpose_ptr; + std::string_view transpose_out = transpose.Outputs()[0]; + graph.CopyValueInfo(input, transpose_out); + graph.GetValueInfo(transpose_out)->PermuteDims(perm); + + if (consumers->comprehensive && consumers->nodes.size() == 0) { + graph.RemoveNode(*inp_node); + } + + node.SetInput(i, transpose_out); + + return; + } else { + // fall through to regular processing if the Transpose prior to the DQ doesn't cancel out cleanly } - node.SetInput(i, transpose_out); - return; } } - // Case 3: A Transpose op might already exist - for (size_t j = 0; j < consumers->nodes.size(); ++j) { - api::NodeRef& consumer = *consumers->nodes[j]; - if (consumer.IsOp("Transpose") && GetPermAttrIfValid(consumer) == perm) { - node.SetInput(i, consumer.Outputs()[0]); + // any DQ node special casing doesn't apply anymore, so go back to the original inp_node + if (dq_node) { + inp_node = std::move(dq_node); + consumers = graph.GetValueConsumers(input); + } + + // Case 3: A Transpose op with the same perms might already exist + for (auto& consumer : consumers->nodes) { + if (consumer->IsOp("Transpose") && GetPermAttrIfValid(*consumer) == perm) { + node.SetInput(i, consumer->Outputs()[0]); return; } } @@ -540,11 +948,29 @@ void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, std::string_view transpose_out = transpose.Outputs()[0]; graph.CopyValueInfo(input, transpose_out); graph.GetValueInfo(transpose_out)->PermuteDims(perm); + node.SetInput(i, transpose_out); + + if (inp_node && inp_node->OpType() == "DequantizeLinear") { + MakeQDQNodeUnit(graph, *inp_node); + } +} + +// this TransposeInput is used by the layout transformer to wrap a node in Transpose ops. +// there's no OptimizerCtx in that scenario +void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, + const std::vector& perm, + const std::vector& perm_inv) { + TransposeInputImpl(graph, node, i, perm, perm_inv); +} + +static void TransposeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, const std::vector& perm, + const std::vector& perm_inv) { + TransposeInputImpl(ctx.graph, node, i, perm, perm_inv); } // Unsqueezes inputs of node to have uniform rank. Returns false if input ranks are unknown or exceed the target rank. -static bool NormalizeInputRanks(OptimizerCtx ctx, api::NodeRef& node, size_t target_rank, +static bool NormalizeInputRanks(OptimizerCtx& ctx, api::NodeRef& node, size_t target_rank, const std::vector& input_indices) { auto inputs = node.Inputs(); @@ -579,7 +1005,7 @@ void TransposeInputs(OptimizerCtx& ctx, api::NodeRef& node, const std::vector& input_indices) { auto perm_inv = InvertPerm(perm); for (size_t j : input_indices) { - TransposeInput(ctx.graph, node, j, perm, perm_inv); + TransposeInput(ctx, node, j, perm, perm_inv); } } @@ -670,30 +1096,95 @@ static bool CanLikelyRemoveTranspose(const api::GraphRef& graph, api::NodeRef& t return true; } +// return true if +// - the value is a constant initializer +// - the value is the output of a DQ node who's input is a constant initializer +// - UnsqueezeInput/TransposeInput can look past the DQ to update the constant initializer directly +// - DQ node is currently ignored if it uses per-channel quantization +// - supporting per-channel quantization requires modifying the scales and zero point data, which can be done +// if/when there's a use-case to justify the development cost. +// - the input was originally connected to a shared constant initializer that was updated in place by UnsqueezeInput +// or TransposeInput, and usage by this node had Squeeze/Transpose nodes inserted to counteract the effect of the +// in-place update. if we push the same transpose through this node it should cancel out that Squeeze/Transpose +// +// in all these cases we expect pushing the transpose through to not require a runtime Transpose node +static bool IsConstant(const api::GraphRef& graph, std::string_view value_name) { + std::unique_ptr producer_node = graph.GetNodeProducingOutput(value_name); + + if (!producer_node) { + // initializer or graph input. + // initializer may or may not be constant depending on whether it has a matching graph input + std::unique_ptr constant = graph.GetConstant(value_name); + return constant != nullptr; + } + + // look past a DQ node + if (producer_node->OpType() == "DequantizeLinear") { + std::unique_ptr dq_node = GetDQWithConstInitializerInputAndSingleConsumer(graph, value_name); + if (dq_node != nullptr) { + // DQ node pointing to an constant initializer + return true; + } + } + + return false; +} + // Estimates the cost of transposing an input. Currently uses rank heuristic. Negative if transpose is removed. // Feel free to improve as needed. static int EstimateTransposeValueCost(const api::GraphRef& graph, std::string_view input, - const std::vector& perm_inv, - const HandlerMap& extended_handlers) { + const std::vector& perm_inv, const HandlerMap& extended_handlers) { // Case 1: Transposing constants probably costs nothing. - std::unique_ptr constant = graph.GetConstant(input); - if (constant != nullptr) { + if (IsConstant(graph, input)) { return 0; } // Case 2: Transposing a transpose either cancels it or composes the permutations. - std::unique_ptr node = graph.GetNodeProducingOutput(input); - if (node != nullptr && node->IsOp("Transpose")) { - std::optional> perm2 = GetPermAttrIfValid(*node); - if (perm2 != std::nullopt) { - if (*perm2 == perm_inv && CanLikelyRemoveTranspose(graph, *node, extended_handlers)) { - return -EstimateValueRank(graph, input); - } else { - return 0; + std::unique_ptr producer_node = graph.GetNodeProducingOutput(input); + + if (producer_node != nullptr) { + // this handles cancelling out a Transpose or Squeeze added to a shared initializer that was updated + // by TransposeInputImpl Case 1 or UnqueezeInput Case 1. + // - if a shared initializer is not broadcast, we have -> Transpose -> DQ + // - if a shared initializer is broadcast, we have -> Transpose -> Squeeze -> DQ and need + // to look slightly further in the hopes of finding the Transpose. + // - in practice it's only necessary if the operator that we're looking to push the transpose through has + // more than 2 inputs, and at least one of them is broadcastable. When there are 2 inputs the input with + // the Transpose will have a negative weight. If we don't look past DQ -> Squeeze to find the Transpose + // on the other input the positive weight of the broadcast initializer will always be less as it's based on + // rank, so the total cost estimate will always be negative and we'll push the Transpose. + // onnx::Where may be the only operator that requires the look past Squeeze. + // + // look past a DQ as we do that in the TransposeInput/UnsqueezeInput handling. + // match onnx and contrib ops domain for Q/DQ while we have those ops in both domains. + if (producer_node->OpType() == "DequantizeLinear") { + auto dq_input_node = graph.GetNodeProducingOutput(producer_node->Inputs()[0]); + if (dq_input_node != nullptr) { + if (dq_input_node->OpType() == "Squeeze") { + auto squeeze_input_node = graph.GetNodeProducingOutput(dq_input_node->Inputs()[0]); + if (squeeze_input_node->OpType() == "Transpose") { + // we only want to set this if it is a Transpose as otherwise we're invalidating the cost given it is + // rank based and the Squeeze will change that. + producer_node = std::move(squeeze_input_node); + } + } else { + // DQ doesn't change the rank so we don't need to check the OpType of the DQ input + producer_node = std::move(dq_input_node); + } } } - } + if (producer_node->IsOp("Transpose")) { + std::optional> perm2 = GetPermAttrIfValid(*producer_node); + if (perm2 != std::nullopt) { + if (*perm2 == perm_inv && CanLikelyRemoveTranspose(graph, *producer_node, extended_handlers)) { + return -EstimateValueRank(graph, input); + } else { + return 0; + } + } + } + } // Case 3: We will likely need to add a transpose. return EstimateValueRank(graph, input); } @@ -708,6 +1199,7 @@ static int EstimateTransposeInputsCost(const api::GraphRef& graph, const api::No for (size_t j : input_indices) { cost += EstimateTransposeValueCost(graph, inputs[j], perm_inv, extended_handlers); } + return cost; } @@ -734,8 +1226,10 @@ static bool HandleSimpleNodeBase(HandlerArgs& args, bool broadcast_inputs) { if (broadcast_inputs && !NormalizeInputRanks(args.ctx, args.node, rank, args.transposible_inputs)) { return false; } + TransposeInputs(args.ctx, args.node, args.perm_inv, args.transposible_inputs); TransposeOutputs(args.ctx, args.node, args.perm); + return true; } @@ -907,38 +1401,29 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con size_t rank = perm.size(); int64_t rank_int = gsl::narrow_cast(rank); - std::string_view input = node.Inputs()[i]; - auto constant = graph.GetConstant(input); + std::string_view input_name = node.Inputs()[i]; + auto constant = graph.GetConstant(input_name); if (constant != nullptr) { auto shape = constant->Shape(); if (shape.size() == 1 && (shape[0] == rank_int || shape[0] == 0)) { - Permute1DConstant(graph, node, *constant, i, input, perm); + Permute1DConstant(graph, node, *constant, i, input_name, perm); return; } } + // we don't check for a DQ input here as PermuteInput is only used for Resize (roi/scales/sizes) and Pad (pads) + // inputs that would never be quantized. std::string_view gather_indices_const = AddInitializerInt64(graph, /*shape*/ {rank_int}, perm); - std::vector gather_inputs{input, gather_indices_const}; + std::vector gather_inputs{input_name, gather_indices_const}; auto gather_ptr = graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1); api::NodeRef& gather = *gather_ptr; std::string_view gather_output = gather.Outputs()[0]; - graph.CopyValueInfo(input, gather_output); + graph.CopyValueInfo(input_name, gather_output); gather.SetAttributeInt("axis", 0); node.SetInput(i, gather_output); } -static bool HandleResize([[maybe_unused]] HandlerArgs& args) { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) || defined(USE_WEBNN) - // The CUDA Resize kernel requires that the input is NCHW, so we can't push a Transpose through a Resize - // in ORT builds with CUDA enabled. - // The ROCm EP is generated from the CUDA EP kernel so the same applies to builds with ROCm enabled. - // The QNN EP requires the input to be NHWC, so the Resize handler is also not enabled for QNN builds. - // - // TODO: Remove this special case once the CUDA Resize kernel is implemented "generically" (i.e.) aligning with the - // generic nature of the ONNX spec. - // See https://github.com/microsoft/onnxruntime/pull/10824 for a similar fix applied to the CPU Resize kernel. - return false; -#else +bool HandleResize([[maybe_unused]] HandlerArgs& args) { auto inputs = args.node.Inputs(); int64_t rank_int = gsl::narrow_cast(args.perm.size()); @@ -964,10 +1449,10 @@ static bool HandleResize([[maybe_unused]] HandlerArgs& args) { TransposeOutputs(args.ctx, args.node, args.perm); return true; -#endif } -constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize}; +// Not currently registered by default. +// constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize}; static bool HandlePad(HandlerArgs& args) { size_t rank = args.perm.size(); @@ -1719,8 +2204,11 @@ static const std::unordered_map handler_ma {"Split", split_handler}, {"Shape", shape_handler}, {"Pad", pad_handler}, - {"Resize", resize_handler}, - {"ReduceSum", reduce_op_handler}, + + // Execution providers tend to only implement Resize for specific layouts. Due to that, it's safer to not + // push a Transpose through a Resize unless the EP specifically checks that it can handle the change via an + // extended handler. + // {"Resize", resize_handler}, {"ReduceLogSum", reduce_op_handler}, {"ReduceLogSumExp", reduce_op_handler}, @@ -1728,6 +2216,7 @@ static const std::unordered_map handler_ma {"ReduceMean", reduce_op_handler}, {"ReduceMin", reduce_op_handler}, {"ReduceProd", reduce_op_handler}, + {"ReduceSum", reduce_op_handler}, {"ReduceSumSquare", reduce_op_handler}, {"ReduceL1", reduce_op_handler}, {"ReduceL2", reduce_op_handler}, @@ -1749,14 +2238,6 @@ static const std::unordered_map handler_ma {"Reshape", reshape_handler}, }; -constexpr bool IsOnnxDomain(std::string_view domain) { - return (domain == onnxruntime::kOnnxDomain) || (domain == onnxruntime::kOnnxDomainAlias); -} - -constexpr bool IsMSDomain(std::string_view domain) { - return domain == onnxruntime::kMSDomain; -} - static const HandlerInfo* GetHandler(api::NodeRef& node, const HandlerMap& extended_handlers) { std::string key; auto domain = node.Domain(); @@ -1817,12 +2298,12 @@ static int CalculateCost(const api::GraphRef& graph, const api::NodeRef& node, } // Default cost check. Returns `true` if pushing the Transpose through the node is considered to be beneficial. -static bool ShouldPushTranspose(const api::GraphRef& graph, const api::NodeRef& node, - const std::vector& perm, - const std::unordered_set& outputs_leading_to_transpose, - const HandlerInfo& info, - const std::vector transposable_input_indices, - const HandlerMap& extended_handlers) { +static bool DefaultCostCheck(const api::GraphRef& graph, const api::NodeRef& node, + const std::vector& perm, + const std::unordered_set& outputs_leading_to_transpose, + const HandlerInfo& info, + const std::vector transposable_input_indices, + const HandlerMap& extended_handlers) { if (node.IsOp("Transpose")) { return true; } @@ -1854,8 +2335,8 @@ bool ProcessTranspose(OptimizerCtx& ctx, api::NodeRef& transpose, api::NodeRef& } if (cost == CostCheckResult::kFallThrough) { - cost = ShouldPushTranspose(ctx.graph, node, perm, outputs_leading_to_transpose, *info, input_indices, - ctx.extended_handlers) + cost = DefaultCostCheck(ctx.graph, node, perm, outputs_leading_to_transpose, *info, input_indices, + ctx.extended_handlers) ? CostCheckResult::kPushTranspose : CostCheckResult::kStop; } @@ -2009,75 +2490,99 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) { } } } - if (!have_dq) { result.graph_modified = changed; return result; } - // Run second optimization pass. - // If any transpose succeeds a DQ node, move it above the DQ node if it's not part of a QDQ node group. - // In QDQ models this helps to preserve the QDQ node group when a Transpose was pushed across a DQ into - // an existing QDQ node group. - // In all other scenarios this is beneficial as well because moving transpose above DQ node is more efficient as - // transpose node now handles less data. + // Run 'fix up' pass for QDQ node units. + // + // Repair broken QDQ node unit from Transpose being blocked on Op inside a QDQ node unit. + // DQ -> Transpose -> Op -> Q => + // DQ -> Transpose -> Q -> DQ -> Op -> Q + // + // Create QDQ node unit for Transpose after DQ that provides graph output. + // DQ -> Transpose -> graph output => + // DQ -> Transpose -> Q -> DQ -> graph output + // + // Remove empty DQ -> Q pair from moving a Transpose downstream or a Transpose being cancelled out. + // DQ -> Q -> consumer node => + // consumer node + auto graph_nodes = ctx.graph.Nodes(); for (size_t i = 1; i < graph_nodes.size(); i++) { - const auto& node = *graph_nodes[i]; + auto& node = *graph_nodes[i]; if (!can_modify_node(node)) { continue; } - if (node.OpType() == "Transpose") { - auto& transpose_node = *graph_nodes[i]; - auto dq_node = ctx.graph.GetNodeProducingOutput(transpose_node.Inputs()[0]); - if (!dq_node || dq_node->OpType() != "DequantizeLinear") { + for (size_t i_idx = 0, i_end = node.Inputs().size(); i_idx < i_end; ++i_idx) { + // any change requires a DQ as the input to the current node + auto input_node = ctx.graph.GetNodeProducingOutput(node.Inputs()[i_idx]); + if (!input_node || input_node->OpType() != "DequantizeLinear") { continue; } - // Check if Transpose node is the only consumer of dq node - auto consumers_of_dq_node = ctx.graph.GetValueConsumers(dq_node->Outputs()[0]); - if (!consumers_of_dq_node->comprehensive || consumers_of_dq_node->nodes.size() > 1) { - continue; - } + auto& dq_node = *input_node; + std::unique_ptr single_consumer_node; + + // remove empty DQ -> Q before a consumer node if the DQ and Q have matching types, scale and zp. + if (node.OpType() == "QuantizeLinear") { + // we don't need to check scale and zp inputs, and we may remove nodes invalidating `node` if we + // continue with the loop of inputs so set i_end to bail + i_end = 1; + + auto& q_node = node; + if (OutputValueHasSingleConsumerNode(ctx.graph, dq_node, 0, single_consumer_node) && + OutputValueHasSingleConsumerNode(ctx.graph, q_node, 0, single_consumer_node) && + CheckQDQNodePairMatch(ctx.graph, dq_node, q_node)) { + // connect Q consumer to DQ input + for (size_t j_idx = 0, j_end = single_consumer_node->Inputs().size(); j_idx < j_end; ++j_idx) { + if (single_consumer_node->Inputs()[j_idx] == q_node.Outputs()[0]) { + single_consumer_node->SetInput(j_idx, dq_node.Inputs()[0]); + // break; in theory the Q might be providing multiple inputs. + } + } - auto consumers_of_transpose_node = ctx.graph.GetValueConsumers(transpose_node.Outputs()[0]); - bool is_part_of_qdq_group = std::find_if(consumers_of_transpose_node->nodes.cbegin(), - consumers_of_transpose_node->nodes.cend(), - [](const std::unique_ptr& node) { - return node->OpType() == "QuantizeLinear"; - }) != consumers_of_transpose_node->nodes.cend(); - if (is_part_of_qdq_group) { - continue; - } + // disconnect other nodes and remove + dq_node.SetInput(0, ""); + q_node.SetInput(0, ""); + ctx.graph.RemoveNode(dq_node); + ctx.graph.RemoveNode(q_node); - // Update Dequantize Node and move the transpose above it - auto perm = GetPermAttrIfValid(transpose_node); - if (!perm.has_value()) { - continue; + changed = true; + continue; + } } - // we're moving the Transpose to before the DQ, so we need to use the inverse permutations to update the axis - // attribute correctly when doing per-axis dequantization - std::string_view dq_domain = dq_node->Domain(); - std::vector perm_inv = InvertPerm(*perm); + // DQ -> Transpose => DQ -> Transpose -> Q -> DQ if needed + if (node.OpType() == "Transpose") { + auto& transpose_node = node; + + // GetValueConsumers sets `comprehensive` to false for graph outputs and implicit inputs. + // we know Transpose doesn't have implicit inputs so if nodes are empty it can only be a graph output. + auto transpose_output = transpose_node.Outputs()[0]; + auto consumers = ctx.graph.GetValueConsumers(transpose_output); + if (consumers->nodes.empty()) { + // DQ -> Transpose -> graph output + } else { + if (consumers->nodes.size() > 1) { + // unexpected to have DQ -> Transpose -> multiple consumers + continue; + } - if (IsOnnxDomain(dq_domain) && !HandleQuantizeDequantizeAxis(ctx.graph, perm_inv, *dq_node, ctx.opset)) { - continue; - } + if (consumers->nodes[0]->OpType() == "QuantizeLinear") { + // already in QDQ node unit + continue; + } + } - if (IsMSDomain(dq_domain) && !TransposeQuantizeDequantizeAxis(ctx.graph, perm_inv, *dq_node)) { - continue; + // Add Q -> DQ after the DQ -> Transpose + if (MakeQDQNodeUnit(ctx.graph, dq_node)) { + changed = true; + } } - - TransposeFirstInput(ctx, *dq_node, *perm); - - // remove existing transpose node - transpose_node.SetInput(0, ""); - ctx.graph.MoveOutput(transpose_node, 0, *dq_node, 0); - ctx.graph.RemoveNode(transpose_node); - changed = true; } } diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h index 1a54e7834a4ae..6d1f1f8535ba4 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h @@ -3,6 +3,7 @@ #pragma once +#include #include // implementation details of the transpose optimizer API defined in optimizer_api.h. @@ -38,6 +39,8 @@ struct HandlerInfo { bool transposes_outputs = true; }; +using NodeIdToInputIdxsMap = std::unordered_map>; + struct OptimizerCtx { int64_t opset; api::GraphRef& graph; @@ -69,6 +72,7 @@ bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional default_ // base handlers that are used by extended handlers. add from transpose_optimizer.cc as needed. bool HandleReduceOps(HandlerArgs& args); +bool HandleResize([[maybe_unused]] HandlerArgs& args); void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std::vector& perm, diff --git a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h index 40a03f24f7648..c45aaef0cf02f 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h +++ b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h @@ -242,6 +242,12 @@ class NodeRef { /// since version or default value -1 virtual int SinceVersion() const = 0; + /// + /// Get the unique id of the node. + /// + /// Id + virtual int64_t Id() const = 0; + virtual ~NodeRef(){}; }; @@ -436,13 +442,20 @@ class GraphRef { return !unused; } + /// + /// Is the value a graph output. + /// + /// Value name. + /// True if output of the Graph. + virtual bool IsGraphOutput(std::string_view name) const = 0; + virtual ~GraphRef(){}; }; } // namespace api constexpr int64_t kMinSupportedOpset = 7; -constexpr int64_t kMaxSupportedOpset = 19; +constexpr int64_t kMaxSupportedOpset = 20; // enum of results that a CostCheckFn can return. enum class CostCheckResult { diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index b30c94d7b3e40..d9f08ffe1171e 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -95,7 +95,8 @@ class ApiNode final : public api::NodeRef { void ClearAttribute(std::string_view name) override; void SetInput(size_t i, std::string_view name) override; std::string_view GetExecutionProviderType() const override; - virtual int SinceVersion() const override; + int SinceVersion() const override; + int64_t Id() const override; private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ApiNode); @@ -106,10 +107,17 @@ class ApiGraph final : public api::GraphRef { onnxruntime::Graph& graph_; AllocatorPtr cpu_allocator_; const char* new_node_ep_; + std::unordered_set graph_outputs_; // graph_.GetOutputs() names for efficient lookup public: explicit ApiGraph(onnxruntime::Graph& graph, AllocatorPtr cpu_allocator, const char* new_node_ep) - : graph_(graph), cpu_allocator_(std::move(cpu_allocator)), new_node_ep_(new_node_ep) {} + : graph_(graph), cpu_allocator_(std::move(cpu_allocator)), new_node_ep_(new_node_ep) { + const auto& graph_outputs = graph_.GetOutputs(); + graph_outputs_.reserve(graph_outputs.size()); + for (const auto* output : graph_outputs) { + graph_outputs_.insert(output->Name()); + } + } onnxruntime::Graph& Graph() { return graph_; @@ -137,6 +145,7 @@ class ApiGraph final : public api::GraphRef { void MoveOutput(api::NodeRef& src_node, size_t src_idx, api::NodeRef& dst_node, size_t dst_idx) override; void CopyValueInfo(std::string_view src_name, std::string_view dst_name) override; bool HasValueConsumers(std::string_view name) const override; + bool IsGraphOutput(std::string_view name) const override; private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ApiGraph); @@ -417,6 +426,10 @@ int ApiNode::SinceVersion() const { return node_.SinceVersion(); } +int64_t ApiNode::Id() const { + return node_.Index(); +} + // std::optional ApiGraph::Opset(std::string_view domain) const { @@ -442,6 +455,10 @@ std::vector> ApiGraph::Nodes() const { return nodes; } +bool ApiGraph::IsGraphOutput(std::string_view name) const { + return graph_outputs_.find(name) != graph_outputs_.end(); +} + std::unique_ptr ApiGraph::GetConstant(std::string_view name) const { const auto* tensor = graph_.GetConstantInitializer(std::string(name), /*check_outer_scope*/ true); if (tensor == nullptr) { @@ -489,11 +506,8 @@ std::unique_ptr ApiGraph::GetValueConsumers(std::string_vie } } - const auto& graph_outputs = graph_.GetOutputs(); - for (const auto* output : graph_outputs) { - if (output->Name() == name) { - consumers->comprehensive = false; - } + if (IsGraphOutput(name)) { + consumers->comprehensive = false; } return consumers; @@ -505,14 +519,7 @@ bool ApiGraph::HasValueConsumers(std::string_view name) const { return true; } - const auto& graph_outputs = graph_.GetOutputs(); - for (const auto* output : graph_outputs) { - if (output->Name() == name) { - return true; - } - } - - return false; + return IsGraphOutput(name); } std::unique_ptr ApiGraph::GetNodeProducingOutput(std::string_view name) const { @@ -699,10 +706,6 @@ static std::optional GetLayoutTransformationPotentiallyAddedOpSinceVersion( // Based on the opset version imported for this model, returns the since version for the node. static int GetSinceVersionForNewOp(std::string_view op_type, std::string_view domain, const std::unordered_map& domain_to_version_map) { - // TODO do we need this check? we will also check kLayoutTransformationPotentiallyAddedOps - ORT_ENFORCE(domain == kOnnxDomain, "Transpose optimizer is expected to add only onnx domain ops. Domain: ", - domain, " provided for op: ", op_type); - const auto opset_import_iter = domain_to_version_map.find(std::string(domain)); ORT_ENFORCE(opset_import_iter != domain_to_version_map.end(), domain, " domain not found in opset imports."); diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc index ead82a6b56741..8eaac3d34c3af 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc @@ -5,12 +5,35 @@ #include #include "core/graph/constants.h" +#include "core/framework/utils.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" using namespace onnx_transpose_optimization; namespace onnxruntime { +static bool EPAwareHandleResize(HandlerArgs& args) { + // Whilst Resize is not technically layout sensitive, execution providers typically implement handling for only one + // layout. Due to that, only push a Transpose through a Resize once it is assigned and we know it's being handled + // by an EP that supports multiple layouts. Currently that's the CPU and XNNPACK EPs. + const auto ep_type = args.node.GetExecutionProviderType(); + if (ep_type == kCpuExecutionProvider) { + // allow NCHW <-> NHWC for now. not clear any other sort of transpose has a valid usage in a real model + int64_t rank_int = gsl::narrow_cast(args.perm.size()); + if (rank_int == 4) { + static const std::vector nchw_to_nhwc_perm{0, 2, 3, 1}; + static const std::vector nhwc_to_nchw_perm{0, 3, 1, 2}; + if (args.perm == nchw_to_nhwc_perm || args.perm == nhwc_to_nchw_perm) { + return HandleResize(args); + } + } + } + + return false; +} + +constexpr HandlerInfo ep_aware_resize_handler = {&FirstInput, &EPAwareHandleResize}; + static bool HandleQLinearConcat(HandlerArgs& args) { return HandleSimpleNodeWithAxis(args); } @@ -62,7 +85,7 @@ static bool HandleMaxPool(HandlerArgs& args) { ORT_UNUSED_PARAMETER(args); return false; #else - if (args.node.GetExecutionProviderType() != "CPUExecutionProvider") { + if (args.node.GetExecutionProviderType() != kCpuExecutionProvider) { return false; } @@ -103,6 +126,7 @@ static bool HandleContribQuantizeDequantizeLinear(HandlerArgs& args) { } constexpr HandlerInfo max_pool_op_handler = {&FirstInput, &HandleMaxPool}; + constexpr HandlerInfo node_1_inp_handler = {&FirstInput, &HandleSimpleNode}; constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps}; constexpr HandlerInfo contrib_quantize_dequantize_linear_handler = {&FirstInput, @@ -113,6 +137,7 @@ const HandlerMap& OrtExtendedHandlers() { static const HandlerMap extended_handler_map = []() { HandlerMap map = { {"MaxPool", max_pool_op_handler}, + {"Resize", ep_aware_resize_handler}, {"com.microsoft.QuantizeLinear", contrib_quantize_dequantize_linear_handler}, {"com.microsoft.DequantizeLinear", contrib_quantize_dequantize_linear_handler}, {"com.microsoft.QLinearAdd", q_linear_binary_op_handler}, diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h index 0a5dbd6d13d06..8245d8c3b4eae 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h @@ -10,7 +10,6 @@ namespace onnxruntime { /// /// Get the extended handlers for ORT specific transpose optimization. /// These include handlers for contrib ops, and where we have an NHWC version of a layout sensitive op. -/// Extends the handlers returned by OrtHandlers. /// /// HandlerMap const onnx_transpose_optimization::HandlerMap& OrtExtendedHandlers(); diff --git a/onnxruntime/core/optimizer/transpose_optimizer.cc b/onnxruntime/core/optimizer/transpose_optimizer.cc index 33e3f5eeaf0fa..092df9cc7dcfb 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer.cc @@ -18,10 +18,18 @@ namespace onnxruntime { Status TransposeOptimizer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { - auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ nullptr); - - OptimizeResult result = onnx_transpose_optimization::Optimize(*api_graph, "", /* default cost check*/ nullptr, - OrtExtendedHandlers()); + OptimizeResult result; + + if (ep_.empty()) { + // basic usage - no EP specific optimizations + auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ nullptr); + result = onnx_transpose_optimization::Optimize(*api_graph, "", /* default cost check*/ nullptr, + OrtExtendedHandlers()); + } else { + // EP specific optimizations enabled. Currently only used for CPU EP. + auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ ep_.c_str()); + result = onnx_transpose_optimization::Optimize(*api_graph, ep_, OrtEPCostCheck, OrtExtendedHandlers()); + } if (result.error_msg) { // currently onnx_layout_transformation::Optimize only fails if we hit an unsupported opset. diff --git a/onnxruntime/core/optimizer/transpose_optimizer.h b/onnxruntime/core/optimizer/transpose_optimizer.h index 1ae6d611d2f0e..97d7ab4d0e220 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer.h +++ b/onnxruntime/core/optimizer/transpose_optimizer.h @@ -15,10 +15,14 @@ Push transposes through ops and eliminate them. class TransposeOptimizer : public GraphTransformer { private: AllocatorPtr cpu_allocator_; + const std::string ep_; public: - explicit TransposeOptimizer(AllocatorPtr cpu_allocator) noexcept - : GraphTransformer("TransposeOptimizer"), cpu_allocator_(std::move(cpu_allocator)) {} + explicit TransposeOptimizer(AllocatorPtr cpu_allocator, + const std::string& ep = {}) noexcept + : GraphTransformer(ep.empty() ? "TransposeOptimizer" : "TransposeOptimizer_" + ep), + cpu_allocator_(std::move(cpu_allocator)), + ep_{ep} {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; diff --git a/onnxruntime/core/platform/windows/debug_alloc.cc b/onnxruntime/core/platform/windows/debug_alloc.cc index b08d189f79866..ff6a059607367 100644 --- a/onnxruntime/core/platform/windows/debug_alloc.cc +++ b/onnxruntime/core/platform/windows/debug_alloc.cc @@ -12,7 +12,7 @@ // #ifndef NDEBUG #ifdef ONNXRUNTIME_ENABLE_MEMLEAK_CHECK -constexpr int c_callstack_limit = 16; // Maximum depth of callstack in leak trace +constexpr int c_callstack_limit = 32; // Maximum depth of callstack in leak trace #define VALIDATE_HEAP_EVERY_ALLOC 0 // Call HeapValidate on every new/delete #pragma warning(disable : 4073) // initializers put in library initialization area (this is intentional) @@ -223,6 +223,11 @@ Memory_LeakCheck::~Memory_LeakCheck() { // empty_group_names = new std::map; }); if (string.find("RtlRunOnceExecuteOnce") == std::string::npos && string.find("re2::RE2::Init") == std::string::npos && + string.find("dynamic initializer for 'FLAGS_") == std::string::npos && + string.find("AbslFlagDefaultGenForgtest_") == std::string::npos && + string.find("AbslFlagDefaultGenForundefok::Gen") == std::string::npos && + string.find("::SetProgramUsageMessage") == std::string::npos && + string.find("testing::internal::ParseGoogleTestFlagsOnly") == std::string::npos && string.find("testing::internal::Mutex::ThreadSafeLazyInit") == std::string::npos && string.find("testing::internal::ThreadLocalRegistryImpl::GetThreadLocalsMapLocked") == std::string::npos && string.find("testing::internal::ThreadLocalRegistryImpl::GetValueOnCurrentThread") == std::string::npos && diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index f02c61daabeed..45648010baf86 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -32,7 +32,7 @@ limitations under the License. #include "core/common/span_utils.h" #include "core/platform/env.h" #include "core/platform/scoped_resource.h" -#include "unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h" +#include #include #include "core/platform/path_lib.h" // for LoopDir() diff --git a/onnxruntime/core/platform/windows/stacktrace.cc b/onnxruntime/core/platform/windows/stacktrace.cc index cac6f4f29043b..d7d423e4a483e 100644 --- a/onnxruntime/core/platform/windows/stacktrace.cc +++ b/onnxruntime/core/platform/windows/stacktrace.cc @@ -10,7 +10,6 @@ #include #endif #endif -#include #include "core/common/logging/logging.h" #include "core/common/gsl.h" diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index b142db86a7902..b4132d3b770ec 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -51,7 +51,7 @@ class BaseOpBuilder : public IOpBuilder { virtual bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const; virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; } - virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 19; } + virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 20; } private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; diff --git a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc new file mode 100644 index 0000000000000..c454a2a779f6e --- /dev/null +++ b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc @@ -0,0 +1,128 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/coreml/builders/impl/base_op_builder.h" + +#include "core/framework/tensorprotoutils.h" +#include "core/providers/common.h" +#include "core/providers/coreml/shape_utils.h" +#include "core/providers/shared/utils/utils.h" + +#ifdef __APPLE__ +#include "core/providers/coreml/builders/model_builder.h" +#endif +#include "core/providers/coreml/builders/op_builder_factory.h" + +namespace onnxruntime { +namespace coreml { + +class SoftmaxOpBuilder : public BaseOpBuilder { + // Add operator related +#ifdef __APPLE__ + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override; +#endif + + // Operator support related + private: + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; +}; + +// Add operator related + +#ifdef __APPLE__ + +Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + std::unique_ptr layer = CreateNNLayer(model_builder, node); + const auto& input_name = node.InputDefs()[0]->Name(); + const auto& output_name = node.OutputDefs()[0]->Name(); + + std::vector data_shape; + ORT_RETURN_IF_NOT(GetStaticShape(*node.InputDefs()[0], data_shape, logger), "Failed to get input shape."); + + NodeAttrHelper helper(node); + int32_t axis_default_value = (node.SinceVersion() < 13) ? 1 : -1; + const auto axis = helper.Get("axis", axis_default_value); + const auto axis_nonnegative = HandleNegativeAxis(axis, data_shape.size()); + + if (node.SinceVersion() >= 13 || (data_shape.size() == 2)) { + auto* coreml_softmaxnd = layer->mutable_softmaxnd(); + coreml_softmaxnd->set_axis(axis); + *layer->mutable_input()->Add() = input_name; + *layer->mutable_output()->Add() = output_name; + model_builder.AddLayer(std::move(layer)); + } else { + // note: if opsets < 13, onnx Softmax coerces the input shape to be 2D based on axis. + // we need to manually reshape to 2D and apply SoftmaxND to axis -1 to achieve equivalent results for CoreML. + TensorShape input_shape(data_shape); + const auto size_to_dimension = input_shape.SizeToDimension(axis_nonnegative); + const auto size_from_dimension = input_shape.SizeFromDimension(axis_nonnegative); + + TensorShapeVector target_shape; + target_shape.push_back(size_to_dimension); + target_shape.push_back(size_from_dimension); + + const auto reshape1_output_name = model_builder.GetUniqueName(MakeString(node.Name(), "reshape1_output")); + { // Add reshape layer + const auto softmax_reshape1_layer_name = + model_builder.GetUniqueName(MakeString(node.Name(), "_Softmax_reshape1")); + auto reshape_layer = CreateNNLayer(softmax_reshape1_layer_name); + *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()}; + *reshape_layer->mutable_input()->Add() = input_name; + *reshape_layer->mutable_output()->Add() = reshape1_output_name; + model_builder.AddLayer(std::move(reshape_layer)); + } + const auto softmax_output_name = model_builder.GetUniqueName(MakeString(node.Name(), "softmax_output")); + { + auto* coreml_softmaxnd = layer->mutable_softmaxnd(); + coreml_softmaxnd->set_axis(-1); + *layer->mutable_input()->Add() = reshape1_output_name; + *layer->mutable_output()->Add() = softmax_output_name; + model_builder.AddLayer(std::move(layer)); + } + { + // Add reshape back layer + const auto softmax_reshape2_layer_name = + model_builder.GetUniqueName(MakeString(node.Name(), "_Softmax_reshape2")); + auto reshape_layer = CreateNNLayer(softmax_reshape2_layer_name); + *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {data_shape.cbegin(), data_shape.cend()}; + *reshape_layer->mutable_input()->Add() = softmax_output_name; + *reshape_layer->mutable_output()->Add() = output_name; + model_builder.AddLayer(std::move(reshape_layer)); + } + } + + return Status::OK(); +} + +#endif + +// Operator support related + +bool SoftmaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetStaticShape(*input_defs[0], input_shape, logger)) + return false; + + const TensorShape shape(input_shape); + if (shape.Size() == 0) { + LOGS(logger, VERBOSE) << "Empty input data is not supported."; + return false; + } + + return true; +} + +void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc new file mode 100644 index 0000000000000..815f68128ffaf --- /dev/null +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/coreml/builders/impl/base_op_builder.h" + +#include "core/optimizer/initializer.h" +#include "core/providers/common.h" +#include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/shape_utils.h" +#include "core/providers/shared/utils/utils.h" + +#if defined(__APPLE__) +#include "core/providers/coreml/builders/model_builder.h" +#endif + +namespace onnxruntime { +namespace coreml { + +class SplitOpBuilder : public BaseOpBuilder { + // Add operator related +#ifdef __APPLE__ + private: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override; +#endif + + // Operator support related + private: + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + + // Split opset 13- uses "split" as attribute. Currently it's not supported. + int GetMinSupportedOpSet(const Node& /* node */) const override { return 13; } +}; + +// Add operator related + +#ifdef __APPLE__ + +void SplitOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + const auto& input_defs = node.InputDefs(); + + if (input_defs.size() > 1 && input_defs[1]->Exists()) { // optional second input "split" + model_builder.AddInitializerToSkip(input_defs[1]->Name()); + } +} + +Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + std::vector data_shape; + ORT_RETURN_IF_NOT(GetShape(*node.InputDefs()[0], data_shape, logger), "Failed to get input shape."); + + NodeAttrHelper helper(node); + const auto axis = helper.Get("axis", 0); + + // attribute introduced since opset 18 + uint64_t num_outputs; + + std::unique_ptr layer = CreateNNLayer(model_builder, node); + auto* coreml_splitnd = layer->mutable_splitnd(); + coreml_splitnd->set_axis(axis); + + if (input_defs.size() > 1) { + // if "split" is explicitly provided as an input + const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); + Initializer unpacked_tensor(split_tensor); + auto split_span = unpacked_tensor.DataAsSpan(); + auto split_sizes = split_span.size(); + num_outputs = narrow(split_sizes); + for (size_t i = 0; i < split_sizes; i++) { + coreml_splitnd->add_splitsizes(split_span[i]); + } + } else if (node.SinceVersion() < 18) { + num_outputs = narrow(node.OutputDefs().size()); + coreml_splitnd->set_numsplits(num_outputs); + } else { + // note: for opset 18+ 'num_outputs' is a required attribute + num_outputs = narrow(helper.GetInt("num_outputs").value()); + // note: checked in IsOpSupportedImpl that ensures the dim value at splitting axis exists + auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())]; + uint64_t chunk_size = narrow((split_dim_size + num_outputs - 1) / num_outputs); + uint64_t remainder = split_dim_size % chunk_size; + if (remainder) { + // uneven + auto split_sizes = InlinedVector(num_outputs, chunk_size); + split_sizes.back() = remainder; + for (size_t i = 0; i < split_sizes.size(); i++) { + coreml_splitnd->add_splitsizes(split_sizes[i]); + } + } else { + // even + coreml_splitnd->set_numsplits(num_outputs); + } + } + + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + // variadic number of outputs. Calculated based on the length of the given splitSizes if provided. + // Otherwise, uses attribute value 'num_outputs'. + for (uint64_t i = 0; i < num_outputs; i++) { + *layer->mutable_output()->Add() = node.OutputDefs()[i]->Name(); + } + model_builder.AddLayer(std::move(layer)); + + return Status::OK(); +} + +#endif + +// Operator support related + +bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); + + NodeAttrHelper helper(node); + const auto axis = helper.Get("axis", 0); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + const auto split_dims_at_axis = input_shape[HandleNegativeAxis(axis, input_shape.size())]; + if (input_defs.size() > 1 && input_defs[1]->Exists()) { + if (!CheckIsConstantInitializer(*input_defs[1], input_params.graph_viewer, logger, "'split'")) { + return false; + } + const auto split_shape = *input_defs[1]->Shape(); + if (split_shape.dim_size() < 2) { + LOGS(logger, VERBOSE) << "CoreML SplitND requires to produce at least 2 outputs."; + return false; + } + const auto& splits_tensor = *initializers.at(input_defs[1]->Name()); + Initializer unpacked_tensor(splits_tensor); + auto splits_span = unpacked_tensor.DataAsSpan(); + int sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), 0); + if (sum_of_splits != split_dims_at_axis) { + LOGS(logger, VERBOSE) << "Mismatch between the sum of 'split'. Expected: " + << split_dims_at_axis + << "Actual: " + << sum_of_splits; + return false; + } + auto it = std::find(splits_span.begin(), splits_span.end(), 0); + if (it != splits_span.end()) { + LOGS(logger, VERBOSE) << "Invalid value in 'splits' input."; + return false; + } + if (split_dims_at_axis == -1) { + LOGS(logger, VERBOSE) << "Dim at the splitting axis is not allowed to be dynamic."; + return false; + } + } else { + if (node.SinceVersion() >= 18) { + const auto num_outputs = helper.GetInt("num_outputs"); + if (!num_outputs.has_value()) { + LOGS(logger, VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute."; + return false; + } + if (num_outputs.value() < 2) { + LOGS(logger, VERBOSE) << "Invalid num_outputs. The value cannot be lower than 2.\n" + << "CoreML SplitND requires at least 2 outputs. num_outputs: " << num_outputs.value(); + return false; + } + if (num_outputs.value() != static_cast(node.OutputDefs().size()) || num_outputs.value() > split_dims_at_axis) { + LOGS(logger, VERBOSE) << "Invalid num_outputs provided.\n." + << "The value should be smaller or equal to the size of dimension being split. num_outputs: " + << num_outputs.value(); + return false; + } + } + } + return true; +} + +void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc index c1b09cec8a30a..2c06659852134 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc @@ -122,6 +122,14 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateSliceOpBuilder("Slice", op_registrations); } + { // Softmax + CreateSoftmaxOpBuilder("Softmax", op_registrations); + } + + { // Split + CreateSplitOpBuilder("Split", op_registrations); + } + return op_registrations; } diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h index b2c8dc765d33d..d72420bcfff88 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h @@ -36,6 +36,8 @@ void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateSqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 60e0b1c061a43..155201ad4c39c 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -31,6 +32,13 @@ using namespace onnxruntime::coreml; namespace { +// Converts a UTF8 const char* to an NSString. Throws on failure. +NSString* _Nonnull Utf8StringToNSString(const char* utf8_str) { + NSString* result = [NSString stringWithUTF8String:utf8_str]; + ORT_ENFORCE(result != nil, "NSString conversion failed."); + return result; +} + /** * Computes the static output shape used to allocate the output tensor. * `inferred_shape` is the inferred shape known at model compile time. It may contain dynamic dimensions (-1). @@ -151,24 +159,79 @@ Status CreateInputFeatureProvider(const std::unordered_map mlmultiarray_buffer_size) { + if (mlmultiarray_buffer == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "mlmultiarray_buffer has no data"); + } + + const size_t num_elements = array_info.count; + const auto onnx_data_type = tensor_info->data_type; + switch (onnx_data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + const auto output_data_byte_size = num_elements * sizeof(float); + ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == output_data_byte_size, + "CoreML output buffer size and expected output size differ"); + memcpy(tensor_buffer, mlmultiarray_buffer, output_data_byte_size); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + const auto output_data_byte_size = num_elements * sizeof(int32_t); + ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == output_data_byte_size, + "CoreML output buffer size and expected output size differ"); + memcpy(tensor_buffer, mlmultiarray_buffer, output_data_byte_size); + break; + } + // For this case, since Coreml Spec only uses int32 for model output while onnx provides + // int64 for model output data type. We are doing a type casting (int32 -> int64) here + // when copying the model to ORT + case ONNX_NAMESPACE::TensorProto_DataType_INT64: { + ORT_RETURN_IF_NOT(array_info.dataType == MLMultiArrayDataTypeInt32, + "CoreML output data type is not MLMultiArrayDataTypeInt32"); + ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == num_elements * sizeof(int32_t), + "CoreML output buffer size and expected output size differ"); + const auto model_output_span = gsl::span{static_cast(mlmultiarray_buffer), num_elements}; + const auto output_span = gsl::span{static_cast(tensor_buffer), num_elements}; + std::transform(model_output_span.begin(), model_output_span.end(), output_span.begin(), + [](int32_t v) { return static_cast(v); }); + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Output data type is not supported, actual type: ", onnx_data_type); + } + return Status::OK(); +} } // namespace NS_ASSUME_NONNULL_BEGIN @@ -196,7 +259,7 @@ - (Status)predict:(const std::unordered_map&)inputs get_output_tensor_mutable_raw_data_fn API_AVAILABLE_OS_VERSIONS; -@property MLModel* model API_AVAILABLE_OS_VERSIONS; +@property(nullable) MLModel* model API_AVAILABLE_OS_VERSIONS; @end @@ -240,14 +303,17 @@ - (void)dealloc { } - (Status)loadModel { - NSError* error = nil; NSURL* modelUrl = [NSURL URLWithString:coreml_model_path_]; - NSAssert(modelUrl != nil, @"modelUrl must not be nil"); + if (modelUrl == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create model URL from path"); + } + + NSError* error = nil; NSURL* compileUrl = [MLModel compileModelAtURL:modelUrl error:&error]; if (error != nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error compiling model ", - [[error localizedDescription] cStringUsingEncoding:NSUTF8StringEncoding]); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error compiling model: ", + [[error localizedDescription] UTF8String]); } compiled_model_path_ = [compileUrl path]; @@ -258,9 +324,9 @@ - (Status)loadModel { : MLComputeUnitsAll; _model = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error]; - if (error != NULL) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error Creating MLModel ", - [[error localizedDescription] cStringUsingEncoding:NSUTF8StringEncoding]); + if (error != nil || _model == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create MLModel", + (error != nil) ? MakeString(", error: ", [[error localizedDescription] UTF8String]) : ""); } return Status::OK(); @@ -272,7 +338,7 @@ - (Status)predict:(const std::unordered_map&)inputs Status status = Status::OK(); ORT_TRY { if (_model == nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model is not loaded"); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Model is not loaded"); } id input_features; @@ -287,20 +353,20 @@ - (Status)predict:(const std::unordered_map&)inputs if (error != nil) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error executing model: ", - [[error localizedDescription] cStringUsingEncoding:NSUTF8StringEncoding]); + [[error localizedDescription] UTF8String]); } for (const auto& [output_name, output_tensor_info] : outputs) { MLFeatureValue* output_value = - [output_features featureValueForName:[NSString stringWithUTF8String:output_name.c_str()]]; + [output_features featureValueForName:Utf8StringToNSString(output_name.c_str())]; if (output_value == nil) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "output_features has no value for ", output_name); } - auto* data = [output_value multiArrayValue]; + MLMultiArray* data = [output_value multiArrayValue]; - const auto coreml_static_output_shape = [&]() { + const auto coreml_static_output_shape = [data]() { InlinedVector result; result.reserve(data.shape.count); for (NSNumber* dim in data.shape) { @@ -324,41 +390,21 @@ - (Status)predict:(const std::unordered_map&)inputs ") do not match"); } - const void* model_output_buffer = data.dataPointer; - - if (model_output_buffer == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model_output_buffer has no data for ", output_name); - } - - const auto onnx_data_type = output_tensor_info.data_type; - switch (onnx_data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { - const auto output_data_byte_size = num_elements * sizeof(float); - memcpy(output_buffer, model_output_buffer, output_data_byte_size); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_INT32: { - const auto output_data_byte_size = num_elements * sizeof(int32_t); - memcpy(output_buffer, model_output_buffer, output_data_byte_size); - break; - } - // For this case, since Coreml Spec only uses int32 for model output while onnx provides - // int64 for model output data type. We are doing a type casting (int32 -> int64) here - // when copying the model to ORT - case ONNX_NAMESPACE::TensorProto_DataType_INT64: { - ORT_RETURN_IF_NOT(data.dataType == MLMultiArrayDataTypeInt32, - "CoreML output data type is not MLMultiArrayDataTypeInt32"); - - const auto model_output_span = gsl::span{static_cast(model_output_buffer), num_elements}; - const auto output_span = gsl::span{static_cast(output_buffer), num_elements}; - std::transform(model_output_span.begin(), model_output_span.end(), output_span.begin(), - [](int32_t v) { return static_cast(v); }); - break; - } - default: - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Output data type is not supported, actual type: ", onnx_data_type); + ORT_RETURN_IF_NOT(IsArrayContiguous(data), + "Non-contiguous output MLMultiArray is not currently supported"); + __block Status copy_status; + const auto* tensor_info = &output_tensor_info; + // `getBytesWithHandler` replaces deprecated `.dataPointer` on new versions + if (@available(macOS 12.3, iOS 15.4, *)) { + [data getBytesWithHandler:^(const void* bytes, NSInteger size) { + copy_status = CopyMLMultiArrayBuffer(bytes, output_buffer, data, tensor_info, size); + }]; + } else { + // disable size check as old API does not return buffer length + copy_status = CopyMLMultiArrayBuffer(data.dataPointer, output_buffer, data, tensor_info, std::nullopt); } + if (!copy_status.IsOK()) + return copy_status; } } } @@ -417,7 +463,7 @@ Status Predict(const std::unordered_map& inputs, return status; } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::LoadModel requires macos 10.15+ or ios 13+ "); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::LoadModel requires macos 10.15+ or ios 13+"); } Status Execution::Predict(const std::unordered_map& inputs, @@ -433,7 +479,7 @@ Status Predict(const std::unordered_map& inputs, } } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::LoadModel requires macos 10.15+ or ios 13+ "); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::Predict requires macos 10.15+ or ios 13+"); } Model::Model(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 18010960e11c8..4553e7ee18913 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -273,7 +273,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma // Opset 9 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Compress); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantOfShape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 19, ConstantOfShape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, MeanVarianceNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, Greater); @@ -365,7 +365,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, Slice); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 11, Dropout); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 19, IsInf); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, float, RoiAlign); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, double, RoiAlign); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ReverseSequence); @@ -682,9 +682,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Ga class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterND); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterElements); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Identity); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, IsNaN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, IsNaN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, NonZero); @@ -798,7 +798,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 18, If); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, RoiAlign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, GridSample); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 19, float, GridSample); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterND); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, string, Where); @@ -958,6 +958,23 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Scan); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Shape); +// Opset 20 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, GridSample); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, GridSample); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, AffineGrid); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN); +#if !defined(DISABLE_FLOAT8_TYPES) +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN); +#endif +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf); + // !!PLEASE READ BELOW!! Following that, add new entries above this comment /* *** IMPORTANT! *** @@ -1332,7 +1349,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { // Opset 9 BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2383,6 +2400,23 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // Opset 20 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#if !defined(DISABLE_FLOAT8_TYPES) + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#endif + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index ddcc04cf4a45c..9c55d37f550f4 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -115,6 +115,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU { Status PrepareOutputShape(const Tensor* indices, const int64_t depth_val, const int64_t axis, int64_t& prefix_dim_size, int64_t& suffix_dim_size, TensorShapeVector& output_shape) override { return onnxruntime::PrepareOutputShape(indices, depth_val, axis, prefix_dim_size, suffix_dim_size, output_shape); } // From cpu/tensor/slice.h (direct) + Status SliceBase__FlattenOutputDims(gsl::span input_dimensions, gsl::span output_dims, + TensorShapeVector& starts, TensorShapeVector& ends, TensorShapeVector& steps, + TensorShapeVector*& p_flattened_input_dims, TensorShapeVector*& p_flattened_output_dims) override { + return SliceBase::FlattenOutputDims(input_dimensions, output_dims, starts, ends, steps, p_flattened_input_dims, p_flattened_output_dims); + } + Status SliceBase__PrepareForCompute(gsl::span raw_starts, gsl::span raw_ends, gsl::span raw_axes, @@ -239,6 +245,26 @@ struct ProviderHostCPUImpl : ProviderHostCPU { subgraph_session_state); } + Status WhisperBeamSearch__Compute(const contrib::transformers::WhisperBeamSearch* p, OpKernelContext* ctx) override { + return p->contrib::transformers::WhisperBeamSearch::Compute(ctx); + } + + void BeamSearchParameters__ParseFromAttributes(contrib::transformers::BeamSearchParameters* p, const OpKernelInfo& info) override { + p->contrib::transformers::BeamSearchParameters::ParseFromAttributes(info); + } + + void GreedySearchParameters__ParseFromAttributes(contrib::transformers::GreedySearchParameters* p, const OpKernelInfo& info) override { + p->contrib::transformers::GreedySearchParameters::ParseFromAttributes(info); + } + + void SamplingParameters__ParseFromAttributes(contrib::transformers::SamplingParameters* p, const OpKernelInfo& info) override { + p->contrib::transformers::SamplingParameters::ParseFromAttributes(info); + } + + void WhisperBeamSearchParameters__ParseFromAttributes(contrib::transformers::WhisperBeamSearchParameters* p, const OpKernelInfo& info) override { + p->contrib::transformers::WhisperBeamSearchParameters::ParseFromAttributes(info); + } + void GreedySearch__Init(contrib::transformers::GreedySearch* p, const OpKernelInfo& info) override { p->contrib::transformers::GreedySearch::Init(info); } diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 7d4620f0039eb..8dee1cd620282 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -8,8 +8,13 @@ class LongformerAttentionBase; class AttentionBase; namespace transformers { class BeamSearch; +class WhisperBeamSearch; class GreedySearch; class Sampling; +struct BeamSearchParameters; +struct GreedySearchParameters; +struct SamplingParameters; +struct WhisperBeamSearchParameters; } // namespace transformers } // namespace contrib @@ -63,6 +68,10 @@ struct ProviderHostCPU { virtual Status PrepareOutputShape(const Tensor* indices, const int64_t depth_val, const int64_t axis, int64_t& prefix_dim_size, int64_t& suffix_dim_size, TensorShapeVector& output_shape) = 0; // From cpu/tensor/slice.h + virtual Status SliceBase__FlattenOutputDims(gsl::span input_dimensions, gsl::span output_dims, + TensorShapeVector& starts, TensorShapeVector& ends, TensorShapeVector& steps, + TensorShapeVector*& p_flattened_input_dims, TensorShapeVector*& p_flattened_output_dims) = 0; + virtual Status SliceBase__PrepareForCompute(gsl::span raw_starts, gsl::span raw_ends, gsl::span raw_axes, @@ -165,6 +174,15 @@ struct ProviderHostCPU { const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; + virtual Status WhisperBeamSearch__Compute(const contrib::transformers::WhisperBeamSearch* p, OpKernelContext* ctx) = 0; + + virtual void BeamSearchParameters__ParseFromAttributes(contrib::transformers::BeamSearchParameters* p, const OpKernelInfo& info) = 0; + + virtual void GreedySearchParameters__ParseFromAttributes(contrib::transformers::GreedySearchParameters* p, const OpKernelInfo& info) = 0; + + virtual void SamplingParameters__ParseFromAttributes(contrib::transformers::SamplingParameters* p, const OpKernelInfo& info) = 0; + + virtual void WhisperBeamSearchParameters__ParseFromAttributes(contrib::transformers::WhisperBeamSearchParameters* p, const OpKernelInfo& info) = 0; // GreedySearch virtual void GreedySearch__Init(contrib::transformers::GreedySearch* p, const OpKernelInfo& info) = 0; diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc b/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc index 920db5ed34dd1..a93da12ccf595 100644 --- a/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc @@ -11,11 +11,16 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0, ConstantOfShapeDefaultOutputTypes); +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, 20, Output, 0, + ConstantOfShapeDefaultOutputTypesOpset20); + // pytorch converter uses ConstantOfShape with int64 to create Pad input // https://github.com/pytorch/pytorch/blob/044b519a80459f6787f6723c1c091a18b153d184/torch/onnx/symbolic_opset11.py#L449 ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0, int64_t); + } // namespace op_kernel_type_control namespace { @@ -24,6 +29,10 @@ using EnabledOutputTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0); +using EnabledOutputTypesOpset20 = + ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, 20, Output, 0); + class ConstantOfShape final : public ConstantOfShapeBase, public OpKernel { public: explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), OpKernel(info) {} @@ -66,13 +75,22 @@ Status ConstantOfShape::Compute(OpKernelContext* ctx) const { } // namespace -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( ConstantOfShape, 9, + 19, KernelDefBuilder() .TypeConstraint("T1", DataTypeImpl::GetTensorType()) .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()), ConstantOfShape); +ONNX_CPU_OPERATOR_KERNEL( + ConstantOfShape, + 20, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", + BuildKernelDefConstraintsFromTypeList()), + ConstantOfShape); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h index d96ff06e3d6d8..9aa73c714daea 100644 --- a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h @@ -23,6 +23,18 @@ using ConstantOfShapeDefaultOutputTypes = uint8_t, uint16_t, uint32_t, uint64_t, bool>; +using ConstantOfShapeDefaultOutputTypesOpset20 = + TypeList< + BFloat16, + MLFloat16, + float, double, +#if !defined(DISABLE_FLOAT8_TYPES) + Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ, +#endif + int8_t, int16_t, int32_t, int64_t, + uint8_t, uint16_t, uint32_t, uint64_t, + bool>; + template class ConstantOfShapeBase { protected: diff --git a/onnxruntime/core/providers/cpu/generator/random.cc b/onnxruntime/core/providers/cpu/generator/random.cc index b63c0d2161ad5..dfa27f1f44d5a 100644 --- a/onnxruntime/core/providers/cpu/generator/random.cc +++ b/onnxruntime/core/providers/cpu/generator/random.cc @@ -428,4 +428,14 @@ template Status MultinomialComputeShared(AllocatorPtr& alloc, std::default_random_engine& generator, Tensor& Y); +#if !defined(DISABLE_CONTRIB_OPS) +// used by onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h +template Status MultinomialComputeShared(AllocatorPtr& alloc, + const Tensor& X, + const int64_t batch_size, + const int64_t num_classes, + const int64_t num_samples, + std::default_random_engine& generator, + Tensor& Y); +#endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 3192c8573c5c0..1d524a90302e7 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -967,7 +967,7 @@ Status Xor::Compute(OpKernelContext* context) const { }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array() ^ per_iter_bh.EigenInput1().array(); + per_iter_bh.EigenInput0().array() != per_iter_bh.EigenInput1().array(); }}; UntypedBroadcastTwo(*context, funcs, 1.0); diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h index 8507d87fd2442..a5d46aff83b50 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -22,11 +23,17 @@ class BatchNormHelper { const Tensor* B, const Tensor* mean, const Tensor* var, - bool is_spatial = true) { + bool is_spatial = true, + bool is_nhwc = false) { const auto& x_dims = X->Shape().GetDims(); // If x_dims size < 2, num_channels defaults to 1. - int64_t num_channels = x_dims.size() > 1 ? x_dims[1] : 1; + int64_t num_channels; + if (is_nhwc) { + num_channels = x_dims.size() > 1 ? x_dims[x_dims.size() - 1] : 1; + } else { + num_channels = x_dims.size() > 1 ? x_dims[1] : 1; + } // the first 2 are respectively - N and C. int num_feature_dims = x_dims.size() > 1 ? static_cast(x_dims.size() - 2) : 0; @@ -109,7 +116,7 @@ class BatchNormHelper { return common::Status::OK(); } - static void NormalizeDims(const TensorShape& x_shape, std::vector& new_dims) { + static void NormalizeDims(const TensorShape& x_shape, std::vector& new_dims, bool is_nhwc = false) { new_dims.clear(); auto orig_dims = x_shape.GetDims(); ORT_ENFORCE(orig_dims.size() < 6, @@ -122,13 +129,19 @@ class BatchNormHelper { auto rank = x_shape.NumDimensions(); auto num_samples = rank > 0 ? orig_dims[0] : 1; // NCHW - auto num_channels = rank > 1 ? orig_dims[1] : 1; - auto height = rank > 2 ? orig_dims[2] : 1; + const size_t channel_dim = is_nhwc ? rank - 1 : 1; + const size_t height_dim = is_nhwc ? 1 : 2; + auto num_channels = rank > 1 ? orig_dims[channel_dim] : 1; + auto height = rank > 2 ? orig_dims[height_dim] : 1; int64_t width = 1; - new_dims = {num_samples, num_channels, height, width}; + if (is_nhwc) { + new_dims = {num_samples, height, width, num_channels}; + } else { + new_dims = {num_samples, num_channels, height, width}; + } } }; } // namespace onnxruntime #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h index a4d67ec63f0c2..4b3b934834ac8 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h @@ -14,6 +14,7 @@ * limitations under the License. */ /* Modifications Copyright (c) Microsoft. */ +// Copyright (c) 2023 NVIDIA Corporation. #pragma once @@ -44,17 +45,19 @@ struct ConvTransposeAttributes : public ConvAttributes { }; Status PrepareForCompute(OpKernelContext* context, bool has_bias, Prepare& p, - bool dynamic_padding = false, const TensorShape* filter_shape = nullptr) const { + bool dynamic_padding = false, const TensorShape* filter_shape = nullptr, + bool is_nhwc = false) const { const Tensor* X = context->Input(0); const Tensor* F = (filter_shape != nullptr) ? nullptr : context->Input(1); const TensorShape& F_Shape = (filter_shape != nullptr) ? *filter_shape : F->Shape(); const Tensor* Pads = dynamic_padding ? context->Input(2) : nullptr; const Tensor* B = has_bias ? (dynamic_padding ? context->Input(3) : context->Input(2)) : nullptr; - TensorShape input_shape = X->Shape().Slice(2); - const int64_t num_input_channels = X->Shape()[1]; + const int rank = static_cast(X->Shape().NumDimensions()); + TensorShape input_shape = X->Shape().Slice(is_nhwc ? 1 : 2, is_nhwc ? rank - 1 : rank); + const int64_t num_input_channels = is_nhwc ? X->Shape()[rank - 1] : X->Shape()[1]; const int64_t N = X->Shape()[0]; - const int64_t num_output_channels_multiplier = F_Shape[1]; + const int64_t num_output_channels_multiplier = is_nhwc ? F_Shape[3] : F_Shape[1]; const int64_t num_output_channels = num_output_channels_multiplier * group; // input validations @@ -85,7 +88,7 @@ struct ConvTransposeAttributes : public ConvAttributes { } TensorShapeVector kernel_shape; - ORT_RETURN_IF_ERROR(ComputeKernelShape(F_Shape, kernel_shape)); + ORT_RETURN_IF_ERROR(ComputeKernelShape(F_Shape, kernel_shape, is_nhwc)); TensorShapeVector local_output_padding(output_padding); if (local_output_padding.empty()) { @@ -115,7 +118,7 @@ struct ConvTransposeAttributes : public ConvAttributes { TensorShapeVector Y_dims; ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape, - local_strides, local_dilations, local_output_padding, N, &local_pads, &Y_dims); + local_strides, local_dilations, local_output_padding, N, &local_pads, &Y_dims, is_nhwc); TensorShape Yshape(Y_dims); Tensor* Y = context->Output(0, Yshape); @@ -137,9 +140,14 @@ struct ConvTransposeAttributes : public ConvAttributes { void ComputePadsAndOutputShape(TensorShape input_shape, int64_t output_channel, const TensorShapeVector& kernel_shape, const TensorShapeVector& p_strides, const TensorShapeVector& p_dilations, const TensorShapeVector& p_output_padding, const int64_t N, - ConvPadVector* p_pads, TensorShapeVector* output_shape_p) const { + ConvPadVector* p_pads, TensorShapeVector* output_shape_p, + bool is_nhwc = false) const { size_t output_shape_size = output_shape.size(); - output_shape_p->insert(output_shape_p->begin(), {N, output_channel}); + if (is_nhwc) { + output_shape_p->insert(output_shape_p->begin(), {N}); + } else { + output_shape_p->insert(output_shape_p->begin(), {N, output_channel}); + } size_t rank = input_shape.NumDimensions(); for (size_t dim = 0; dim < rank; ++dim) { @@ -163,6 +171,9 @@ struct ConvTransposeAttributes : public ConvAttributes { ORT_ENFORCE(dim_size > 0, "Invalid input shape: ", input_shape.ToString()); output_shape_p->push_back(dim_size); } + if (is_nhwc) { + output_shape_p->push_back(output_channel); + } } TensorShapeVector output_padding; diff --git a/onnxruntime/core/providers/cpu/nn/instance_norm_helper.h b/onnxruntime/core/providers/cpu/nn/instance_norm_helper.h index 48e54ac7eeefb..9a2a710fd291a 100644 --- a/onnxruntime/core/providers/cpu/nn/instance_norm_helper.h +++ b/onnxruntime/core/providers/cpu/nn/instance_norm_helper.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -8,13 +9,16 @@ #include "core/framework/tensor.h" #endif #include +#include namespace onnxruntime { class InstanceNormHelper { public: - static common::Status ValidateInputs(const Tensor* input, const Tensor* scale, const Tensor* B) { - if (input->Shape().NumDimensions() < 3) { + static common::Status ValidateInputs(const Tensor* input, const Tensor* scale, const Tensor* B, + bool is_nhwc = false) { + const auto rank = input->Shape().NumDimensions(); + if (rank < 3) { std::ostringstream ostr; ostr << "Invalid input data: number of dimensions is less than 3: " << input->Shape().NumDimensions(); return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); @@ -24,10 +28,13 @@ class InstanceNormHelper { ostr << "Invalid input scale: number of dimensions is not 1: " << scale->Shape().NumDimensions(); return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); } - if (scale->Shape().Size() != input->Shape().GetDims()[1]) { + auto in_dims = input->Shape().GetDims(); + auto in_channels = is_nhwc ? in_dims[rank - 1] : in_dims[1]; + + if (scale->Shape().Size() != in_channels) { std::ostringstream ostr; - ostr << "Mismatch between input data and scale: size of scale != input channel count " - << scale->Shape().Size() << " vs. " << input->Shape().GetDims()[1]; + ostr << "Mismatch between input data and scale: size of scale != input channel count " << scale->Shape().Size() + << " vs. " << in_channels; return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); } @@ -37,10 +44,10 @@ class InstanceNormHelper { return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); } - if (B->Shape().Size() != input->Shape().GetDims()[1]) { + if (B->Shape().Size() != in_channels) { std::ostringstream ostr; - ostr << "Mismatch between input data and B: size of B != input channel count " - << B->Shape().Size() << " vs. " << input->Shape().GetDims()[1]; + ostr << "Mismatch between input data and B: size of B != input channel count " << B->Shape().Size() << " vs. " + << in_channels; return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); } diff --git a/onnxruntime/core/providers/cpu/nn/pool_attributes.h b/onnxruntime/core/providers/cpu/nn/pool_attributes.h index 54f41f09f4b24..118cb4a3ba4bd 100644 --- a/onnxruntime/core/providers/cpu/nn/pool_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/pool_attributes.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -98,28 +99,34 @@ struct PoolAttributes { TensorShapeVector SetOutputSize(const TensorShape& input_shape, int64_t output_channel, - TensorShapeVector* actual_pads) const { + TensorShapeVector* actual_pads, + bool is_nhwc = false) const { ORT_ENFORCE(input_shape.Size() > 0 || input_shape[0] == 0, "Invalid input shape. Only N can be zero. Got:", input_shape); TensorShapeVector output_dims; int64_t N = input_shape[0]; - InferOutputSize(input_shape.GetDims(), &output_dims, actual_pads); - - output_dims.insert(output_dims.begin(), {N, output_channel}); - + InferOutputSize(input_shape.GetDims(), &output_dims, actual_pads, is_nhwc); + if (is_nhwc) { + output_dims.insert(output_dims.begin(), N); + output_dims.push_back(output_channel); + } else { + output_dims.insert(output_dims.begin(), {N, output_channel}); + } return output_dims; } void InferOutputSize(gsl::span input_dims, TensorShapeVector* output_dims, - TensorShapeVector* actual_pads) const { + TensorShapeVector* actual_pads, + bool is_nhwc = false) const { ORT_ENFORCE(input_dims.size() >= 2); if (global_pooling) { output_dims->assign(input_dims.size() - 2, 1); } else { for (size_t dim = 0; dim < input_dims.size() - 2; ++dim) { int64_t dim_size = 0; - ComputeSizePadDilations(static_cast(input_dims[dim + 2]), + auto spatial_dim = is_nhwc ? input_dims[dim + 1] : input_dims[dim + 2]; + ComputeSizePadDilations(static_cast(spatial_dim), strides[dim], kernel_shape[dim], &actual_pads->at(dim), diff --git a/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.cc b/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.cc index f36b75c508da0..eb245a4c9ba0c 100644 --- a/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.cc +++ b/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.cc @@ -141,14 +141,11 @@ struct TfIdfVectorizer::Impl { Impl(const Impl&) = delete; Impl& operator=(const Impl&) = delete; - void IncrementCount(size_t ngram_id, size_t row_num, - std::vector& frequencies) const { + inline size_t OutputIdToIncrement(size_t ngram_id) const { assert(ngram_id != 0); --ngram_id; assert(ngram_id < ngram_indexes_.size()); - size_t output_idx = row_num * output_size_ + SafeInt(ngram_indexes_[ngram_id]); - assert(output_idx < frequencies.size()); - ++frequencies[output_idx]; + return SafeInt(ngram_indexes_[ngram_id]); } }; @@ -252,77 +249,17 @@ TfIdfVectorizer::TfIdfVectorizer(const OpKernelInfo& info) : OpKernel(info), imp TfIdfVectorizer::~TfIdfVectorizer() = default; -void TfIdfVectorizer::OutputResult(OpKernelContext* ctx, size_t B, const std::vector& frequences) const { - const Impl& impl = *impl_; - std::vector output_dims; - if (B == 0) { - output_dims.push_back(impl.output_size_); - B = 1; // For use in the loops below - } else { - output_dims.push_back(B); - output_dims.push_back(impl.output_size_); - } - - const auto row_size = impl.output_size_; - - TensorShape output_shape(output_dims); - assert(frequences.size() == static_cast(output_shape.Size())); - - auto Y = ctx->Output(0, output_shape); - auto output_data = Y->MutableData(); - const auto& w = impl.weights_; - switch (impl.weighting_criteria_) { - case kTF: { - for (auto f : frequences) { - *output_data++ = static_cast(f); - } - } break; - case kIDF: { - if (!w.empty()) { - const auto* freqs = frequences.data(); - for (size_t batch = 0; batch < B; ++batch) { - for (size_t i = 0; i < row_size; ++i) { - *output_data++ = (*freqs++ > 0) ? w[i] : 0; - } - } - } else { - for (auto f : frequences) { - *output_data++ = (f > 0) ? 1.0f : 0; - } - } - } break; - case kTFIDF: { - if (!w.empty()) { - const auto* freqs = frequences.data(); - for (size_t batch = 0; batch < B; ++batch) { - for (size_t i = 0; i < row_size; ++i) { - *output_data++ = *freqs++ * w[i]; - } - } - } else { - for (auto f : frequences) { - *output_data++ = static_cast(f); - } - } - } break; - case kNone: // fall-through - default: - assert(false); - } -} - -void TfIdfVectorizer::ComputeImpl(OpKernelContext* ctx, ptrdiff_t row_num, size_t row_size, - std::vector& frequencies) const { - auto X = ctx->Input(0); - const auto elem_size = X->DataType()->Size(); - - const void* const row_begin = AdvanceElementPtr(X->DataRaw(), row_num * row_size, elem_size); +void TfIdfVectorizer::ComputeImpl(const void* x_data_raw, size_t elem_size, ptrdiff_t row_num, size_t row_size, + bool is_input_string, gsl::span output_data, + std::function&)>& fn_weight) const { + const void* const row_begin = AdvanceElementPtr(x_data_raw, row_num * row_size, elem_size); const void* const row_end = AdvanceElementPtr(row_begin, row_size, elem_size); const auto& impl = *impl_; const auto max_gram_length = impl.max_gram_length_; const auto max_skip_distance = impl.max_skip_count_ + 1; // Convert to distance auto start_ngram_size = impl.min_gram_length_; + size_t output_idx; for (auto skip_distance = 1; skip_distance <= max_skip_distance; ++skip_distance) { auto ngram_start = row_begin; @@ -336,7 +273,7 @@ void TfIdfVectorizer::ComputeImpl(OpKernelContext* ctx, ptrdiff_t row_num, size_ } auto ngram_item = ngram_start; - if (X->IsDataTypeString()) { + if (is_input_string) { const std::string* str_item = reinterpret_cast(ngram_item); const StrMap* str_map = &impl.str_map_; for (auto ngram_size = 1; @@ -349,7 +286,8 @@ void TfIdfVectorizer::ComputeImpl(OpKernelContext* ctx, ptrdiff_t row_num, size_ break; } if (ngram_size >= start_ngram_size && hit->second->id_ != 0) { - impl.IncrementCount(hit->second->id_, row_num, frequencies); + output_idx = impl.OutputIdToIncrement(hit->second->id_); + fn_weight(output_idx, output_data); } str_map = &hit->second->leafs_; } @@ -360,13 +298,14 @@ void TfIdfVectorizer::ComputeImpl(OpKernelContext* ctx, ptrdiff_t row_num, size_ ngram_size <= max_gram_length && ngram_item < ngram_row_end; ++ngram_size, ngram_item = AdvanceElementPtr(ngram_item, skip_distance, elem_size)) { - int64_t val = (X->IsDataType()) ? int64_t{*reinterpret_cast(ngram_item)} : *reinterpret_cast(ngram_item); + int64_t val = (elem_size == 4) ? int64_t{*reinterpret_cast(ngram_item)} : *reinterpret_cast(ngram_item); auto hit = int_map->find(val); if (hit == int_map->end()) { break; } if (ngram_size >= start_ngram_size && hit->second->id_ != 0) { - impl.IncrementCount(hit->second->id_, row_num, frequencies); + output_idx = impl.OutputIdToIncrement(hit->second->id_); + fn_weight(output_idx, output_data); } int_map = &hit->second->leafs_; } @@ -412,31 +351,76 @@ Status TfIdfVectorizer::Compute(OpKernelContext* ctx) const { } assert((num_rows * C) == total_items); - // Frequency holder allocate [B..output_size_] - // and init all to zero - std::vector frequencies; - frequencies.resize(num_rows * impl_->output_size_, 0); + const Impl& impl = *impl_; + TensorShapeVector output_dims; + if (B == 0) { + output_dims.push_back(impl.output_size_); + B = 1; // For use in the loops below + } else { + output_dims.push_back(B); + output_dims.push_back(impl.output_size_); + } + TensorShape output_shape(output_dims); + + auto Y = ctx->Output(0, output_shape); + auto output_data = Y->MutableData(); + const bool is_input_string = X->IsDataTypeString(); if (total_items == 0 || - (X->IsDataTypeString() && impl_->str_map_.empty()) || + (is_input_string && impl_->str_map_.empty()) || ((X->IsDataType() || X->IsDataType()) && impl_->int64_map_.empty())) { // TfidfVectorizer may receive an empty input when it follows a Tokenizer // (for example for a string containing only stopwords). // TfidfVectorizer returns a zero tensor of shape // {b_dim, output_size} when b_dim is the number of received observations // and output_size the is the maximum value in ngram_indexes attribute plus 1. - OutputResult(ctx, B, frequencies); + memset(output_data, 0, static_cast(output_shape.Size() * sizeof(float))); return Status::OK(); } - std::function fn = [this, ctx, C, &frequencies](ptrdiff_t row_num) { - ComputeImpl(ctx, row_num, C, frequencies); - }; + auto x_data_raw = ctx->Input(0)->DataRaw(); + const auto elem_size = X->DataType()->Size(); + int32_t num_batches = std::min(concurrency::ThreadPool::DegreeOfParallelism(ctx->GetOperatorThreadPool()) * 2, num_rows); - concurrency::ThreadPool::TryBatchParallelFor(ctx->GetOperatorThreadPool(), num_rows, std::move(fn), 0); + const auto& w = impl.weights_; + std::function&)> fn_weight; - OutputResult(ctx, B, frequencies); + switch (impl.weighting_criteria_) { + case kTF: + fn_weight = [](size_t i, gsl::span& out) { out[i] += 1.0f; }; + break; + case kIDF: + if (!w.empty()) { + fn_weight = [&w](size_t i, gsl::span& out) { out[i] = w[i]; }; + } else { + fn_weight = [](size_t i, gsl::span& out) { out[i] = 1.0f; }; + } + break; + case kTFIDF: + if (!w.empty()) { + fn_weight = [&w](size_t i, gsl::span& out) { out[i] += w[i]; }; + } else { + fn_weight = [](size_t i, gsl::span& out) { out[i] += 1.0f; }; + } + break; + case kNone: // fall-through + default: + assert(false); + } + + std::function fn = [this, C, output_data, x_data_raw, elem_size, + is_input_string, num_batches, num_rows, &fn_weight](ptrdiff_t batch_num) { + // Frequency holder allocate [B..output_size_] and init all to zero. + auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_batches, static_cast(num_rows)); + std::vector frequencies(this->impl_->output_size_); + for (auto row_num = work.start; row_num < work.end; ++row_num) { + auto out = gsl::span(output_data + row_num * this->impl_->output_size_, this->impl_->output_size_); + std::fill(out.begin(), out.end(), 0.0f); + ComputeImpl(x_data_raw, elem_size, row_num, C, is_input_string, out, fn_weight); + } + }; + concurrency::ThreadPool::TrySimpleParallelFor(ctx->GetOperatorThreadPool(), num_batches, std::move(fn)); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.h b/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.h index 45db40d893231..14488d91c23e9 100644 --- a/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.h +++ b/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.h @@ -19,11 +19,8 @@ class TfIdfVectorizer final : public OpKernel { Status Compute(OpKernelContext* ctx) const override; private: - void ComputeImpl(OpKernelContext* ctx, ptrdiff_t row_num, size_t row_size, - std::vector& frequencies) const; - - // Apply weighing criteria and output - void OutputResult(OpKernelContext* ctx, size_t b_dim, const std::vector& frequences) const; + void ComputeImpl(const void* x_data_raw, size_t elem_size, ptrdiff_t row_num, size_t row_size, bool is_input_string, + gsl::span output_data, std::function&)>& fn_weight) const; struct Impl; std::unique_ptr impl_; diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc index e9fc8d857b831..21a256eee6f14 100644 --- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc @@ -77,7 +77,8 @@ class QLinearConv : public OpKernel { W_zero_point_value = W_zero_point_data[0]; for (int64_t i = 1; i < W_zero_point_size; i++) { ORT_ENFORCE(W_zero_point_data[i] == W_zero_point_value, - "QLinearConv : zero point of per-channel filter must be same"); + "QLinearConv : zero point of per-channel filter must be same. " + "This happens by design if the quantization is symmetric."); } } diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc index ce834e371fdef..3c83394fb0bf4 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc @@ -688,21 +688,23 @@ FastReduceKind OptimizeShapeForFastReduce(gsl::span input_shape, return FastReduceKind::kNone; } -void ValidateCommonFastReduce(const Tensor* axes_tensor) { - ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); - ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, - "An axes tensor must be a vector tensor."); -} - // template bool CommonFastReduceCopy(OpKernelContext* ctx, TensorShapeVector& input_axes, bool noop_with_empty_axes) { if (ctx->InputCount() == 2) { // second input holds the axes. + // the argument is optional const Tensor* axes_tensor = ctx->Input(1); - ValidateCommonFastReduce(axes_tensor); - auto nDims = static_cast(axes_tensor->Shape()[0]); - const auto* data = axes_tensor->Data(); - input_axes.insert(input_axes.begin(), data, data + nDims); + + if (axes_tensor != nullptr) { + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a vector tensor."); + + const auto data_span = axes_tensor->DataAsSpan(); + input_axes.assign(data_span.begin(), data_span.end()); + } else { + input_axes.clear(); + } + if (input_axes.empty() && noop_with_empty_axes) { const Tensor* input = ctx->Input(0); auto* output = ctx->Output(0, input->Shape()); diff --git a/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc b/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc index 4759938cd8250..8064bc0a58cb1 100644 --- a/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc +++ b/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc @@ -334,27 +334,14 @@ Status SequenceConstruct::Compute(OpKernelContext* context) const { // SplitToSequence -namespace op_kernel_type_control { -ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS( - kCpuExecutionProvider, kOnnxDomain, SplitToSequence, Input, 0, - float, double, int32_t, int64_t, std::string); -} // namespace op_kernel_type_control - -namespace { -using EnabledSplitToSequenceDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS( - kCpuExecutionProvider, kOnnxDomain, SplitToSequence, Input, 0); -} // namespace - ONNX_CPU_OPERATOR_KERNEL( SplitToSequence, 11, KernelDefBuilder() .TypeConstraint("T", - BuildKernelDefConstraintsFromTypeList()) + BuildKernelDefConstraints()) .TypeConstraint("S", DataTypeImpl::AllSequenceTensorTypes()) - .TypeConstraint("I", std::vector{ - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + .TypeConstraint("I", BuildKernelDefConstraints()), SplitToSequence); SplitToSequence::SplitToSequence(const OpKernelInfo& info) : OpKernel(info) { @@ -366,29 +353,14 @@ Status SplitToSequence::Compute(OpKernelContext* context) const { const Tensor& input = *context->Input(0); const Tensor* p_split_input = context->Input(1); - Status status; - - if (input.IsDataType()) - status = ComputeImpl(*context, input, p_split_input); - else if (input.IsDataType()) - status = ComputeImpl(*context, input, p_split_input); - else if (input.IsDataType()) - status = ComputeImpl(*context, input, p_split_input); - else if (input.IsDataType()) - status = ComputeImpl(*context, input, p_split_input); - else if (input.IsDataTypeString()) - status = ComputeImpl(*context, input, p_split_input); - else - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SplitToSequence operator does not support ", input.DataType(), " yet"); - - return status; + return ComputeImpl(*context, input, p_split_input); } Status SplitToSequence::PrepareForCompute(const TensorShape& input_shape, int64_t split_scalar, bool is_split_input_scalar, int64_t& num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, bool& is_uneven_split, int& num_remaining_splits, - std::vector& split_sizes) const { + InlinedVector& split_sizes) const { auto input_dims = input_shape.GetDims(); const auto num_dimensions = gsl::narrow_cast(input_shape.NumDimensions()); axis = HandleNegativeAxis(axis_, num_dimensions); // handle negative and enforce axis is valid @@ -416,7 +388,7 @@ Status SplitToSequence::PrepareForCompute(const TensorShape& input_shape, int64_ // populate split_sizes with the same size for each output num_outputs = split_dim_size; // https://github.com/onnx/onnx/issues/2396 - split_sizes = std::vector(static_cast(num_outputs), DEFAULT_LENGTH_EACH_OUTPUT_); + split_sizes = InlinedVector(static_cast(num_outputs), DEFAULT_LENGTH_EACH_OUTPUT_); } else { auto split_size_sum = std::accumulate(split_sizes.cbegin(), split_sizes.cend(), 0LL); if (split_size_sum != split_dim_size) { @@ -453,7 +425,7 @@ static int64_t GetScalarSplitInput(const Tensor& tensor) { return retval; } -static void GetSplitSizesInput(const Tensor& tensor, std::vector& split_sizes) { +static void GetSplitSizesInput(const Tensor& tensor, InlinedVector& split_sizes) { auto num_elems = tensor.Shape().Size(); split_sizes.reserve(onnxruntime::narrow(num_elems)); if (tensor.IsDataType()) { @@ -467,13 +439,8 @@ static void GetSplitSizesInput(const Tensor& tensor, std::vector& split } } -template Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& input, const Tensor* p_split_input) const { - if (!utils::HasType()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Data type is not supported in this build."); - } - auto& input_shape = input.Shape(); int64_t num_outputs = 0; int64_t axis = axis_; @@ -484,7 +451,9 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu bool is_split_input_scalar = false; bool is_uneven_split = false; int num_remaining_splits = 0; - std::vector split_sizes; + InlinedVector split_sizes; + const bool is_string_type = input.IsDataTypeString(); + const size_t element_size = (is_string_type) ? 0U : input.DataType()->Size(); // figure out split_scalar or split_sizes if (p_split_input) { @@ -520,8 +489,8 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu // copy dimensions so we can update the selected axis in place auto output_dimensions = input_shape.AsShapeVector(); - int64_t input_offset = 0; - const T* input_data = input.Data(); + SafeInt input_offset = 0; + const void* input_data = input.DataRaw(); for (int i = 0; i < num_outputs; ++i) { // update size of dimension for axis we're splitting on while considering uneven split int split_size; @@ -535,20 +504,50 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu AllocatorPtr alloc; ORT_RETURN_IF_ERROR(context.GetTempSpaceAllocator(&alloc)); Tensor output_tensor(input.DataType(), onnxruntime::TensorShape(output_dimensions), alloc); - T* output_data = output_tensor.MutableData(); - - ::onnxruntime::math::CopyMatrix( - before_dims, // M - split_size * after_dims_excluding_split, // N - static_cast(input_data + input_offset), // A - after_dims_including_split_axis, // lda - static_cast(output_data), // B - split_size * after_dims_excluding_split, // ldb - [](const T* src, T* dst, size_t count) { - copy_data(src, dst, count); - }); - - input_offset += static_cast(split_size) * after_dims_excluding_split; // offset by the N data we used in this iteration + void* output_data = output_tensor.MutableDataRaw(); + + const auto M = before_dims; + const auto* A = static_cast(input_data) + static_cast(input_offset * element_size); + const auto lda = after_dims_including_split_axis; + auto* B = output_data; + + const auto N = split_size * after_dims_excluding_split; + const auto ldb = N; + + if (is_string_type) { + const auto* src = reinterpret_cast(A); + auto* dst = reinterpret_cast(B); + if (lda == N) { + copy_data(src, dst, static_cast(M * N)); + } else { + size_t lda_offset = 0; + size_t ldb_offset = 0; + for (size_t idx = 0; idx < static_cast(M); ++idx, + lda_offset += lda, ldb_offset += ldb) { + copy_data(src + lda_offset, dst + ldb_offset, static_cast(N)); + } + } + } else { + if (lda == N) { + // if the data is contiguous, we can just copy the data + const size_t bytes_to_copy = static_cast(N) * static_cast(M) * element_size; + memcpy(B, A, bytes_to_copy); + } else { + // otherwise we need to copy each row + const size_t row_bytes = SafeInt(N) * element_size; + const auto lda_bytes_inc = SafeInt(lda) * element_size; + const auto ldb_bytes_inc = SafeInt(ldb) * element_size; + SafeInt lda_bytes_offset = 0; + SafeInt ldb_bytes_offset = 0; + for (size_t idx = 0; idx < static_cast(M); ++idx, + lda_bytes_offset += lda_bytes_inc, ldb_bytes_offset += ldb_bytes_inc) { + memcpy(reinterpret_cast(B) + static_cast(ldb_bytes_offset), + reinterpret_cast(A) + static_cast(lda_bytes_offset), row_bytes); + } + } + } + + input_offset += SafeInt(split_size) * after_dims_excluding_split; // offset by the N data we used in this iteration // if keep_dims = 0, reshape the tensor by dropping the dimension corresponding to 'axis' if (use_keep_dims && keepdims_ == 0) { diff --git a/onnxruntime/core/providers/cpu/sequence/sequence_ops.h b/onnxruntime/core/providers/cpu/sequence/sequence_ops.h index 9466d3f0fd108..ccca226fb07ee 100644 --- a/onnxruntime/core/providers/cpu/sequence/sequence_ops.h +++ b/onnxruntime/core/providers/cpu/sequence/sequence_ops.h @@ -60,13 +60,12 @@ class SplitToSequence final : public OpKernel { Status Compute(OpKernelContext* context) const override; private: - template Status ComputeImpl(OpKernelContext& context, const Tensor& input, const Tensor* p_split_input) const; Status PrepareForCompute(const TensorShape& input_shape, int64_t split_scalar, bool is_split_input_scalar, int64_t& num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, bool& is_uneven_split, int& num_remaining_splits, - std::vector& split_sizes) const; + InlinedVector& split_sizes) const; int64_t axis_{}; int64_t keepdims_{1}; const int64_t DEFAULT_LENGTH_EACH_OUTPUT_ = 1; diff --git a/onnxruntime/core/providers/cpu/tensor/affine_grid.cc b/onnxruntime/core/providers/cpu/tensor/affine_grid.cc new file mode 100644 index 0000000000000..15900ba553983 --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/affine_grid.cc @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cpu/tensor/affine_grid.h" + +#include "core/common/common.h" +#include "core/providers/op_kernel_type_control.h" +#include "core/util/math_cpuonly.h" +#include +#include "Eigen/src/Core/Map.h" +#include +#include "core/common/eigen_common_wrapper.h" + +namespace onnxruntime { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + AffineGrid, \ + 20, \ + T, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + AffineGrid); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(double) + +template +void generate_base_grid_2d(int64_t H, int64_t W, bool align_corners, Eigen::Matrix& base_grid) { + Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(static_cast(W), -1, 1); + if (!align_corners) { + row_vec = row_vec * (W - 1) / W; + } + Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(static_cast(H), -1, 1); + if (!align_corners) { + col_vec = col_vec * (H - 1) / H; + } + + base_grid.resize(static_cast(H * W), 2); + for (Eigen::Index j = 0; j < H; j++) { + for (Eigen::Index i = 0; i < W; i++) { + base_grid.row(j * static_cast(W) + i) << row_vec(i), col_vec(j); + } + } +} + +template +void generate_base_grid_3d(int64_t D, int64_t H, int64_t W, bool align_corners, Eigen::Matrix& base_grid) { + Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(static_cast(W), -1, 1); + if (!align_corners) { + row_vec = row_vec * (W - 1) / W; + } + Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(static_cast(H), -1, 1); + if (!align_corners) { + col_vec = col_vec * (H - 1) / H; + } + Eigen::VectorXf slice_vec = Eigen::VectorXf::LinSpaced(static_cast(D), -1, 1); + if (!align_corners) { + slice_vec = slice_vec * (D - 1) / D; + } + + base_grid.resize(static_cast(D * H * W), 3); + for (Eigen::Index k = 0; k < D; k++) { + for (Eigen::Index j = 0; j < H; j++) { + for (Eigen::Index i = 0; i < W; i++) { + base_grid.row(k * static_cast(H * W) + j * static_cast(W) + i) << row_vec(i), col_vec(j), slice_vec(k); + } + } + } +} + +template +void affine_grid_generator_2d(const Tensor* theta, const Eigen::Matrix& base_grid_transposed, int64_t batch_num, int64_t H, int64_t W, Tensor* grid) { + const Eigen::StorageOptions option = Eigen::RowMajor; + auto theta_batch_offset = batch_num * 2 * 3; + const T* theta_data = theta->Data() + theta_batch_offset; + const Eigen::Matrix theta_R{{theta_data[0], theta_data[1]}, {theta_data[3], theta_data[4]}}; + const Eigen::Array theta_T(theta_data[2], theta_data[5]); + + auto grid_batch_offset = batch_num * H * W * 2; + T* grid_data = grid->MutableData() + grid_batch_offset; + Eigen::Map> grid_matrix(grid_data, narrow(H * W), 2); + grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose(); +} + +template +void affine_grid_generator_3d(const Tensor* theta, const Eigen::Matrix& base_grid_transposed, int64_t batch_num, int64_t D, int64_t H, int64_t W, Tensor* grid) { + const Eigen::StorageOptions option = Eigen::RowMajor; + auto theta_batch_offset = batch_num * 3 * 4; + const T* theta_data = theta->Data() + theta_batch_offset; + const Eigen::Matrix theta_R{ + {theta_data[0], theta_data[1], theta_data[2]}, + {theta_data[4], theta_data[5], theta_data[6]}, + {theta_data[8], theta_data[9], theta_data[10]}}; + const Eigen::Array theta_T(theta_data[3], theta_data[7], theta_data[11]); + + auto grid_batch_offset = batch_num * D * H * W * 3; + T* grid_data = grid->MutableData() + grid_batch_offset; + Eigen::Map> grid_matrix(grid_data, narrow(D * H * W), 3); + grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose(); +} + +template +Status AffineGrid::Compute(OpKernelContext* context) const { + const Tensor* theta = context->Input(0); + const TensorShape& theta_shape = theta->Shape(); + if (theta_shape.NumDimensions() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "AffineGrid : Input theta tensor dimension is not 3"); + } + + const Tensor* size = context->Input(1); + const TensorShape& size_shape = size->Shape(); + const int64_t* size_data = size->Data(); + + if (size_shape.GetDims()[0] == 4 /*&& get_check_2d_grid_sample_consistency(theta_shape, size_shape, N, C, H, W)*/) { + int64_t N = size_data[0], H = size_data[2], W = size_data[3]; + + TensorShape grid_shape{N, H, W, 2}; + auto grid = context->Output(0, grid_shape); + + Eigen::Matrix base_grid; + generate_base_grid_2d(H, W, align_corners_, base_grid); + Eigen::Matrix base_grid_transposed = base_grid.transpose(); + + std::function fn = [theta, base_grid_transposed, H, W, grid](ptrdiff_t batch_num) { + affine_grid_generator_2d(theta, base_grid_transposed, batch_num, H, W, grid); + }; + + concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow(N), std::move(fn), 0); + } else if (size_shape.GetDims()[0] == 5 /*&& get_check_2d_grid_sample_consistency(theta_shape, size_shape, N, C, H, W)*/) { + int64_t N = size_data[0], D = size_data[2], H = size_data[3], W = size_data[4]; + + TensorShape grid_shape{N, D, H, W, 3}; + auto grid = context->Output(0, grid_shape); + + Eigen::Matrix base_grid; + generate_base_grid_3d(D, H, W, align_corners_, base_grid); + Eigen::Matrix base_grid_transposed = base_grid.transpose(); + + std::function fn = [theta, base_grid_transposed, D, H, W, grid](ptrdiff_t batch_num) { + affine_grid_generator_3d(theta, base_grid_transposed, batch_num, D, H, W, grid); + }; + + concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow(N), std::move(fn), 0); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "AffineGrid : Invalidate size - length of size should be 4 or 5."); + } + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/affine_grid.h b/onnxruntime/core/providers/cpu/tensor/affine_grid.h new file mode 100644 index 0000000000000..5ffe660e986f2 --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/affine_grid.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +template +class AffineGrid final : public OpKernel { + public: + AffineGrid(const OpKernelInfo& info) : OpKernel(info) { + int64_t align_corners = info.GetAttrOrDefault("align_corners", 0); + align_corners_ = (align_corners != 0); + } + + Status Compute(OpKernelContext* context) const override; + + private: + bool align_corners_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/grid_sample.cc b/onnxruntime/core/providers/cpu/tensor/grid_sample.cc index c58a7d8337114..a83ba378d7f1e 100644 --- a/onnxruntime/core/providers/cpu/tensor/grid_sample.cc +++ b/onnxruntime/core/providers/cpu/tensor/grid_sample.cc @@ -11,17 +11,23 @@ namespace onnxruntime { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_CPU_OPERATOR_TYPED_KERNEL( \ - GridSample, \ - 16, \ - T, \ - KernelDefBuilder() \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - GridSample); +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(GridSample, kOnnxDomain, 16, 19, T, kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + GridSample); + +#define REGISTER_KERNEL_TYPED_20(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX(GridSample, kOnnxDomain, 20, T, kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + GridSample); REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED_20(float) +REGISTER_KERNEL_TYPED_20(double) // Restore normalized location to actual image location // When align_corners is true: @@ -44,16 +50,15 @@ T GsDenormalize(T n, int64_t length, bool align_corners) { } // Reflect by the near border till within the borders -// Use float for borders to avoid potential issues with integer T template -T GsReflect(T x, float x_min, float x_max) { - float dx = {}; - float fx = static_cast(x); - float range = x_max - x_min; +T GsReflect(T x, T x_min, T x_max) { + T dx = {}; + T fx = static_cast(x); + T range = x_max - x_min; if (fx < x_min) { dx = x_min - fx; int n = static_cast(dx / range); - float r = dx - n * range; + T r = dx - n * range; if (n % 2 == 0) { fx = x_min + r; } else { @@ -62,7 +67,7 @@ T GsReflect(T x, float x_min, float x_max) { } else if (fx > x_max) { dx = fx - x_max; int n = static_cast(dx / range); - float r = dx - n * range; + T r = dx - n * range; if (n % 2 == 0) { fx = x_max - r; } else { @@ -75,9 +80,9 @@ T GsReflect(T x, float x_min, float x_max) { // Calculate cubic convolution interpolation coefficients // ROBERT G. KEYS https://ieeexplore.ieee.org/document/1163711 -// Use float to avoid potential issues with integer T -void GsGetCubicCoeffs(float x, float coeffs[4]) { - constexpr float cubic_alpha = -0.75f; +template +void GsGetCubicCoeffs(T x, T coeffs[4]) { + constexpr T cubic_alpha = -0.75f; x = std::abs(x); coeffs[0] = ((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha; coeffs[1] = ((cubic_alpha + 2) * x - (cubic_alpha + 3)) * x * x + 1; @@ -86,9 +91,9 @@ void GsGetCubicCoeffs(float x, float coeffs[4]) { } template -T GsBicubicInterpolate(T p[4][4], float x, float y) { - float v[4] = {}; - float coeffs[4] = {}; +T GsBicubicInterpolate(T p[4][4], T x, T y) { + T v[4] = {}; + T coeffs[4] = {}; GsGetCubicCoeffs(x, coeffs); for (int64_t i = 0; i < 4; i++) { v[i] = coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3]; @@ -98,7 +103,7 @@ T GsBicubicInterpolate(T p[4][4], float x, float y) { } template -T GridSample::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, float border[/* 4 */]) const { +T GridSample::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, T border[/* 4 */]) const { T pixel = {}; // default 0 if (padding_mode_ == Zeros) { if (c >= 0 && c < W && r >= 0 && r < H) { @@ -116,6 +121,27 @@ T GridSample::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, in return pixel; } +template +T GridSample::PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W, T border[/* 6 */]) const { + T pixel = {}; // default 0 + if (padding_mode_ == Zeros) { + if (w >= 0 && w < W && h >= 0 && h < H && d >= 0 && d < D) { + pixel = image[d * H * W + h * W + w]; + } + } else if (padding_mode_ == Border) { + w = std::clamp(w, 0, W - 1); + h = std::clamp(h, 0, H - 1); + d = std::clamp(d, 0, D - 1); + pixel = image[d * H * W + h * W + w]; + } else { // (padding_mode_ == Reflection) + w = static_cast(GsReflect(static_cast(w), border[0], border[3])); + h = static_cast(GsReflect(static_cast(h), border[1], border[4])); + d = static_cast(GsReflect(static_cast(d), border[2], border[5])); + pixel = image[d * H * W + h * W + w]; + } + return pixel; +} + // When grid sampling, padding is applied before interpolation. // For instance, in bilinear mode and zeros padding-mode, pixel p at actual // image location (-0.5, -0.5) @@ -134,113 +160,203 @@ Status GridSample::Compute(OpKernelContext* context) const { const auto& input_dims = input->Shape(); const auto& grid_dims = grid->Shape(); - if (input_dims.NumDimensions() != 4 || grid_dims.NumDimensions() != 4) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Only 4-D tensor is supported"); - } + int64_t data_dims = input_dims.NumDimensions() - 2; + ORT_ENFORCE(static_cast(grid_dims.NumDimensions()) == data_dims + 2, + "grid dimensions must be ", data_dims + 2, "for input dimension of ", data_dims); + + ORT_ENFORCE(grid_dims[grid_dims.NumDimensions() - 1] == data_dims, + "Last dimension of grid: ", grid_dims[grid_dims.NumDimensions() - 1], ", expect ", data_dims); + + ORT_ENFORCE(input_dims.NumDimensions() == 4 || input_dims.NumDimensions() == 5, "Only 4-D or 5-D tensor is supported"); auto N = input_dims[0]; auto C = input_dims[1]; - auto H_in = input_dims[2]; - auto W_in = input_dims[3]; - auto H_out = grid_dims[1]; - auto W_out = grid_dims[2]; ORT_ENFORCE(grid_dims[0] == N, "Grid batch size ", grid_dims[0], " does not match input batch size ", N); - ORT_ENFORCE(grid_dims[3] == 2, "Last dimension of grid: ", grid_dims[3], ", expect 2"); - TensorShape Y_shape = {N, C, H_out, W_out}; - auto& Y = *context->Output(0, Y_shape); - // Return early if the output tensor is going to be of size 0 - if (Y.Shape().Size() == 0) { - return Status::OK(); + if (input_dims.NumDimensions() == 5) { + ORT_ENFORCE(mode_ != Cubic, "Only support GridSample Cubic mode in 4-D cases."); } - // Force float here to avoid possible issue in integer T case - float x_min = -0.5f; - float x_max = W_in - 0.5f; - float y_min = -0.5f; - float y_max = H_in - 0.5f; - - if (align_corners_) { - x_min = 0.f; - x_max = W_in - 1.f; - y_min = 0.f; - y_max = H_in - 1.f; - } - float border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b - - concurrency::ThreadPool* tp = H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr; - for (int64_t n = 0; n < N; n++) { - const T* grid_data = grid->Data() + n * (H_out * W_out) * 2; - concurrency::ThreadPool::TrySimpleParallelFor( - tp, onnxruntime::narrow(C), - [&](std::ptrdiff_t c) { - const T* X_data = input->Data() + (n * C + c) * (H_in * W_in); - T* Y_data = Y.MutableData() + (n * C + c) * (H_out * W_out); - - for (int64_t oy = 0; oy < H_out; oy++) { - for (int64_t ox = 0; ox < W_out; ox++) { - const T* gridpoint = grid_data + (oy * W_out + ox) * 2; - T* Y_gridpoint = Y_data + oy * W_out + ox; - auto nx = gridpoint[0]; // normalized location - auto ny = gridpoint[1]; - auto x = GsDenormalize(nx, W_in, align_corners_); // actual location - auto y = GsDenormalize(ny, H_in, align_corners_); - - if (mode_ == Nearest) { - x = static_cast(std::nearbyintf(static_cast(x))); - y = static_cast(std::nearbyintf(static_cast(y))); - } + if (data_dims == 2) { + // sample 2d; + auto H_in = input_dims[2]; + auto W_in = input_dims[3]; + auto H_out = grid_dims[1]; + auto W_out = grid_dims[2]; + TensorShape Y_shape = {N, C, H_out, W_out}; + auto& Y = *context->Output(0, Y_shape); + // Return early if the output tensor is going to be of size 0 + if (Y.Shape().Size() == 0) { + return Status::OK(); + } - if (x < x_min || x > x_max || y < y_min || y > y_max) { // out of bound - if (padding_mode_ == Border) { - // use original border in both align_corner cases - x = std::clamp(x, static_cast(0), static_cast(W_in - 1)); - y = std::clamp(y, static_cast(0), static_cast(H_in - 1)); - } else if (padding_mode_ == Reflection) { - x = GsReflect(x, x_min, x_max); - y = GsReflect(y, y_min, y_max); - } - } // out of bound + T x_min = -0.5f; + T x_max = W_in - 0.5f; + T y_min = -0.5f; + T y_max = H_in - 0.5f; - if (mode_ == Nearest) { - // x, y are integers in all padding modes - *Y_gridpoint = PixelAtGrid(X_data, static_cast(y), static_cast(x), H_in, W_in, border); - continue; - } + if (align_corners_) { + x_min = 0.f; + x_max = W_in - 1.f; + y_min = 0.f; + y_max = H_in - 1.f; + } + T border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b - if (mode_ == Bilinear) { - int64_t x1 = static_cast(std::floor(x)); - int64_t y1 = static_cast(std::floor(y)); - int64_t x2 = x1 + 1; - int64_t y2 = y1 + 1; - - T p11 = PixelAtGrid(X_data, y1, x1, H_in, W_in, border); - T p12 = PixelAtGrid(X_data, y1, x2, H_in, W_in, border); - T p21 = PixelAtGrid(X_data, y2, x1, H_in, W_in, border); - T p22 = PixelAtGrid(X_data, y2, x2, H_in, W_in, border); - - T dx2 = static_cast(x2) - x; - T dx1 = x - static_cast(x1); - T dy2 = static_cast(y2) - y; - T dy1 = y - static_cast(y1); - *Y_gridpoint = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22); + concurrency::ThreadPool* tp = H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr; + for (int64_t n = 0; n < N; n++) { + const T* grid_data = grid->Data() + n * (H_out * W_out) * 2; + concurrency::ThreadPool::TrySimpleParallelFor( + tp, onnxruntime::narrow(C), + [&](std::ptrdiff_t c) { + const T* X_data = input->Data() + (n * C + c) * (H_in * W_in); + T* Y_data = Y.MutableData() + (n * C + c) * (H_out * W_out); + + for (int64_t oy = 0; oy < H_out; oy++) { + for (int64_t ox = 0; ox < W_out; ox++) { + const T* gridpoint = grid_data + (oy * W_out + ox) * 2; + T* Y_gridpoint = Y_data + oy * W_out + ox; + auto nx = gridpoint[0]; // normalized location + auto ny = gridpoint[1]; + auto x = GsDenormalize(nx, W_in, align_corners_); // actual location + auto y = GsDenormalize(ny, H_in, align_corners_); + + if (mode_ == Nearest) { + x = static_cast(std::nearbyint(static_cast(x))); + y = static_cast(std::nearbyint(static_cast(y))); + // x, y are integers in all padding modes + *Y_gridpoint = PixelAtGrid(X_data, static_cast(y), static_cast(x), H_in, W_in, border); + } else if (mode_ == Linear) { + int64_t x1 = static_cast(std::floor(x)); + int64_t y1 = static_cast(std::floor(y)); + int64_t x2 = x1 + 1; + int64_t y2 = y1 + 1; + + T p11 = PixelAtGrid(X_data, y1, x1, H_in, W_in, border); + T p12 = PixelAtGrid(X_data, y1, x2, H_in, W_in, border); + T p21 = PixelAtGrid(X_data, y2, x1, H_in, W_in, border); + T p22 = PixelAtGrid(X_data, y2, x2, H_in, W_in, border); + + T dx2 = static_cast(x2) - x; + T dx1 = x - static_cast(x1); + T dy2 = static_cast(y2) - y; + T dy1 = y - static_cast(y1); + *Y_gridpoint = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22); + } else if (mode_ == Cubic) { + int64_t x0 = static_cast(std::floor(x)) - 1; // top-left corner of the bbox + int64_t y0 = static_cast(std::floor(y)) - 1; + + T p[4][4] = {}; // [H][W] + for (int64_t h = 0; h < 4; h++) { + for (int64_t w = 0; w < 4; w++) { + p[h][w] = PixelAtGrid(X_data, h + y0, w + x0, H_in, W_in, border); + } + } + T dx = static_cast(x - x0 - 1); + T dy = static_cast(y - y0 - 1); + *Y_gridpoint = GsBicubicInterpolate(p, dx, dy); + } } - if (mode_ == Bicubic) { - int64_t x0 = static_cast(std::floor(x)) - 1; // top-left corner of the bbox - int64_t y0 = static_cast(std::floor(y)) - 1; - T p[4][4] = {}; // [H][W] - for (int64_t h = 0; h < 4; h++) { - for (int64_t w = 0; w < 4; w++) { - p[h][w] = PixelAtGrid(X_data, h + y0, w + x0, H_in, W_in, border); + } + }); + } + } else if (data_dims == 3) { + // sample 3d; + auto D_in = input_dims[2]; + auto H_in = input_dims[3]; + auto W_in = input_dims[4]; + auto D_out = grid_dims[1]; + auto H_out = grid_dims[2]; + auto W_out = grid_dims[3]; + TensorShape Y_shape = {N, C, D_out, H_out, W_out}; + auto& Y = *context->Output(0, Y_shape); + // Return early if the output tensor is going to be of size 0 + if (Y.Shape().Size() == 0) { + return Status::OK(); + } + + T x_min = -0.5f; + T x_max = W_in - 0.5f; + T y_min = -0.5f; + T y_max = H_in - 0.5f; + T z_min = -0.5f; + T z_max = D_in - 0.5f; + + if (align_corners_) { + x_min = 0.f; + x_max = W_in - 1.f; + y_min = 0.f; + y_max = H_in - 1.f; + z_min = 0.f; + z_max = D_in - 1.f; + } + T border[] = {x_min, y_min, z_min, x_max, y_max, z_max}; + + concurrency::ThreadPool* tp = D_out * H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr; + for (int64_t n = 0; n < N; n++) { + const T* grid_data = grid->Data() + n * (D_out * H_out * W_out) * 3; + concurrency::ThreadPool::TrySimpleParallelFor( + tp, onnxruntime::narrow(C), + [&](std::ptrdiff_t c) { + const T* X_data = input->Data() + (n * C + c) * (D_in * H_in * W_in); + T* Y_data = Y.MutableData() + (n * C + c) * (D_out * H_out * W_out); + + for (int64_t oz = 0; oz < D_out; oz++) { + for (int64_t oy = 0; oy < H_out; oy++) { + for (int64_t ox = 0; ox < W_out; ox++) { + const T* gridpoint = grid_data + (oz * H_out * W_out + oy * W_out + ox) * 3; + T* Y_gridpoint = Y_data + oz * H_out * W_out + oy * W_out + ox; + auto nx = gridpoint[0]; // normalized location + auto ny = gridpoint[1]; + auto nz = gridpoint[2]; + auto x = GsDenormalize(nx, W_in, align_corners_); // actual location + auto y = GsDenormalize(ny, H_in, align_corners_); + auto z = GsDenormalize(nz, D_in, align_corners_); + + if (mode_ == Nearest) { + x = static_cast(std::nearbyint(static_cast(x))); + y = static_cast(std::nearbyint(static_cast(y))); + z = static_cast(std::nearbyint(static_cast(z))); + + // x, y are integers in all padding modes + *Y_gridpoint = PixelAtGrid3D(X_data, static_cast(z), static_cast(y), static_cast(x), + D_in, H_in, W_in, border); + } else if (mode_ == Linear) { + int64_t x1 = static_cast(std::floor(x)); + int64_t y1 = static_cast(std::floor(y)); + int64_t z1 = static_cast(std::floor(z)); + int64_t x2 = x1 + 1; + int64_t y2 = y1 + 1; + int64_t z2 = z1 + 1; + + T dx2 = static_cast(x2) - x; + T dx1 = x - static_cast(x1); + T dy2 = static_cast(y2) - y; + T dy1 = y - static_cast(y1); + T dz2 = static_cast(z2) - z; + T dz1 = z - static_cast(z1); + + T p111 = PixelAtGrid3D(X_data, z1, y1, x1, D_in, H_in, W_in, border); + T p112 = PixelAtGrid3D(X_data, z1, y1, x2, D_in, H_in, W_in, border); + T p121 = PixelAtGrid3D(X_data, z1, y2, x1, D_in, H_in, W_in, border); + T p122 = PixelAtGrid3D(X_data, z1, y2, x2, D_in, H_in, W_in, border); + T Y_gridpoint_z1 = dy2 * (dx2 * p111 + dx1 * p112) + dy1 * (dx2 * p121 + dx1 * p122); + + T p211 = PixelAtGrid3D(X_data, z2, y1, x1, D_in, H_in, W_in, border); + T p212 = PixelAtGrid3D(X_data, z2, y1, x2, D_in, H_in, W_in, border); + T p221 = PixelAtGrid3D(X_data, z2, y2, x1, D_in, H_in, W_in, border); + T p222 = PixelAtGrid3D(X_data, z2, y2, x2, D_in, H_in, W_in, border); + T Y_gridpoint_z2 = dy2 * (dx2 * p211 + dx1 * p212) + dy1 * (dx2 * p221 + dx1 * p222); + *Y_gridpoint = dz2 * Y_gridpoint_z1 + dz1 * Y_gridpoint_z2; } } - T dx = static_cast(x - x0 - 1); - T dy = static_cast(y - y0 - 1); - *Y_gridpoint = GsBicubicInterpolate(p, static_cast(dx), static_cast(dy)); } } - } - }); + }); + } + } else { + // shall not reach here due to above checks + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Only support GirdSample in 4-D or 5-D cases."); } return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/tensor/grid_sample.h b/onnxruntime/core/providers/cpu/tensor/grid_sample.h index 2dd828b3ae3f1..dee0c4701ee21 100644 --- a/onnxruntime/core/providers/cpu/tensor/grid_sample.h +++ b/onnxruntime/core/providers/cpu/tensor/grid_sample.h @@ -15,37 +15,52 @@ template class GridSample final : public OpKernel { public: explicit GridSample(const OpKernelInfo& info) : OpKernel(info) { - std::string mode_str = info.GetAttrOrDefault("mode", "bilinear"); - std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros"); - align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0)); - ORT_ENFORCE(mode_str == "bilinear" || mode_str == "nearest" || mode_str == "bicubic", - "mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic"); - ORT_ENFORCE(padding_mode_str == "zeros" || padding_mode_str == "border" || padding_mode_str == "reflection", - "padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection"); - if (mode_str == "bicubic") { - mode_ = Bicubic; - } else if (mode_str == "nearest") { - mode_ = Nearest; + int start_version = info.node().SinceVersion(); + if (start_version >= 20) { + std::string mode_str = info.GetAttrOrDefault("mode", "linear"); + if (mode_str == "cubic") { + mode_ = Cubic; + } else if (mode_str == "nearest") { + mode_ = Nearest; + } else if (mode_str == "linear") { + mode_ = Linear; + } else { + ORT_THROW("mode \"", mode_str, "\" not supported, expect linear, nearest or cubic"); + } } else { - mode_ = Bilinear; + std::string mode_str = info.GetAttrOrDefault("mode", "bilinear"); + if (mode_str == "bicubic") { + mode_ = Cubic; + } else if (mode_str == "nearest") { + mode_ = Nearest; + } else if (mode_str == "bilinear") { + mode_ = Linear; + } else { + ORT_THROW("mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic"); + } } + + std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros"); + align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0)); if (padding_mode_str == "reflection") { padding_mode_ = Reflection; } else if (padding_mode_str == "border") { padding_mode_ = Border; - } else { + } else if (padding_mode_str == "zeros") { padding_mode_ = Zeros; + } else { + ORT_THROW("padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection"); } } Status Compute(OpKernelContext* context) const override; private: - enum GridSampleInterpolationMode { - Bilinear, + typedef enum { + Linear, + Cubic, Nearest, - Bicubic - }; + } GridSampleInterpolationMode; enum GridSamplePaddingMode { Zeros, @@ -53,9 +68,10 @@ class GridSample final : public OpKernel { Reflection }; - T PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, float border[/* 4 */]) const; + T PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, T border[/* 4 */]) const; + T PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W, T border[/* 6 */]) const; - GridSampleInterpolationMode mode_{Bilinear}; + GridSampleInterpolationMode mode_{Linear}; GridSamplePaddingMode padding_mode_{Zeros}; bool align_corners_{0}; }; diff --git a/onnxruntime/core/providers/cpu/tensor/isinf.cc b/onnxruntime/core/providers/cpu/tensor/isinf.cc index bc99caa8036cf..1b449f46927a2 100644 --- a/onnxruntime/core/providers/cpu/tensor/isinf.cc +++ b/onnxruntime/core/providers/cpu/tensor/isinf.cc @@ -14,15 +14,38 @@ namespace onnxruntime { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsInf namespace op_kernel_type_control { -ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS( - kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0, - float, double); +using IsInfTypesOpset10 = TypeList; + +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, IsInf, 10, Input, 0, + IsInfTypesOpset10); + +using IsInfTypesOpset20 = + TypeList< + float, + double +#if !defined(DISABLE_FLOAT8_TYPES) + , + Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ +#endif + >; + +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( + kCpuExecutionProvider, + kOnnxDomain, + IsInf, + 20, + Input, + 0, + IsInfTypesOpset20); } // namespace op_kernel_type_control class IsInf final : public OpKernel { public: - using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, - IsInf, Input, 0); + using EnabledDataTypes10 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, + IsInf, 10, Input, 0); + using EnabledDataTypes20 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, + IsInf, 20, Input, 0); explicit IsInf(const OpKernelInfo& info); Status Compute(OpKernelContext* context) const override; @@ -30,14 +53,25 @@ class IsInf final : public OpKernel { private: int64_t detect_positive_{1}; int64_t detect_negative_{1}; + int opset_; }; -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( IsInf, 10, + 19, KernelDefBuilder() .TypeConstraint("T1", - BuildKernelDefConstraintsFromTypeList()) + BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsInf); + +ONNX_CPU_OPERATOR_KERNEL( + IsInf, + 20, + KernelDefBuilder() + .TypeConstraint("T1", + BuildKernelDefConstraintsFromTypeList()) .TypeConstraint("T2", DataTypeImpl::GetTensorType()), IsInf); @@ -46,6 +80,7 @@ IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) { ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive"); status = info.GetAttr("detect_negative", &detect_negative_); ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative"); + opset_ = info.node().SinceVersion(); } namespace isinf_internal { @@ -78,6 +113,49 @@ struct ComputeDispatchTarget { } } }; + +#if !defined(DISABLE_FLOAT8_TYPES) +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor&, Tensor& Y, bool, bool) const { + EigenMap(Y).array() = false; + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor&, Tensor& Y, bool, bool) const { + EigenMap(Y).array() = false; + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { + auto& dims = X.Shape(); + auto input = ConstEigenVectorMap(static_cast(static_cast(X.Data())), onnxruntime::narrow(dims.Size())); + auto output = EigenMap(Y); + + // S.11111.00 + if (detect_positive && detect_negative) { + output.array() = input.array() == 0b01111100 || input.array() == 0b11111100; + } else if (detect_positive) { + output.array() = input.array() == 0b01111100; + } else if (detect_negative) { + output.array() = input.array() == 0b11111100; + } else { + output.array() = false; + } + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor&, Tensor& Y, bool, bool) const { + EigenMap(Y).array() = false; + } +}; +#endif } // namespace isinf_internal Status IsInf::Compute(OpKernelContext* context) const { @@ -88,8 +166,13 @@ Status IsInf::Compute(OpKernelContext* context) const { using namespace isinf_internal; - utils::MLTypeCallDispatcherFromTypeList dispatcher{X.GetElementType()}; - dispatcher.Invoke(X, Y, detect_positive_ != 0, detect_negative_ != 0); + if (opset_ < 20) { + utils::MLTypeCallDispatcherFromTypeList dispatcher{X.GetElementType()}; + dispatcher.Invoke(X, Y, detect_positive_ != 0, detect_negative_ != 0); + } else { + utils::MLTypeCallDispatcherFromTypeList dispatcher{X.GetElementType()}; + dispatcher.Invoke(X, Y, detect_positive_ != 0, detect_negative_ != 0); + } return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/tensor/isnan.cc b/onnxruntime/core/providers/cpu/tensor/isnan.cc index 33d0f8eb6c1ae..34495e382278a 100644 --- a/onnxruntime/core/providers/cpu/tensor/isnan.cc +++ b/onnxruntime/core/providers/cpu/tensor/isnan.cc @@ -20,10 +20,20 @@ namespace onnxruntime { .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ IsNaN); +#define ADD_TYPED_ISNAN_OP_13(data_type) \ + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ + IsNaN, \ + 13, 19, \ + data_type, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + IsNaN); + #define ADD_TYPED_ISNAN_OP(data_type) \ ONNX_CPU_OPERATOR_TYPED_KERNEL( \ IsNaN, \ - 13, \ + 20, \ data_type, \ KernelDefBuilder() \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ @@ -33,10 +43,20 @@ namespace onnxruntime { ADD_TYPED_ISNAN_OP_9(float); ADD_TYPED_ISNAN_OP_9(double); ADD_TYPED_ISNAN_OP_9(MLFloat16); +ADD_TYPED_ISNAN_OP_13(float); +ADD_TYPED_ISNAN_OP_13(double); +ADD_TYPED_ISNAN_OP_13(MLFloat16); ADD_TYPED_ISNAN_OP(float); ADD_TYPED_ISNAN_OP(double); ADD_TYPED_ISNAN_OP(MLFloat16); +#if !defined(DISABLE_FLOAT8_TYPES) +ADD_TYPED_ISNAN_OP(Float8E4M3FN); +ADD_TYPED_ISNAN_OP(Float8E4M3FNUZ); +ADD_TYPED_ISNAN_OP(Float8E5M2); +ADD_TYPED_ISNAN_OP(Float8E5M2FNUZ); +#endif + template Status IsNaN::Compute(OpKernelContext* context) const { const auto* X_ptr = context->Input(0); @@ -70,4 +90,63 @@ Status IsNaN::Compute(OpKernelContext* context) const { return Status::OK(); } + +#if !defined(DISABLE_FLOAT8_TYPES) +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto& dims = X->Shape(); + auto& Y = *context->Output(0, dims); + + auto input = ConstEigenVectorMap(static_cast(static_cast(X->Data())), onnxruntime::narrow(dims.Size())); + auto output = EigenMap(Y); + + // S.1111.111 + std::transform(input.begin(), input.end(), output.begin(), [](uint8_t c) { return (c & 0x7f) == 0x7f; }); + return Status::OK(); +} + +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto X_data = X->Data(); + auto& dims = X->Shape(); + auto shape_size = dims.Size(); + auto& Y = *context->Output(0, dims); + + // 1.0000.000 + EigenMap(Y) = + ConstEigenVectorMap(static_cast(static_cast(X_data)), onnxruntime::narrow(shape_size)).array() == 0x80; + + return Status::OK(); +} + +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto& dims = X->Shape(); + auto& Y = *context->Output(0, dims); + + auto input = ConstEigenVectorMap(static_cast(static_cast(X->Data())), onnxruntime::narrow(dims.Size())); + auto output = EigenMap(Y); + + // S.11111.{01, 10, 11} + std::transform(input.begin(), input.end(), output.begin(), [](uint8_t c) { return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00); }); + return Status::OK(); +} + +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto X_data = X->Data(); + auto& dims = X->Shape(); + auto shape_size = dims.Size(); + auto& Y = *context->Output(0, dims); + + // 1.0000.000 + EigenMap(Y) = ConstEigenVectorMap(static_cast(static_cast(X_data)), onnxruntime::narrow(shape_size)).array() == 0x80; + + return Status::OK(); +} +#endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/slice.cc b/onnxruntime/core/providers/cpu/tensor/slice.cc index a8cb74a62e02d..e0cd74343b83d 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.cc +++ b/onnxruntime/core/providers/cpu/tensor/slice.cc @@ -76,9 +76,9 @@ ONNX_CPU_OPERATOR_KERNEL( // e.g. if input shape is { 2, 2, 2, 2, 2 }, output shape is { 2, 2, 1, 2, 2 }, // and the 'steps' value for all dims is 1 except dim-2, then the input shape is coalesced to { 4, 2, 4 } // and the output shape is coalesced to { 4, 1, 4 }. -static void FlattenOutputDims(gsl::span input_dimensions, gsl::span output_dims, - TensorShapeVector& starts, TensorShapeVector& ends, TensorShapeVector& steps, - TensorShapeVector*& p_flattened_input_dims, TensorShapeVector*& p_flattened_output_dims) { +Status SliceBase::FlattenOutputDims(gsl::span input_dimensions, gsl::span output_dims, + TensorShapeVector& starts, TensorShapeVector& ends, TensorShapeVector& steps, + TensorShapeVector*& p_flattened_input_dims, TensorShapeVector*& p_flattened_output_dims) { size_t cur = 0; size_t nxt = 0; while (true) { @@ -131,6 +131,8 @@ static void FlattenOutputDims(gsl::span input_dimensions, gsl::sp ends.resize(cur); steps.resize(cur); } + + return Status::OK(); } // Slice V1-9 & DynamicSlice @@ -138,9 +140,9 @@ Status SliceBase::PrepareForCompute(gsl::span raw_starts, gsl::sp gsl::span raw_axes, SliceOp::PrepareForComputeMetadata& compute_metadata) { ORT_RETURN_IF_ERROR(SliceOp::PrepareForComputeHelper(raw_starts, raw_ends, raw_axes, compute_metadata)); - FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_, - compute_metadata.ends_, compute_metadata.steps_, compute_metadata.p_flattened_input_dims_, - compute_metadata.p_flattened_output_dims_); + ORT_RETURN_IF_ERROR(FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_, + compute_metadata.ends_, compute_metadata.steps_, compute_metadata.p_flattened_input_dims_, + compute_metadata.p_flattened_output_dims_)); return Status::OK(); } @@ -149,9 +151,9 @@ Status SliceBase::PrepareForCompute(gsl::span raw_starts, gsl::sp gsl::span raw_axes, gsl::span raw_steps, SliceOp::PrepareForComputeMetadata& compute_metadata) { ORT_RETURN_IF_ERROR(SliceOp::PrepareForComputeHelper(raw_starts, raw_ends, raw_axes, raw_steps, compute_metadata)); - FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_, - compute_metadata.ends_, compute_metadata.steps_, compute_metadata.p_flattened_input_dims_, - compute_metadata.p_flattened_output_dims_); + ORT_RETURN_IF_ERROR(FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_, + compute_metadata.ends_, compute_metadata.steps_, compute_metadata.p_flattened_input_dims_, + compute_metadata.p_flattened_output_dims_)); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/tensor/slice.h b/onnxruntime/core/providers/cpu/tensor/slice.h index 28e76aca4ea21..1503a87931bcf 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.h +++ b/onnxruntime/core/providers/cpu/tensor/slice.h @@ -38,6 +38,10 @@ class SliceBase { TensorShapeVector& input_axes, TensorShapeVector& input_steps); + static Status FlattenOutputDims(gsl::span input_dimensions, gsl::span output_dims, + TensorShapeVector& starts, TensorShapeVector& ends, TensorShapeVector& steps, + TensorShapeVector*& p_flattened_input_dims, TensorShapeVector*& p_flattened_output_dims); + protected: SliceBase(const OpKernelInfo& info, bool dynamic = false) : dynamic_(dynamic) { diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index c13c9d42dd392..0b3ce6f477843 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ -239,7 +239,7 @@ class UpsampleBase { if (coordinate_transform_mode_name == "half_pixel_symmetric") { return HALF_PIXEL_SYMMETRIC; } - ORT_THROW("coordinate_transform_mode:[" + coordinate_transform_mode_name + "] is not supportted!"); + ORT_THROW("coordinate_transform_mode:[" + coordinate_transform_mode_name + "] is not supported!"); } GetOriginalCoordinateFunc GetOriginalCoordinateFromResizedCoordinate( @@ -352,7 +352,7 @@ class UpsampleBase { (scales.size() == 4 && scales[0] == 1 && scales[3] == 1) || scales.size() == 3 || (scales.size() == 5 && scales[0] == 1 && scales[1] == 1), - "'Linear' mode only support:\n" + "'Linear' mode only supports:\n" " * 2-D inputs or\n" " * 3-D inputs ('Bilinear', 'Trilinear') or\n" " * 4-D inputs with the corresponding outermost 2 scale values being 1" diff --git a/onnxruntime/core/providers/cuda/cuda_common.cc b/onnxruntime/core/providers/cuda/cuda_common.cc index 57477f167c555..33f2938940e4d 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.cc +++ b/onnxruntime/core/providers/cuda/cuda_common.cc @@ -27,5 +27,91 @@ const HalfGemmOptions* HalfGemmOptions::GetInstance() { return &instance; } +const char* cublasGetErrorEnum(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + default: + return ""; + } +} + +const char* CudaDataTypeToString(cudaDataType_t dt) { + switch (dt) { + case CUDA_R_16F: + return "CUDA_R_16F"; + case CUDA_R_16BF: + return "CUDA_R_16BF"; + case CUDA_R_32F: + return "CUDA_R_32F"; +#if !defined(DISABLE_FLOAT8_TYPES) + // Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8 + case CUDA_R_8F_E4M3: + return "CUDA_R_8F_E4M3"; + case CUDA_R_8F_E5M2: + return "CUDA_R_8F_E5M2"; +#endif + default: + return ""; + } +} + +const char* CublasComputeTypeToString(cublasComputeType_t ct) { + switch (ct) { + case CUBLAS_COMPUTE_16F: + return "CUBLAS_COMPUTE_16F"; + case CUBLAS_COMPUTE_32F: + return "CUBLAS_COMPUTE_32F"; + case CUBLAS_COMPUTE_32F_FAST_16F: + return "CUBLAS_COMPUTE_32F_FAST_16F"; + case CUBLAS_COMPUTE_32F_FAST_16BF: + return "CUBLAS_COMPUTE_32F_FAST_16BF"; + case CUBLAS_COMPUTE_32F_FAST_TF32: + return "CUBLAS_COMPUTE_32F_FAST_TF32"; + case CUBLAS_COMPUTE_64F: + return "CUBLAS_COMPUTE_64F"; + default: + return ""; + } +} + +// It must exist somewhere already. +cudaDataType_t ToCudaDataType(int32_t element_type) { + switch (element_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return CUDA_R_32F; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return CUDA_R_16F; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + return CUDA_R_16BF; +#if !defined(DISABLE_FLOAT8_TYPES) + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: + return CUDA_R_8F_E4M3; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: + return CUDA_R_8F_E5M2; +#endif + default: + ORT_THROW("Unexpected element_type=", element_type, "."); + } +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index fa258961f1155..707099bac3ce0 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -11,6 +11,7 @@ #include "core/providers/shared_library/provider_api.h" #include "core/common/status.h" +#include "core/framework/float8.h" #include "core/framework/float16.h" #include "core/providers/cuda/cuda_pch.h" #include "core/providers/cuda/shared_inc/cuda_call.h" @@ -48,6 +49,37 @@ class ToCudaType { } }; +template <> +class ToCudaType { + public: + typedef BFloat16 MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template <> +class ToCudaType { + public: + typedef Float8E4M3FN MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + +template <> +class ToCudaType { + public: + typedef Float8E5M2 MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + +#endif + inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) { int stride = 1; if (dims.empty() || p.size() < dims.size()) @@ -152,5 +184,13 @@ class HalfGemmOptions { static HalfGemmOptions instance; }; +const char* cublasGetErrorEnum(cublasStatus_t error); + +const char* CudaDataTypeToString(cudaDataType_t dt); + +const char* CublasComputeTypeToString(cublasComputeType_t ct); + +cudaDataType_t ToCudaDataType(int32_t element_type); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index ad892eab3b843..d8a0792209b0f 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #include "core/common/inlined_containers.h" @@ -15,6 +16,10 @@ #include "contrib_ops/cuda/cuda_contrib_kernels.h" #endif +#ifdef ENABLE_CUDA_NHWC_OPS +#include "core/providers/cuda/cuda_nhwc_kernels.h" +#endif + #ifdef ENABLE_TRAINING_OPS #include "orttraining/training_ops/cuda/cuda_training_kernels.h" #endif @@ -233,6 +238,10 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in : IExecutionProvider{onnxruntime::kCudaExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_{info}, tuning_context_(this, &info_.tunable_op) { +#ifndef ENABLE_CUDA_NHWC_OPS + ORT_ENFORCE(info_.prefer_nhwc == 0, "This build does not support NHWC layout"); +#endif + CUDA_CALL_THROW(cudaSetDevice(info_.device_id)); // must wait GPU idle, otherwise cudaGetDeviceProperties might fail @@ -250,7 +259,7 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in if (info.external_allocator_info.UseExternalAllocator()) { use_ep_level_unified_stream_ = true; stream_ = nullptr; - } else if (info.enable_cuda_graph) { + } else if (info.enable_cuda_graph || info.use_ep_level_unified_stream) { // current cuda graph implementation only works with single stream // use EP level unified stream for all the reqeust CUDA_CALL_THROW(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); @@ -271,6 +280,10 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in #endif } +DataLayout CUDAExecutionProvider::GetPreferredLayout() const { + return this->IsNHWCPreferred() ? DataLayout::NHWC : DataLayout::NCHW; +} + CUDAExecutionProvider::~CUDAExecutionProvider() { // clean up thread local context caches { @@ -373,7 +386,7 @@ Status CUDAExecutionProvider::OnRunStart() { // always set CUDA device when session::Run() in case it runs in a worker thread CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId())); if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { - LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; + LOGS(*GetLogger(), INFO) << "Capturing the cuda graph for this model"; GetPerThreadContext().CaptureBegin(); } return Status::OK(); @@ -632,51 +645,54 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, GlobalMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalMaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, uint8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, float, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, double, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, MLFloat16, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, int32_t, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, int64_t, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceLogSumExp); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, float, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, double, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, MLFloat16, Cast); @@ -811,12 +827,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod); // opset 11 -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Concat); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Flatten); @@ -830,45 +840,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Loop); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, NonMaxSuppression); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Range); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 15, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, ScatterElements); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, Slice); @@ -946,22 +917,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Pow); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int8_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, uint8_t, ReduceMax); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int64_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int8_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, uint8_t, ReduceMin); - class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int64_t, GatherND); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout); @@ -1016,6 +971,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Neg); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Floor); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Floor); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Floor); @@ -1115,50 +1071,36 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int64_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int8_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, ReduceSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Dropout); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Resize); @@ -1257,13 +1199,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, Trilu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Sub); @@ -1287,6 +1229,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, PRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 18, Scan); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, Where); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, BFloat16, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double_t, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, int32_t, Where); @@ -1316,6 +1259,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, // Opset 18 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, Split); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); + // Opset 19 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, double, Cast); @@ -1581,51 +1530,51 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1764,12 +1713,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // opset 11 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1783,45 +1729,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1895,22 +1802,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1965,6 +1856,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2064,50 +1956,36 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2206,13 +2084,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2236,6 +2114,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2264,6 +2143,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { // Opset 18 BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 19 BuildKernelCreateInfo, @@ -2330,6 +2214,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { ORT_RETURN_IF_ERROR(::onnxruntime::contrib::cuda::RegisterCudaContribKernels(kernel_registry)); #endif +#ifdef ENABLE_CUDA_NHWC_OPS + ORT_RETURN_IF_ERROR(::onnxruntime::cuda::RegisterCudaNhwcKernels(kernel_registry)); +#endif + #ifdef ENABLE_TRAINING_OPS ORT_RETURN_IF_ERROR(::onnxruntime::cuda::RegisterCudaTrainingKernels(kernel_registry)); #endif @@ -2410,7 +2298,7 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node, return false; } -static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) { +static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger) { const auto& node_attributes = node.GetAttributes(); // Check attributes for (auto& attr : node_attributes) { @@ -2428,7 +2316,7 @@ static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) { int rank = pads_size / 2; for (int i = 0; i < rank; i++) { if (pads.Get(i) != pads.Get(i + rank)) { - LOGS_DEFAULT(WARNING) << "Dropping the ConvTranspose node: " << node.Name() + LOGS(logger, WARNING) << "Dropping the ConvTranspose node: " << node.Name() << " to CPU because it requires asymmetric padding which the CUDA EP" << " currently does not support"; return true; @@ -2450,7 +2338,7 @@ static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) { // symmetric padding. // TODO: Remove this after we have supported asymmetric padding in the CUDA ConvTranspose kernel if (auto_pad_attr == "SAME_UPPER" || auto_pad_attr == "SAME_LOWER") { - LOGS_DEFAULT(WARNING) << "Dropping the ConvTranspose node: " << node.Name() + LOGS(logger, WARNING) << "Dropping the ConvTranspose node: " << node.Name() << " to CPU because it uses the auto_pad attribute which may lead to asymmetric padding which" << " the CUDA EP currently does not support"; return true; @@ -2487,6 +2375,9 @@ std::vector> CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup) const { InlinedVector candidates; + // A subset of the above vector. A subset of the tentative_nodes might be moved to CPU. + InlinedVector tentative_nodes; + const logging::Logger& logger = *GetLogger(); for (auto& node_index : graph.GetNodesInTopologicalOrder()) { const auto* p_node = graph.GetNode(node_index); if (p_node == nullptr) @@ -2494,13 +2385,16 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const auto& node = *p_node; if (!node.GetExecutionProviderType().empty()) { + if (node.GetExecutionProviderType() == kCudaExecutionProvider) { + candidates.push_back(node.Index()); + } continue; } const KernelCreateInfo* cuda_kernel_def = kernel_lookup.LookUpKernel(node); // none of the provided registries has a CUDA kernel for this node if (cuda_kernel_def == nullptr) { - LOGS_DEFAULT(INFO) << "CUDA kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name(); + LOGS(logger, INFO) << "CUDA kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name(); continue; } @@ -2520,7 +2414,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, not_supported = RNNNeedFallbackToCPU(node, activations_supported, node.OpType()); force_inside = !not_supported; } else if ("ConvTranspose" == node.OpType()) { - not_supported = ConvTransposeNeedFallbackToCPU(node); + not_supported = ConvTransposeNeedFallbackToCPU(node, logger); force_inside = !not_supported; } else if ("Cast" == node.OpType()) { not_supported = CastNeedFallbackToCPU(node); @@ -2529,9 +2423,10 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, if (!force_inside && not_supported) { if (not_supported) { - LOGS_DEFAULT(WARNING) << "CUDA kernel not supported. Fallback to CPU execution provider for Op type: " << node.OpType() << " node name: " << node.Name(); + LOGS(logger, WARNING) << "CUDA kernel not supported. Fallback to CPU execution provider for Op type: " << node.OpType() << " node name: " << node.Name(); } } else { + tentative_nodes.push_back(node.Index()); candidates.push_back(node.Index()); } } @@ -2539,7 +2434,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // For CUDA EP, exclude the subgraph that is preferred to be placed in CPU // These are usually shape related computation subgraphs // Following logic can be extended for other EPs - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, candidates); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index c9e510b7f472b..d0bb2321edf0a 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -32,6 +33,8 @@ class CUDAExecutionProvider : public IExecutionProvider { Status OnRunEnd(bool sync_stream) override; + DataLayout GetPreferredLayout() const override; + const void* GetExecutionHandle() const noexcept override { // The CUDA interface does not return anything interesting. return nullptr; @@ -49,6 +52,12 @@ class CUDAExecutionProvider : public IExecutionProvider { return GetPerThreadContext().CudnnHandle(); } + cudaStream_t ComputeStream() { + // this will return the CUDA EP level stream which can differ from the actual compute tasks stream + // the compute task stream is supplied within OpKernelContext during inference + return stream_; + } + template const T* GetConstOnes(size_t count, cudaStream_t stream) { return GetPerThreadContext().template GetConstOnes(count, stream); @@ -68,6 +77,7 @@ class CUDAExecutionProvider : public IExecutionProvider { bool GetCudnnConvUseMaxWorkspace() const { return info_.cudnn_conv_use_max_workspace; } bool GetCudnnConv1dPadToNc1d() const { return info_.cudnn_conv1d_pad_to_nc1d; } bool IsSkipLayerNormInStrictMode() const { return info_.enable_skip_layer_norm_strict_mode; } + bool IsNHWCPreferred() const { return info_.prefer_nhwc; } ProviderOptions GetProviderOptions() const override { return CUDAExecutionProviderInfo::ToProviderOptions(info_); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index ca88b3474b758..daa3b5ff3d72f 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" @@ -29,6 +30,8 @@ constexpr const char* kTunableOpEnable = "tunable_op_enable"; constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable"; constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms"; constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_strict_mode"; +constexpr const char* kPreferNCHWMode = "prefer_nhwc"; +constexpr const char* KUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; } // namespace provider_option_names } // namespace cuda @@ -99,6 +102,8 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(cuda::provider_option_names::kEnableCudaGraph, info.enable_cuda_graph) .AddAssignmentToReference(cuda::provider_option_names::kCudnnConv1dPadToNc1d, info.cudnn_conv1d_pad_to_nc1d) .AddAssignmentToReference(cuda::provider_option_names::kEnableSkipLayerNormStrictMode, info.enable_skip_layer_norm_strict_mode) + .AddAssignmentToReference(cuda::provider_option_names::kPreferNCHWMode, info.prefer_nhwc) + .AddAssignmentToReference(cuda::provider_option_names::KUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) .AddValueParser( cuda::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -144,6 +149,8 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)}, {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)}, {cuda::provider_option_names::kEnableSkipLayerNormStrictMode, MakeStringWithClassicLocale(info.enable_skip_layer_norm_strict_mode)}, + {cuda::provider_option_names::kPreferNCHWMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, + {cuda::provider_option_names::KUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, }; return options; @@ -162,6 +169,8 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid {cuda::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op_enable)}, {cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op_tuning_enable)}, {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op_max_tuning_duration_ms)}, + {cuda::provider_option_names::kPreferNCHWMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, + {cuda::provider_option_names::KUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, }; return options; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index 789b02b0e1d8c..b286f5a9161b0 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -71,6 +72,9 @@ struct CUDAExecutionProviderInfo { cuda::TunableOpInfo tunable_op{}; bool enable_skip_layer_norm_strict_mode{false}; + bool prefer_nhwc{false}; + + bool use_ep_level_unified_stream{false}; static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const CUDAExecutionProviderInfo& info); diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index 58517c2850baf..e3106e41e77c8 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -170,10 +170,10 @@ class CudaKernel : public OpKernel { return provider_->PerThreadDefaultCudnnHandle(); } - protected: - template - inline const T* GetConstOnes(size_t count, cudaStream_t stream) const { - return provider_->template GetConstOnes(count, stream); + inline cudaStream_t DefaultCudaStream() const { + // this will return the CUDA EP level stream which can differ from the actual compute tasks stream + // the compute task stream is supplied within OpKernelContext during inference + return provider_->ComputeStream(); } inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { @@ -181,6 +181,12 @@ class CudaKernel : public OpKernel { return gpu_data_transfer->CopyTensorAsync(src, dst, stream); } + protected: + template + inline const T* GetConstOnes(size_t count, cudaStream_t stream) const { + return provider_->template GetConstOnes(count, stream); + } + inline int GetDeviceId() const { return provider_->GetDeviceId(); } private: diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc new file mode 100644 index 0000000000000..f416caecd115f --- /dev/null +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. +// Licensed under the MIT License. + +#ifdef ENABLE_CUDA_NHWC_OPS + +#include + +#include "core/providers/shared_library/provider_api.h" +#include "core/providers/cuda/cuda_fwd.h" + +#include "core/providers/cuda/cuda_nhwc_kernels.h" + +namespace onnxruntime::cuda { + +// When adding new supported NHWC operations make sure to also integrate them into: ConvertNodeLayout +// in onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float, + Conv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16, + Conv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float, + ConvTranspose); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16, + ConvTranspose); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 9, float, + AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 9, MLFloat16, + AveragePool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalAveragePool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, MLFloat16, + GlobalAveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 7, float, + MaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 7, MLFloat16, + MaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 8, 9, float, + MaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 8, 9, MLFloat16, + MaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalMaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, MLFloat16, GlobalMaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float, + AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16, + AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float, + MaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16, + MaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, float, Conv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, Conv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, float, ConvTranspose); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, + ConvTranspose); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, float, AveragePool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 11, float, + MaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 11, MLFloat16, + MaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, float, MaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16, + BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, float, + BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16, + BatchNormalization); + +Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) { + static const BuildKernelCreateInfoFn nhwc_function_table[] = { + BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + }; + + for (auto& function_table_entry : nhwc_function_table) { + KernelCreateInfo info = function_table_entry(); + if (info.kernel_def != nullptr) { // filter disabled entries where type is void + ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info))); + } + } + return Status::OK(); +} +} // namespace onnxruntime::cuda +#endif diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.h b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.h new file mode 100644 index 0000000000000..0b3a6d5cff0c7 --- /dev/null +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.h @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/status.h" + +namespace onnxruntime::cuda { + +onnxruntime::common::Status RegisterCudaNhwcKernels(onnxruntime::KernelRegistry& kernel_registry); + +} // namespace onnxruntime::cuda diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 5a11f2529f38e..892e8d5329eba 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" @@ -217,11 +218,13 @@ struct CUDA_Provider : Provider { info.default_memory_arena_cfg = params->default_memory_arena_cfg; info.cudnn_conv_use_max_workspace = params->cudnn_conv_use_max_workspace != 0; info.enable_cuda_graph = params->enable_cuda_graph != 0; + info.prefer_nhwc = params->prefer_nhwc; info.cudnn_conv1d_pad_to_nc1d = params->cudnn_conv1d_pad_to_nc1d != 0; info.tunable_op.enable = params->tunable_op_enable; info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; info.enable_skip_layer_norm_strict_mode = params->enable_skip_layer_norm_strict_mode != 0; + info.use_ep_level_unified_stream = params->use_ep_level_unified_stream != 0; return std::make_shared(info); } @@ -243,7 +246,7 @@ struct CUDA_Provider : Provider { cuda_options.arena_extend_strategy = internal_options.arena_extend_strategy; cuda_options.do_copy_in_default_stream = internal_options.do_copy_in_default_stream; cuda_options.has_user_compute_stream = internal_options.has_user_compute_stream; - // The 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance can be set byC API UpdateCUDAProviderOptionsWithValue() as well. + // The 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance can be set by C API UpdateCUDAProviderOptionsWithValue() as well. // We only set the 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance if it is provided in options if (options.find("has_user_compute_stream") != options.end()) { cuda_options.user_compute_stream = internal_options.user_compute_stream; @@ -253,6 +256,8 @@ struct CUDA_Provider : Provider { cuda_options.enable_cuda_graph = internal_options.enable_cuda_graph; cuda_options.cudnn_conv1d_pad_to_nc1d = internal_options.cudnn_conv1d_pad_to_nc1d; cuda_options.enable_skip_layer_norm_strict_mode = internal_options.enable_skip_layer_norm_strict_mode; + cuda_options.prefer_nhwc = internal_options.prefer_nhwc; + cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index e855a515f445a..5f1dbd30f6a3e 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -7,6 +7,25 @@ namespace onnxruntime { +DeferredCpuAllocator::DeferredCpuAllocator(CudaStream& cuda_stream) : cuda_stream_(cuda_stream) { + OrtAllocator::version = ORT_API_VERSION; + OrtAllocator::Alloc = + [](OrtAllocator* this_, size_t size) { + auto self = reinterpret_cast(this_); + return self->cuda_stream_.GetCpuAllocator()->Alloc(size); + }; + OrtAllocator::Free = + [](OrtAllocator* this_, void* p) { + auto self = reinterpret_cast(this_); + self->cuda_stream_.EnqueDeferredCPUBuffer(p); + }; + OrtAllocator::Info = + [](const OrtAllocator* this_) { + auto self = reinterpret_cast(this_); + return &self->cuda_stream_.GetCpuAllocator()->Info(); + }; +} + struct CudaNotification : public synchronize::Notification { CudaNotification(Stream& s) : Notification(s) { CUDA_CALL_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); @@ -46,7 +65,8 @@ CudaStream::CudaStream(cudaStream_t stream, cublasHandle_t external_cublas_handle) : Stream(stream, device), own_stream_(own_flag), cpu_allocator_(cpu_allocator), - release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream) { + release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream), + deferred_cpu_allocator_(*this) { if (own_flag) { CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_)); CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream)); @@ -162,6 +182,9 @@ void* CudaStream::GetResource(int version, int id) const { case CudaResource::cublas_handle_t: return reinterpret_cast(cublas_handle_); break; + case CudaResource::deferred_cpu_allocator_t: + return const_cast(&deferred_cpu_allocator_); + break; default: break; } diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.h b/onnxruntime/core/providers/cuda/cuda_stream_handle.h index 9c62b029b7a36..917702fae08f1 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.h +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.h @@ -9,6 +9,13 @@ namespace onnxruntime { +struct CudaStream; + +struct DeferredCpuAllocator : public OrtAllocator { + DeferredCpuAllocator(CudaStream&); + CudaStream& cuda_stream_; +}; + struct CudaStream : Stream { CudaStream(cudaStream_t stream, const OrtDevice& device, @@ -36,10 +43,13 @@ struct CudaStream : Stream { void* GetResource(int version, int id) const override; + onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); } + private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; bool release_cpu_buffer_on_cuda_stream_{true}; + DeferredCpuAllocator deferred_cpu_allocator_; }; void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index fc02a6509bf24..4df59a98b12e5 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -1,7 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. -#include "cudnn_common.h" +#include + +#include "core/providers/cuda/cudnn_common.h" #include "core/common/inlined_containers.h" #include "core/common/gsl.h" #include "shared_inc/cuda_call.h" @@ -27,7 +30,7 @@ Status CudnnTensor::CreateTensorIfNeeded() { return Status::OK(); } -Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dataType) { +Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dataType, bool is_nhwc) { ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); int rank = gsl::narrow_cast(input_dims.size()); @@ -38,6 +41,10 @@ Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dat dims[i] = gsl::narrow_cast(input_dims[i]); strides[i] = gsl::narrow_cast(pitches[i]); } + if (is_nhwc) { + std::swap(dims[1], dims[rank - 1]); + std::swap(strides[1], strides[rank - 1]); + } CUDNN_RETURN_IF_ERROR(cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast(rank), dims.data(), strides.data())); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index ba75ab4f2c029..8a94a334ee688 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -16,7 +17,7 @@ class CudnnTensor final { ~CudnnTensor(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CudnnTensor); - Status Set(gsl::span input_dims, cudnnDataType_t dataType); + Status Set(gsl::span input_dims, cudnnDataType_t dataType, bool is_nhwc = false); Status Set(const CudnnTensor& x_desc, cudnnBatchNormMode_t mode); // Set 4D tensor format (for NHWC) Status Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int n, int c, int h, int w); diff --git a/onnxruntime/core/providers/cuda/math/matmul.cc b/onnxruntime/core/providers/cuda/math/matmul.cc index 899d506f840a2..e4c37c52a1780 100644 --- a/onnxruntime/core/providers/cuda/math/matmul.cc +++ b/onnxruntime/core/providers/cuda/math/matmul.cc @@ -119,6 +119,161 @@ Status MatMul::ComputeInternal(OpKernelContext* ctx) const { return ComputeDefault(ctx, helper); } +template +Status FuncMatMul( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* A, + const Tensor* B, + float alpha, + bool trans_A, + bool trans_B, + bool trans_batch_A, + bool trans_batch_B, + Tensor* Y) { + typedef typename ToCudaType::MappedType CudaT; + + // Ignore the transpose flag if rank of input being 1. + // Be noted: numpy.transpose on vector does not change anything. + if (A->Shape().NumDimensions() == 1) { + trans_A = false; + } + if (B->Shape().NumDimensions() == 1) { + trans_B = false; + } + + const CudaT cuda_alpha = ToCudaType::FromFloat(alpha); + const CudaT cuda_zero = ToCudaType::FromFloat(0.0f); + + cublasOperation_t cuda_trans_A = trans_A ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t cuda_trans_B = trans_B ? CUBLAS_OP_T : CUBLAS_OP_N; + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR( + helper.Compute(A->Shape(), B->Shape(), trans_A, trans_B, trans_batch_A, trans_batch_B, false)); + const int lda = helper.Lda(trans_A); + const int ldb = helper.Ldb(trans_B); + const int ldc = helper.Ldc(); + int64_t stride_A, stride_B, stride_C, batch_count; + auto& device_prop = cuda_kernel->GetDeviceProp(); + + if (helper.OutputOffsets().size() == 1) { + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + cuda_kernel->GetCublasHandle(ctx), + cuda_trans_B, + cuda_trans_A, + static_cast(helper.N()), + static_cast(helper.M()), + static_cast(helper.K()), + &cuda_alpha, + reinterpret_cast(B->Data()), + ldb, + reinterpret_cast(A->Data()), + lda, + &cuda_zero, + reinterpret_cast(Y->MutableData()), + ldc, + device_prop)); + return Status::OK(); + } else if (CanUseStridedBatchedGemm(A->Shape(), B->Shape(), + trans_A, trans_B, trans_batch_B, trans_batch_B, stride_A, stride_B, stride_C, batch_count)) { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(cuda_kernel->GetCublasHandle(ctx), + cuda_trans_B, + cuda_trans_A, + static_cast(helper.N()), + static_cast(helper.M()), + static_cast(helper.K()), + &cuda_alpha, + reinterpret_cast(B->Data()), + ldb, + stride_B, + reinterpret_cast(A->Data()), + lda, + stride_A, + &cuda_zero, + reinterpret_cast(Y->MutableData()), + ldc, + stride_C, + static_cast(batch_count), + device_prop)); + + return Status::OK(); + } + + // Fill offsets when needed. + helper.FillOffsets(); + CudaKernel::CudaAsyncBuffer A_arrays(cuda_kernel, helper.LeftOffsets().size()); + CudaKernel::CudaAsyncBuffer B_arrays(cuda_kernel, helper.RightOffsets().size()); + CudaKernel::CudaAsyncBuffer Y_arrays(cuda_kernel, helper.OutputOffsets().size()); + MatMulComputeHelper::OffsetToArrays(reinterpret_cast(A->Data()), helper.LeftOffsets(), A_arrays.CpuSpan()); + MatMulComputeHelper::OffsetToArrays(reinterpret_cast(B->Data()), helper.RightOffsets(), B_arrays.CpuSpan()); + MatMulComputeHelper::OffsetToArrays(reinterpret_cast(Y->MutableData()), helper.OutputOffsets(), Y_arrays.CpuSpan()); + ORT_RETURN_IF_ERROR(A_arrays.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(B_arrays.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(Y_arrays.CopyToGpu(ctx->GetComputeStream())); + + // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy. + // It requires Ampere or newer GPU, and pointers of matrics shall be aligned (ideal alignment is 16-byte). + // Assume that start memory of input/output tensor is aligned, we only check offsets of sub-matrix per batch here. + cublasMath_t mode = (std::is_same::value && device_prop.major >= 8 && helper.IsBatchedGemmAligned()) + ? CUBLAS_TF32_TENSOR_OP_MATH + : CUBLAS_DEFAULT_MATH; + CublasMathModeSetter math_mode_setter(device_prop, cuda_kernel->GetCublasHandle(ctx), mode); + + // note that onnxruntime OrtValue is row major, while cublas is column major, + // so swap left/right operands + CUBLAS_RETURN_IF_ERROR(cublasGemmBatchedHelper( + cuda_kernel->GetCublasHandle(ctx), + cuda_trans_B, + cuda_trans_A, + static_cast(helper.N()), + static_cast(helper.M()), + static_cast(helper.K()), + &cuda_alpha, + B_arrays.GpuPtr(), + ldb, + A_arrays.GpuPtr(), + lda, + &cuda_zero, + Y_arrays.GpuPtr(), + ldc, + static_cast(helper.OutputOffsets().size()), + device_prop)); + return Status::OK(); +} + +template Status FuncMatMul( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const CudaKernel* cuda_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const Tensor* A, + const Tensor* B, + float alpha, + bool trans_A, + bool trans_B, + bool trans_batch_A, + bool trans_batch_B, + Tensor* Y); + +template Status FuncMatMul( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const CudaKernel* cuda_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const Tensor* A, + const Tensor* B, + float alpha, + bool trans_A, + bool trans_B, + bool trans_batch_A, + bool trans_batch_B, + Tensor* Y); + template Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& helper) const { typedef typename ToCudaType::MappedType CudaT; diff --git a/onnxruntime/core/providers/cuda/math/matmul.h b/onnxruntime/core/providers/cuda/math/matmul.h index 5ea7b30777402..26de1044eeb23 100644 --- a/onnxruntime/core/providers/cuda/math/matmul.h +++ b/onnxruntime/core/providers/cuda/math/matmul.h @@ -31,5 +31,23 @@ class MatMul final : public CudaKernel { const bool trans_batch_a_; const bool trans_batch_b_; }; + +template +Status FuncMatMul( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const CudaKernel* cuda_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const Tensor* A, + const Tensor* B, + float alpha, + bool trans_A, + bool trans_B, + bool trans_batch_A, + bool trans_batch_B, + Tensor* Y); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 9ede1f8d90ecc..655877f425054 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -99,6 +99,7 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa // F: float // D: double // O: bool +// X: BFloat16 #define UNARY_OP_VERSIONED_HFD(name, startver, endver) \ UNARY_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \ @@ -124,12 +125,18 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa UNARY_OP_TYPED(name, ver, float) \ UNARY_OP_TYPED(name, ver, double) +#define UNARY_OP_HFDX(name, ver) \ + UNARY_OP_TYPED(name, ver, MLFloat16) \ + UNARY_OP_TYPED(name, ver, BFloat16) \ + UNARY_OP_TYPED(name, ver, float) \ + UNARY_OP_TYPED(name, ver, double) + #define UNARY_OP_CSILHFD(name, ver) \ UNARY_OP_TYPED(name, ver, int8_t) \ UNARY_OP_TYPED(name, ver, int16_t) \ UNARY_OP_TYPED(name, ver, int32_t) \ UNARY_OP_TYPED(name, ver, int64_t) \ - UNARY_OP_HFD(name, ver) + UNARY_OP_HFDX(name, ver) #define UNARY_OP_BWUZCSILHFD(name, ver) \ UNARY_OP_TYPED(name, ver, uint8_t) \ diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 1298d53338337..5c3db4a499972 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -53,13 +53,14 @@ UNARY_OPS() // F: float // D: double // O: bool +// X: BFloat16 #define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, half) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, float) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, double) -#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(name) \ +#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(name) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, BFloat16) @@ -68,7 +69,7 @@ UNARY_OPS() SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int16_t) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int32_t) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int64_t) \ - SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name) + SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(name) #define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(name) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, uint8_t) \ @@ -83,8 +84,8 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Floor) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Ceil) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt) -SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(Log) -SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(Exp) +SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Log) +SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin) diff --git a/onnxruntime/core/providers/cuda/nn/batch_norm.cc b/onnxruntime/core/providers/cuda/nn/batch_norm.cc index 4f22b5298a30a..c468971e1e426 100644 --- a/onnxruntime/core/providers/cuda/nn/batch_norm.cc +++ b/onnxruntime/core/providers/cuda/nn/batch_norm.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #include "batch_norm.h" @@ -11,38 +12,38 @@ using namespace std; namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ +#define REGISTER_KERNEL_TYPED(T, DOMAIN, NHWC) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ BatchNormalization, \ - kOnnxDomain, \ + DOMAIN, \ 7, 8, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - BatchNorm); \ + BatchNorm); \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ BatchNormalization, \ - kOnnxDomain, \ + DOMAIN, \ 9, 13, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - BatchNorm); \ + BatchNorm); \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ BatchNormalization, \ - kOnnxDomain, \ + DOMAIN, \ 14, 14, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ .TypeConstraint("U", DataTypeImpl::GetTensorType()), \ - BatchNorm); \ + BatchNorm); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ BatchNormalization, \ - kOnnxDomain, \ + DOMAIN, \ 15, \ T, \ kCudaExecutionProvider, \ @@ -50,10 +51,10 @@ namespace cuda { .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - BatchNorm); + BatchNorm); -template -Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) const { +template +Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) const { typedef typename ToCudaType::MappedType CudaT; const Tensor* X = p_op_kernel_context->Input(0); @@ -62,7 +63,7 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) const const Tensor* mean = p_op_kernel_context->Input(3); const Tensor* var = p_op_kernel_context->Input(4); - ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, scale, B, mean, var, spatial_ == 1)); + ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, scale, B, mean, var, spatial_ == 1, NHWC)); const TensorShape& x_shape = X->Shape(); const TensorShape& channel_shape = mean->Shape(); @@ -87,7 +88,7 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) const CudnnTensor data_desc; vector new_dims; BatchNormHelper::NormalizeDims(x_shape, new_dims); - ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, CudnnTensor::GetDataType(), NHWC)); // For half data type, the alpha, beta, scale, B, mean, var need to be float type if (X->IsDataType()) { @@ -97,7 +98,7 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) const ORT_RETURN_IF_ERROR(bn_tensor_desc.Set(data_desc, cudnn_batch_norm_mode_)); // Convert the scale, B, mean, var to float - const int64_t C = x_shape.GetDims()[1]; + const int64_t C = x_shape.GetDims()[NHWC ? 3 : 1]; auto f_scale = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); auto f_B = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); auto f_mean = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); @@ -175,13 +176,17 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) const return Status::OK(); } -#define SPECIALIZED_COMPUTE(T) \ - REGISTER_KERNEL_TYPED(T) \ - template Status BatchNorm::ComputeInternal(OpKernelContext* ctx) const; +#define SPECIALIZED_COMPUTE(T, DOMAIN, NHWC) \ + REGISTER_KERNEL_TYPED(T, DOMAIN, NHWC) \ + template Status BatchNorm::ComputeInternal(OpKernelContext* ctx) const; -SPECIALIZED_COMPUTE(float) -SPECIALIZED_COMPUTE(double) -SPECIALIZED_COMPUTE(MLFloat16) +SPECIALIZED_COMPUTE(float, kOnnxDomain, false) +SPECIALIZED_COMPUTE(double, kOnnxDomain, false) +SPECIALIZED_COMPUTE(MLFloat16, kOnnxDomain, false) +#ifdef ENABLE_CUDA_NHWC_OPS +SPECIALIZED_COMPUTE(float, kMSInternalNHWCDomain, true) +SPECIALIZED_COMPUTE(MLFloat16, kMSInternalNHWCDomain, true) +#endif } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/batch_norm.h b/onnxruntime/core/providers/cuda/nn/batch_norm.h index 99da7652a1d24..4eb9fb74d3761 100644 --- a/onnxruntime/core/providers/cuda/nn/batch_norm.h +++ b/onnxruntime/core/providers/cuda/nn/batch_norm.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -9,7 +10,7 @@ namespace onnxruntime { namespace cuda { -template +template class BatchNorm final : public CudaKernel { public: BatchNorm(const OpKernelInfo& op_kernel_info) diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 81db3c4186282..82f3503919237 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -1,38 +1,47 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. +#include + #include "core/providers/cuda/nn/conv.h" #include "core/common/span_utils.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" #include "core/providers/cuda/tensor/slice.h" +#include "core/providers/cuda/tensor/transpose.h" namespace onnxruntime { namespace cuda { // Op Set 11 for Conv only update document to clearify default dilations and strides value. // which are already convered by op set 11 cpu versoin, so simply add declaration. -#define REGISTER_KERNEL_TYPED(T) \ +#define REGISTER_KERNEL_TYPED(T, DOMAIN, NHWC) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ Conv, \ - kOnnxDomain, \ + DOMAIN, \ 1, 10, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); \ + Conv); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Conv, \ - kOnnxDomain, \ + DOMAIN, \ 11, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); + Conv); + +REGISTER_KERNEL_TYPED(float, kOnnxDomain, false) +REGISTER_KERNEL_TYPED(double, kOnnxDomain, false) +REGISTER_KERNEL_TYPED(MLFloat16, kOnnxDomain, false) -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(double) -REGISTER_KERNEL_TYPED(MLFloat16) +#ifdef ENABLE_CUDA_NHWC_OPS +REGISTER_KERNEL_TYPED(float, kMSInternalNHWCDomain, true) +REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) +#endif template const cudnnConvolutionFwdAlgo_t Conv::kAllAlgos[] = { @@ -86,6 +95,39 @@ Status SliceOutUnwantedOutputSection(cudaStream_t stream, return SliceCuda::Impl(stream, input_data, input_dims, output_data, compute_metadata, element_size); } +template +Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) { + is_packed = false; + // only layout of weight input is adjusted via PrePack + if (NHWC && is_nhwc_domain_) { // InputTensors::IN_W + if (input_idx == 1) { + // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} + auto orig_shape = tensor.Shape(); + + InlinedVector perm{0, 2, 3, 1}; + gsl::span permutation(perm.data(), 4); + TensorShapeVector new_dims{orig_shape[0], + orig_shape[2], + orig_shape[3], + orig_shape[1]}; + W_ = Tensor::Create(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + + auto status = cuda::Transpose::DoTranspose(GetDeviceProp(), + DefaultCudaStream(), + DefaultCublasHandle(), + permutation, tensor, *W_); + if (!status.IsOK()) { + return status; + } + CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream())); + is_packed = true; + } + } + + return Status::OK(); +} + template Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const { // set X @@ -95,7 +137,12 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) s_.x_data = reinterpret_cast(X->Data()); s_.element_size = X->DataType()->Size(); // set W - const Tensor* W = context->Input(1); + const Tensor* W; + if (!W_) { + W = context->Input(1); + } else { + W = W_.get(); + } const TensorShape& w_shape = W->Shape(); auto w_dims = w_shape.AsShapeVector(); s_.w_data = reinterpret_cast(W->Data()); diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index 07825b93204ca..bcaa4d855b81e 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -1,13 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once +#include +#include + #include "core/platform/ort_mutex.h" #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/cudnn_common.h" #include "core/providers/cpu/nn/conv_attributes.h" -#include namespace onnxruntime { @@ -187,8 +190,12 @@ class Conv : public CudaKernel { Conv(const OpKernelInfo& info) : CudaKernel(info), conv_attrs_(info) { auto pads_size = conv_attrs_.pads.size(); ORT_ENFORCE(pads_size % 2 == 0); + is_nhwc_domain_ = info.node().Domain() == kMSInternalNHWCDomain; } + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; + Status ComputeInternal(OpKernelContext* context) const override; protected: @@ -201,6 +208,8 @@ class Conv : public CudaKernel { mutable CudnnConvState s_; constexpr static auto kDefaultConvAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; static const cudnnConvolutionFwdAlgo_t kAllAlgos[]; + std::unique_ptr W_; + bool is_nhwc_domain_; // prepack is only needed for the Conv in kMSInternalNHWCDomain }; Status SliceOutUnwantedOutputSection(cudaStream_t stream, diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 2b3326c528659..55dceaa2698e8 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -1,7 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. +#include + #include "conv_transpose.h" +#include "core/providers/cuda/tensor/transpose.h" // To suppress FP static analyzer warnings: // https://msdata.visualstudio.com/Vienna/_workitems/edit/1944928 and @@ -17,35 +21,59 @@ namespace cuda { // Op Set 11 for ConvTranspose only update document to clarify default dilations and strides value. // which are already covered by op set 11 cpu version, so simply add declaration. -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kOnnxDomain, \ - 11, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(double) -REGISTER_KERNEL_TYPED(MLFloat16) - -template -Status ConvTranspose::ComputeInternal(OpKernelContext* context) const { +#define REGISTER_KERNEL_TYPED(T, DOMAIN, NHWC) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + ConvTranspose, DOMAIN, 1, 10, T, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), ConvTranspose); \ + ONNX_OPERATOR_TYPED_KERNEL_EX(ConvTranspose, DOMAIN, 11, T, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + ConvTranspose); + +REGISTER_KERNEL_TYPED(float, kOnnxDomain, false) +REGISTER_KERNEL_TYPED(double, kOnnxDomain, false) +REGISTER_KERNEL_TYPED(MLFloat16, kOnnxDomain, false) + +#ifdef ENABLE_CUDA_NHWC_OPS +REGISTER_KERNEL_TYPED(float, kMSInternalNHWCDomain, true) +REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) +#endif + +template +Status ConvTranspose::ComputeInternal(OpKernelContext* context) const { return DoConvTranspose(context, false); } -template -Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { +template +Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, + [[maybe_unused]] PrePackedWeights* prepacked_weights) { + is_packed = false; + // only layout of weight input is adjusted via PrePack + if (NHWC) { // InputTensors::IN_W + if (input_idx == 1) { + // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} + auto orig_shape = tensor.Shape(); + + InlinedVector perm{0, 2, 3, 1}; + gsl::span permutation(perm.data(), 4); + TensorShapeVector new_dims{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]}; + W_ = Tensor::Create(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + + auto status = cuda::Transpose::DoTranspose(GetDeviceProp(), DefaultCudaStream(), DefaultCublasHandle(), + permutation, tensor, *W_); + + if (!status.IsOK()) { + return status; + } + CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream())); + is_packed = true; + } + } + + return Status::OK(); +} + +template +Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { typedef typename ToCudaType::MappedType CudaT; const Tensor* X = context->Input(0); @@ -59,7 +87,12 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.", " X: ", X->Shape().ToString().c_str()); } - const Tensor* W = context->Input(1); + const Tensor* W; + if (!W_) { + W = context->Input(1); + } else { + W = W_.get(); + } const TensorShape& w_shape = W->Shape(); TensorShapeVector w_dims = w_shape.AsShapeVector(); auto w_data = reinterpret_cast(W->Data()); @@ -80,8 +113,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ bool input_dims_changed = (s_.last_x_dims.AsShapeVector() != x_dims); bool w_dims_changed = (s_.last_w_dims.AsShapeVector() != w_dims); if (input_dims_changed || w_dims_changed) { - if (input_dims_changed) - s_.last_x_dims = gsl::make_span(x_dims); + if (input_dims_changed) s_.last_x_dims = gsl::make_span(x_dims); if (w_dims_changed) { s_.last_w_dims = gsl::make_span(w_dims); @@ -89,7 +121,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ } ConvTransposeAttributes::Prepare p; - ORT_RETURN_IF_ERROR(conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding)); + ORT_RETURN_IF_ERROR( + conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding, &w_shape, NHWC)); auto y_dims = p.Y->Shape().AsShapeVector(); if (x_dimensions == 3) { @@ -102,8 +135,15 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ } s_.y_dims = gsl::make_span(y_dims); - if (w_dims_changed) - ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); + if (w_dims_changed) { + if (NHWC) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(w_dims[0]), static_cast(w_dims[3]), + static_cast(w_dims[1]), static_cast(w_dims[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); + } + } // Special case when there is a dim value of 0 in the shape. // Return only after we have cached the following for subsequent runs : @@ -112,31 +152,39 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ if (p.Y->Shape().Size() == 0) { return Status::OK(); } - - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType())); + if (NHWC) { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(x_dims[0]), static_cast(x_dims[3]), + static_cast(x_dims[1]), static_cast(x_dims[2]))); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(y_dims[0]), static_cast(y_dims[3]), + static_cast(y_dims[1]), static_cast(y_dims[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType())); + } cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, - gsl::narrow_cast(conv_transpose_attrs_.group), - mode, CudnnTensor::GetDataType())); + gsl::narrow_cast(conv_transpose_attrs_.group), mode, + CudnnTensor::GetDataType())); if (has_bias) { const auto& b_shape = p.B->Shape(); ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); TensorShapeVector b_dims(2 + p.kernel_shape.size()); - b_dims[0] = 1; // N - b_dims[1] = b_shape[0]; // C - for (size_t i = 0; i < p.kernel_shape.size(); i++) - b_dims[2 + i] = 1; + b_dims[0] = 1; // N + b_dims[NHWC ? 3 : 1] = b_shape[0]; // C + for (size_t i = 0; i < p.kernel_shape.size(); i++) b_dims[(NHWC ? 1 : 2) + i] = 1; - ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), NHWC)); } y_data = reinterpret_cast(p.Y->MutableData()); if (!s_.cached_benchmark_results.contains(x_dims)) { - IAllocatorUniquePtr algo_search_workspace = GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); + IAllocatorUniquePtr algo_search_workspace = + GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); // set math type to tensor core before algorithm search if constexpr (std::is_same::value) @@ -145,19 +193,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ cudnnConvolutionBwdDataAlgoPerf_t perf; int algo_count = 1; CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( - GetCudnnHandle(context), - s_.w_desc, - w_data, - s_.x_tensor, - x_data, - s_.conv_desc, - s_.y_tensor, - y_data, - 1, - &algo_count, - &perf, - algo_search_workspace.get(), - AlgoSearchWorkspaceSize)); + GetCudnnHandle(context), s_.w_desc, w_data, s_.x_tensor, x_data, s_.conv_desc, s_.y_tensor, y_data, 1, + &algo_count, &perf, algo_search_workspace.get(), AlgoSearchWorkspaceSize)); s_.cached_benchmark_results.insert(x_dims, {perf.algo, perf.memory, perf.mathType}); } @@ -188,26 +225,15 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes, context->GetComputeStream()); - CUDNN_RETURN_IF_ERROR( - cudnnConvolutionBackwardData( - GetCudnnHandle(context), - &alpha, - s_.w_desc, - w_data, - s_.x_tensor, - x_data, - s_.conv_desc, - s_.algo, - workspace.get(), - s_.workspace_bytes, - &beta, - s_.y_tensor, - y_data)); + CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardData(GetCudnnHandle(context), &alpha, s_.w_desc, w_data, s_.x_tensor, + x_data, s_.conv_desc, s_.algo, workspace.get(), + s_.workspace_bytes, &beta, s_.y_tensor, y_data)); if (has_bias) { const Tensor* B = dynamic_padding ? context->Input(3) : context->Input(2); auto b_data = reinterpret_cast(B->Data()); - CUDNN_RETURN_IF_ERROR(cudnnAddTensor(GetCudnnHandle(context), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data)); + CUDNN_RETURN_IF_ERROR( + cudnnAddTensor(GetCudnnHandle(context), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data)); } } diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h index 165d548d27fa2..77c9d94162b6b 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once +#include + #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/cudnn_common.h" @@ -12,10 +15,12 @@ namespace onnxruntime { namespace cuda { -template +template class ConvTranspose : public CudaKernel { public: ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), conv_transpose_attrs_(info){}; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; Status DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const; @@ -23,6 +28,7 @@ class ConvTranspose : public CudaKernel { ConvTransposeAttributes conv_transpose_attrs_; mutable CudnnConvState s_; + std::unique_ptr W_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu index 4cc560a1178ef..679b8b6b78886 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu @@ -104,17 +104,17 @@ __device__ void cuWelfordMuSigma2( const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; const T* lvals = vals + i1 * n2; - const T* skip_vals = (skip != NULL) ? skip + i1 * n2 : NULL; + const T* skip_vals = (skip != nullptr) ? skip + i1 * n2 : nullptr; int l = 4 * thrx; for (; l + 3 < n2; l += 4 * numx) { for (int k = 0; k < 4; ++k) { U curr = static_cast(lvals[l + k]); - if (bias != NULL) { + if (bias != nullptr) { curr += static_cast(bias[l + k]); } - if (skip_vals != NULL) { + if (skip_vals != nullptr) { curr += static_cast(skip_vals[l + k]); } @@ -124,11 +124,11 @@ __device__ void cuWelfordMuSigma2( for (; l < n2; ++l) { U curr = static_cast(lvals[l]); - if (bias != NULL) { + if (bias != nullptr) { curr += static_cast(bias[l]); } - if (skip_vals != NULL) { + if (skip_vals != nullptr) { curr += static_cast(skip_vals[l]); } @@ -301,7 +301,7 @@ namespace { // { // extern __device__ void error(void); // error(); -// return NULL; +// return nullptr; // } // }; // https://github.com/NVIDIA/apex/issues/246 @@ -338,9 +338,7 @@ __global__ void cuApplyLayerNorm( const V* __restrict__ beta, const T* __restrict__ skip, const T* __restrict__ bias, - T* __restrict__ skip_input_bias_add_output, - const bool skip_broadcasted, - const int skip_size) { + T* __restrict__ skip_input_bias_add_output) { // Assumptions: // 1) blockDim.x == GPU_WARP_SIZE // 2) Tensors are contiguous @@ -350,38 +348,35 @@ __global__ void cuApplyLayerNorm( U* buf = shared.getPointer(); U mu, sigma2; cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf, skip, bias); - const T* lvals = vals + i1 * n2; - const T* skip_vals = (skip != NULL) ? skip + i1 * n2 : NULL; - V* ovals = output_vals + i1 * n2; - T* skip_input_bias_add_ovals = (skip_input_bias_add_output != NULL) ? skip_input_bias_add_output + i1 * n2 : NULL; + const int offset = i1 * n2; + const T* lvals = vals + offset; + const T* skip_vals = (skip != nullptr) ? skip + offset : nullptr; + + V* ovals = output_vals + offset; + T* skip_input_bias_add_ovals = (skip_input_bias_add_output != nullptr) ? skip_input_bias_add_output + offset : nullptr; U c_inv_std_dev = rsqrt(sigma2 + epsilon); const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; for (int i = thrx; i < n2; i += numx) { U curr = static_cast(lvals[i]); - - - if (bias != NULL) { + if (bias != nullptr) { curr += static_cast(bias[i]); } - if (skip_vals != NULL && skip_broadcasted) { - int skip_i = i % skip_size; - curr += static_cast(skip_vals[skip_i]); //Calculates index for the second dimension of the skip tensor - }else if (skip_vals != NULL){ + if (skip_vals != nullptr) { curr += static_cast(skip_vals[i]); } - U gamma_i = (gamma != NULL) ? (U)gamma[i] : (U)1; - U beta_i = (beta != NULL) ? (U)beta[i] : (U)0; + U gamma_i = (gamma != nullptr) ? (U)gamma[i] : (U)1; + U beta_i = (beta != nullptr) ? (U)beta[i] : (U)0; if (simplified) { ovals[i] = static_cast(gamma_i * c_inv_std_dev * curr); } else { ovals[i] = static_cast(gamma_i * c_inv_std_dev * (curr - mu) + beta_i); } - if (skip_input_bias_add_ovals != NULL) { + if (skip_input_bias_add_ovals != nullptr) { skip_input_bias_add_ovals[i] = static_cast(curr); } } @@ -418,9 +413,7 @@ void HostApplyLayerNorm( const V* beta, const T* skip, const T* bias, - T* skip_input_bias_add_output, - const bool skip_broadcasted, - const int skip_size) { + T* skip_input_bias_add_output) { const int maxGridY = prop.maxGridSize[1]; const int warp_size = prop.warpSize; ORT_ENFORCE(warp_size == GPU_WARP_SIZE_HOST); @@ -452,17 +445,14 @@ void HostApplyLayerNorm( n1, n2, U(epsilon), gamma, beta, - skip, bias, skip_input_bias_add_output, - skip_broadcasted, - skip_size); + skip, bias, skip_input_bias_add_output); } #define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \ template void HostApplyLayerNorm(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \ U* mean, U* inv_std_dev, const T* input, int n1, int n2, \ double epsilon, const V* gamma, const V* beta, const T* skip, \ - const T* bias, T* skip_input_bias_add_output, const bool skip_broadcasted, \ - const int skip_size); + const T* bias, T* skip_input_bias_add_output); LAYERNORM_LINEAR_IMPL(float, float, float, true) LAYERNORM_LINEAR_IMPL(half, float, half, true) diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h index d0d5db8ba3587..e3952eefae35d 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h @@ -43,9 +43,7 @@ void HostApplyLayerNorm( const V* beta, const T* skip = nullptr, const T* bias = nullptr, - T* skip_input_bias_add_output = nullptr, - const bool skip_broadcasted = false, - const int skip_size = 0); + T* skip_input_bias_add_output = nullptr); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc index e632ef20bce43..8bc96958693bc 100644 --- a/onnxruntime/core/providers/cuda/nn/pool.cc +++ b/onnxruntime/core/providers/cuda/nn/pool.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" @@ -11,92 +12,99 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace cuda { -#define POOLING_KERNEL(op_name, data_type, pool_type, since_version) \ +#define POOLING_KERNEL(op_name, data_type, pool_type, since_version, op_domain, nhwc) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ - op_name, \ - kOnnxDomain, \ - since_version, \ - data_type, \ - kCudaExecutionProvider, \ + op_name, op_domain, since_version, data_type, kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Pool); - -#define POOLING_KERNEL_VERSIONED(op_name, data_type, pool_type, since_version, end_version) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - op_name, \ - kOnnxDomain, \ - since_version, \ - end_version, \ - data_type, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Pool); - -#define POOLING_KERNEL_WITH_INDICES(op_name, data_type, pool_type, since_version) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - op_name, \ - kOnnxDomain, \ - since_version, \ - data_type, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ - Pool); - -#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, data_type, pool_type, since_version, end_version) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - op_name, \ - kOnnxDomain, \ - since_version, \ - end_version, \ - data_type, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ - Pool); - -POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 7, 9) -POOLING_KERNEL_VERSIONED(AveragePool, double, AveragePool, 7, 9) -POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 7, 9) -POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 10, 10) -POOLING_KERNEL_VERSIONED(AveragePool, double, AveragePool, 10, 10) -POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 10, 10) + Pool); + +#define POOLING_KERNEL_VERSIONED(op_name, data_type, pool_type, since_version, end_version, op_domain, nhwc) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + op_name, op_domain, since_version, end_version, data_type, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pool); + +#define POOLING_KERNEL_WITH_INDICES(op_name, data_type, pool_type, since_version, op_domain, nhwc) \ + ONNX_OPERATOR_TYPED_KERNEL_EX(op_name, op_domain, since_version, data_type, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ + Pool); + +#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, data_type, pool_type, since_version, end_version, op_domain, \ + nhwc) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(op_name, op_domain, since_version, end_version, data_type, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ + Pool); + +POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 7, 9, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED(AveragePool, double, AveragePool, 7, 9, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 7, 9, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 10, 10, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED(AveragePool, double, AveragePool, 10, 10, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 10, 10, kOnnxDomain, false) // AveragePool and MaxPool op set 11 only update spec document on default value for dilations and strides. -POOLING_KERNEL(AveragePool, float, AveragePool, 11) -POOLING_KERNEL(AveragePool, double, AveragePool, 11) -POOLING_KERNEL(AveragePool, MLFloat16, AveragePool, 11) -POOLING_KERNEL(GlobalAveragePool, float, AveragePool, 1) -POOLING_KERNEL(GlobalAveragePool, double, AveragePool, 1) -POOLING_KERNEL(GlobalAveragePool, MLFloat16, AveragePool, 1) -POOLING_KERNEL_VERSIONED(MaxPool, float, MaxPool<1>, 1, 7) -POOLING_KERNEL_VERSIONED(MaxPool, double, MaxPool<1>, 1, 7) -POOLING_KERNEL_VERSIONED(MaxPool, MLFloat16, MaxPool<1>, 1, 7) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 8, 9) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 8, 9) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 8, 9) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 10, 10) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 10, 10) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 10, 10) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 11, 11) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 11, 11) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 11, 11) -POOLING_KERNEL_WITH_INDICES(MaxPool, float, MaxPool<8>, 12) -POOLING_KERNEL_WITH_INDICES(MaxPool, double, MaxPool<8>, 12) -POOLING_KERNEL_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 12) -POOLING_KERNEL_WITH_INDICES(MaxPool, int8_t, MaxPool<8>, 12) -POOLING_KERNEL_WITH_INDICES(MaxPool, uint8_t, MaxPool<8>, 12) - -POOLING_KERNEL(GlobalMaxPool, float, MaxPool<1>, 1) -POOLING_KERNEL(GlobalMaxPool, double, MaxPool<1>, 1) -POOLING_KERNEL(GlobalMaxPool, MLFloat16, MaxPool<1>, 1) +POOLING_KERNEL(AveragePool, float, AveragePool, 11, kOnnxDomain, false) +POOLING_KERNEL(AveragePool, double, AveragePool, 11, kOnnxDomain, false) +POOLING_KERNEL(AveragePool, MLFloat16, AveragePool, 11, kOnnxDomain, false) +POOLING_KERNEL(GlobalAveragePool, float, AveragePool, 1, kOnnxDomain, false) +POOLING_KERNEL(GlobalAveragePool, double, AveragePool, 1, kOnnxDomain, false) +POOLING_KERNEL(GlobalAveragePool, MLFloat16, AveragePool, 1, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED(MaxPool, float, MaxPool<1>, 1, 7, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED(MaxPool, double, MaxPool<1>, 1, 7, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED(MaxPool, MLFloat16, MaxPool<1>, 1, 7, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 8, 9, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 8, 9, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 8, 9, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 10, 10, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 10, 10, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 10, 10, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 11, 11, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 11, 11, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 11, 11, kOnnxDomain, false) +POOLING_KERNEL_WITH_INDICES(MaxPool, float, MaxPool<8>, 12, kOnnxDomain, false) +POOLING_KERNEL_WITH_INDICES(MaxPool, double, MaxPool<8>, 12, kOnnxDomain, false) +POOLING_KERNEL_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 12, kOnnxDomain, false) +POOLING_KERNEL_WITH_INDICES(MaxPool, int8_t, MaxPool<8>, 12, kOnnxDomain, false) +POOLING_KERNEL_WITH_INDICES(MaxPool, uint8_t, MaxPool<8>, 12, kOnnxDomain, false) + +POOLING_KERNEL(GlobalMaxPool, float, MaxPool<1>, 1, kOnnxDomain, false) +POOLING_KERNEL(GlobalMaxPool, double, MaxPool<1>, 1, kOnnxDomain, false) +POOLING_KERNEL(GlobalMaxPool, MLFloat16, MaxPool<1>, 1, kOnnxDomain, false) + +// NHWC variants +#ifdef ENABLE_CUDA_NHWC_OPS +POOLING_KERNEL_VERSIONED(MaxPool, float, MaxPool<1>, 1, 7, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED(MaxPool, MLFloat16, MaxPool<1>, 1, 7, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 8, 9, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 8, 9, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 10, 10, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 10, 10, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 11, 11, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 11, 11, kMSInternalNHWCDomain, true) +POOLING_KERNEL_WITH_INDICES(MaxPool, float, MaxPool<8>, 12, kMSInternalNHWCDomain, true) +POOLING_KERNEL_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 12, kMSInternalNHWCDomain, true) + +POOLING_KERNEL(GlobalMaxPool, float, MaxPool<1>, 1, kMSInternalNHWCDomain, true) +POOLING_KERNEL(GlobalMaxPool, MLFloat16, MaxPool<1>, 1, kMSInternalNHWCDomain, true) + +POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 7, 9, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 7, 9, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 10, 10, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 10, 10, kMSInternalNHWCDomain, true) +// AveragePool and MaxPool op set 11 only update spec document on default value for dilations +POOLING_KERNEL(AveragePool, float, AveragePool, 11, kMSInternalNHWCDomain, true) +POOLING_KERNEL(AveragePool, MLFloat16, AveragePool, 11, kMSInternalNHWCDomain, true) +POOLING_KERNEL(GlobalAveragePool, float, AveragePool, 1, kMSInternalNHWCDomain, true) +POOLING_KERNEL(GlobalAveragePool, MLFloat16, AveragePool, 1, kMSInternalNHWCDomain, true) +#endif class CudnnPoolingDescriptor final { public: - CudnnPoolingDescriptor() : desc_(nullptr) { - } + CudnnPoolingDescriptor() : desc_(nullptr) {} ~CudnnPoolingDescriptor() { if (desc_ != nullptr) { @@ -108,12 +116,9 @@ class CudnnPoolingDescriptor final { CudnnPoolingDescriptor(const CudnnPoolingDescriptor&) = delete; CudnnPoolingDescriptor& operator=(const CudnnPoolingDescriptor&) = delete; - Status Set(cudnnPoolingMode_t mode, - const gsl::span& kernel_shape, - const gsl::span& pads, - const gsl::span& strides) { - if (!desc_) - CUDNN_RETURN_IF_ERROR(cudnnCreatePoolingDescriptor(&desc_)); + Status Set(cudnnPoolingMode_t mode, const gsl::span& kernel_shape, + const gsl::span& pads, const gsl::span& strides) { + if (!desc_) CUDNN_RETURN_IF_ERROR(cudnnCreatePoolingDescriptor(&desc_)); int rank = gsl::narrow_cast(kernel_shape.size()); InlinedVector window(rank); @@ -128,14 +133,8 @@ class CudnnPoolingDescriptor final { for (int i = 0; i < rank; i++) { stride[i] = gsl::narrow_cast(strides[i]); } - CUDNN_RETURN_IF_ERROR(SetPoolingNdDescriptorHelper( - desc_, - mode, - CUDNN_PROPAGATE_NAN, - rank, - window.data(), - padding.data(), - stride.data())); + CUDNN_RETURN_IF_ERROR(SetPoolingNdDescriptorHelper(desc_, mode, CUDNN_PROPAGATE_NAN, rank, window.data(), + padding.data(), stride.data())); return Status::OK(); } @@ -146,8 +145,8 @@ class CudnnPoolingDescriptor final { cudnnPoolingDescriptor_t desc_; }; -template -Status Pool::ComputeInternal(OpKernelContext* context) const { +template +Status Pool::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; const Tensor* X = context->Input(0); const TensorShape& x_shape = X->Shape(); @@ -166,13 +165,12 @@ Status Pool::ComputeInternal(OpKernelContext* context) const { pads.assign(kernel_shape.size(), 0); strides.assign(kernel_shape.size(), 1); } - - auto y_dims = pool_attrs_.SetOutputSize(x_shape, x_shape[1], &pads); + auto out_channel = NHWC ? x_shape[3] : x_shape[1]; + auto y_dims = pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, NHWC); TensorShape y_shape(y_dims); Tensor* Y = context->Output(0, y_shape); // special case when there is a dim value of 0 in the shape. - if (y_shape.Size() == 0) - return Status::OK(); + if (y_shape.Size() == 0) return Status::OK(); auto x_data = reinterpret_cast(X->Data()); auto y_data = reinterpret_cast(Y->MutableData()); @@ -181,12 +179,19 @@ Status Pool::ComputeInternal(OpKernelContext* context) const { TensorShapeVector y_dims_cudnn(y_dims); if (kernel_shape.size() < 2) { // cudnn only takes 4D or 5D input, so pad dimensions if needed - x_dims_cudnn.push_back(1); - y_dims_cudnn.push_back(1); + if (NHWC) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 1, 1); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 1, 1); + kernel_shape.insert(kernel_shape.begin() + 1, 1); + strides.insert(strides.begin() + 1, 1); + } else { + x_dims_cudnn.push_back(1); + y_dims_cudnn.push_back(1); + kernel_shape.push_back(1); + strides.push_back(1); + } pads.insert(pads.begin() + kernel_shape.size(), 0); pads.insert(pads.end(), 0); - kernel_shape.push_back(1); - strides.push_back(1); } cudnnPoolingMode_t mode = CUDNN_POOLING_MAX; @@ -203,8 +208,8 @@ Status Pool::ComputeInternal(OpKernelContext* context) const { const auto beta = Consts::Zero; CudnnTensor x_tensor; CudnnTensor y_tensor; - ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType(), NHWC)); + ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType(), NHWC)); const auto input_count = x_shape.Size(); const auto output_count = y_shape.Size(); @@ -212,24 +217,26 @@ Status Pool::ComputeInternal(OpKernelContext* context) const { IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count, context->GetComputeStream()); auto temp_Y = GetScratchBuffer(output_count, context->GetComputeStream()); Impl_Cast(Stream(context), reinterpret_cast(x_data), temp_X.get(), input_count); - CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, temp_X.get(), &beta, y_tensor, temp_Y.get())); + CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, temp_X.get(), + &beta, y_tensor, temp_Y.get())); Impl_Cast(Stream(context), temp_Y.get(), y_data, output_count); } else { const auto alpha = Consts::One; const auto beta = Consts::Zero; CudnnTensor x_tensor; CudnnTensor y_tensor; - ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType(), NHWC)); + ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType(), NHWC)); - CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data)); + CUDNN_RETURN_IF_ERROR( + PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data)); } return Status::OK(); } -template -Status Pool>::ComputeInternal(OpKernelContext* context) const { +template +Status Pool, NHWC>::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; const Tensor* X = context->Input(0); const TensorShape& x_shape = X->Shape(); @@ -248,13 +255,12 @@ Status Pool>::ComputeInternal(OpKernelContext* context) const { pads.assign(kernel_shape.size(), 0); strides.assign(kernel_shape.size(), 1); } - - auto y_dims = this->pool_attrs_.SetOutputSize(x_shape, x_shape[1], &pads); + auto out_channel = NHWC ? x_shape[3] : x_shape[1]; + auto y_dims = this->pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, NHWC); Tensor* Y = context->Output(0, TensorShape(y_dims)); // special case when there is a dim value of 0 in the shape. - if (Y->Shape().Size() == 0) - return Status::OK(); + if (Y->Shape().Size() == 0) return Status::OK(); auto x_data = reinterpret_cast(X->Data()); auto y_data = reinterpret_cast(Y->MutableData()); @@ -262,20 +268,10 @@ Status Pool>::ComputeInternal(OpKernelContext* context) const { Tensor* I = context->Output(1, TensorShape(y_dims)); if (nullptr != I || !this->pool_attrs_.default_dilations) { auto i_data = nullptr == I ? nullptr : I->MutableData(); - MaxPoolWithIndex( - this->Stream(context), - x_shape, - TensorShape(y_dims), - kernel_shape, - strides, - pads, - this->pool_attrs_.dilations, - this->pool_attrs_.storage_order, - x_data, - y_data, - i_data); + MaxPoolWithIndex(this->Stream(context), x_shape, TensorShape(y_dims), kernel_shape, strides, pads, + this->pool_attrs_.dilations, this->pool_attrs_.storage_order, x_data, y_data, i_data); } else { - ORT_RETURN_IF_ERROR((Pool>::ComputeInternal(context))); + ORT_RETURN_IF_ERROR((Pool, NHWC>::ComputeInternal(context))); } return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/nn/pool.h b/onnxruntime/core/providers/cuda/nn/pool.h index fb223c18d2625..8b5152a1565a9 100644 --- a/onnxruntime/core/providers/cuda/nn/pool.h +++ b/onnxruntime/core/providers/cuda/nn/pool.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -10,7 +11,7 @@ namespace onnxruntime { namespace cuda { -template +template class Pool : public CudaKernel, public PoolBase { public: Pool(const OpKernelInfo& info) : CudaKernel(info), PoolBase(info) {} @@ -18,10 +19,10 @@ class Pool : public CudaKernel, public PoolBase { Status ComputeInternal(OpKernelContext* context) const override; }; -template -class Pool> final : public Pool> { +template +class Pool, NHWC> final : public Pool, NHWC> { public: - Pool(const OpKernelInfo& info) : Pool>(info) {} + explicit Pool(const OpKernelInfo& info) : Pool, NHWC>(info) {} Status ComputeInternal(OpKernelContext* context) const override; }; diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 2f057d53d5607..bc78e577c5052 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -16,140 +16,29 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace cuda { -// opset 11 explicitly added support for negative axis. implementation already allowed it. -#define REGISTER_KERNEL_TYPED(name, T) \ +#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 12, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ + 1, end, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); -#define REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 11, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 12, 12, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + name, \ + kOnnxDomain, \ + version, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()).InputMemoryType(OrtMemTypeCPUInput, 1), \ name); -// Register those with changes in OpSet12. -#define REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(name, T) \ - REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -#define REGISTER_KERNEL_VERSIONED_TYPED_13(name, T) \ - REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, 13, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -// Register ReduceMin int64_t support in OpSet14. -#define REGISTER_KERNEL_TYPED_14(name, T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 14, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -// CUDA ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet -#define REGISTER_KERNEL_VERSIONED_TYPED_11(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 11, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -// Register with the latest version 13 -#define REGISTER_KERNEL_TYPED_13(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 12, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .InputMemoryType(OrtMemTypeCPUInput, 1) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); +#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \ + REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \ + REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur) // TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored. template @@ -725,6 +614,30 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, return Status::OK(); } +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + template template Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnReduceTensorOp_t cudnn_reduce_op) const { @@ -917,69 +830,76 @@ template std::unique_ptr ReduceCompute { #endif -#endif +#endif // DISABLE_FLOAT8_TYPES template __global__ void CastKernelStd(const InT* input, OutT* output, CUDA_LONG N, CastStd cast) { diff --git a/onnxruntime/core/providers/cuda/tensor/expand.cc b/onnxruntime/core/providers/cuda/tensor/expand.cc index e9634df205842..806ecfa1aab17 100644 --- a/onnxruntime/core/providers/cuda/tensor/expand.cc +++ b/onnxruntime/core/providers/cuda/tensor/expand.cc @@ -142,6 +142,86 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const { input_strides); } +Status FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* /*input_shape_tensor*/, + Tensor* output_tensor) { + TensorShape output_shape = output_tensor->Shape(); + +#ifdef ENABLE_STRIDED_TENSORS + // Strided output. + if (input_data_tensor->DataRaw() == output_tensor->DataRaw()) { + gsl::span input_strides = input_data_tensor->Strides(); + TensorShapeVector output_strides = + ComputeOutputStrides(input_data_tensor->Shape(), input_strides, output_shape); + output_tensor->SetShapeAndStrides(output_shape, output_strides); + return Status::OK(); + } +#endif + + auto output_dims = output_shape.AsShapeVector(); + auto input_dims = input_data_tensor->Shape().AsShapeVector(); + + CalcEffectiveDims(input_dims, output_dims); + int rank = gsl::narrow_cast(output_dims.size()); + + TensorPitches original_input_strides(input_dims); + TensorPitches original_output_strides(output_dims); + + TArray input_strides(rank); + for (auto i = 0; i < rank; i++) { + input_strides[i] = input_dims[i] == 1 ? 0 : original_input_strides[i]; + } + + TArray output_strides(rank); + for (auto i = 0; i < rank; i++) { + output_strides[i] = fast_divmod(static_cast(original_output_strides[i])); + } + + return ExpandImpl( + cuda_kernel->Stream(ctx), + input_data_tensor->DataType()->Size(), + gsl::narrow_cast(output_shape.Size()), + gsl::narrow_cast(input_data_tensor->Shape().Size()), + input_data_tensor->DataRaw(), + output_tensor->MutableDataRaw(), + output_strides, + input_strides); +} + +std::unique_ptr FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* input_shape_tensor) { + // new shape to be expanded to + const auto* p_shape = input_shape_tensor->Data(); + TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor->Shape().Size()}; + TensorShape output_shape(output_dims); + + ORT_ENFORCE( + ComputeOutputShape( + cuda_kernel->Node().Name(), + input_data_tensor->Shape(), + output_dims, output_shape) + .IsOK()); + + // Pre-allocate output. + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc).IsOK()); + auto output_tensor = Tensor::Create(input_data_tensor->DataType(), output_shape, alloc); + + // Only assign output values when output tensor is non-empty + // because empty tensor doesn't own any data. + if (output_shape.Size() > 0) { + ORT_ENFORCE(FuncExpand(cuda_kernel, ctx, input_data_tensor, input_shape_tensor, output_tensor.get()).IsOK()); + } + + return output_tensor; +} + #ifdef ENABLE_STRIDED_TENSORS #define CREATE_EXPAND_KERNEL_DEF (*KernelDefBuilder::Create()).MayStridedOutput(0, 0) #else diff --git a/onnxruntime/core/providers/cuda/tensor/expand.h b/onnxruntime/core/providers/cuda/tensor/expand.h index 4cf4c14e61058..a0b12790017f6 100644 --- a/onnxruntime/core/providers/cuda/tensor/expand.h +++ b/onnxruntime/core/providers/cuda/tensor/expand.h @@ -20,5 +20,18 @@ Status ComputeOutputShape( const TensorShape& rhs_shape, TensorShape& out_shape); +Status FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* /*input_shape_tensor*/, + Tensor* output_tensor); + +std::unique_ptr FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* input_shape_tensor); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu index ad2a44793fe26..1da308811fa48 100644 --- a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu +++ b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu @@ -104,7 +104,7 @@ struct RoundSat { #endif -#endif +#endif // DISABLE_FLOAT8_TYPES template <> struct RoundStd { @@ -189,7 +189,7 @@ __global__ void QuantizeLinearKernelAxisSat(const InT* input, OutT* output, cons } } -#endif +#endif // DISABLE_FLOAT8_TYPES template Status CudaQuantizeLinearStd(cudaStream_t stream, const InT* input, OutT* output, const InT* scale, const OutT* zero_point, size_t num_of_element) { diff --git a/onnxruntime/core/providers/cuda/tensor/reshape.cc b/onnxruntime/core/providers/cuda/tensor/reshape.cc index 3c6d900cee9a4..ab364c274a32d 100644 --- a/onnxruntime/core/providers/cuda/tensor/reshape.cc +++ b/onnxruntime/core/providers/cuda/tensor/reshape.cc @@ -6,6 +6,81 @@ namespace onnxruntime { namespace cuda { +TensorShape InferReshapeOutputShape( + const TensorShape& data_tensor_shape, // Data tensor's shape. + const gsl::span& shape_span, // Shape that data tensor reshape to. + bool allow_zero) { + TensorShapeVector shape_vector(shape_span.begin(), shape_span.end()); + ReshapeHelper helper(data_tensor_shape, shape_vector, allow_zero); + return TensorShape(shape_vector); +} + +TensorShape InferReshapeOutputShape(const Tensor* src, const Tensor* shape, bool allow_zero) { + ORT_ENFORCE(shape != nullptr, "Cannot reshape to a null shape."); + ORT_ENFORCE(shape->Shape().NumDimensions() == 1, "Shape must be an 1-D tensor."); + ORT_ENFORCE(shape->Location().device.Type() == OrtDevice::CPU, "Shape must be on CPU."); + + return InferReshapeOutputShape( + src->Shape(), + shape->template DataAsSpan(), + allow_zero); +} + +Status FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool /*allow_zero*/, + Tensor* Y) { + if (!X) return Status(common::ONNXRUNTIME, common::FAIL, "Missing data tensor to be reshaped."); + if (!shape) return Status(common::ONNXRUNTIME, common::FAIL, "Missing shape tensor for reshaping."); + if (shape->Shape().NumDimensions() != 1) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, FAIL, "The shape tensor for reshaping must be a vector, but got ", shape->Shape(), "."); + } + if (shape->Location().device.Type() != OrtDevice::CPU) { + return Status(common::ONNXRUNTIME, common::FAIL, "Shape tensor must be on CPU."); + } + + const void* src_data = X->DataRaw(); + void* dst_data = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (src_data != dst_data) { + ORT_ENFORCE(ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(cuda_kernel->CopyTensor(*X, *Y, *ctx->GetComputeStream())); + } + + return Status::OK(); +} + +std::unique_ptr FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool allow_zero) { + // TODO(wechi): Study if Tensor can be created as view to existing tensor. + // This feature can refine code for re-sharding and shape broadcasting. + + ORT_ENFORCE(X != nullptr, "Missing data tensor to be reshaped."); + ORT_ENFORCE(shape != nullptr, "Missing shape tensor for reshaping."); + ORT_ENFORCE(shape->Shape().NumDimensions() == 1, "The shape tensor for reshaping must be a vector, but got ", shape->Shape(), "."); + ORT_ENFORCE(shape->Location().device.Type() == OrtDevice::CPU, "Shape tensor must be on CPU."); + + // Calculate output's shape. + auto dst_shape = InferReshapeOutputShape(X, shape, allow_zero); + + // Pre-allocate output. + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc).IsOK()); + auto Y = Tensor::Create(X->DataType(), dst_shape, alloc); + + // Do reshape. It's equivalent to memcpy. + ORT_ENFORCE(FuncReshape(cuda_kernel, ctx, X, shape, allow_zero, Y.get()).IsOK()); + return Y; +} + ONNX_OPERATOR_KERNEL_EX( Reshape, kOnnxDomain, diff --git a/onnxruntime/core/providers/cuda/tensor/reshape.h b/onnxruntime/core/providers/cuda/tensor/reshape.h index 01e933e65888f..8f33265071ed3 100644 --- a/onnxruntime/core/providers/cuda/tensor/reshape.h +++ b/onnxruntime/core/providers/cuda/tensor/reshape.h @@ -10,6 +10,39 @@ namespace onnxruntime { namespace cuda { +// Deduce output shape from ONNX Reshape's inputs. +// +// Arguments: +// data_tensor_shape: The shape of the data tensor (i.e., 1st input). +// shape_span: Elements in the shape tensor (i.e., 2nd input). +// +// Returns: +// The output shape of this Reshape. No symbolic values such as "-1" or "0". +TensorShape InferReshapeOutputShape( + const TensorShape& data_tensor_shape, + const gsl::span& shape_span, + bool allow_zero); + +TensorShape InferReshapeOutputShape( + const Tensor* src, + const Tensor* shape, + bool allow_zero); + +Status FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool /*allow_zero*/, + Tensor* Y); + +std::unique_ptr FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool allow_zero); + class Reshape final : public CudaKernel { public: Reshape(const OpKernelInfo& info) : CudaKernel(info), @@ -18,27 +51,11 @@ class Reshape final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override { // Copy the second input tensor into the shape vector - const Tensor* shapeTensor = context->Input(1); - if (shapeTensor == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); - if (shapeTensor->Shape().NumDimensions() != 1) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "A shape tensor must be a vector tensor, got ", shapeTensor->Shape().NumDimensions(), " dimensions"); - auto data_span = shapeTensor->template DataAsSpan(); - TensorShapeVector shape(data_span.begin(), data_span.end()); - const Tensor* X = context->Input(0); - if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); - const TensorShape& X_shape = X->Shape(); - - ReshapeHelper helper(X_shape, shape, allow_zero_); - - Tensor* Y = context->Output(0, TensorShape(shape)); - const void* source = X->DataRaw(); - void* target = Y->MutableDataRaw(); - // If source and target pointers are not equal (non-inplace operation), we need to copy the data. - if (target != source) { - ORT_ENFORCE(context->GetComputeStream()); - ORT_RETURN_IF_ERROR(CopyTensor(*X, *Y, *context->GetComputeStream())); - } - - return Status::OK(); + const Tensor* data_tensor = context->Input(0); + const Tensor* shape_tensor = context->Input(1); + const auto target_shape = InferReshapeOutputShape(data_tensor, shape_tensor, allow_zero_); + Tensor* output_tensor = context->Output(0, target_shape); + return FuncReshape(this, context, data_tensor, shape_tensor, allow_zero_, output_tensor); } private: diff --git a/onnxruntime/core/providers/cuda/tensor/slice.cc b/onnxruntime/core/providers/cuda/tensor/slice.cc index 440b19bce9fb6..db285ba547b6a 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.cc +++ b/onnxruntime/core/providers/cuda/tensor/slice.cc @@ -3,6 +3,7 @@ #include "core/providers/cuda/tensor/slice.h" #include "core/providers/cpu/tensor/utils.h" +#include "core/providers/cpu/tensor/slice_helper.h" #include "core/providers/cuda/tensor/slice_impl.h" namespace onnxruntime { @@ -235,5 +236,58 @@ Status Slice::CallSliceImp(size_t element_size, size_t dimension_count, output_shape); } +Status FuncSlice( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const CudaKernel* cuda_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const Tensor* input, + const std::vector& starts, + const std::vector& ends, + const std::vector& axes, + const std::vector& steps, + Tensor* output) { + gsl::span starts_span = gsl::make_span(starts.data(), starts.size()); + gsl::span ends_span = gsl::make_span(ends.data(), ends.size()); + gsl::span axes_span = gsl::make_span(axes.data(), axes.size()); + gsl::span steps_span = gsl::make_span(steps.data(), steps.size()); + const auto& input_shape = input->Shape(); + const auto input_dimensions = input_shape.GetDims(); + + SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions); + + ORT_RETURN_IF_ERROR( + SliceOp::PrepareForComputeHelper(starts_span, ends_span, axes_span, steps_span, compute_metadata)); + + ORT_RETURN_IF_ERROR(SliceBase::FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_, + compute_metadata.ends_, compute_metadata.steps_, compute_metadata.p_flattened_input_dims_, + compute_metadata.p_flattened_output_dims_)); + + TensorShape output_shape(compute_metadata.output_dims_); + + TArray starts_buffer(compute_metadata.starts_); + TArray steps_buffer(compute_metadata.steps_); + TArray input_strides; + TArray output_strides; + + ORT_RETURN_IF_ERROR(SliceCuda::ComputeSliceStrides(input_shape, input_strides, output_strides, compute_metadata)); + + ORT_RETURN_IF_ERROR(SliceImpl( + cuda_kernel->Stream(ctx), + input->DataType()->Size(), + gsl::narrow_cast(input_dimensions.size()), + starts_buffer, + steps_buffer, + input_strides, + output_strides, + input->DataRaw(), + output->MutableDataRaw(), + output_shape.Size())); + + return Status::OK(); +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/slice.h b/onnxruntime/core/providers/cuda/tensor/slice.h index 444e37c2167e8..d5c53611d3421 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.h +++ b/onnxruntime/core/providers/cuda/tensor/slice.h @@ -38,5 +38,20 @@ class Slice : public CudaKernel, public SliceBase { const TArray& output_strides, OpKernelContext* ctx, const TensorShape& output_shape) const; }; + +Status FuncSlice( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const CudaKernel* cuda_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const Tensor* input, + const std::vector& starts, + const std::vector& ends, + const std::vector& axes, + const std::vector& steps, + Tensor* output); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/where.cc b/onnxruntime/core/providers/cuda/tensor/where.cc index b3f92c913a84b..4d98b3a8a145e 100644 --- a/onnxruntime/core/providers/cuda/tensor/where.cc +++ b/onnxruntime/core/providers/cuda/tensor/where.cc @@ -216,5 +216,6 @@ SPECIALIZED_COMPUTE(int64_t) SPECIALIZED_COMPUTE(float) SPECIALIZED_COMPUTE(double_t) SPECIALIZED_COMPUTE(MLFloat16) +SPECIALIZED_COMPUTE(BFloat16) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/where_impl.cu b/onnxruntime/core/providers/cuda/tensor/where_impl.cu index 0fbd2062ca1b2..d7909454e922c 100644 --- a/onnxruntime/core/providers/cuda/tensor/where_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/where_impl.cu @@ -238,6 +238,7 @@ SPECIALIZED_IMPL(int64_t) SPECIALIZED_IMPL(float) SPECIALIZED_IMPL(double_t) SPECIALIZED_IMPL(half) +SPECIALIZED_IMPL(BFloat16) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h index 52018500b134c..cdb0338157561 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h @@ -3,6 +3,9 @@ #pragma once interface IMLOperatorRegistry; +interface IDMLDevice; +interface ID3D12CommandQueue; +interface ID3D12Resource; #include "core/common/status.h" #include "core/framework/data_transfer.h" @@ -28,7 +31,8 @@ namespace Dml std::unique_ptr CreateExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands = true); + bool enableMetacommands, + bool enableDynamicGraphFusion); ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr); void FlushContext(onnxruntime::IExecutionProvider* provider); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 04381b6ce355c..074f13b309181 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -7,11 +7,14 @@ #include #include #include +#include #include "core/framework/op_kernel.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h" struct AbstractOperatorDesc; interface IMLOperatorTensor; +interface IDMLOperator; struct DML_INPUT_GRAPH_EDGE_DESC; struct DML_OUTPUT_GRAPH_EDGE_DESC; struct DML_INTERMEDIATE_GRAPH_EDGE_DESC; @@ -92,6 +95,8 @@ namespace Windows::AI::MachineLearning::Adapter const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, + const EdgeShapes* inputShapesOverrides, + /*out*/ EdgeShapes* outputShapes, /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo )>; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index ede3e7f2c2257..eb068087de4ad 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -491,6 +491,8 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, + const EdgeShapes* inputShapesOverrides, + /*out*/ EdgeShapes* outputShapes, /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo ) { @@ -498,15 +500,15 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( onnxruntime::OpNodeProtoHelper protoHelper(&nodeContext); // Use the same list of required constant inputs for the shape inferrer and the kernel. - EdgeShapes outputShapes; - InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes); + InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, *outputShapes); // Create the kernel while allowing input shape and output shape queries according to options ComPtr kernelInfoWrapper = wil::MakeOrThrow( &protoHelper, executionHandle, true, - &outputShapes, + inputShapesOverrides, + outputShapes, &defaultAttributesCapture, graphNodeCreateInfo, constantCpuInputCapture, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp index d9bfdc3473ca7..b696aefecf664 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp @@ -13,7 +13,7 @@ namespace Dml ComPtr resource; auto buffer = CD3DX12_RESOURCE_DESC::Buffer(size, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); ORT_THROW_IF_FAILED(m_device->CreateCommittedResource( - &CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT), + unmove_ptr(CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT)), D3D12_HEAP_FLAG_NONE, &buffer, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h index 9bf8c58f7a3ec..c4d260b9736df 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h @@ -6,6 +6,11 @@ #include #include "core/providers/dml/OperatorAuthorHelper/Common.h" +template +auto unmove_ptr(T&& t) { + return &static_cast(t); +} + namespace Dml { using namespace OperatorHelper; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h new file mode 100644 index 0000000000000..5ff70493252bd --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace Windows::AI::MachineLearning::Adapter +{ + // edges and unused edges have an empty array of dimensions. + class EdgeShapes + { + public: + EdgeShapes() = default; + + EdgeShapes(size_t count) : m_shapes(count) {} + + const std::vector& GetShape(size_t edgeIndex) const + { + return m_shapes[edgeIndex]; + } + + std::vector& GetMutableShape(size_t edgeIndex) + { + return m_shapes[edgeIndex]; + } + + size_t EdgeCount() const { return m_shapes.size(); } + + void Reset(size_t edge_count) + { + m_shapes.clear(); + m_shapes.resize(edge_count); + } + + bool operator!=(const EdgeShapes& other) const noexcept + { + return (m_shapes != other.m_shapes); + } + + private: + std::vector> m_shapes; + }; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 51b93efb3a646..4f7ec188140b5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -1,7 +1,7 @@ #pragma once #include "DmlGraphFusionHelper.h" - +#include "DmlRuntimeFusedGraphKernel.h" namespace Dml { @@ -103,6 +103,36 @@ namespace DmlGraphFusionHelper ORT_THROW_IF_FAILED(resourceUnk->QueryInterface(resource)); } + std::tuple, std::vector, std::byte*, size_t> UnpackInitializer( + const onnxruntime::Graph& graph, + const ONNX_NAMESPACE::TensorProto* initializer) + { + std::unique_ptr unpackedTensor; + std::vector unpackedExternalTensor; + std::byte* tensorPtr = nullptr; + size_t tensorByteSize = 0; + + // The tensor may be stored as raw data or in typed fields. + if (initializer->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) + { + THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*initializer, graph.ModelPath(), unpackedExternalTensor)); + tensorPtr = reinterpret_cast(unpackedExternalTensor.data()); + tensorByteSize = unpackedExternalTensor.size(); + } + else if (initializer->has_raw_data()) + { + tensorPtr = (std::byte*)(initializer->raw_data().c_str()); + tensorByteSize = initializer->raw_data().size(); + } + else + { + std::tie(unpackedTensor, tensorByteSize) = Windows::AI::MachineLearning::Adapter::UnpackTensor(*initializer, graph.ModelPath()); + tensorPtr = unpackedTensor.get(); + } + + return std::make_tuple(std::move(unpackedTensor), std::move(unpackedExternalTensor), tensorPtr, tensorByteSize); + } + void ProcessInputData( const ExecutionProviderImpl* providerImpl, const std::vector& isInputsUploadedByDmlEP, @@ -161,32 +191,11 @@ namespace DmlGraphFusionHelper auto iter = initializerNameToInitializerMap.find(subGraphInputArgNames[i]); if (iter != initializerNameToInitializerMap.end()) { - std::byte* tensorPtr = nullptr; - size_t tensorByteSize = 0; - std::vector unpackedExternalTensor; - - std::unique_ptr unpackedTensor; - - //auto& initializer = iter->second; auto* initializer = iter->second.first; + auto [unpackedTensor, unpackedExternalTensor, tensorPtr, tensorByteSize] = UnpackInitializer(graph, initializer); - // The tensor may be stored as raw data or in typed fields. - if (initializer->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) - { - THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*initializer, graph.ModelPath(), unpackedExternalTensor)); - tensorPtr = reinterpret_cast(unpackedExternalTensor.data()); - tensorByteSize = unpackedExternalTensor.size(); - } - else if (initializer->has_raw_data()) + if (initializer->data_location() != onnx::TensorProto_DataLocation_EXTERNAL && !initializer->has_raw_data()) { - tensorPtr = (std::byte*)(initializer->raw_data().c_str()); - tensorByteSize = initializer->raw_data().size(); - } - else - { - std::tie(unpackedTensor, tensorByteSize) = Windows::AI::MachineLearning::Adapter::UnpackTensor(*initializer, graph.ModelPath()); - tensorPtr = unpackedTensor.get(); - // Free the initializer if this is the last usage of it. if (initializerToLastInputIndexMap[initializer] == i) { @@ -501,5 +510,173 @@ namespace DmlGraphFusionHelper graph.FinalizeFuseSubGraph(indexedSubGraph, fusedNode); } + + void RegisterDynamicKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const ExecutionProviderImpl* providerImpl, + std::unordered_map graphNodePropertyMap, + const std::unordered_set& dynamicCpuInputMap, + std::shared_ptr indexedSubGraph, + std::unordered_map>&& isInitializerTransferable) + { + struct NodeInfo + { + std::string name; + std::string opType; + std::string description; + std::string domain; + onnxruntime::NodeAttributes attributes; + std::vector inputDefPointers; + std::vector outputDefPointers; + }; + + auto partitionNodePropsMap = DmlGraphFusionHelper::CreatePartitionNodePropsMap( + graph, + *indexedSubGraph, + std::move(graphNodePropertyMap)); + + auto modelPath = graph.ModelPath(); + + const gsl::span subGraphInputArgNames = indexedSubGraph->GetMetaDef()->inputs; + const gsl::span subGraphOutputArgNames = indexedSubGraph->GetMetaDef()->outputs; + + std::vector nodesInfo; + nodesInfo.reserve(indexedSubGraph->nodes.size()); + + std::vector subgraphInputs; + subgraphInputs.reserve(subGraphInputArgNames.size()); + + std::vector subgraphOutputs; + subgraphOutputs.reserve(subGraphOutputArgNames.size()); + + std::vector nodeAttributes; + nodeAttributes.reserve(indexedSubGraph->nodes.size()); + + std::vector> intermediateNodeArgs; + + for (size_t sortedNodeIndex : indexedSubGraph->nodes) + { + auto node = graph.GetNode(sortedNodeIndex); + + nodeAttributes.push_back(node->GetAttributes()); + + NodeInfo nodeInfo{}; + nodeInfo.name = node->Name(); + nodeInfo.opType = node->OpType(); + nodeInfo.description = node->Description(); + nodeInfo.domain = node->Domain(); + nodeInfo.attributes = node->GetAttributes(); + nodeInfo.inputDefPointers.reserve(node->InputDefs().size()); + nodeInfo.outputDefPointers.reserve(node->OutputDefs().size()); + + for (const onnxruntime::NodeArg* inputDef : node->InputDefs()) + { + intermediateNodeArgs.emplace_back(std::make_shared(inputDef->Name(), inputDef->TypeAsProto())); + nodeInfo.inputDefPointers.push_back(intermediateNodeArgs.back().get()); + } + + for (const onnxruntime::NodeArg* outputDef : node->OutputDefs()) + { + intermediateNodeArgs.emplace_back(std::make_shared(outputDef->Name(), outputDef->TypeAsProto())); + nodeInfo.outputDefPointers.push_back(intermediateNodeArgs.back().get()); + } + + nodesInfo.push_back(std::move(nodeInfo)); + } + + for (const std::string& graphInputName : subGraphInputArgNames) + { + subgraphInputs.push_back(graph.GetNodeArg(graphInputName)); + } + + for (const std::string& graphOutputName : subGraphOutputArgNames) + { + subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName)); + } + + // We need to keep the initializers alive since they will be freed once the nodes are removed from the graph + std::vector ownedInitializers; + ownedInitializers.reserve(isInitializerTransferable.size()); + + for (auto& kvp : isInitializerTransferable) + { + auto [unpackedTensor, unpackedExternalTensor, tensorPtr, tensorByteSize] = UnpackInitializer(graph, kvp.second.first); + + ONNX_NAMESPACE::TensorProto tensorProto; + tensorProto.set_data_type(kvp.second.first->data_type()); + tensorProto.set_raw_data(tensorPtr, tensorByteSize); + tensorProto.set_name(kvp.second.first->name()); + + for (int i = 0; i < kvp.second.first->dims_size(); ++i) + { + tensorProto.add_dims(kvp.second.first->dims(i)); + } + ownedInitializers.push_back(std::move(tensorProto)); + kvp.second.first = &ownedInitializers.back(); + } + + // lamda captures for the kernel registration + auto fused_kernel_func = [ + indexedSubGraph, + &modelPath, + nodesInfo = std::move(nodesInfo), + intermediateNodeArgs = std::move(intermediateNodeArgs), + subgraphInputs = std::move(subgraphInputs), + subgraphOutputs = std::move(subgraphOutputs), + partitionNodePropsMap = std::move(partitionNodePropsMap), + ownedInitializers = std::move(ownedInitializers)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr& out) mutable ->onnxruntime::Status + { + std::vector> subgraphNodes; + subgraphNodes.reserve(nodesInfo.size()); + + for (const NodeInfo& nodeInfo : nodesInfo) + { + subgraphNodes.emplace_back(std::make_shared( + nodeInfo.name, + nodeInfo.opType, + nodeInfo.description, + nodeInfo.inputDefPointers, + nodeInfo.outputDefPointers, + &nodeInfo.attributes, + nodeInfo.domain)); + } + + out.reset(CreateRuntimeFusedGraphKernel( + info, + indexedSubGraph, + modelPath, + std::move(subgraphNodes), + std::move(subgraphInputs), + std::move(subgraphOutputs), + std::move(intermediateNodeArgs), + std::move(partitionNodePropsMap), + std::move(ownedInitializers))); + return Status::OK(); + }; + + // build the kernel definition on the fly, and register it to the fused_kernel_regisitry. + onnxruntime::KernelDefBuilder builder; + builder.SetName(indexedSubGraph->GetMetaDef()->name) + .SetDomain(indexedSubGraph->GetMetaDef()->domain) + .SinceVersion(indexedSubGraph->GetMetaDef()->since_version) + .Provider(onnxruntime::kDmlExecutionProvider); + + // Force the CPU inputs to be allocated on the CPU + for (int i = 0; i < subGraphInputArgNames.size(); ++i) + { + if (dynamicCpuInputMap.find(subGraphInputArgNames[i]) != dynamicCpuInputMap.end()) + { + builder.InputMemoryType(OrtMemTypeCPUInput, i); + } + } + + ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func)); + + auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name); + fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); + + graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode); + } } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h index 030cffc2a8794..f8f6162aaa1e0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h @@ -80,5 +80,14 @@ namespace DmlGraphFusionHelper std::vector&& isInputsUploadedByDmlEP, const GraphDescBuilder::GraphDesc& graphDesc, Microsoft::WRL::ComPtr compiledExecutionPlanOperator); + + void RegisterDynamicKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const ExecutionProviderImpl* providerImpl, + std::unordered_map graphNodePropertyMap, + const std::unordered_set& dynamicCpuInputMap, + std::shared_ptr indexedSubGraph, + std::unordered_map>&& isInitializerTransferable); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index 4813707cdf50c..679738b639ec9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -15,6 +15,18 @@ namespace Dml { + namespace + { + struct CompiledPartitionInfo + { + Microsoft::WRL::ComPtr compiledOperator; + onnxruntime::IndexedSubGraph indexedSubGraph; + std::vector isInputsUploadedByDmlEP; + GraphDescBuilder::GraphDesc graphDesc; + std::unordered_map> isInitializerTransferable; + }; + } + DmlGraphFusionTransformer::DmlGraphFusionTransformer( const std::string& name, const onnxruntime::IExecutionProvider* provider @@ -24,15 +36,6 @@ namespace Dml { } - struct CompiledPartitionInfo - { - Microsoft::WRL::ComPtr compiledOperator; - onnxruntime::IndexedSubGraph indexedSubGraph; - std::vector isInputsUploadedByDmlEP; - GraphDescBuilder::GraphDesc graphDesc; - std::unordered_map> isInitializerTransferable; - }; - onnxruntime::common::Status DmlGraphFusionTransformer::ApplyImpl( onnxruntime::Graph& graph, bool& modified, @@ -87,6 +90,7 @@ namespace Dml { // Initializers needed by any graph partition std::unordered_set requiredInitializerMap; + std::unordered_set dynamicCpuInputMap; std::unordered_map graphNodePropertyMap; onnxruntime::GraphViewer graphViewer(graph); std::vector> partitions = BuildPartitions( @@ -96,8 +100,10 @@ namespace Dml m_providerImpl->GetSupportedDeviceDataTypeMask(), graphNodePropertyMap, requiredInitializerMap, + dynamicCpuInputMap, additionalSplittingNodes, - implicitInputDefs); + implicitInputDefs, + false); // Reset the splitting nodes for the current iteration additionalSplittingNodes.clear(); @@ -190,17 +196,48 @@ namespace Dml std::move(graphNodePropertyMap)); // Convert partitionONNXGraph into DML EP GraphDesc + auto modelPath = graph.ModelPath(); + + const gsl::span subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs; + const gsl::span subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs; + + std::vector subgraphNodes; + subgraphNodes.reserve(indexedSubGraph.nodes.size()); + + std::vector subgraphInputs; + subgraphInputs.reserve(subGraphInputArgNames.size()); + + std::vector subgraphOutputs; + subgraphOutputs.reserve(subGraphOutputArgNames.size()); + + for (size_t sortedNodeIndex : indexedSubGraph.nodes) + { + subgraphNodes.push_back(graph.GetNode(sortedNodeIndex)); + } + + for (const std::string& graphInputName : subGraphInputArgNames) + { + subgraphInputs.push_back(graph.GetNodeArg(graphInputName)); + } + + for (const std::string& graphOutputName : subGraphOutputArgNames) + { + subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName)); + } + ComPtr device; ORT_THROW_IF_FAILED(m_providerImpl->GetDmlDevice(device.GetAddressOf())); GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( isInputsUploadedByDmlEP.data(), isInputsUploadedByDmlEP.size(), isInitializerTransferable, - graph, - indexedSubGraph, partitionNodePropsMap, device.Get(), - m_providerImpl); + m_providerImpl, + modelPath, + subgraphNodes, + subgraphInputs, + subgraphOutputs); // Compile the operator auto compiledPartition = DmlGraphFusionHelper::TryCreateCompiledOperator( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp new file mode 100644 index 0000000000000..5c7b7bff1e370 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -0,0 +1,369 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h" + +using namespace Windows::AI::MachineLearning::Adapter; + +namespace Dml +{ + class DmlRuntimeFusedGraphKernel : public onnxruntime::OpKernel + { + public: + DmlRuntimeFusedGraphKernel() = delete; + + DmlRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& kernelInfo, + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers) + : OpKernel(kernelInfo), + m_indexedSubGraph(std::move(indexedSubGraph)), + m_modelPath(modelPath), + m_subgraphNodes(std::move(subgraphNodes)), + m_subgraphInputs(std::move(subgraphInputs)), + m_subgraphOutputs(std::move(subgraphOutputs)), + m_intermediateNodeArgs(std::move(intermediateNodeArgs)), + m_partitionNodePropsMap(std::move(partitionNodePropsMap)), + m_ownedInitializers(std::move(ownedInitializers)) + { + for (const auto& initializer : m_ownedInitializers) + { + m_isInitializerTransferable[initializer.name()] = std::make_pair(&initializer, false); + } + + // Get the execution provider interfaces + auto executionHandle = kernelInfo.GetExecutionProvider()->GetExecutionHandle(); + if (executionHandle) + { + // We assume the execution object inherits IUnknown as its first base + ComPtr providerExecutionObject = const_cast(static_cast(executionHandle)); + + // Get the WinML-specific execution provider interface from the execution object. + ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_provider)); + ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_winmlProvider)); + } + + m_subgraphNodePointers.reserve(m_subgraphNodes.size()); + + for (auto& subgraphNode : m_subgraphNodes) + { + m_subgraphNodePointers.push_back(subgraphNode.get()); + } + } + + void TranslateAndCompileGraph( + const onnxruntime::OpKernelInfo& kernelInfo, + std::vector>& initializeResourceRefs, + std::vector initInputBindings) const + { + // Allocate a persistent resource and initialize the operator + UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize; + if (persistentResourceSize > 0) + { + ORT_THROW_IF_FAILED(m_provider->AllocatePooledResource( + static_cast(persistentResourceSize), + AllocatorRoundingMode::Disabled, + m_persistentResource.ReleaseAndGetAddressOf(), + m_persistentResourceAllocatorUnk.ReleaseAndGetAddressOf())); + + m_persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize }; + } + + ORT_THROW_IF_FAILED(m_provider->InitializeOperator( + m_compiledExecutionPlanOperator.Get(), + m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, + gsl::make_span(initInputBindings))); + + std::for_each( + initializeResourceRefs.begin(), + initializeResourceRefs.end(), + [&](ComPtr& resource){ m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(resource).Get()); } + ); + } + + onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override + { + ORT_THROW_HR_IF(E_UNEXPECTED, static_cast(m_subgraphInputs.size()) != kernelContext->InputCount()); + + bool recompileNeeded = m_compiledExecutionPlanOperator == nullptr; + + for (int inputIndex = 0; inputIndex < kernelContext->InputCount(); ++inputIndex) + { + const auto& input = kernelContext->RequiredInput(inputIndex); + const std::string& inputName = m_subgraphInputs[inputIndex]->Name(); + auto shapeIter = m_inferredInputShapes.find(inputName); + + if (shapeIter == m_inferredInputShapes.end()) + { + m_inferredInputShapes[inputName] = input.Shape(); + recompileNeeded = true; + } + else if (shapeIter->second != input.Shape()) + { + shapeIter->second = input.Shape(); + recompileNeeded = true; + } + + // If we have CPU inputs that are not initializers (i.e. they were computed at runtime), add them to the initializer list + if (input.Location().device.Type() == OrtDevice::CPU) + { + auto inputProto = onnxruntime::utils::TensorToTensorProto(input, inputName); + + // We can only avoid recompiling the graph when all CPU inputs are identical + auto initializerIter = m_isInitializerTransferable.find(inputName); + + if (initializerIter != m_isInitializerTransferable.end()) + { + if (initializerIter->second.first->raw_data().length() == inputProto.raw_data().length()) + { + for (int i = 0; i < inputProto.raw_data().length(); ++i) + { + if (initializerIter->second.first->raw_data()[i] != inputProto.raw_data()[i]) + { + recompileNeeded = true; + break; + } + } + } + else + { + recompileNeeded = true; + } + } + else + { + recompileNeeded = true; + } + + m_ownedCpuInputs.push_back(std::make_unique(std::move(inputProto))); + m_isInitializerTransferable[inputName] = std::make_pair(m_ownedCpuInputs.back().get(), false); + } + } + + if (recompileNeeded) + { + // Go through all the node args and replace their shapes with the real ones + for (auto& nodeArg : m_intermediateNodeArgs) + { + auto iter = m_inferredInputShapes.find(nodeArg->Name()); + if (iter != m_inferredInputShapes.end()) + { + auto tensorShape = *nodeArg->Shape(); + ORT_THROW_HR_IF(E_UNEXPECTED, tensorShape.dim_size() != static_cast(iter->second.NumDimensions())); + + for (int i = 0; i < tensorShape.dim_size(); ++i) + { + tensorShape.mutable_dim(i)->set_dim_value(iter->second.GetDims()[i]); + } + + nodeArg->SetShape(tensorShape); + } + } + + // Populate input bindings for operator initialization + const uint32_t fusedNodeInputCount = gsl::narrow_cast(m_indexedSubGraph->GetMetaDef()->inputs.size()); + std::vector> initializeResourceRefs; // For lifetime control + std::vector initInputBindings(fusedNodeInputCount); + std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); + auto providerImpl = static_cast(Info().GetExecutionProvider())->GetImpl(); + + // Convert partitionONNXGraph into DML EP GraphDesc + ComPtr device; + ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); + GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( + isInputsUploadedByDmlEP.data(), + isInputsUploadedByDmlEP.size(), + m_isInitializerTransferable, + m_partitionNodePropsMap, + device.Get(), + providerImpl, + m_modelPath, + m_subgraphNodePointers, + m_subgraphInputs, + m_subgraphOutputs); + + m_outputShapes = graphDesc.outputShapes; + + // Walk through each graph edge and mark used inputs + m_inputsUsed.resize(fusedNodeInputCount, false); + for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) + { + m_inputsUsed[edge.GraphInputIndex] = true; + } + + // Compile the operator + m_compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator( + graphDesc, + *m_indexedSubGraph, + providerImpl); + + // Queue references to objects which must be kept alive until resulting GPU work completes + m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); + + TranslateAndCompileGraph( + Info(), + initializeResourceRefs, + initInputBindings); + } + + // Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator + OpKernelContextWrapper contextWrapper( + kernelContext, + Info().GetExecutionProvider(), + true, + nullptr); + + ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + + // Get input resources for execution, excluding those which were specified as owned by DML and provided + // at initialization instead. + std::vector> inputTensors(kernelContext->InputCount()); + std::vector inputPtrs(kernelContext->InputCount()); + + for (int i = 0; i < kernelContext->InputCount(); ++i) + { + if (!m_inputsUsed[i]) + { + continue; + } + + ORT_THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf())); + inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get()); + } + + auto outputTensors = contextWrapper.GetOutputTensors(m_outputShapes); + ExecuteOperator( + m_compiledExecutionPlanOperator.Get(), + m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, + inputPtrs, + outputTensors); + + ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + + return onnxruntime::Status::OK(); + } + + void ExecuteOperator( + IDMLCompiledOperator* op, + _In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding, + gsl::span inputTensors, + gsl::span outputTensors) const + { + auto FillBindingsFromTensors = [this](auto& bufferBindings, auto& bindingDescs, gsl::span& tensors) + { + for (IMLOperatorTensor* tensor : tensors) + { + if (tensor) + { + assert(tensor->IsDataInterface()); + ID3D12Resource* resource = m_provider->DecodeResource(MLOperatorTensor(tensor).GetDataInterface().Get()); + D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc(); + bufferBindings.push_back({ resource, 0, resourceDesc.Width }); + bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() }); + } + else + { + bufferBindings.push_back({ nullptr, 0, 0 }); + bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + } + } + }; + + auto FillBindingsFromBuffers = [](auto& bufferBindings, auto& bindingDescs, gsl::span& resources) + { + for (ID3D12Resource* resource : resources) + { + if (resource) + { + D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc(); + bufferBindings.push_back({ resource, 0, resourceDesc.Width }); + bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() }); + } + else + { + bufferBindings.push_back({ nullptr, 0, 0 }); + bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + } + } + }; + + std::vector inputBufferBindings; + inputBufferBindings.reserve(inputTensors.size()); + std::vector inputBindings; + inputBindings.reserve(inputTensors.size()); + FillBindingsFromBuffers(inputBufferBindings, inputBindings, inputTensors); + + std::vector outputBufferBindings; + outputBufferBindings.reserve(outputTensors.size()); + std::vector outputBindings; + outputBindings.reserve(outputTensors.size()); + FillBindingsFromTensors(outputBufferBindings, outputBindings, outputTensors); + + ORT_THROW_IF_FAILED(m_provider->ExecuteOperator( + op, + persistentResourceBinding, + inputBindings, + outputBindings)); + } + + private: + ComPtr m_winmlProvider; + ComPtr m_provider; + + mutable std::optional m_persistentResourceBinding; + std::shared_ptr m_indexedSubGraph; + const onnxruntime::Path& m_modelPath; + + std::vector> m_subgraphNodes; + std::vector m_subgraphInputs; + std::vector m_subgraphOutputs; + mutable std::vector> m_intermediateNodeArgs; + std::unordered_map m_partitionNodePropsMap; + std::vector m_ownedInitializers; + mutable std::unordered_map> m_isInitializerTransferable; + std::vector m_subgraphNodePointers; + + // Bindings from previous executions of a re-used command list + mutable std::vector> m_ownedCpuInputs; + mutable ComPtr m_compiledExecutionPlanOperator; + mutable std::vector m_inputsUsed; + mutable ComPtr m_persistentResource; + mutable ComPtr m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator + mutable Windows::AI::MachineLearning::Adapter::EdgeShapes m_outputShapes; + mutable std::unordered_map m_inferredInputShapes; + }; + + onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& info, + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers) + { + return new DmlRuntimeFusedGraphKernel( + info, + std::move(indexedSubGraph), + modelPath, + std::move(subgraphNodes), + std::move(subgraphInputs), + std::move(subgraphOutputs), + std::move(intermediateNodeArgs), + std::move(partitionNodePropsMap), + std::move(ownedInitializers) + ); + } +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h new file mode 100644 index 0000000000000..d679c5aa5667c --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/op_kernel.h" +#include "GraphDescBuilder.h" +#include "DmlRuntimeGraphFusionTransformer.h" + +namespace Dml +{ + onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& info, + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers + ); +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp new file mode 100644 index 0000000000000..6318b0d5e2865 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp @@ -0,0 +1,161 @@ +#pragma once + +#include "precomp.h" +#include "GraphDescBuilder.h" +#include "ExecutionProvider.h" +#include "DmlRuntimeGraphFusionTransformer.h" +#include "GraphPartitioner.h" +#include "core/framework/kernel_type_str_resolver.h" +#include "core/framework/kernel_lookup.h" +#include "core/optimizer/constant_sharing.h" +#include "DmlRuntimeFusedGraphKernel.h" +#include "MLOperatorAuthorImpl.h" +#include "DmlGraphFusionHelper.h" + +namespace Dml +{ + namespace + { + struct CompiledPartitionInfo + { + std::shared_ptr indexedSubGraph; + std::unordered_map> isInitializerTransferable; + }; + } + + DmlRuntimeGraphFusionTransformer::DmlRuntimeGraphFusionTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ) + :onnxruntime::GraphTransformer(name), + m_providerImpl(static_cast(provider)->GetImpl()) + { + } + + onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImpl( + onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger) const + { + return ApplyImplHelper(graph, modified, graphLevel, logger, {}); + } + + onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const + { + onnxruntime::ProviderType providerType = onnxruntime::kDmlExecutionProvider; + const gsl::not_null registry = m_providerImpl->GetKernelRegistry().get(); + const auto kernelTypeStrResolver = onnxruntime::OpSchemaKernelTypeStrResolver{}; + const auto kernelLookup = onnxruntime::KernelLookup( + providerType, + gsl::make_span(®istry, 1), + kernelTypeStrResolver); + + onnxruntime::GraphViewer graphViewer(graph); + const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder(); + + for (auto nodeIndex : nodeTopologyList) + { + auto* node = graph.GetNode(nodeIndex); + if (!node) + { + continue; // node was removed + } + + std::unordered_map subgraphImplicitInputDefs; + for (const onnxruntime::NodeArg* inputDef : node->ImplicitInputDefs()) + { + subgraphImplicitInputDefs[inputDef->Name()] = inputDef; + } + + for (auto& entry : node->GetAttributeNameToMutableSubgraphMap()) + { + auto& subgraph = *entry.second; + ORT_RETURN_IF_ERROR(ApplyImplHelper(subgraph, modified, graphLevel + 1, logger, subgraphImplicitInputDefs)); + } + } + + // Initializers needed by any graph partition + std::vector additionalSplittingNodes; + std::unordered_map graphNodePropertyMap; + std::unordered_set requiredInitializerMap; + std::unordered_set dynamicCpuInputMap; + std::vector> partitions = BuildPartitions( + graphViewer, + *m_providerImpl->GetInternalRegistrationInfoMap(), + kernelLookup, + m_providerImpl->GetSupportedDeviceDataTypeMask(), + graphNodePropertyMap, + requiredInitializerMap, + dynamicCpuInputMap, + additionalSplittingNodes, + implicitInputDefs, + true); + + // Reset the splitting nodes for the current iteration + additionalSplittingNodes.clear(); + + // Reset the compiled operators for the current iteration + std::vector> compiledPartitionInfos(partitions.size()); + + // Create a map between each initialized tensor and the partition(s) it is part of. + auto initializerPartitionMap = DmlGraphFusionHelper::GetInitializerToPartitionMap(graphViewer, partitions); + + for (uint32_t partitionIndex = 0; partitionIndex < partitions.size(); ++partitionIndex) + { + auto& partition = partitions[partitionIndex]; + + if (partition->GetRootMergedPartition() != partition.get() || + !partition->IsDmlPartition()) + { + continue; + } + + if (partition->IsDmlGraphPartition()) + { + std::unordered_map> isInitializerTransferable; + + std::string partitionKernelPrefix = std::to_string(m_providerImpl->GetPartitionKernelPrefixVal()) + "_"; + m_providerImpl->IncreasePartitionKernelPrefixVal(); + + // populate isInitializerTransferable + for (const auto& input : partition->GetInputs()) + { + const onnx::TensorProto* tensor = nullptr; + if (graph.GetInitializedTensor(input, tensor) && requiredInitializerMap.find(input) != requiredInitializerMap.end()) + { + isInitializerTransferable[input] = {tensor, false}; + } + } + + compiledPartitionInfos[partitionIndex] = std::make_shared(); + compiledPartitionInfos[partitionIndex]->indexedSubGraph = std::make_shared( + DmlGraphFusionHelper::CreateIndexedSubGraph(partition.get(), partitionIndex, partitionKernelPrefix)); + compiledPartitionInfos[partitionIndex]->isInitializerTransferable = std::move(isInitializerTransferable); + } + } + + for (auto&& compiledPartitionInfo : compiledPartitionInfos) + { + // Null compiled operators were not DML partitions + if (compiledPartitionInfo) + { + DmlGraphFusionHelper::RegisterDynamicKernel( + graph, + m_providerImpl->GetKernelRegistry().get(), + m_providerImpl, + graphNodePropertyMap, + dynamicCpuInputMap, + std::move(compiledPartitionInfo->indexedSubGraph), + std::move(compiledPartitionInfo->isInitializerTransferable)); + } + } + + return onnxruntime::common::Status::OK(); + } +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h new file mode 100644 index 0000000000000..cfa743e1f2b85 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include "core/optimizer/graph_transformer.h" +#include "core/framework/execution_providers.h" + +namespace Dml +{ +class ExecutionProviderImpl; + +class DmlRuntimeGraphFusionTransformer : public onnxruntime::GraphTransformer +{ +public: + DmlRuntimeGraphFusionTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ); + +public: + static inline const char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlRuntimeFusedNode_"; + static inline const char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlRuntimeFusedNodeDomain"; + +private: + onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger) const final; + + onnxruntime::common::Status ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const; + +private: + const ExecutionProviderImpl* m_providerImpl = nullptr; +}; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index f97b72aa2d385..8644b8d56a426 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -67,7 +67,8 @@ namespace Dml ExecutionProvider::ExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands) : + bool enableMetacommands, + bool enableDynamicGraphFusion) : IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)) { D3D12_COMMAND_LIST_TYPE queueType = commandQueue->GetDesc().Type; @@ -80,7 +81,7 @@ namespace Dml ComPtr device; GRAPHICS_THROW_IF_FAILED(commandQueue->GetDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf()))); - m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), commandQueue, enableMetacommands); + m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), commandQueue, enableMetacommands, enableDynamicGraphFusion); } std::vector> @@ -147,12 +148,12 @@ namespace Dml // Task 24384515: Update ORT AIInfra release agent pool to install 19H1 SDK on VM bootstrap #define D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE ((D3D_FEATURE_LEVEL)0x1000) - ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands) + ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands, bool enableDynamicGraphFusion) : m_d3d12Device(d3d12Device), m_dmlDevice(dmlDevice), - m_areMetacommandsEnabled(enableMetacommands) + m_areMetacommandsEnabled(enableMetacommands), + m_dynamicGraphFusionEnabled(enableDynamicGraphFusion) { - D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {}; D3D_FEATURE_LEVEL featureLevelsList[] = { @@ -636,7 +637,7 @@ namespace Dml bool IsCpuOnDmlOperator(const onnxruntime::Node& node) { - auto cpuOnDmlOperators = std::array{ + auto cpuOnDmlOperators = std::array{ "SequenceAt", "SequenceConstruct", "SequenceEmpty", @@ -659,7 +660,7 @@ namespace Dml bool IsDmlSequenceOperator(const onnxruntime::Node& node) { - auto sequence_ops = std::array{ + auto sequence_ops = std::array{ "ConcatFromSequence" }; @@ -675,7 +676,7 @@ namespace Dml bool IsCustomOpShader(const onnxruntime::Node& node) { - auto custom_ops = std::array{ + auto custom_ops = std::array{ "DFT", "STFT", "GridSample" @@ -1093,6 +1094,11 @@ namespace Dml return m_areMetacommandsEnabled; } + bool ExecutionProviderImpl::DynamicGraphFusionEnabled() const noexcept + { + return m_dynamicGraphFusionEnabled; + } + std::shared_ptr ExecutionProviderImpl::GetInternalRegistrationInfoMap() const { @@ -1129,9 +1135,10 @@ namespace Dml std::unique_ptr CreateExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands) + bool enableMetacommands, + bool enableDynamicGraphFusion) { - return std::make_unique(dmlDevice, commandQueue, enableMetacommands); + return std::make_unique(dmlDevice, commandQueue, enableMetacommands, enableDynamicGraphFusion); } ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 31b893a2f25d7..3aaa11cdee479 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -5,6 +5,7 @@ #include "GraphTransformer.h" #include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h" +#include "core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h" #include #include @@ -34,7 +35,8 @@ namespace Dml IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, - bool enableMetacommands = true); + bool enableMetacommands, + bool enableDynamicGraphFusion); void ReleaseCompletedReferences(); @@ -150,6 +152,7 @@ namespace Dml STDMETHOD_(bool, IsMcdmDevice)() const noexcept final; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final; + bool DynamicGraphFusionEnabled() const noexcept; std::shared_ptr GetGpuAllocator(); std::shared_ptr GetCpuInputAllocator(); @@ -184,6 +187,7 @@ namespace Dml ComPtr m_dmlDevice; bool m_isMcdmDevice = false; bool m_areMetacommandsEnabled = true; + bool m_dynamicGraphFusionEnabled = false; bool m_native16BitShaderOpsSupported = false; std::shared_ptr m_context; std::unique_ptr m_uploadHeap; @@ -236,7 +240,8 @@ namespace Dml explicit ExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands = true + bool enableMetacommands, + bool enableDynamicGraphFusion ); std::unique_ptr GetDataTransfer() const final override @@ -299,9 +304,9 @@ namespace Dml return m_impl.Get(); } - void MetacommandsEnabled() + bool DynamicGraphFusionEnabled() const { - m_impl->MetacommandsEnabled(); + return m_impl->DynamicGraphFusionEnabled(); } virtual std::vector CreatePreferredAllocators() override diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 636f46428ce99..3fc8f415e5a58 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -147,14 +147,14 @@ namespace Dml::GraphDescBuilder const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle) + const void* executionHandle, + const onnxruntime::Path& modelPath, + gsl::span subgraphNodes, + gsl::span subgraphInputs, + gsl::span subgraphOutputs) { - const gsl::span subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs; - const gsl::span subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs; struct NodeAndIndex { uint32_t nodeIndex; // The index of the node itself @@ -164,12 +164,14 @@ namespace Dml::GraphDescBuilder // Map from Lotus node argument names to the new node and index where it will be produced std::unordered_map nameToNodeAndIndexMap; + std::unordered_map nodeOutputShapes; + // Map from Lotus node argument names to input indices of the fused kernel node. std::unordered_map nameToDmlFusedNodeInputIndex; - for (size_t inputIndex = 0; inputIndex < subGraphInputArgNames.size(); ++inputIndex) + for (size_t inputIndex = 0; inputIndex < subgraphInputs.size(); ++inputIndex) { - const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(subGraphInputArgNames[inputIndex]); + const onnxruntime::NodeArg* graphInput = subgraphInputs[inputIndex]; if (!graphInput) { @@ -196,13 +198,11 @@ namespace Dml::GraphDescBuilder const uint32_t minNodeCountToReuseCommandList = 5; bool reuseCommandList = false; - if (indexedSubGraph.nodes.size() >= minNodeCountToReuseCommandList) + if (subgraphNodes.size() >= minNodeCountToReuseCommandList) { reuseCommandList = true; } - auto modelPath = graph.ModelPath(); - auto constantCpuGraphInputGetter = [&isInitializerTransferable, &modelPath](const std::string& argName) { ComPtr tensorWrapper; @@ -219,9 +219,11 @@ namespace Dml::GraphDescBuilder // Iterate through each node and create a corresponding node in the new graph // We can iterate the nodes in any order because the edge connectivity will take care of the topological order - for (size_t sortedNodeIndex : indexedSubGraph.nodes) + std::unordered_map> inferredOutputShapes; + + for (const onnxruntime::Node* subgraphNode : subgraphNodes) { - const onnxruntime::Node& node = *graph.GetNode(sortedNodeIndex); + const onnxruntime::Node& node = *subgraphNode; const GraphNodeProperties& graphNodeProps = graphNodePropertyMap.find(GetUniqueNodeName(node))->second; const auto& requiredConstantCpuInputs = graphNodeProps.internalRegInfo->requiredConstantCpuInputs; @@ -244,14 +246,45 @@ namespace Dml::GraphDescBuilder return tensor; }; + EdgeShapes inputShapesOverrides(node.InputDefs().size()); + + // Override the input shapes with shapes that were previously inferred + for (int inputIndex = 0; inputIndex < node.InputDefs().size(); ++inputIndex) + { + auto inputDef = node.InputDefs()[inputIndex]; + + auto outputShapesIter = inferredOutputShapes.find(inputDef->Name()); + if (outputShapesIter != inferredOutputShapes.end()) + { + inputShapesOverrides.GetMutableShape(inputIndex) = outputShapesIter->second; + } + else if (inputDef->HasTensorOrScalarShape()) + { + for (int i = 0; i < inputDef->Shape()->dim_size(); ++i) + { + ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->Shape()->dim(i).has_dim_value()); + inputShapesOverrides.GetMutableShape(inputIndex).push_back(gsl::narrow_cast(inputDef->Shape()->dim(i).dim_value())); + } + } + } + + EdgeShapes outputShapes; DmlGraphNodeCreateInfo graphNodeCreateInfo; graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory( node, constantCpuNodeInputGetter, executionHandle, + &inputShapesOverrides, + /*out*/ &outputShapes, /*out*/ &graphNodeCreateInfo ); + ORT_THROW_HR_IF(E_UNEXPECTED, outputShapes.EdgeCount() != node.OutputDefs().size()); + for (int i = 0; i < node.OutputDefs().size(); ++i) + { + inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i); + } + // Create a map between operatorGraphNodeIndex to mainGraphNodeIndex. std::unordered_map operatorGraphNodeIndexToMainGraphNodeIndexMap; uint32_t graphNodeCount = gsl::narrow_cast(graphNodes.size()); @@ -347,6 +380,8 @@ namespace Dml::GraphDescBuilder operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphOutputEdge.FromNodeIndex], operatorGraphOutputEdge.FromNodeOutputIndex }; + + nodeOutputShapes[arg->Name()] = outputShapes; } } @@ -367,10 +402,12 @@ namespace Dml::GraphDescBuilder } } + EdgeShapes graphOutputShapes(subgraphOutputs.size()); + // Add graph output nodes, which might be in a different order from the encapsulating node - for (size_t outputIndex = 0; outputIndex < subGraphOutputArgNames.size(); ++outputIndex) + for (size_t outputIndex = 0; outputIndex < subgraphOutputs.size(); ++outputIndex) { - const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(subGraphOutputArgNames[outputIndex]); + const onnxruntime::NodeArg* graphOutput = subgraphOutputs[outputIndex]; ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg"); const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name()); @@ -380,6 +417,7 @@ namespace Dml::GraphDescBuilder edge.FromNodeOutputIndex = outputNodeAndIndex.targetIndex; edge.GraphOutputIndex = gsl::narrow_cast(outputIndex); graphOutputEdges.push_back(edge); + graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex); } RemoveUnconnectedNodes(graphNodes, graphInputEdges, graphIntermediateEdges, graphOutputEdges); @@ -390,6 +428,7 @@ namespace Dml::GraphDescBuilder graphDesc.outputEdges = std::move(graphOutputEdges); graphDesc.intermediateEdges = std::move(graphIntermediateEdges); graphDesc.reuseCommandList = reuseCommandList; + graphDesc.outputShapes = std::move(graphOutputShapes); return graphDesc; } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index 5c04962e55557..0039678c00e59 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -9,10 +9,10 @@ namespace Dml { struct GraphNodeProperties { - std::shared_ptr + std::shared_ptr internalRegInfo; - // These are currently passed from the partitioning step since the only DML operators current + // These are currently passed from the partitioning step since the only DML operators current // supporting graph nodes don't customize the order of edges or shapes, other than coercing // dimension count. This will change as the supported set of operators as graph nodes increases. Windows::AI::MachineLearning::Adapter::EdgeShapes inputShapes; @@ -38,16 +38,19 @@ namespace Dml std::vector outputEdges; std::vector intermediateEdges; bool reuseCommandList; + Windows::AI::MachineLearning::Adapter::EdgeShapes outputShapes; }; GraphDesc BuildGraphDesc( const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle); + const void* executionHandle, + const onnxruntime::Path& modelPath, + gsl::span subgraphNodes, + gsl::span subgraphInputs, + gsl::span subgraphOutputs); } -} \ No newline at end of file +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 18943878ccedc..f7a4743801d81 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -151,6 +151,8 @@ namespace Dml _In_opt_ const std::unordered_map* nodeNameToPartitionMap, _Inout_ std::unordered_map& dmlNodePropertyMap, _Inout_ std::unordered_set& requiredInitializerMap, + _Inout_ std::unordered_set& dynamicCpuInputMap, + bool allowDmlGraphDynamicShapes, _Out_ bool* isDmlGraphNode ) { @@ -172,36 +174,68 @@ namespace Dml if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration) { - bool requiredCpuInputsConstant = true; - for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) + if (allowDmlGraphDynamicShapes) { - if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) { - continue; - } + if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + { + continue; + } - const onnx::TensorProto* tensor = nullptr; - const std::string& inputName = node.InputDefs()[inputIndex]->Name(); + const onnx::TensorProto* tensor = nullptr; + const std::string& inputName = node.InputDefs()[inputIndex]->Name(); - if (!graph.GetInitializedTensor(inputName, tensor)) - { - requiredCpuInputsConstant = false; - break; + if (graph.GetInitializedTensor(inputName, tensor)) + { + requiredInitializerMap.insert(inputName); + } + else + { + dynamicCpuInputMap.insert(inputName); + } } - requiredInitializerMap.insert(inputName); + std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; + if (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size()) + { + *isDmlGraphNode = true; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + } } - - std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; - if (requiredCpuInputsConstant && - TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && - TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && - (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + else { - *isDmlGraphNode = true; - graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + bool requiredCpuInputsConstant = true; + for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) + { + if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + { + continue; + } + + const onnx::TensorProto* tensor = nullptr; + const std::string& inputName = node.InputDefs()[inputIndex]->Name(); + + if (!graph.GetInitializedTensor(inputName, tensor)) + { + requiredCpuInputsConstant = false; + break; + } + + requiredInitializerMap.insert(inputName); + } + + std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; + if (requiredCpuInputsConstant && + TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && + TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && + (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + { + *isDmlGraphNode = true; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + } } } } @@ -379,8 +413,10 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, + std::unordered_set& dynamicCpuInputMap, gsl::span additionalSplittingNodes, - const std::unordered_map& implicitInputs) + const std::unordered_map& implicitInputs, + bool allowDmlGraphDynamicShapes) { // Nodes are uniquely identified by the name of their first output argument std::vector> partitions; @@ -443,6 +479,8 @@ namespace Dml &nodeNameToPartitionMap, graphNodePropertyMap, requiredInitializerMap, + dynamicCpuInputMap, + allowDmlGraphDynamicShapes, /*out*/ &isDmlGraphNode ); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h index 37d577f647fb5..3bddb5ae16086 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h @@ -50,6 +50,8 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, + std::unordered_set& dynamicCpuInputMap, gsl::span additionalSplittingNodes, - const std::unordered_map& implicitInputs); + const std::unordered_map& implicitInputs, + bool allowDmlGraphDynamicShapes); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h index d7a0a607cdec9..a8a6d6745e908 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h @@ -2,8 +2,15 @@ // Licensed under the MIT License. #pragma once + +#include + #include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" +interface IDMLCompiledOperator; +struct DML_BUFFER_BINDING; +struct DML_BINDING_DESC; + namespace Dml { struct Binding diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 6cd10e14e08d2..4deec620fe5fb 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1356,13 +1356,14 @@ namespace Windows::AI::MachineLearning::Adapter const onnxruntime::OpNodeProtoHelper* protoHelper, const void* executionHandle, bool isInternalOperator, + const EdgeShapes* inputShapesOverrides, const EdgeShapes* inferredOutputShapes, const AttributeMap* defaultAttributes, DmlGraphNodeCreateInfo* graphNodeCreateInfo, gsl::span requiredConstantCpuInputs, MLOperatorTensorGetter& constantInputGetter ) - : OpNodeInfoWrapper(protoHelper, nullptr, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr), + : OpNodeInfoWrapper(protoHelper, inputShapesOverrides, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr), m_inferredOutputShapes(inferredOutputShapes), m_internalOperator(isInternalOperator), m_graphNodeCreateInfo(graphNodeCreateInfo) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index a7f8bebb2de78..913997ff4ad49 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -4,6 +4,7 @@ #pragma once #include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h" #include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h" #include "core/framework/op_kernel.h" #include "core/framework/customregistry.h" #include "core/framework/tensorprotoutils.h" @@ -93,42 +94,6 @@ struct AttributeValue using AttributeMap = std::map; -// Encapsulation of shapes across different edges of an operator. Non-tensor -// edges and unused edges have an empty array of dimensions. -class EdgeShapes -{ -public: - EdgeShapes() = default; - - EdgeShapes(size_t count) : m_shapes(count) {} - - const std::vector& GetShape(size_t edgeIndex) const - { - return m_shapes[edgeIndex]; - } - - std::vector& GetMutableShape(size_t edgeIndex) - { - return m_shapes[edgeIndex]; - } - - size_t EdgeCount() const { return m_shapes.size(); } - - void Reset(size_t edge_count) - { - m_shapes.clear(); - m_shapes.resize(edge_count); - } - - bool operator!=(const EdgeShapes& other) const noexcept - { - return (m_shapes != other.m_shapes); - } - - private: - std::vector> m_shapes; -}; - // Base class for ABI objects which may be "Closed", at which point calls will predictably // fail or return a dummy value. This is used for transient ABI context objects which // are passed to methods on kernel or inferencers, and which wrap Lotus objects whose lifetimes @@ -434,6 +399,7 @@ class DmlGraphOpKernelInfoWrapper : public OpNodeInfoWrapper< const onnxruntime::OpNodeProtoHelper * protoHelper, const void* executionHandle, bool isInternalOperator, + const EdgeShapes* inputShapesOverrides, const EdgeShapes* inferredOutputShapes, const AttributeMap* defaultAttributes, DmlGraphNodeCreateInfo* graphNodeCreateInfo, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h index c63863853fb4e..4bbc8a4b718da 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h @@ -51,16 +51,6 @@ namespace GridSample_float_float #include "GeneratedShaders/grid_sample_float_float.h" } -namespace GridSample_double_float -{ - #include "GeneratedShaders/grid_sample_double_float.h" -} - -namespace GridSample_bool_float -{ - #include "GeneratedShaders/grid_sample_bool_float.h" -} - namespace GridSample_uint16_fp16 { #include "GeneratedShaders/grid_sample_uint16_fp16.h" @@ -101,66 +91,6 @@ namespace GridSample_float_fp16 #include "GeneratedShaders/grid_sample_float_fp16.h" } -namespace GridSample_double_fp16 -{ - #include "GeneratedShaders/grid_sample_double_fp16.h" -} - -namespace GridSample_bool_fp16 -{ - #include "GeneratedShaders/grid_sample_bool_fp16.h" -} - -namespace GridSample_uint16_double -{ - #include "GeneratedShaders/grid_sample_uint16_double.h" -} - -namespace GridSample_uint_double -{ - #include "GeneratedShaders/grid_sample_uint_double.h" -} - -namespace GridSample_uint64_double -{ - #include "GeneratedShaders/grid_sample_uint64_double.h" -} - -namespace GridSample_int16_double -{ - #include "GeneratedShaders/grid_sample_int16_double.h" -} - -namespace GridSample_int_double -{ - #include "GeneratedShaders/grid_sample_int_double.h" -} - -namespace GridSample_int64_double -{ - #include "GeneratedShaders/grid_sample_int64_double.h" -} - -namespace GridSample_fp16_double -{ - #include "GeneratedShaders/grid_sample_fp16_double.h" -} - -namespace GridSample_float_double -{ - #include "GeneratedShaders/grid_sample_float_double.h" -} - -namespace GridSample_double_double -{ - #include "GeneratedShaders/grid_sample_double_double.h" -} - -namespace GridSample_bool_double -{ - #include "GeneratedShaders/grid_sample_bool_double.h" -} - #include #include @@ -471,14 +401,6 @@ class DmlGridSampleOperator : public WRL::Base computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_float_float::g_GridSample, sizeof(GridSample_float_float::g_GridSample)); break; - case MLOperatorTensorDataType::Double: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_double_float::g_GridSample, sizeof(GridSample_double_float::g_GridSample)); - break; - - case MLOperatorTensorDataType::Bool: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_bool_float::g_GridSample, sizeof(GridSample_bool_float::g_GridSample)); - break; - default: ORT_THROW_HR(E_INVALIDARG); } @@ -520,63 +442,6 @@ class DmlGridSampleOperator : public WRL::Base computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_float_fp16::g_GridSample, sizeof(GridSample_float_fp16::g_GridSample)); break; - case MLOperatorTensorDataType::Double: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_double_fp16::g_GridSample, sizeof(GridSample_double_fp16::g_GridSample)); - break; - - case MLOperatorTensorDataType::Bool: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_bool_fp16::g_GridSample, sizeof(GridSample_bool_fp16::g_GridSample)); - break; - - default: - ORT_THROW_HR(E_INVALIDARG); - } - break; - } - case MLOperatorTensorDataType::Double: - { - switch (inputDataType) - { - case MLOperatorTensorDataType::UInt16: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint16_double::g_GridSample, sizeof(GridSample_uint16_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::UInt32: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint_double::g_GridSample, sizeof(GridSample_uint_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::UInt64: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint64_double::g_GridSample, sizeof(GridSample_uint64_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::Int16: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int16_double::g_GridSample, sizeof(GridSample_int16_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::Int32: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int_double::g_GridSample, sizeof(GridSample_int_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::Int64: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int64_double::g_GridSample, sizeof(GridSample_int64_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::Float16: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_fp16_double::g_GridSample, sizeof(GridSample_fp16_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::Float: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_float_double::g_GridSample, sizeof(GridSample_float_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::Double: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_double_double::g_GridSample, sizeof(GridSample_double_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::Bool: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_bool_double::g_GridSample, sizeof(GridSample_bool_double::g_GridSample)); - break; - default: ORT_THROW_HR(E_INVALIDARG); } @@ -901,6 +766,7 @@ class DmlGridSampleOperatorFactory : public WRL::Base kernelDescription.executionType = MLOperatorExecutionType::D3D12; // T1: tensor(float16), tensor(float), tensor(double), tensor(bfloat16) + // tensor(double) is not supported for GPU MLOperatorEdgeTypeConstrant t1Constraint; t1Constraint.typeLabel = "T1"; std::vector t1AllowedEdges diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp index 60b235880e23f..9f9cfad670919 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp @@ -143,7 +143,8 @@ class DmlOperatorBatchNormalization15 : public DmlOperator, BatchNormalizationHe ); DML_EXECUTION_FLAGS executionFlags = GetExecutionFlags(); - m_compiledOperator.Attach(graph.Compile(executionFlags, { batchNormalization }).Detach()); + std::array outputs = { batchNormalization }; + m_compiledOperator.Attach(graph.Compile(executionFlags, outputs).Detach()); } void Compute(const MLOperatorKernelContext& kernelContext) override diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp index 9c1a7baeaa8df..03500d0ee86a9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp @@ -205,12 +205,34 @@ class DmlOperatorMultiHeadAttention : public DmlOperator else { const auto keyPaddingMaskTensorShape = m_inputTensorDescs[dmlMaskIndex].GetSizes(); - ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape.size() == 2); + size_t maskDimCount = keyPaddingMaskTensorShape.size(); + ML_CHECK_VALID_ARGUMENT(maskDimCount >= 2 || maskDimCount <= 4); ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[0] == batchSize); - ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength); - const uint32_t actualShape[4] = {batchSize, 1, 1, kvSequenceLength}; - const uint32_t desiredShape[4] = {batchSize, numHeads, sequenceLength, kvSequenceLength}; + std::array actualShape{}; + std::array desiredShape{}; + + if (maskDimCount == 2) + { + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength); + actualShape = {batchSize, 1, 1, kvSequenceLength}; + desiredShape = {batchSize, numHeads, sequenceLength, kvSequenceLength}; + } + else if (maskDimCount == 3) + { + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == sequenceLength); + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == totalSequenceLength); + actualShape = {batchSize, 1, sequenceLength, totalSequenceLength}; + desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength}; + } + else if (maskDimCount == 4) + { + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == numHeads); + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == sequenceLength); + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[3] == totalSequenceLength); + actualShape = {batchSize, numHeads, sequenceLength, totalSequenceLength}; + desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength}; + } m_inputTensorDescs[dmlMaskIndex] = TensorDesc::ConstructBroadcastedTensorDesc( m_inputTensorDescs[dmlMaskIndex].GetMlOperatorDataType(), diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp index 4f8b5a1bc7fac..e8d5b2746aa13 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp @@ -84,7 +84,7 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase poolingDesc.EndPadding = m_kernel.endPadding; DML_OPERATOR_DESC opDesc = {}; - opDesc.Type = ApiTraits::OperatorDescTraits::type>::Type; + opDesc.Type = ApiTraits::OperatorDescTraits::type>::Type; opDesc.Desc = &poolingDesc; SetDmlOperatorDesc(opDesc, kernelInfo); }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp new file mode 100644 index 0000000000000..30c339b845b36 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -0,0 +1,436 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +// This operator is easier to understand by looking at a python implementation of the non-interleaved version: +// +// def rotate_half(x): +// """Rotates half the hidden dims of the input.""" +// half_dim = x.shape[-1] // 2 +// x1 = x[..., :half_dim] +// x2 = x[..., half_dim:] +// return np.concatenate((-x2, x1), dim=-1) +// +// +// def apply_rope(x, cos, sin, position_ids): +// cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] +// sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] +// x_embed = (x * cos) + (rotate_half(x) * sin) +// return x_embed +// +// For the non-interleaved version, we multiply the cos cache by the non-rotated input tensor while we multiply the sin cache +// by the rotated input tensor. Rotating the tensor means slicing it in half on the head dimension and swapping the 2 halves. +// +// The interleaved version is very similar but instead of swapping 2 halves, we swap every pair of adjacent elements and we swap +// the sign of every adjacent element. + +namespace Dml +{ +class DmlOperatorRotaryEmbedding : public DmlOperator +{ +public: + DmlOperatorRotaryEmbedding(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo) + { + enum InputIndex : uint32_t + { + inputDataIndex, + positionIdsIndex, + cosCacheIndex, + sinCacheIndex, + }; + + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4); + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); + + // When positionIds is a scalar, it represents the start offset for each sequence + const bool positionIdsIsOffset = kernelInfo.GetInputTensorDimensionCount(positionIdsIndex) == 1; + + Initialize(kernelInfo); + + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetDimensionCount() == 4); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[positionIdsIndex].GetDimensionCount() == 4); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetDimensionCount() == 4); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[sinCacheIndex].GetDimensionCount() == 4); + + ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs[0].GetDimensionCount() == 4); + + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetSizes() == m_inputTensorDescs[sinCacheIndex].GetSizes()); + const uint32_t headSize = m_inputTensorDescs[cosCacheIndex].GetSizes().back() * 2; + + // The last dimension of the data is the hidden size, so it must be divisible by the head size + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetSizes().back() % headSize == 0); + + // We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize] + const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes(); + const uint32_t batchSize = inputDataSizes[1]; + const uint32_t sequenceLength = inputDataSizes[2]; + const uint32_t numHeads = inputDataSizes[3] / headSize; + + const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes(); + const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2]; + + if (sequenceLength > maxSequenceLength) + { + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); + } + + const bool interleaved = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::Interleaved, 0)); + + std::vector inputDescs = GetDmlInputDescs(); + const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; + + // Splitting the hiddenSize into numHeads and headSize dimensions makes it easier for DML to handle + const std::array inputOutputShape = {batchSize, sequenceLength, numHeads, headSize}; + TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); + const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc(); + + // Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase. + DML_SCALE_BIAS scaleBias = {1.0f, 0.0f}; + + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{}; + copyInputDesc.InputTensor = &inputOutputDmlTensorDesc; + copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc; + copyInputDesc.ScaleBias = &scaleBias; + const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, ©InputDesc}; + + // Split the input data into 2 equal parts + const std::vector inputDataTensorShape = interleaved + ? std::vector({batchSize, sequenceLength, numHeads, headSize / 2, 2}) + : std::vector({batchSize, sequenceLength, numHeads, 2, headSize / 2}); + + const std::vector splitInputDataTensorShape = interleaved + ? std::vector({batchSize, sequenceLength, numHeads, headSize / 2, 1}) + : std::vector({batchSize, sequenceLength, numHeads, 1, headSize / 2}); + + TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc(); + + TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape); + const std::array splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()}; + + DML_SPLIT_OPERATOR_DESC splitInputDesc{}; + splitInputDesc.InputTensor = &inputDataDmlTensorDesc; + splitInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data(); + splitInputDesc.OutputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + splitInputDesc.Axis = interleaved + ? gsl::narrow_cast(splitInputDataTensorShape.size()) - 1 + : gsl::narrow_cast(splitInputDataTensorShape.size()) - 2; + + const DML_OPERATOR_DESC splitInputDmlDesc = {DML_OPERATOR_SPLIT, &splitInputDesc}; + + // Swap the 2 halves and join them together + DML_JOIN_OPERATOR_DESC joinInputDesc{}; + joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data(); + joinInputDesc.OutputTensor = &inputDataDmlTensorDesc; + joinInputDesc.Axis = splitInputDesc.Axis; + joinInputDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc}; + + // We generate a sequence from 0 to sequenceLength and add the offset to it + const std::array positionIdsRangeShape = {1, 1, 1, sequenceLength}; + auto positionIdsDataType = kernelInfo.GetInputEdgeDescription(positionIdsIndex).tensorDataType; + TensorDesc positionIdsRangeTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, positionIdsRangeShape); + const DML_TENSOR_DESC positionIdsRangeDmlTensorDesc = positionIdsRangeTensorDesc.GetDmlDesc(); + + const std::array broadcastedPositionIdsRangeShape = {1, 1, batchSize, sequenceLength}; + TensorDesc broadcastedPositionIdsRangeTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedPositionIdsRangeShape, positionIdsRangeShape); + const DML_TENSOR_DESC broadcastedPositionIdsRangeDmlTensorDesc = broadcastedPositionIdsRangeTensorDesc.GetDmlDesc(); + + const std::array broadcastedOffsetShape = {1, 1, batchSize, sequenceLength}; + TensorDesc broadcastedOffsetTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedOffsetShape, m_inputTensorDescs[positionIdsIndex].GetSizes()); + const DML_TENSOR_DESC broadcastedOffsetDmlTensorDesc = broadcastedOffsetTensorDesc.GetDmlDesc(); + + TensorDesc offsetPositionIdsTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, broadcastedOffsetShape); + const DML_TENSOR_DESC offsetPositionIdsRangeDmlTensorDesc = offsetPositionIdsTensorDesc.GetDmlDesc(); + + DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC positionIdsRange{}; + DML_ELEMENT_WISE_ADD_OPERATOR_DESC positionIdsAddOffset{}; + if (positionIdsIsOffset) + { + ML_CHECK_VALID_ARGUMENT(positionIdsDataType == MLOperatorTensorDataType::Int64); + positionIdsRange.ValueDataType = DML_TENSOR_DATA_TYPE_INT64; + positionIdsRange.ValueDelta.Int64 = 1; + positionIdsRange.OutputTensor = &positionIdsRangeDmlTensorDesc; + + positionIdsAddOffset.ATensor = &broadcastedPositionIdsRangeDmlTensorDesc; + positionIdsAddOffset.BTensor = &broadcastedOffsetDmlTensorDesc; + positionIdsAddOffset.OutputTensor = &offsetPositionIdsRangeDmlTensorDesc; + } + const DML_OPERATOR_DESC positionIdsRangeDmlDesc = {DML_OPERATOR_FILL_VALUE_SEQUENCE, &positionIdsRange}; + const DML_OPERATOR_DESC positionIdsAddOffsetDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &positionIdsAddOffset}; + + // Gather the cos/sin values based on the position ids + const std::array gatheredCosSinShape = {1, batchSize, sequenceLength, headSize / 2}; + TensorDesc gatheredCosSinTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, gatheredCosSinShape); + const DML_TENSOR_DESC gatheredCosSinDmlTensorDesc = gatheredCosSinTensorDesc.GetDmlDesc(); + + DML_GATHER_OPERATOR_DESC gatherCosSinDesc{}; + gatherCosSinDesc.InputTensor = &inputDescs[cosCacheIndex]; + gatherCosSinDesc.IndicesTensor = positionIdsIsOffset ? &offsetPositionIdsRangeDmlTensorDesc : &inputDescs[positionIdsIndex]; + gatherCosSinDesc.OutputTensor = &gatheredCosSinDmlTensorDesc; + gatherCosSinDesc.Axis = 2; + gatherCosSinDesc.IndexDimensions = 2; + const DML_OPERATOR_DESC gatherCosSinDmlDesc {DML_OPERATOR_GATHER, &gatherCosSinDesc}; + + // After gathering cos/sin, reshape and broadcast them to match the number of heads of the input data + const std::vector reshapedCosSinShape = interleaved + ? std::vector({batchSize, sequenceLength, 1, headSize / 2, 1}) + : std::vector({batchSize, sequenceLength, 1, 1, headSize / 2}); + TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedCosSinShape); + const DML_TENSOR_DESC broadcastedCosSinDmlTensorDesc = broadcastedCosSinTensorDesc.GetDmlDesc(); + + // Create a vector that contains the sign values {-1, 1} + const std::array signTensorShape = {2}; + TensorDesc signTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, signTensorShape); + const DML_TENSOR_DESC signDmlTensorDesc = signTensorDesc.GetDmlDesc(); + + DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC signRange{}; + signRange.OutputTensor = &signDmlTensorDesc; + if (dataType == MLOperatorTensorDataType::Float16) + { + const auto valueStart = static_cast(-1.0f); + const auto valueDelta = static_cast(2.0f); + memcpy(signRange.ValueStart.Bytes, reinterpret_cast(&valueStart), sizeof(valueStart)); + memcpy(signRange.ValueDelta.Bytes, reinterpret_cast(&valueDelta), sizeof(valueDelta)); + signRange.ValueDataType = DML_TENSOR_DATA_TYPE_FLOAT16; + } + else + { + ML_CHECK_VALID_ARGUMENT(dataType == MLOperatorTensorDataType::Float); + signRange.ValueStart.Float32 = -1.0f; + signRange.ValueDelta.Float32 = 2.0f; + signRange.ValueDataType = DML_TENSOR_DATA_TYPE_FLOAT32; + } + const DML_OPERATOR_DESC signRangeDmlDesc = {DML_OPERATOR_FILL_VALUE_SEQUENCE, &signRange}; + + // Multiply the broadcasted sign values with the rotated input + const std::vector reshapedSignShape = interleaved + ? std::vector({1, 1, 1, 1, 2}) + : std::vector({1, 1, 1, 2, 1}); + TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedSignShape); + const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc(); + + DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{}; + mulSignDesc.ATensor = &inputDataDmlTensorDesc; + mulSignDesc.BTensor = &broadcastedSignDmlTensorDesc; + mulSignDesc.OutputTensor = &inputDataDmlTensorDesc; + const DML_OPERATOR_DESC mulSignDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulSignDesc}; + + // Multiply the non-rotated data with the cos and the rotated data with the sin + DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulCosSinDesc{}; + mulCosSinDesc.ATensor = &inputDataDmlTensorDesc; + mulCosSinDesc.BTensor = &broadcastedCosSinDmlTensorDesc; + mulCosSinDesc.OutputTensor = &inputDataDmlTensorDesc; + const DML_OPERATOR_DESC mulCosSinDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulCosSinDesc}; + + // Add the multiplied cos and sin values together + DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{}; + addDesc.ATensor = &inputOutputDmlTensorDesc; + addDesc.BTensor = &inputOutputDmlTensorDesc; + addDesc.OutputTensor = &inputOutputDmlTensorDesc; + const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc}; + + // Construct the graph + std::vector inputEdges; + std::vector intermediateEdges; + std::vector outputEdges; + + std::vector opDescs = { + ©InputDmlDesc, // Copy the input data to preseve the real input shape + &splitInputDmlDesc, // Split the input data + &gatherCosSinDmlDesc, // Gather cos + &gatherCosSinDmlDesc, // Gather sin + &signRangeDmlDesc, // Generate the signs + + &joinInputDmlDesc, // Join the split data + &mulCosSinDmlDesc, // Multiply cos with the non-rotated data + &mulCosSinDmlDesc, // Multiply sin with the rotated data + &mulSignDmlDesc, // Multiply the sign with the rotated data + &addDmlDesc, // Add the rotated cos and non-rotated sin parts together + }; + + enum NodeIndex : uint32_t + { + copyInputOpIndex, + splitInputOpIndex, + gatherCosOpIndex, + gatherSinOpIndex, + signRangeOpIndex, + + joinInputOpIndex, + mulCosOpIndex, + mulSinOpIndex, + mulSignOpIndex, + addOpIndex, + + // The following indices are optional + positionIdsRangeOpIndex, + positionIdsAddOffsetOpIndex, + }; + + if (positionIdsIsOffset) + { + opDescs.push_back(&positionIdsRangeDmlDesc); + opDescs.push_back(&positionIdsAddOffsetDmlDesc); + + DML_INPUT_GRAPH_EDGE_DESC positionIdsToAddOffsetEdge = {}; + positionIdsToAddOffsetEdge.GraphInputIndex = positionIdsIndex; + positionIdsToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsToAddOffsetEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToAddOffsetEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsOffsetToAddOffsetEdge = {}; + positionIdsOffsetToAddOffsetEdge.FromNodeIndex = positionIdsRangeOpIndex; + positionIdsOffsetToAddOffsetEdge.FromNodeOutputIndex = 0; + positionIdsOffsetToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsOffsetToAddOffsetEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(positionIdsOffsetToAddOffsetEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherCosEdge = {}; + positionIdsAddOffsetToGatherCosEdge.FromNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsAddOffsetToGatherCosEdge.FromNodeOutputIndex = 0; + positionIdsAddOffsetToGatherCosEdge.ToNodeIndex = gatherCosOpIndex; + positionIdsAddOffsetToGatherCosEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(positionIdsAddOffsetToGatherCosEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherSinEdge = {}; + positionIdsAddOffsetToGatherSinEdge.FromNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsAddOffsetToGatherSinEdge.FromNodeOutputIndex = 0; + positionIdsAddOffsetToGatherSinEdge.ToNodeIndex = gatherSinOpIndex; + positionIdsAddOffsetToGatherSinEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(positionIdsAddOffsetToGatherSinEdge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherCosEdge = {}; + positionIdsToGatherCosEdge.GraphInputIndex = positionIdsIndex; + positionIdsToGatherCosEdge.ToNodeIndex = gatherCosOpIndex; + positionIdsToGatherCosEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToGatherCosEdge); + + DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherSinEdge = {}; + positionIdsToGatherSinEdge.GraphInputIndex = positionIdsIndex; + positionIdsToGatherSinEdge.ToNodeIndex = gatherSinOpIndex; + positionIdsToGatherSinEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToGatherSinEdge); + } + + DML_INPUT_GRAPH_EDGE_DESC inputToCopyInputEdge = {}; + inputToCopyInputEdge.GraphInputIndex = inputDataIndex; + inputToCopyInputEdge.ToNodeIndex = copyInputOpIndex; + inputToCopyInputEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToCopyInputEdge); + + DML_INPUT_GRAPH_EDGE_DESC cosToGatherEdge = {}; + cosToGatherEdge.GraphInputIndex = cosCacheIndex; + cosToGatherEdge.ToNodeIndex = gatherCosOpIndex; + cosToGatherEdge.ToNodeInputIndex = 0; + inputEdges.push_back(cosToGatherEdge); + + DML_INPUT_GRAPH_EDGE_DESC sinToGatherEdge = {}; + sinToGatherEdge.GraphInputIndex = sinCacheIndex; + sinToGatherEdge.ToNodeIndex = gatherSinOpIndex; + sinToGatherEdge.ToNodeInputIndex = 0; + inputEdges.push_back(sinToGatherEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC inputToSplitEdge = {}; + inputToSplitEdge.FromNodeIndex = copyInputOpIndex; + inputToSplitEdge.FromNodeOutputIndex = 0; + inputToSplitEdge.ToNodeIndex = splitInputOpIndex; + inputToSplitEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(inputToSplitEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC nonRotatedDataToMulEdge = {}; + nonRotatedDataToMulEdge.FromNodeIndex = copyInputOpIndex; + nonRotatedDataToMulEdge.FromNodeOutputIndex = 0; + nonRotatedDataToMulEdge.ToNodeIndex = mulCosOpIndex; + nonRotatedDataToMulEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(nonRotatedDataToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfDataToJoinEdge = {}; + secondHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex; + secondHalfDataToJoinEdge.FromNodeOutputIndex = 1; + secondHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex; + secondHalfDataToJoinEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(secondHalfDataToJoinEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfDataToJoinEdge = {}; + firstHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex; + firstHalfDataToJoinEdge.FromNodeOutputIndex = 0; + firstHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex; + firstHalfDataToJoinEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(firstHalfDataToJoinEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC cosToMulEdge = {}; + cosToMulEdge.FromNodeIndex = gatherCosOpIndex; + cosToMulEdge.FromNodeOutputIndex = 0; + cosToMulEdge.ToNodeIndex = mulCosOpIndex; + cosToMulEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(cosToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedDataToMulEdge = {}; + rotatedDataToMulEdge.FromNodeIndex = joinInputOpIndex; + rotatedDataToMulEdge.FromNodeOutputIndex = 0; + rotatedDataToMulEdge.ToNodeIndex = mulSinOpIndex; + rotatedDataToMulEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(rotatedDataToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC sinToMulEdge = {}; + sinToMulEdge.FromNodeIndex = gatherSinOpIndex; + sinToMulEdge.FromNodeOutputIndex = 0; + sinToMulEdge.ToNodeIndex = mulSinOpIndex; + sinToMulEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(sinToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedSinToMulEdge = {}; + rotatedSinToMulEdge.FromNodeIndex = mulSinOpIndex; + rotatedSinToMulEdge.FromNodeOutputIndex = 0; + rotatedSinToMulEdge.ToNodeIndex = mulSignOpIndex; + rotatedSinToMulEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(rotatedSinToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC signToMulEdge = {}; + signToMulEdge.FromNodeIndex = signRangeOpIndex; + signToMulEdge.FromNodeOutputIndex = 0; + signToMulEdge.ToNodeIndex = mulSignOpIndex; + signToMulEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(signToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC nonRotatedCosToAddEdge = {}; + nonRotatedCosToAddEdge.FromNodeIndex = mulCosOpIndex; + nonRotatedCosToAddEdge.FromNodeOutputIndex = 0; + nonRotatedCosToAddEdge.ToNodeIndex = addOpIndex; + nonRotatedCosToAddEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(nonRotatedCosToAddEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedSinToAddEdge = {}; + rotatedSinToAddEdge.FromNodeIndex = mulSignOpIndex; + rotatedSinToAddEdge.FromNodeOutputIndex = 0; + rotatedSinToAddEdge.ToNodeIndex = addOpIndex; + rotatedSinToAddEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(rotatedSinToAddEdge); + + DML_OUTPUT_GRAPH_EDGE_DESC addToOutputEdge = {}; + addToOutputEdge.FromNodeIndex = addOpIndex; + addToOutputEdge.FromNodeOutputIndex = 0; + addToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(addToOutputEdge); + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(RotaryEmbedding, DmlOperatorRotaryEmbedding); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/GenerateShaders.bat b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/GenerateShaders.bat index fb087bd800ff0..c5580ee103595 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/GenerateShaders.bat +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/GenerateShaders.bat @@ -16,8 +16,6 @@ if "%1" == "DEBUG" ( dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int64_float.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_fp16_float.h fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=float /DTBUFFER2=float /Zi /Od /Fh grid_sample_float_float.h - fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=double /DTBUFFER2=float /Zi /Od /Fh grid_sample_double_float.h - fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=bool /DTBUFFER2=float /Zi /Od /Fh grid_sample_bool_float.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint16_fp16.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint_fp16.h @@ -27,20 +25,6 @@ if "%1" == "DEBUG" ( dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int64_fp16.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_fp16_fp16.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_float_fp16.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_double_fp16.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_bool_fp16.h - - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint16_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint64_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int16_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int64_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_fp16_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_float_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_double_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_bool_double.h - ) else ( fxc.exe ..\Shaders\stockham.hlsl -E DFT -T cs_5_0 /DTBUFFER=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh stockham.h dxc.exe ..\Shaders\stockham.hlsl -E DFT -T cs_6_2 -DTBUFFER=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh stockham_fp16.h @@ -56,8 +40,6 @@ if "%1" == "DEBUG" ( dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int64_float.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_fp16_float.h fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=float /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_float_float.h - fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=double /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_double_float.h - fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=bool /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_bool_float.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint16_fp16.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint_fp16.h @@ -67,18 +49,5 @@ if "%1" == "DEBUG" ( dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int64_fp16.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_fp16_fp16.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_float_fp16.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_double_fp16.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_bool_fp16.h - - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint16_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint64_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int16_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int64_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_fp16_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_float_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_double_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_bool_double.h ) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/grid_sample_bool_double.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/grid_sample_bool_double.h deleted file mode 100644 index 6d83865ecb7a2..0000000000000 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/grid_sample_bool_double.h +++ /dev/null @@ -1,6398 +0,0 @@ -#if 0 -; -; Note: shader requires additional functionality: -; Double-precision floating point -; -; -; Input signature: -; -; Name Index Mask Register SysValue Format Used -; -------------------- ----- ------ -------- -------- ------- ------ -; no parameters -; -; Output signature: -; -; Name Index Mask Register SysValue Format Used -; -------------------- ----- ------ -------- -------- ------- ------ -; no parameters -; shader hash: eff86a5dd3f8ca652b3700c52570e535 -; -; Pipeline Runtime Information: -; -; -; -; Buffer Definitions: -; -; cbuffer -; { -; -; [116 x i8] (type annotation not present) -; -; } -; -; Resource bind info for -; { -; -; [4 x i8] (type annotation not present) -; -; } -; -; Resource bind info for -; { -; -; [8 x i8] (type annotation not present) -; -; } -; -; Resource bind info for -; { -; -; [4 x i8] (type annotation not present) -; -; } -; -; -; Resource Bindings: -; -; Name Type Format Dim ID HLSL Bind Count -; ------------------------------ ---------- ------- ----------- ------- -------------- ------ -; cbuffer NA NA CB0 cb0 1 -; UAV struct r/w U0 u0 1 -; UAV struct r/w U1 u1 1 -; UAV struct r/w U2 u2 1 -; -target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64" -target triple = "dxil-ms-dx" - -%dx.types.Handle = type { i8* } -%dx.types.CBufRet.i32 = type { i32, i32, i32, i32 } -%dx.types.ResRet.i32 = type { i32, i32, i32, i32, i32 } -%"class.RWStructuredBuffer" = type { i32 } -%"class.RWStructuredBuffer" = type { double } -%Constants = type { i32, i32, i32, i32, <4 x i32>, <4 x i32>, <4 x i32>, <4 x i32>, <4 x i32>, <4 x i32>, i32 } - -define void @GridSample() { - %1 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 2, i32 2, i1 false) ; CreateHandle(resourceClass,rangeId,index,nonUniformIndex) - %2 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 1, i32 1, i1 false) ; CreateHandle(resourceClass,rangeId,index,nonUniformIndex) - %3 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 0, i1 false) ; CreateHandle(resourceClass,rangeId,index,nonUniformIndex) - %4 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 2, i32 0, i32 0, i1 false) ; CreateHandle(resourceClass,rangeId,index,nonUniformIndex) - %5 = call i32 @dx.op.threadId.i32(i32 93, i32 0) ; ThreadId(component) - %6 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %4, i32 0) ; CBufferLoadLegacy(handle,regIndex) - %7 = extractvalue %dx.types.CBufRet.i32 %6, 0 - %8 = add i32 %7, %5 - %9 = extractvalue %dx.types.CBufRet.i32 %6, 1 - %10 = icmp ult i32 %8, %9 - br i1 %10, label %11, label %3389 - -;