From 4477f57ee3151287a9759bd09d269f0e258a9eda Mon Sep 17 00:00:00 2001 From: Phoebe Chen Date: Thu, 25 Jan 2024 08:27:05 +0800 Subject: [PATCH 01/11] Enable RISC-V 64-bit Cross-Compiling Support for ONNX Runtime on Linux (#19238) ### Description This pull request introduces the necessary changes to enable RISC-V 64-bit cross-compiling support for the ONNX Runtime on Linux. The RISC-V architecture has gained popularity as an open standard instruction set architecture, and this contribution aims to extend ONNX Runtime's compatibility to include RISC-V, thereby broadening the reach of ONNX models to a wider range of devices. ### Motivation and Context RISC-V is a free and open-source instruction set architecture (ISA) based on established RISC principles. It is provided under open licenses without fees. Due to its extensibility and freedom in both software and hardware, RISC-V is poised for widespread adoption in the future, especially in applications related to AI, parallel computing, and data centers. ### Example Build Command ``` ./build.sh --parallel --config Debug --rv64 --riscv_toolchain_root=/path/to/toolchain/root --skip_tests ``` ### Documentation Updates Relevant sections of the documentation will be updated to reflect the newly supported RISC-V 64-bit cross-compilation feature. https://github.com/microsoft/onnxruntime/pull/19239 --------- Signed-off-by: Phoebe Chen --- cmake/external/xnnpack.cmake | 6 +- cmake/onnxruntime_common.cmake | 4 +- cmake/riscv64.toolchain.cmake | 35 +++++++++ tools/ci_build/build.py | 35 ++++++++- tools/scripts/build_riscv64.sh | 129 +++++++++++++++++++++++++++++++++ 5 files changed, 206 insertions(+), 3 deletions(-) create mode 100644 cmake/riscv64.toolchain.cmake create mode 100755 tools/scripts/build_riscv64.sh diff --git a/cmake/external/xnnpack.cmake b/cmake/external/xnnpack.cmake index e661aa51bfc17..41f02ce6f22bc 100644 --- a/cmake/external/xnnpack.cmake +++ b/cmake/external/xnnpack.cmake @@ -6,10 +6,14 @@ set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "") set(PTHREADPOOL_BUILD_TESTS OFF CACHE INTERNAL "") set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE INTERNAL "") +if(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") + set(XNNPACK_USE_SYSTEM_LIBS OFF) +endif() + # BF16 instructions cause ICE in Android NDK compiler if(CMAKE_ANDROID_ARCH_ABI STREQUAL armeabi-v7a) set(XNNPACK_ENABLE_ARM_BF16 OFF) -ENDIF() +endif() # fp16 depends on psimd FetchContent_Declare(psimd URL ${DEP_URL_psimd} URL_HASH SHA1=${DEP_SHA1_psimd}) diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index 43d5fa9bdee34..6b8c2560b1714 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -189,6 +189,8 @@ elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set(ARM TRUE) elseif(dumpmachine_output MATCHES "^aarch64.*") set(ARM64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") + set(RISCV64 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") set(X86 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") @@ -198,7 +200,7 @@ elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") endif() -if (ARM64 OR ARM OR X86 OR X64 OR X86_64) +if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) if((WIN32 AND NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) OR ((ARM64 OR ARM) AND MSVC)) # msvc compiler report syntax error with cpuinfo arm source files # and cpuinfo does not have code for getting arm uarch info under windows diff --git a/cmake/riscv64.toolchain.cmake b/cmake/riscv64.toolchain.cmake new file mode 100644 index 0000000000000..0fda239f9a628 --- /dev/null +++ b/cmake/riscv64.toolchain.cmake @@ -0,0 +1,35 @@ +# Copyright (c) 2024 SiFive, Inc. All rights reserved. +# Copyright (c) 2024, Phoebe Chen +# Licensed under the MIT License. + +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_SYSTEM_PROCESSOR riscv64) + +list(APPEND CMAKE_TRY_COMPILE_PLATFORM_VARIABLES RISCV_TOOLCHAIN_ROOT) + +if(NOT RISCV_TOOLCHAIN_ROOT) + message(FATAL_ERROR "RISCV_TOOLCHAIN_ROOT is not defined. Please set the RISCV_TOOLCHAIN_ROOT variable.") +endif() + +set(CMAKE_C_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-gcc") +set(CMAKE_ASM_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-gcc") +set(CMAKE_CXX_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-g++") + +set(CMAKE_FIND_ROOT_PATH ${RISCV_TOOLCHAIN_ROOT}) +set(CMAKE_SYSROOT "${RISCV_TOOLCHAIN_ROOT}/sysroot") +set(CMAKE_INCLUDE_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/include/") +set(CMAKE_LIBRARY_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/lib/") +set(CMAKE_PROGRAM_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/bin/") + +if(RISCV_QEMU_PATH) + message(STATUS "RISCV_QEMU_PATH=${RISCV_QEMU_PATH} is defined during compilation.") + set(CMAKE_CROSSCOMPILING_EMULATOR "${RISCV_QEMU_PATH};-L;${CMAKE_SYSROOT}") +endif() + +set(CMAKE_CROSSCOMPILING TRUE) + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) + diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 6e5cd7b57e403..186bb699ad209 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -328,6 +328,12 @@ def convert_arg_line_to_args(self, arg_line): help="[cross-compiling] Create Windows x86 makefiles. Requires --update and no existing cache " "CMake setup. Delete CMakeCache.txt if needed", ) + parser.add_argument( + "--rv64", + action="store_true", + help="[cross-compiling] Create riscv64 makefiles. Requires --update and no existing cache " + "CMake setup. Delete CMakeCache.txt if needed", + ) parser.add_argument( "--arm", action="store_true", @@ -351,6 +357,18 @@ def convert_arg_line_to_args(self, arg_line): action="store_true", help="[cross-compiling] Create ARM64X Binary.", ) + parser.add_argument( + "--riscv_toolchain_root", + type=str, + default="", + help="Path to RISC-V toolchain root dir. e.g. --riscv_toolchain_root=$HOME/riscv-tools/", + ) + parser.add_argument( + "--riscv_qemu_path", + type=str, + default="", + help="Path to RISC-V qemu. e.g. --riscv_qemu_path=$HOME/qemu-dir/qemu-riscv64", + ) parser.add_argument("--msvc_toolset", help="MSVC toolset to use. e.g. 14.11") parser.add_argument("--windows_sdk_version", help="Windows SDK version to use. e.g. 10.0.19041.0") parser.add_argument("--android", action="store_true", help="Build for Android") @@ -1077,6 +1095,19 @@ def generate_build_tree( "-Donnxruntime_DISABLE_OPTIONAL_TYPE=" + ("ON" if disable_optional_type else "OFF"), ] + if args.rv64: + add_default_definition(cmake_extra_defines, "onnxruntime_CROSS_COMPILING", "ON") + if not args.riscv_toolchain_root: + raise BuildError("The --riscv_toolchain_root option is required to build for riscv64.") + if not args.skip_tests and not args.riscv_qemu_path: + raise BuildError("The --riscv_qemu_path option is required for testing riscv64.") + + cmake_args += [ + "-DRISCV_TOOLCHAIN_ROOT:PATH=" + args.riscv_toolchain_root, + "-DRISCV_QEMU_PATH:PATH=" + args.riscv_qemu_path, + "-DCMAKE_TOOLCHAIN_FILE=" + os.path.join(source_dir, "cmake", "riscv64.toolchain.cmake"), + ] + # By default on Windows we currently support only cross compiling for ARM/ARM64 # (no native compilation supported through this script). if args.arm64 or args.arm64ec or args.arm: @@ -1553,7 +1584,9 @@ def generate_build_tree( ] if is_linux() and platform.machine() == "x86_64": # The following flags needs GCC 8 and newer - cflags += ["-fstack-clash-protection", "-fcf-protection"] + cflags += ["-fstack-clash-protection"] + if not args.rv64: + cflags += ["-fcf-protection"] cxxflags = cflags.copy() if args.use_cuda: cudaflags = cflags.copy() diff --git a/tools/scripts/build_riscv64.sh b/tools/scripts/build_riscv64.sh new file mode 100755 index 0000000000000..65681c0b6307d --- /dev/null +++ b/tools/scripts/build_riscv64.sh @@ -0,0 +1,129 @@ +#!/bin/bash +# Copyright (c) 2024 SiFive, Inc. All rights reserved. +# Copyright (c) 2024, Phoebe Chen +# Licensed under the MIT License. + + +# The script is a sample for RISC-V 64-bit cross compilation in +# GNU/Linux, and you should ensure that your environment meets +# ORT requirements. You may need to make changes before using it. + +set -e +set -o pipefail + +# Get directory this script is in +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +OS=$(uname -s) + +if [ "$OS" == "Linux" ]; then + LINUX_DISTRO=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') + if [[ "${LINUX_DISTRO}" == "ubuntu" ]] ;then + DIR_OS="Linux" + else + echo "${LINUX_DISTRO} is not supported" + return 1 + fi +else + echo "$OS is not supported" + return 1 +fi + +function cleanup { + if [ -d "$WORK_DIR" ]; then + rm -rf "$WORK_DIR" + fi +} + +# The riscv toolchain, qemu and other platform related settings. +ORT_ROOT_DIR=$DIR/../.. + +PREBUILT_DIR="${ORT_ROOT_DIR}/riscv_tools" + +read -rp "Enter the riscv tools root path(press enter to use default path:${PREBUILT_DIR}): " INPUT_PATH +if [[ "${INPUT_PATH}" ]]; then + PREBUILT_DIR=${INPUT_PATH} +fi +echo "The riscv tool prefix path: ${PREBUILT_DIR}" + +WORK_DIR=$DIR/.prebuilt + +# The prebuit toolchain download from riscv-collab works with Ubuntu. +RISCV_GNU_TOOLCHAIN_URL="https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download" +TOOLCHAIN_VERSION="2023.11.20" +RISCV_TOOLCHAIN_FILE_NAME="riscv64-glibc-ubuntu-22.04-llvm-nightly-2023.11.20-nightly.tar.gz" +RISCV_TOOLCHAIN_FILE_SHA="98d6531b757fac01e065460c19abe8974976c607a8d88631cc5c1529d90ba7ba" + +TOOLCHAIN_PATH_PREFIX=${PREBUILT_DIR} + +execute () { + if ! eval "$1"; then + echo "command:\"$1\" error" + exit 1 + fi +} + +execute "mkdir -p $WORK_DIR" + +# Call the cleanup function when this tool exits. +trap cleanup EXIT + +# Download and install the toolchain from +# https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download +download_file() { + local file_name="$1" + local install_path="$2" + local file_sha="$3" + + echo "Install $1 to $2" + if [[ "$(ls -A "$2")" ]]; then + read -rp "The file already exists. Keep it (y/n)? " replaced + case ${replaced:0:1} in + y|Y ) + echo "Skip download $1." + return + ;; + * ) + rm -rf "$2" + ;; + esac + fi + + echo "Download ${file_name} ..." + mkdir -p "$install_path" + wget --progress=bar:force:noscroll --directory-prefix="${WORK_DIR}" \ + "${RISCV_GNU_TOOLCHAIN_URL}/${TOOLCHAIN_VERSION}/${file_name}" && \ + echo "${file_sha} ${WORK_DIR}/${file_name}" | sha256sum -c - + echo "Extract ${file_name} ..." + tar -C "${install_path}" -xf "${WORK_DIR}/${file_name}" --no-same-owner \ + --strip-components=1 +} + + +read -rp "Install RISCV toolchain(y/n)? " answer +case ${answer:0:1} in + y|Y ) + download_file "${RISCV_TOOLCHAIN_FILE_NAME}" \ + "${TOOLCHAIN_PATH_PREFIX}" \ + "${RISCV_TOOLCHAIN_FILE_SHA}" + ;; + * ) + echo "Skip install RISCV toolchain." + ;; +esac +echo "download finished." + + +# RISC-V cross compilation in GNU/Linux +RISCV_TOOLCHAIN_ROOT=${TOOLCHAIN_PATH_PREFIX} +RISCV_QEMU_PATH=${TOOLCHAIN_PATH_PREFIX}/bin/qemu-riscv64 +python3 "${ORT_ROOT_DIR}"/tools/ci_build/build.py \ + --build_dir "${ORT_ROOT_DIR}/build/${DIR_OS}" \ + --rv64 \ + --parallel \ + --skip_tests \ + --config RelWithDebInfo \ + --cmake_generator=Ninja \ + --riscv_qemu_path="${RISCV_QEMU_PATH}" \ + --riscv_toolchain_root="${RISCV_TOOLCHAIN_ROOT}" "$@" + + From 7dd1f4b8e27f38b55f2430f84ddaae1128bef9f4 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 24 Jan 2024 18:12:04 -0800 Subject: [PATCH 02/11] Pad-18 Cuda implementation (#19211) ### Description Implement Pad-18 for Cuda. ### Motivation and Context Latest models converted by Dynamo fall back on CPU for Pad with performance degradation. This contributes to https://github.com/microsoft/onnx-rewriter/issues/126 --- docs/OperatorKernels.md | 3 +- .../core/providers/cpu/cpu_provider_shared.cc | 8 +- .../core/providers/cpu/cpu_provider_shared.h | 8 +- onnxruntime/core/providers/cpu/tensor/pad.cc | 252 +++++++++--------- .../core/providers/cpu/tensor/padbase.h | 77 +++++- .../providers/cuda/cuda_execution_provider.cc | 38 +-- onnxruntime/core/providers/cuda/tensor/pad.cc | 37 ++- .../providers/rocm/rocm_execution_provider.cc | 26 +- .../provider_bridge_provider.cc | 9 +- 9 files changed, 287 insertions(+), 171 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 31cca232fde34..9d9b266355335 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -682,7 +682,8 @@ Do not modify directly.* |PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(double), tensor(float), tensor(float16)| |||[9, 15]|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| -|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)| |ParametricSoftplus|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index 9c55d37f550f4..bf73c59fb78ca 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -87,7 +87,13 @@ struct ProviderHostCPUImpl : ProviderHostCPU { const TensorShape& indice_shape, const TensorShape& update_shape) override { return ScatterND::ValidateShapes(input_shape, indice_shape, update_shape); } // From cpu/tensor/padbase.h (direct) - Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) override { return PadBase::HandleDimValueZero(mode, input_shape, output_shape); } + Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) override { return PadBase::HandleDimValueZero(mode, input_shape, output_shape); } + + void PadBase__ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, + PadsVector& pads) override { + PadBase::ComputePads(ctx, data_rank, pads_data, pads); + } + // From cpu/tensor/split.h (direct) Status SplitBase__PrepareForCompute(const SplitBase* p, const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 8dee1cd620282..f33eec4b93e98 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -25,6 +25,8 @@ class UnsqueezeBase__Prepare; // Directly maps to UnsqueezeBase::Pr class contrib__AdamWOptimizerBase__Prepare; class contrib__SGDOptimizerV2Base__Prepare; +using PadsVector = InlinedVector; + struct ProviderHostCPU { // From cpu/tensor/gatherbase.h virtual Status GatherBase__PrepareForCompute(const GatherBase* p, OpKernelContext* context, GatherBase__Prepare& prepare) = 0; @@ -44,7 +46,11 @@ struct ProviderHostCPU { const TensorShape& indice_shape, const TensorShape& update_shape) = 0; // From cpu/tensor/padbase.h - virtual Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) = 0; + virtual Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) = 0; + + virtual void PadBase__ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, + PadsVector& pads) = 0; + // From cpu/tensor/split.h virtual Status SplitBase__PrepareForCompute(const SplitBase* p, const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, diff --git a/onnxruntime/core/providers/cpu/tensor/pad.cc b/onnxruntime/core/providers/cpu/tensor/pad.cc index fe5267f20712b..912280687e229 100644 --- a/onnxruntime/core/providers/cpu/tensor/pad.cc +++ b/onnxruntime/core/providers/cpu/tensor/pad.cc @@ -9,6 +9,8 @@ #include "core/providers/op_kernel_type_control.h" #include "core/util/math.h" +#include + // there's no way to use a raw pointer as the copy destination with std::copy_n // (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset // without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning. @@ -167,47 +169,7 @@ ONNX_CPU_OPERATOR_KERNEL( using PadsVector = PadBase::PadsVector; -// This is the general padding method to n-dimensionally do edge or reflection padding (based on the inputDelta values) -template -static void PadAxis(T* output, T* input, ptrdiff_t input_delta, ptrdiff_t input_pitch, - size_t block_size, size_t block_count) { - for (size_t block_index = 0; block_index < block_count; block_index++) { - for (size_t i = 0; i < block_size; i++) { - *output++ = *input; - input += input_delta; - } - input += input_pitch; - } -} - -// These are optimizations of PadAxis. The inner loop is removed since the innermost axis has a blockSize of 1, -// and inputPitch and inputDelta are just a single value added each iteration. -template -static void PadInnermostAxis(T* output, T* input, ptrdiff_t input_delta, size_t block_count) { - for (size_t block_index = 0; block_index < block_count; block_index++) { - *output++ = *input; - input += input_delta; - } -} - -// For constant padding, there is no input, just a size to write the constant to -template -static void PadAxisConstant(T* output, T constant, size_t size) { - if (size == 1) { - *output = constant; - } else if (size == 2) { - *output = constant; - *(output + 1) = constant; - } else { - // This would be faster with SSE instructions. - // That would mean to have an implementation for each type (uint8, uint32, uint64). - T* end = output + size; - for (; output != end;) - *output++ = constant; - } -} - -Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) { +Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) { switch (mode) { case Mode::Constant: { // default behavior is fine @@ -242,34 +204,66 @@ Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_sh return Status::OK(); } -// special handling for edge case where the input has one or more dims with value of 0 -template -static Status PadInputWithDimValueOfZero(OpKernelContext* ctx, - const Mode& mode, - const TensorShape& input_shape, - TensorShapeVector& output_dims, - T value) { - TensorShape output_shape(output_dims); - ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode, input_shape, output_shape)); - - auto& output_tensor = *ctx->Output(0, output_shape); - - // we need to add pads if mode is constant, otherwise the output has one or more dim values of 0 so is empty - if (mode == Mode::Constant) { - // we add pads with the default value to all dims including those with a value of 0 - auto* output = reinterpret_cast(output_tensor.MutableDataRaw()); - std::fill_n(output, output_shape.Size(), value); +static void ComputePadWithAxes( + gsl::span pads_tensor_raw_data, + std::function get_axis, + size_t axes_size, + size_t data_rank, + PadsVector& pads) { + for (size_t i = 0; i < axes_size; ++i) { + const size_t axis = onnxruntime::narrow(HandleNegativeAxis(get_axis(i), data_rank)); + pads[axis] = pads_tensor_raw_data[i]; // xi_begin + pads[data_rank + axis] = pads_tensor_raw_data[axes_size + i]; // xi_end } +} - return Status::OK(); +void PadBase::ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, + PadsVector& pads) { + pads.reserve(2 * data_rank); + const Tensor* axes_tensor = ctx.Input(3); + if (axes_tensor) { + const size_t num_axes_dims = axes_tensor->Shape().NumDimensions(); + ORT_ENFORCE(num_axes_dims == 1, "Axes tensor should be a 1D tensor "); + + const int64_t num_axes = axes_tensor->Shape().Size(); + ORT_ENFORCE(pads_data.size() == narrow(2 * num_axes), + "Pads tensor size should be equal to twice the number of explicitly provided axes."); + + pads.resize(2 * data_rank, 0); + if (axes_tensor->IsDataType()) { + auto axes_data = axes_tensor->DataAsSpan(); + ComputePadWithAxes( + pads_data, + [axes_data](size_t idx) -> int64_t { + return axes_data[idx]; + }, + axes_data.size(), + data_rank, + pads); + } else if (axes_tensor->IsDataType()) { + auto axes_data = axes_tensor->DataAsSpan(); + ComputePadWithAxes( + pads_data, + [axes_data](size_t idx) { + return axes_data[idx]; + }, + axes_data.size(), + data_rank, + pads); + } + } else { + ORT_ENFORCE(pads_data.size() == 2 * data_rank, + "Pads tensor size should be equal to twice the input dimension count "); + pads.assign(pads_data.begin(), pads_data.end()); + } } // Flatten no padding inner most Axis, so one memcpy cover multiple Axis. // For example, for a shape of [1,224,224,3] with padding [0,3,3,0,0,3,3,0], can be flatten as // [1,224,224*3] with padding [0,3,3*3,0,3,3*3]. -static void FlattenInnerShape(const TensorShapeVector& input_dims, const PadsVector& pads, - const PadsVector& slices, TensorShapeVector& reshaped_dims) { - size_t dims_count = input_dims.size(); +void PadBase::FlattenInnerShape(gsl::span input_dims, gsl::span pads, + gsl::span slices, TensorShapeVector& reshaped_dims) { + const size_t dims_count = input_dims.size(); size_t inner_axis = dims_count - 1; size_t inner_size = 1; @@ -288,14 +282,14 @@ static void FlattenInnerShape(const TensorShapeVector& input_dims, const PadsVec } while (inner_axis-- > 0); reshaped_dims.reserve(inner_axis + 1); - std::copy(input_dims.cbegin(), input_dims.cbegin() + inner_axis + 1, std::back_inserter(reshaped_dims)); + std::copy(input_dims.begin(), input_dims.begin() + inner_axis + 1, std::back_inserter(reshaped_dims)); // Flatten inner axis. reshaped_dims[inner_axis] = inner_size; } -static void ReshapePads(const PadsVector& src_pad, size_t src_dim_count, size_t new_dim_count, - size_t inner_no_pad_size, PadsVector& reshaped_pad) { +void PadBase::ReshapePads(gsl::span src_pad, size_t src_dim_count, size_t new_dim_count, + size_t inner_no_pad_size, PadsVector& reshaped_pad) { size_t inner_axis = new_dim_count - 1; std::copy(src_pad.begin(), src_pad.begin() + inner_axis, reshaped_pad.begin()); std::copy(src_pad.begin() + src_dim_count, src_pad.begin() + src_dim_count + inner_axis, @@ -306,6 +300,68 @@ static void ReshapePads(const PadsVector& src_pad, size_t src_dim_count, size_t reshaped_pad[inner_axis + new_dim_count] = src_pad[inner_axis + src_dim_count] * inner_no_pad_size; } +// special handling for edge case where the input has one or more dims with value of 0 +template +static Status PadInputWithDimValueOfZero(OpKernelContext* ctx, + const Mode& mode, + const TensorShape& input_shape, + TensorShapeVector& output_dims, + T value) { + TensorShape output_shape(output_dims); + ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode, input_shape, output_shape)); + + auto& output_tensor = *ctx->Output(0, output_shape); + + // we need to add pads if mode is constant, otherwise the output has one or more dim values of 0 so is empty + if (mode == Mode::Constant) { + // we add pads with the default value to all dims including those with a value of 0 + auto* output = reinterpret_cast(output_tensor.MutableDataRaw()); + std::fill_n(output, output_shape.Size(), value); + } + + return Status::OK(); +} + +// This is the general padding method to n-dimensionally do edge or reflection padding (based on the inputDelta values) +template +static void PadAxis(T* output, T* input, ptrdiff_t input_delta, ptrdiff_t input_pitch, + size_t block_size, size_t block_count) { + for (size_t block_index = 0; block_index < block_count; block_index++) { + for (size_t i = 0; i < block_size; i++) { + *output++ = *input; + input += input_delta; + } + input += input_pitch; + } +} + +// These are optimizations of PadAxis. The inner loop is removed since the innermost axis has a blockSize of 1, +// and inputPitch and inputDelta are just a single value added each iteration. +template +static void PadInnermostAxis(T* output, T* input, ptrdiff_t input_delta, size_t block_count) { + for (size_t block_index = 0; block_index < block_count; block_index++) { + *output++ = *input; + input += input_delta; + } +} + +// For constant padding, there is no input, just a size to write the constant to +template +static void PadAxisConstant(T* output, T constant, size_t size) { + if (size == 1) { + *output = constant; + } else if (size == 2) { + *output = constant; + *(output + 1) = constant; + } else { + // This would be faster with SSE instructions. + // That would mean to have an implementation for each type (uint8, uint32, uint64). + T* end = output + size; + for (; output != end;) + *output++ = constant; + } +} + template static Status PadImpl(OpKernelContext* ctx, const PadsVector& pads, @@ -327,7 +383,7 @@ static Status PadImpl(OpKernelContext* ctx, // Reshape input dims TensorShapeVector reshaped_input_dims; - FlattenInnerShape(output_dims, pads, slices, reshaped_input_dims); + PadBase::FlattenInnerShape(output_dims, pads, slices, reshaped_input_dims); // Reshape padding size_t new_dims_count = reshaped_input_dims.size(); @@ -336,8 +392,8 @@ static Status PadImpl(OpKernelContext* ctx, ? reshaped_input_dims[inner_axis] / output_dims[inner_axis] : 0); PadsVector reshaped_pad(2 * new_dims_count), reshaped_slice(2 * new_dims_count); - ReshapePads(pads, data_rank, new_dims_count, inner_no_pad_size, reshaped_pad); - ReshapePads(slices, data_rank, new_dims_count, inner_no_pad_size, reshaped_slice); + PadBase::ReshapePads(pads, data_rank, new_dims_count, inner_no_pad_size, reshaped_pad); + PadBase::ReshapePads(slices, data_rank, new_dims_count, inner_no_pad_size, reshaped_slice); TensorShapeVector reshaped_output_dims = reshaped_input_dims; TensorShapeVector input_starts; @@ -575,20 +631,6 @@ static PadValue PadValueFromFloat(float value, MLDataType data_type) { return result; } -template -void ComputePadWithAxes( - gsl::span pads_tensor_raw_data, - gsl::span axes_tensor_raw_data, - size_t data_rank, - PadsVector& pads) { - size_t axes_size = axes_tensor_raw_data.size(); - for (size_t i = 0; i < axes_size; ++i) { - int64_t axis = HandleNegativeAxis(onnxruntime::narrow(axes_tensor_raw_data[i]), data_rank); - pads[onnxruntime::narrow(axis)] = pads_tensor_raw_data[i]; // xi_begin - pads[data_rank + onnxruntime::narrow(axis)] = pads_tensor_raw_data[axes_size + i]; // xi_end - } -} - Status Pad::Compute(OpKernelContext* ctx) const { const Tensor& input_tensor = *ctx->Input(0); MLDataType data_type = input_tensor.DataType(); @@ -608,48 +650,14 @@ Status Pad::Compute(OpKernelContext* ctx) const { ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1), "Pads tensor should be a 1D tensor of shape [2 * num_axes] " "or a 2D tensor of shape [1, 2 * num_axes]"); - const int64_t* pads_tensor_raw_data = pads_tensor.Data(); - size_t pads_size = static_cast(pads_tensor.Shape().Size()); - pads.reserve(2 * data_rank); - - const Tensor* axes_tensor = ctx->Input(3); - if (axes_tensor) { - const auto& axes_tensor_dims = axes_tensor->Shape().GetDims(); - ORT_ENFORCE(axes_tensor_dims.size() == 1, "Axes tensor should be a 1D tensor "); - int64_t axes_size = axes_tensor_dims[0]; - - pads.resize(2 * data_rank, 0); - if (axes_tensor->IsDataType()) { - const int32_t* axes_tensor_raw_data = axes_tensor->Data(); - ComputePadWithAxes( - {pads_tensor_raw_data, onnxruntime::narrow(2 * axes_size)}, - {axes_tensor_raw_data, onnxruntime::narrow(axes_size)}, - data_rank, - pads); - } else if (axes_tensor->IsDataType()) { - const int64_t* axes_tensor_raw_data = axes_tensor->Data(); - ComputePadWithAxes( - {pads_tensor_raw_data, onnxruntime::narrow(2 * axes_size)}, - {axes_tensor_raw_data, onnxruntime::narrow(axes_size)}, - data_rank, - pads); - } - } else { - ORT_ENFORCE(pads_size == 2 * data_rank, - "Pads tensor size should be equal to twice the input dimension count "); - for (size_t i = 0; i < pads_size; ++i) { - pads.push_back(pads_tensor_raw_data[i]); - } - } + + const auto pads_data = pads_tensor.DataAsSpan(); + + // Compute Pads by applying axes if specified otherwise copy the supplied pads. + PadBase::ComputePads(*ctx, data_rank, pads_data, pads); // Separate out any negative pads into the slices array - slices.assign(pads.size(), 0); - for (size_t index = 0; index < pads.size(); index++) { - if (pads[index] < 0) { - slices[index] = pads[index]; - pads[index] = 0; - } - } + PadBase::SeparateNegativeToSlices(pads, slices); value.u64 = 0U; const Tensor* value_tensor = ctx->Input(2); diff --git a/onnxruntime/core/providers/cpu/tensor/padbase.h b/onnxruntime/core/providers/cpu/tensor/padbase.h index d869ed1a6dda2..43f9cbfc9f9a4 100644 --- a/onnxruntime/core/providers/cpu/tensor/padbase.h +++ b/onnxruntime/core/providers/cpu/tensor/padbase.h @@ -19,9 +19,80 @@ class PadBase { // Pads and slices are usually about twice the shapes involved using PadsVector = InlinedVector; - // Update the output_shape to make it consistent with numpy handling where there are one or more dimensions - // in the input_shape with a value of zero. - static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape); + // The following several functions are shared among the providers + + /// + /// Handle the case when the input shape has zero dim values. + /// Depending on the mode, the input dim with zero value must match the output dim value. + /// + /// + /// Padding mode enum value + /// actual input shape + /// output_shape + /// Error if current mode padding can not be achieved with zero dim values + static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape); + + /// + /// Compute Pads by applying axes if specified otherwise copy the supplied pads. + /// + /// The function queries optional axes input (since version 18) and if present, + /// applies it as a mask to the pads. If axes is not present, the pads are copied as is. + /// If axes are present, they are used as a mask over pads, so only those axes are being padded. + /// + /// kernel context to query axes input + /// input rank + /// pads data from pads input + /// resulting pads + static void ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, + PadsVector& pads); + + /// + /// Separates negative pad values to slices and zeros them out in original pads. + /// Leaving the rest of slices values as zero. + /// + /// This function is used inline in the Pad CUDA implementation and is not exposed via a provider + /// interfaces. + /// + /// pad values + /// slices output + static void SeparateNegativeToSlices(gsl::span pads, PadsVector& slices) { + slices.assign(pads.size(), 0); + for (size_t index = 0, lim = pads.size(); index < lim; index++) { + if (pads[index] < 0) { + slices[index] = pads[index]; + pads[index] = 0; + } + } + } + + // End provider shared + + /// + /// Flatten no padding inner most Axis, so one memcpy cover multiple Axis. + /// For example, for a shape of [1,224,224,3] with padding [0,3,3,0,0,3,3,0], can be flatten as + /// [1,224,224*3] with padding [0,3,3*3,0,3,3*3]. + /// + /// This is a helper function pads are expected to be twice the rank + /// + /// original input dims + /// pad values + /// slices + /// result dims + static void FlattenInnerShape(gsl::span input_dims, gsl::span pads, + gsl::span slices, TensorShapeVector& reshaped_dims); + + /// + /// Used after the inner shape is flattened, so we can apply this function to pads and slices + /// to reshape them as well. + /// + /// pads + /// original dim count + /// expected flattended dim count + /// is the left most dimension that was flattened. + /// In the example above, that would be 224, reverse computed from 224*3 + /// resulting reshaped pads or slices + static void ReshapePads(gsl::span src_pad, size_t src_dim_count, size_t new_dim_count, + size_t inner_no_pad_size, PadsVector& reshaped_pad); protected: PadBase(const OpKernelInfo& info) : value_(info.GetAttrOrDefault("value", 0.f)) { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 644bcaaa24cd4..3fc4ed355a12b 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1121,10 +1121,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Identity); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign); @@ -1269,6 +1269,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad); // Opset 19 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast); @@ -2008,10 +2012,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2091,13 +2095,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2150,11 +2147,22 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { // Opset 18 BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 19 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc index 4584e5fd8272c..bdd6567d2ef34 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad.cc +++ b/onnxruntime/core/providers/cuda/tensor/pad.cc @@ -29,15 +29,27 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, 2) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 13, 17, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Pad, \ kOnnxDomain, \ - 13, \ + 18, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .InputMemoryType(OrtMemTypeCPUInput, 1) \ .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ Pad); @@ -94,28 +106,15 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { if (is_dynamic_) { const Tensor& pads_tensor = *ctx->Input(1); const auto pads_tensor_dims = pads_tensor.Shape().GetDims(); - ORT_ENFORCE(utils::IsPrimitiveDataType(pads_tensor.DataType()), - "Pads tensor should be an INT64 tensor"); ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1), - "Pads tensor should be a 1D tensor of shape [2 * input_rank] or a 2D tensor of shape [1, 2 * input_rank]"); + "Pads tensor should be a 1D tensor of shape [2 * num_axes] or a 2D tensor of shape [1, 2 * num_axes]"); - const int64_t* pads_tensor_raw_data = pads_tensor.Data(); - size_t pads_size = static_cast(pads_tensor.Shape().Size()); - ORT_ENFORCE(pads_size == 2 * static_cast(dimension_count), - "Pads tensor size should be equal to twice the input dimension count "); + const auto pads_data = pads_tensor.DataAsSpan(); + + PadBase::ComputePads(*ctx, input_shape.NumDimensions(), pads_data, pads); - pads.reserve(2LL * dimension_count); - for (size_t i = 0; i < pads_size; ++i) { - pads.push_back(pads_tensor_raw_data[i]); - } // Separate out any negative pads into the slices array - slices.resize(pads.size(), 0); - for (size_t index = 0; index < pads.size(); index++) { - if (pads[index] < 0) { - slices[index] = pads[index]; - pads[index] = 0; - } - } + PadBase::SeparateNegativeToSlices(pads, slices); T raw_value{}; const Tensor* value_tensor = ctx->Input(2); diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index d7bec337a6be4..fff3d14b763d5 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1158,10 +1158,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Identity); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterND); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, SpaceToDepth); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, DepthToSpace); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign); @@ -1298,6 +1298,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization); // Opset 18 +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad); + class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split); // Opset 19 @@ -2088,10 +2093,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2228,6 +2233,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 18 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 19 diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index a3155fe6b86cf..e1d0e310425c5 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -547,7 +547,14 @@ Status ScatterND::ValidateShapes(const TensorShape& input_shape, const TensorShape& indice_shape, const TensorShape& update_shape) { return g_host_cpu.ScatterNDBase__ValidateShapes(input_shape, indice_shape, update_shape); } -Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) { return g_host_cpu.PadBase__HandleDimValueZero(mode, input_shape, output_shape); } +Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) { + return g_host_cpu.PadBase__HandleDimValueZero(mode, input_shape, output_shape); +} + +void PadBase::ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, + PadsVector& pads) { + g_host_cpu.PadBase__ComputePads(ctx, data_rank, pads_data, pads); +} Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, const ConcatBase::InlinedTensorsVector& input_tensors, Prepare& p) const { From 2b87dd373a3567c2c426e2f090b201b8b051a346 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 25 Jan 2024 10:16:41 +0800 Subject: [PATCH 03/11] [ORTModule] Remove Mod from Hash to Avoid Conflict for Triton Code-gen (#19256) Remove mod (10**8) from hash to avoid conflict for Triton code-gen. --- .../python/training/ort_triton/kernel/_mm.py | 20 +++++++++---------- .../training/ort_triton/triton_op_executor.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py b/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py index ed92923589d48..a3681a13699a0 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py @@ -11,7 +11,7 @@ import torch from .._cache import ModuleCache, PyCodeCache -from .._utils import next_power_of_2 +from .._utils import gen_unique_name, next_power_of_2 _DEBUG_MODE = "ORTMODULE_TRITON_DEBUG" in os.environ and int(os.getenv("ORTMODULE_TRITON_DEBUG")) == 1 @@ -305,18 +305,18 @@ def _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name): def _gen_mm_key(dtype: torch.dtype, m: int, n: int, k: int, trans_a: bool, trans_b: bool, alpha: float) -> int: - return hash(f"mm|{dtype}|{m}|{n}|{k}|{trans_a}|{trans_b}|{alpha}") % (10**8) + return hash(f"mm|{dtype}|{m}|{n}|{k}|{trans_a}|{trans_b}|{alpha}") def _gen_mm_module( dtype: torch.dtype, m: int, n: int, k: int, trans_a: bool, trans_b: bool, alpha: float ) -> Tuple[str, ModuleType]: - func_name = f"mm_{_gen_mm_key(dtype, m, n, k, trans_a, trans_b, alpha)}" + func_name = gen_unique_name("mm") kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name) src_code = _MM_TEMPLATE.format(**kwargs) if _DEBUG_MODE: os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True) - with open(f"triton_debug/{func_name}.py", "w") as f: + with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f: f.write(src_code) return func_name, PyCodeCache().load(src_code) @@ -333,7 +333,7 @@ def _gen_gemm_key( alpha: float, beta: float, ) -> int: - return hash(f"gemm|{dtype}|{m}|{n}|{k}|{stride_cm}|{stride_cn}|{trans_a}|{trans_b}|{alpha}|{beta}") % (10**8) + return hash(f"gemm|{dtype}|{m}|{n}|{k}|{stride_cm}|{stride_cn}|{trans_a}|{trans_b}|{alpha}|{beta}") def _gen_gemm_module( @@ -348,7 +348,7 @@ def _gen_gemm_module( alpha: float, beta: float, ) -> Tuple[str, ModuleType]: - func_name = f"gemm_{_gen_gemm_key(dtype, m, n, k, stride_cm, stride_cn, trans_a, trans_b, alpha, beta)}" + func_name = gen_unique_name("gemm") kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name) kwargs["stride_cm"] = stride_cm kwargs["stride_cn"] = stride_cn @@ -356,7 +356,7 @@ def _gen_gemm_module( src_code = _GEMM_TEMPLATE.format(**kwargs) if _DEBUG_MODE: os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True) - with open(f"triton_debug/{func_name}.py", "w") as f: + with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f: f.write(src_code) return func_name, PyCodeCache().load(src_code) @@ -364,13 +364,13 @@ def _gen_gemm_module( def _gen_bmm_key( dtype: torch.dtype, m: int, n: int, k: int, batch_a: int, batch_b: int, trans_a: bool, trans_b: bool, alpha: float ) -> int: - return hash(f"bmm|{dtype}|{m}|{n}|{k}|{batch_a}|{batch_b}|{trans_a}|{trans_b}|{alpha}") % (10**8) + return hash(f"bmm|{dtype}|{m}|{n}|{k}|{batch_a}|{batch_b}|{trans_a}|{trans_b}|{alpha}") def _gen_bmm_module( dtype: torch.dtype, m: int, n: int, k: int, batch_a: int, batch_b: int, trans_a: bool, trans_b: bool, alpha: float ) -> Tuple[str, ModuleType]: - func_name = f"bmm_{_gen_bmm_key(dtype, m, n, k, batch_a, batch_b, trans_a, trans_b, alpha)}" + func_name = gen_unique_name("bmm") kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name) batch = batch_a if batch_a >= batch_b else batch_b kwargs["stride_aq"] = m * k if batch_a == batch else 0 @@ -379,7 +379,7 @@ def _gen_bmm_module( src_code = _BMM_TEMPLATE.format(**kwargs) if _DEBUG_MODE: os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True) - with open(f"triton_debug/{func_name}.py", "w") as f: + with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f: f.write(src_code) return func_name, PyCodeCache().load(src_code) diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py index 1fe61750e651e..f16abc71251ed 100644 --- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py +++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py @@ -67,7 +67,7 @@ def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[in def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> int: # pylint: disable=unused-argument - return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") % (10**8) + return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]: From 1c92e56dc0f906a43128e2f0c4c6729349aac92b Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Thu, 25 Jan 2024 22:28:47 +0800 Subject: [PATCH 04/11] [Cuda] Refactor GroupNorm (#19146) Split GroupNorm implementation into multiple files, to make ROCm EP can reuse cuda code. Related PR: https://github.com/microsoft/onnxruntime/pull/19158 --------- Co-authored-by: Peixuan Zuo --- cmake/onnxruntime_rocm_hipify.cmake | 3 + .../cuda/diffusion/group_norm_common_base.cc | 101 ++++ .../cuda/diffusion/group_norm_common_base.h | 186 ++++++ .../cuda/diffusion/group_norm_impl.cu | 529 +----------------- .../cuda/diffusion/group_norm_impl_kernel.cuh | 355 ++++++++++++ 5 files changed, 653 insertions(+), 521 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index f70961a66329a..d485abe6bb1a6 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -47,6 +47,9 @@ set(contrib_ops_excluded_files "diffusion/group_norm.cc" "diffusion/group_norm_impl.cu" "diffusion/group_norm_impl.h" + "diffusion/group_norm_impl_kernel.cuh" + "diffusion/group_norm_common_base.h" + "diffusion/group_norm_common_base.cc" "diffusion/nhwc_conv.cc" "math/gemm_float8.cc" "math/gemm_float8.cu" diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc new file mode 100644 index 0000000000000..5dec690528847 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc @@ -0,0 +1,101 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/diffusion/group_norm_common_base.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +int NextSize(int x) { + for (size_t i = 0; i < kNumOfSizes; ++i) { + if (x <= kSizes[i]) { + return kSizes[i]; + } + } + + return x; +} + +int32_t GetThreadsPerBlock(int32_t channels_per_block, int32_t channels_per_thread) { + return NextSize(channels_per_block) / channels_per_thread; +} + +int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { + int32_t max_divisor = -1; + for (int32_t i = 1; i <= std::sqrt(n); i++) { + if (n % i == 0) { + int32_t divisor1 = n / i; + int32_t divisor2 = i; + + if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { + max_divisor = divisor1; + } + if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { + max_divisor = divisor2; + } + } + } + return max_divisor; +} + +// Find proper channels per block based on a cost function: The cost is number of channels corresponding to +// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has +// work to do so it is ideal case. +int FindChannelsPerBlock(int num_channels, int channels_per_group) { + int min_cost = -1; + int best_candidate = -1; + for (size_t i = kNumOfSizes; i > 0; --i) { + if (kSizes[i - 1] < channels_per_group) { + break; + } + + int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; + int blocks = (num_channels + channels_per_block - 1) / channels_per_block; + int cost = blocks * kSizes[i - 1] - num_channels; + if (cost == 0) { + return channels_per_block; + } + + if (min_cost == -1 || cost < min_cost) { + min_cost = cost; + best_candidate = channels_per_block; + } + } + + return best_candidate; +} + +int GetChannelsPerBlock(int num_channels, int num_groups) { + int32_t channels_per_group = num_channels / num_groups; + int32_t channels_per_block = channels_per_group; + if (channels_per_group < kMaxSize / 2) { + channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); + } + return channels_per_block; +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h new file mode 100644 index 0000000000000..84f3403b8d5ae --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h @@ -0,0 +1,186 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include "core/providers/cuda/cuda_common.h" +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. +constexpr static int32_t CHANNELS_PER_THREAD = 2; + +constexpr static int kSizes[] = {128, 256, 320, 384, 512}; +constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); +constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; + +int32_t GetThreadsPerBlock(int32_t channels_per_block, int32_t channels_per_thread); + +static inline int32_t DivUp(int32_t m, int32_t n) { + return (m + n - 1) / n; +} + +int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor); + +int GetChannelsPerBlock(int num_channels, int num_groups); + +template +struct GroupNormNHWCParams { + // The output buffer. Shape is (n, h, w, c). + T* dst; + + // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). + T* add_out; + + // The input buffer. Shape is (n, h, w, c). + T const* src; + + // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). + T const* skip; + + // Optional input buffer for bias tensor. Shape is (c). + T const* bias; + + // The gamma scaling factor. + float const* gamma; + + // The beta term to add in GN. + float const* beta; + + // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. + float* group_sum_buffer; + + // The number of instances in the batch. + int32_t n; + + // The height and width of each activation map. + int32_t h; + int32_t w; + + // Number of channels. + int32_t c; + + // Number of groups. + int32_t groups; + + // Do we apply the SiLU activation function? + bool use_silu; + + // Precomputed values and parameters to control the execution of the kernels. + + // Number of activations per instance (h * w) + int32_t hw; + + // Number of activations per block + int32_t hw_per_block; + + // Number of channels per block in the C dimension. + int32_t channels_per_block; + + // Number of channels per group in the C dimension. + int32_t channels_per_group; + + // The precomputed stride between instances. + int32_t hwc; + // The inverse of hw*channels_per_group to compute mean of a group. + float inv_hw_channels_per_group; + // The precomputed number of groups per block. + int32_t groups_per_block; + + // Number of threads per block + int32_t threads_per_block; + + // Epsilon to get stable variance in normalization. + float epsilon; + + // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. + bool broadcast_skip; + + // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. + T* skip_workspace; + + GroupNormNHWCParams(T* output, + T* add_out, + const T* input, + const T* skip, + const T* bias, + const float* gamma, + const float* beta, + void* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + int32_t channels_per_group = num_channels / num_groups; + // channels_per_block is computed in PrePack. + // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. + if (channels_per_block < channels_per_group) { + channels_per_block = GetChannelsPerBlock(num_channels, num_groups); + } + + this->use_silu = use_silu; + this->dst = output; + this->add_out = add_out; + this->src = input; + this->skip = skip; + this->bias = bias; + this->gamma = gamma; + this->beta = beta; + this->group_sum_buffer = reinterpret_cast(workspace); + this->n = batch_size; + this->h = height; + this->w = width; + this->c = num_channels; + this->groups = num_groups; + this->hw = this->h * this->w; + + // This will allocate as many blocks as possible to partition HW. + // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. + // TODO: tune this logic to find proper blocks when hw is small. + constexpr int32_t max_blocks_per_hw = 1024; + const int32_t blocks_per_hw = FindMaxDivisor(this->hw, max_blocks_per_hw); + this->hw_per_block = DivUp(this->hw, blocks_per_hw); + + this->channels_per_block = channels_per_block; + this->channels_per_group = channels_per_group; + this->hwc = this->hw * this->c; + this->inv_hw_channels_per_group = 1.F / (float)(this->hw * this->channels_per_group); + this->groups_per_block = channels_per_block / this->channels_per_group; + this->epsilon = epsilon; + this->broadcast_skip = broadcast_skip; + + // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. + this->skip_workspace = (this->add_out != nullptr) ? this->add_out : this->dst; + + this->threads_per_block = GetThreadsPerBlock(channels_per_block, CHANNELS_PER_THREAD); + } +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 48b161552ce0c..d7b2cc2379f4f 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -27,6 +27,8 @@ #include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cuda/diffusion/group_norm_common_base.h" +#include "contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh" using namespace onnxruntime::cuda; @@ -34,329 +36,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { -namespace { - -// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. -constexpr static int32_t CHANNELS_PER_THREAD = 2; - -constexpr static int kSizes[] = {128, 256, 320, 384, 512}; -constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); -constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; - -int NextSize(int x) { - for (size_t i = 0; i < kNumOfSizes; ++i) { - if (x <= kSizes[i]) { - return kSizes[i]; - } - } - - return x; -} -} // namespace - -static inline int32_t DivUp(int32_t m, int32_t n) { - return (m + n - 1) / n; -} - -static inline __device__ __host__ float sigmoid(float x) { - return 1.F / (1.F + expf(-x)); -} - -struct GroupSums { - // Is it the 1st element of the group? - int32_t flag; - // The sum. - float sum; - // The sum of squares. - float sum_sq; -}; - -struct GroupSumsOp { - inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { - GroupSums dst; - dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); - dst.flag = a.flag + b.flag; - return dst; - } -}; - -template -struct GroupNormNHWCParams { - // The output buffer. Shape is (n, h, w, c). - T* dst; - - // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). - T* add_out; - - // The input buffer. Shape is (n, h, w, c). - T const* src; - - // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). - T const* skip; - - // Optional input buffer for bias tensor. Shape is (c). - T const* bias; - - // The gamma scaling factor. - float const* gamma; - - // The beta term to add in GN. - float const* beta; - - // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. - float* group_sum_buffer; - - // The number of instances in the batch. - int32_t n; - - // The height and width of each activation map. - int32_t h; - int32_t w; - - // Number of channels. - int32_t c; - - // Number of groups. - int32_t groups; - - // Do we apply the SiLU activation function? - bool use_silu; - - // Precomputed values and parameters to control the execution of the kernels. - - // Number of activations per instance (h * w) - int32_t hw; - - // Number of activations per block - int32_t hw_per_block; - - // Number of channels per block in the C dimension. - int32_t channels_per_block; - - // Number of channels per group in the C dimension. - int32_t channels_per_group; - - // The precomputed stride between instances. - int32_t hwc; - // The inverse of hw*channels_per_group to compute mean of a group. - float inv_hw_channels_per_group; - // The precomputed number of groups per block. - int32_t groups_per_block; - - // Number of threads per block - int32_t threads_per_block; - - // Epsilon to get stable variance in normalization. - float epsilon; - - // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. - bool broadcast_skip; - - // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. - T* skip_workspace; -}; - -template -inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); - -template <> -inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { - // Fetch two channels per thread. - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - - float2 f2 = __half22float2(h2); - - // Update the sum. - sum += f2.x + f2.y; - - // Update the sum of squares. - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template <> -inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { - // Fetch two channels per thread. - float2 f2 = *reinterpret_cast(&src[offset]); - - // Update the sum. - sum += f2.x + f2.y; - - // Update the sum of squares. - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] -template -inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); - -template <> -inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { - // Fetch two channels per thread. - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); - __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); - h2 = h2 + b; - h2 = h2 + s; - - *reinterpret_cast<__half2*>(&add_out[offset]) = h2; - - float2 f2 = __half22float2(h2); - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template <> -inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { - float2 f2 = *reinterpret_cast(&src[offset]); - float2 s = *reinterpret_cast(&skip[skip_offset]); - float2 b = *reinterpret_cast(&bias[bias_offset]); - f2.x += s.x + b.x; - f2.y += s.y + b.y; - - *reinterpret_cast(&add_out[offset]) = f2; - - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] -template -inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); - -template <> -inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); - h2 = h2 + s; - - *reinterpret_cast<__half2*>(&add_out[offset]) = h2; - - float2 f2 = __half22float2(h2); - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template <> -inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { - float2 f2 = *reinterpret_cast(&src[offset]); - float2 s = *reinterpret_cast(&skip[skip_offset]); - f2.x += s.x; - f2.y += s.y; - *reinterpret_cast(&add_out[offset]) = f2; - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template -__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { - // The object in charge of doing the sums for the different blocks. - typedef cub::BlockScan BlockScan; - - // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage temp_storage; - - // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. - __shared__ float2 smem[THREADS_PER_BLOCK]; - - // The instance in the batch. - int32_t ni = blockIdx.z; - - // The channel loaded by that thread. - int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; - - if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { - return; - } - - // The first activation loaded by that block. - int32_t hw_begin = blockIdx.y * params.hw_per_block; - // The last activation loaded by that block. - int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); - - // The sums. - float sum = 0.F; - float sum_sq = 0.F; - - // Iterate over the activations to compute the sums. - int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; - if (params.skip != nullptr) { - // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) - const int64_t bias_offset = static_cast(ci); - T* add_out = params.skip_workspace; - if (params.broadcast_skip) { - const int64_t skip_offset = static_cast(ni) * params.c + ci; - - if (params.bias != nullptr) { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); - } - } else { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); - } - } - } else { - if (params.bias != nullptr) { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); - } - } else { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); - } - } - } - } else { // GroupNorm - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - UpdateSum(params.src, offset, sum, sum_sq); - } - } - - // The group index relative to the first group within the same block. - int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; - // The channel in the group. - int32_t cj = ci % params.channels_per_group; - - // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; - - // Do the segmented scan. InclusiveScan is not deterministic. - GroupSums out; - BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); - - // Store the results for the groups in shared memory (to produce coalesced stores later). - // For each group, only the last thread of that group is picked to save sum to shared memory. - if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { - smem[gi] = make_float2(out.sum, out.sum_sq); - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // Threads that have nothing left to do, exit. - if (threadIdx.x >= params.groups_per_block) { - return; - } - - // The global group index. - // Use neighboring threads for coalesced write. - int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; - - if (gj < params.groups) { - float2 sums = smem[threadIdx.x]; - const int index = (2 * ni) * params.groups + gj; - atomicAdd(¶ms.group_sum_buffer[index], sums.x); - atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); - } -} - template void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; @@ -390,102 +69,6 @@ void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) } } -template -__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, - float2& gamma_f2, float2& beta_f2, bool silu); - -template <> -__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, - float2& gamma_f2, float2& beta_f2, bool silu) { - // Fetch two channels per thread. - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - - // Extract the two half values. - float2 f2 = __half22float2(h2); - - // Normalize the channels. - f2.x = (f2.x - mean) * inv_std_dev; - f2.y = (f2.y - mean) * inv_std_dev; - - // Scale by gamma and add beta. - f2.x = gamma_f2.x * f2.x + beta_f2.x; - f2.y = gamma_f2.y * f2.y + beta_f2.y; - - // Apply SiLU activation if needed. - if (silu) { - f2.x = f2.x * sigmoid(f2.x); - f2.y = f2.y * sigmoid(f2.y); - } - - *reinterpret_cast<__half2*>(&dst[offset]) = __float22half2_rn(f2); -} - -template <> -__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, - float2& gamma_f2, float2& beta_f2, bool silu) { - // Fetch two channels per thread. - float2 f2 = *reinterpret_cast(&src[offset]); - - // Normalize the channels. - f2.x = (f2.x - mean) * inv_std_dev; - f2.y = (f2.y - mean) * inv_std_dev; - - // Scale by gamma and add beta. - f2.x = gamma_f2.x * f2.x + beta_f2.x; - f2.y = gamma_f2.y * f2.y + beta_f2.y; - - // Apply SiLU activation if needed. - if (silu) { - f2.x = f2.x * sigmoid(f2.x); - f2.y = f2.y * sigmoid(f2.y); - } - - *reinterpret_cast(&dst[offset]) = f2; -} - -template -__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { - // The channel loaded by that thread. - int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; - if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { - return; - } - - // The instance in the batch. - int32_t ni = blockIdx.z; - - // The group that thread works on. - int32_t gi = ci / params.channels_per_group; - - // Load the sum and sum of squares for the group. - float sum = 0.F, sum_sq = 0.F; - if (gi < params.groups) { - const int index = (2 * ni) * params.groups + gi; - sum = params.group_sum_buffer[index]; - sum_sq = params.group_sum_buffer[index + params.groups]; - } - - // Load gamma/beta. Fetch two per thread. - float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); - float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); - - // Compute the mean. - float mean = sum * params.inv_hw_channels_per_group; - // Compute the variance. - float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); - // Compute the inverse of the stddev. - float inv_std_dev = rsqrtf(var + params.epsilon); - - int32_t hw_begin = blockIdx.y * params.hw_per_block; - int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); - - const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; - int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); - } -} - template void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; @@ -517,60 +100,6 @@ void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t strea } } -int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { - int32_t max_divisor = -1; - for (int32_t i = 1; i <= std::sqrt(n); i++) { - if (n % i == 0) { - int32_t divisor1 = n / i; - int32_t divisor2 = i; - - if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { - max_divisor = divisor1; - } - if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { - max_divisor = divisor2; - } - } - } - return max_divisor; -} - -// Find proper channels per block based on a cost function: The cost is number of channels corresponding to -// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has -// work to do so it is ideal case. -int FindChannelsPerBlock(int num_channels, int channels_per_group) { - int min_cost = -1; - int best_candidate = -1; - for (size_t i = kNumOfSizes; i > 0; --i) { - if (kSizes[i - 1] < channels_per_group) { - break; - } - - int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; - int blocks = (num_channels + channels_per_block - 1) / channels_per_block; - int cost = blocks * kSizes[i - 1] - num_channels; - if (cost == 0) { - return channels_per_block; - } - - if (min_cost == -1 || cost < min_cost) { - min_cost = cost; - best_candidate = channels_per_block; - } - } - - return best_candidate; -} - -int GetChannelsPerBlock(int num_channels, int num_groups) { - int32_t channels_per_group = num_channels / num_groups; - int32_t channels_per_block = channels_per_group; - if (channels_per_group < kMaxSize / 2) { - channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); - } - return channels_per_block; -} - template Status LaunchGroupNormKernel( cudaStream_t stream, @@ -591,19 +120,13 @@ Status LaunchGroupNormKernel( bool use_silu, bool broadcast_skip, int channels_per_block) { - GroupNormNHWCParams params; - - int32_t channels_per_group = num_channels / num_groups; - // channels_per_block is computed in PrePack. - // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. - if (channels_per_block < channels_per_group) { - channels_per_block = GetChannelsPerBlock(num_channels, num_groups); - } + GroupNormNHWCParams params(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, + batch_size, num_channels, height, width, num_groups, use_silu, + broadcast_skip, channels_per_block); - // TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases - if (channels_per_block % channels_per_group != 0 || - channels_per_block > kMaxSize || - (channels_per_group % CHANNELS_PER_THREAD != 0)) { + if (params.channels_per_block % params.channels_per_group != 0 || + params.channels_per_block > kMaxSize || + (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "GroupNorm in CUDA does not support the input: n=", batch_size, " h=", height, @@ -612,42 +135,6 @@ Status LaunchGroupNormKernel( " groups=", num_groups); } - params.use_silu = use_silu; - params.dst = output; - params.add_out = add_out; - params.src = input; - params.skip = skip; - params.bias = bias; - params.gamma = gamma; - params.beta = beta; - params.group_sum_buffer = reinterpret_cast(workspace); - params.n = batch_size; - params.h = height; - params.w = width; - params.c = num_channels; - params.groups = num_groups; - params.hw = params.h * params.w; - - // This will allocate as many blocks as possible to partition HW. - // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. - // TODO: tune this logic to find proper blocks when hw is small. - constexpr int32_t max_blocks_per_hw = 1024; - const int32_t blocks_per_hw = FindMaxDivisor(params.hw, max_blocks_per_hw); - params.hw_per_block = DivUp(params.hw, blocks_per_hw); - - params.channels_per_block = channels_per_block; - params.channels_per_group = channels_per_group; - params.hwc = params.hw * params.c; - params.inv_hw_channels_per_group = 1.F / (float)(params.hw * params.channels_per_group); - params.groups_per_block = channels_per_block / params.channels_per_group; - params.epsilon = epsilon; - params.broadcast_skip = broadcast_skip; - - // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. - params.skip_workspace = (params.add_out != nullptr) ? params.add_out : params.dst; - - params.threads_per_block = NextSize(channels_per_block) / CHANNELS_PER_THREAD; - CUDA_RETURN_IF_ERROR(cudaMemsetAsync( params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream)); diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh new file mode 100644 index 0000000000000..081e9a3de578c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh @@ -0,0 +1,355 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/diffusion/group_norm_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +static inline __device__ __host__ float sigmoid(float x) { + return 1.F / (1.F + expf(-x)); +} + +struct GroupSums { + // Is it the 1st element of the group? + int32_t flag; + // The sum. + float sum; + // The sum of squares. + float sum_sq; +}; + +struct GroupSumsOp { + inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { + GroupSums dst; + dst.sum = b.flag ? b.sum : (a.sum + b.sum); + dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); + dst.flag = a.flag + b.flag; + return dst; + } +}; + +template +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); + +template <> +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + + float2 f2 = __half22float2(h2); + + // Update the sum. + sum += f2.x + f2.y; + + // Update the sum of squares. + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + float2 f2 = *reinterpret_cast(&src[offset]); + + // Update the sum. + sum += f2.x + f2.y; + + // Update the sum of squares. + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] +template +inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); + h2 = h2 + b; + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + float2 b = *reinterpret_cast(&bias[bias_offset]); + f2.x += s.x + b.x; + f2.y += s.y + b.y; + + *reinterpret_cast(&add_out[offset]) = f2; + + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] +template +inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + f2.x += s.x; + f2.y += s.y; + *reinterpret_cast(&add_out[offset]) = f2; + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template +__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { + // The object in charge of doing the sums for the different blocks. + typedef cub::BlockScan BlockScan; + + // Allocate shared memory for BlockScan. + __shared__ typename BlockScan::TempStorage temp_storage; + + // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. + __shared__ float2 smem[THREADS_PER_BLOCK]; + + // The instance in the batch. + int32_t ni = blockIdx.z; + + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + return; + } + + // The first activation loaded by that block. + int32_t hw_begin = blockIdx.y * params.hw_per_block; + // The last activation loaded by that block. + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); + + // The sums. + float sum = 0.F; + float sum_sq = 0.F; + + // Iterate over the activations to compute the sums. + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + if (params.skip != nullptr) { + // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) + const int64_t bias_offset = static_cast(ci); + T* add_out = params.skip_workspace; + if (params.broadcast_skip) { + const int64_t skip_offset = static_cast(ni) * params.c + ci; + + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); + } + } + } else { + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); + } + } + } + } else { // GroupNorm + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + UpdateSum(params.src, offset, sum, sum_sq); + } + } + + // The group index relative to the first group within the same block. + int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; + // The channel in the group. + int32_t cj = ci % params.channels_per_group; + + // The data for the summations. + GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; + + // Do the segmented scan. InclusiveScan is not deterministic. + GroupSums out; + BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); + + // Store the results for the groups in shared memory (to produce coalesced stores later). + // For each group, only the last thread of that group is picked to save sum to shared memory. + if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { + smem[gi] = make_float2(out.sum, out.sum_sq); + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Threads that have nothing left to do, exit. + if (threadIdx.x >= params.groups_per_block) { + return; + } + + // The global group index. + // Use neighboring threads for coalesced write. + int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; + + if (gj < params.groups) { + float2 sums = smem[threadIdx.x]; + const int index = (2 * ni) * params.groups + gj; + atomicAdd(¶ms.group_sum_buffer[index], sums.x); + atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); + } +} + +template +__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu); + +template <> +__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + + // Extract the two half values. + float2 f2 = __half22float2(h2); + + // Normalize the channels. + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; + + // Scale by gamma and add beta. + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; + + // Apply SiLU activation if needed. + if (silu) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + + *reinterpret_cast<__half2*>(&dst[offset]) = __float22half2_rn(f2); +} + +template <> +__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { + // Fetch two channels per thread. + float2 f2 = *reinterpret_cast(&src[offset]); + + // Normalize the channels. + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; + + // Scale by gamma and add beta. + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; + + // Apply SiLU activation if needed. + if (silu) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + + *reinterpret_cast(&dst[offset]) = f2; +} + +template +__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + return; + } + + // The instance in the batch. + int32_t ni = blockIdx.z; + + // The group that thread works on. + int32_t gi = ci / params.channels_per_group; + + // Load the sum and sum of squares for the group. + float sum = 0.F, sum_sq = 0.F; + if (gi < params.groups) { + const int index = (2 * ni) * params.groups + gi; + sum = params.group_sum_buffer[index]; + sum_sq = params.group_sum_buffer[index + params.groups]; + } + + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); + float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); + + // Compute the mean. + float mean = sum * params.inv_hw_channels_per_group; + // Compute the variance. + float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); + // Compute the inverse of the stddev. + float inv_std_dev = rsqrtf(var + params.epsilon); + + int32_t hw_begin = blockIdx.y * params.hw_per_block; + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); + + const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime From 5b065050734e6bc397dc38ba0df246aeb57ac508 Mon Sep 17 00:00:00 2001 From: Jiajie Hu Date: Fri, 26 Jan 2024 00:25:35 +0800 Subject: [PATCH 05/11] [js/webgpu] Fix Tanh explosion (#19201) ### Description ```math \tanh(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}}= \left\{ \begin{array}{cc} -\frac{1-e^{-2\cdot(-x)}}{1+e^{-2\cdot(-x)}}, & x<0 \\ 0, & x=0 \\ \frac{1-e^{-2x}}{1+e^{-2x}}, & x>0 \end{array} \right. ``` ### Motivation and Context On some platforms, $$\tanh(1000)=\frac{e^{1000}-e^{-1000}}{e^{1000}+e^{-1000}}$$ would produce NaN instead of 0.999... or 1 (imagine $e^{1000}=\infty$ and $\frac{\infty}{\infty}$ explodes). --- js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 4 +++- js/web/test/data/ops/tanh.jsonc | 26 +++++++++++++++++++++ js/web/test/suite-test-list.jsonc | 1 + 3 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 js/web/test/data/ops/tanh.jsonc diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 82311d72e58b9..76929efb32537 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -279,7 +279,9 @@ export const tan = (context: ComputeContext): void => { }; export const tanh = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', 'tanh')); + // TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved + context.compute(createElementwiseProgramInfo( + context.inputs[0], 'Tanh', a => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`)); }; export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => { diff --git a/js/web/test/data/ops/tanh.jsonc b/js/web/test/data/ops/tanh.jsonc new file mode 100644 index 0000000000000..f7691535bd71c --- /dev/null +++ b/js/web/test/data/ops/tanh.jsonc @@ -0,0 +1,26 @@ +[ + { + "name": "tanh with no attributes", + "operator": "Tanh", + "attributes": [], + "cases": [ + { + "name": "T[2,4]", + "inputs": [ + { + "data": [-1000, -1, 0, 0.1, 0.2, 0.3, 0.4, 1000], + "dims": [2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-1, -0.761594, 0, 0.099668, 0.197375, 0.291313, 0.379949, 1], + "dims": [2, 4], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 373b3c645df57..56db28b0a379c 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1389,6 +1389,7 @@ "sub.jsonc", "sub_int32.jsonc", "tan.jsonc", + "tanh.jsonc", "tile.jsonc", "transpose.jsonc", "transpose_int32_uint32.jsonc", From 2b285cd78a629971a9e465036e94a431e6fef17b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 25 Jan 2024 09:30:15 -0800 Subject: [PATCH 06/11] [CUDA] Add functions to dump bfloat16 tensors (#19266) ### Description GroupQueryAttention add BFloat16 in https://github.com/microsoft/onnxruntime/pull/19095, and there is build error when enable dumping. This supports print bfloat16 tensor to console. --- .../cuda/transformers/dump_cuda_tensor.cc | 88 ++++++++++++------- .../cuda/transformers/dump_cuda_tensor.h | 27 ++++-- 2 files changed, 75 insertions(+), 40 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc index b31f5d243e001..4cfa89a4d58c2 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc @@ -203,23 +203,19 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) { DumpGpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } -void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); + DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); + DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1) const { - Print(name, reinterpret_cast(tensor), dim0, dim1); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); } void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const { @@ -227,9 +223,14 @@ void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); +} + +void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, true); } void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const { @@ -242,6 +243,11 @@ void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int d DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); } +void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, true); +} + void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); @@ -252,22 +258,31 @@ void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, i DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); } -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const { - Print(name, reinterpret_cast(tensor), dim0, dim1, dim2); +void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const { - Print(name, reinterpret_cast(tensor), dim0, dim1, dim2, dim3); +void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); } -void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const { +void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); } -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); +void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1) const { + Print(name, reinterpret_cast(tensor), dim0, dim1); +} + +void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const { + Print(name, reinterpret_cast(tensor), dim0, dim1, dim2); +} + +void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const { + Print(name, reinterpret_cast(tensor), dim0, dim1, dim2, dim3); } void CudaTensorConsoleDumper::Print(const char* name, const Tensor& tensor) const { @@ -301,43 +316,52 @@ void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, } #else -void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const half*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const { } void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int) const { } +void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int, int) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const { +} + void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const half*, int, int) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int) const { } void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int, int) const { diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h index 264ecd7cfe2f5..773401f79531a 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h @@ -16,20 +16,31 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::transformers::ICons public: CudaTensorConsoleDumper() = default; virtual ~CudaTensorConsoleDumper() {} - void Print(const char* name, const float* tensor, int dim0, int dim1) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; + void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; - void Print(const char* name, const half* tensor, int dim0, int dim1) const; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override; + void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override; + void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; + + void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override; + void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; + + void Print(const char* name, const float* tensor, int dim0, int dim1) const override; void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override; void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; + + void Print(const char* name, const half* tensor, int dim0, int dim1) const; void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const; void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; + + void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; + void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const override; + void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; + + void Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const; + void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const; + void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; + void Print(const char* name, const Tensor& value) const override; void Print(const char* name, const OrtValue& value) const override; void Print(const char* name, int index, bool end_line) const override; From a2867b911e67146218b4fc0b32721e5cdbade49b Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Thu, 25 Jan 2024 11:51:39 -0800 Subject: [PATCH 07/11] [TensorRT EP] Fix mem leak for TRT plugins custom ops (#19248) TRT EP's GetTensorRTCustomOpDomainList() will create vector of OrtCustomOpDomain objects and release the ownership of those objects. But, thoses objects are not released forever. In session level, we need to make TRT EP remember what OrtCustomOpDomain objects it created and release them at EP destruction time. --- .../tensorrt/tensorrt_execution_provider.cc | 18 +++++-- .../tensorrt_execution_provider_custom_ops.cc | 37 +++++--------- .../core/session/provider_bridge_ort.cc | 49 +++---------------- .../python/onnxruntime_pybind_state.cc | 6 +-- 4 files changed, 35 insertions(+), 75 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index fe6b959b962de..39e5f5be000e5 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1834,13 +1834,21 @@ nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const { } void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector& custom_op_domain_list) const { - if (info_.custom_op_domain_list.empty()) { - common::Status status = CreateTensorRTCustomOpDomainList(info_); - if (!status.IsOK()) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; + std::string extra_plugin_lib_paths{""}; + if (info_.has_trt_options) { + if (!info_.extra_plugin_lib_paths.empty()) { + extra_plugin_lib_paths = info_.extra_plugin_lib_paths; } + } else { + const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths); + if (!extra_plugin_lib_paths_env.empty()) { + extra_plugin_lib_paths = extra_plugin_lib_paths_env; + } + } + auto status = CreateTensorRTCustomOpDomainList(custom_op_domain_list, extra_plugin_lib_paths); + if (status != Status::OK()) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; } - custom_op_domain_list = info_.custom_op_domain_list; } // Check the graph is the subgraph of control flow op diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 4e466a5d568a6..eb340ba1e64b6 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -27,8 +27,12 @@ extern TensorrtLogger& GetTensorrtLogger(); * So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation. */ common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) { - std::unique_ptr custom_op_domain = std::make_unique(); - custom_op_domain->domain_ = "trt.plugins"; + static std::unique_ptr custom_op_domain = std::make_unique(); + static std::vector> created_custom_op_list; + if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) { + domain_list.push_back(custom_op_domain.get()); + return Status::OK(); + } // Load any extra TRT plugin library if any. // When the TRT plugin library is loaded, the global static object is created and the plugin is registered to TRT registry. @@ -69,38 +73,19 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& continue; } - std::unique_ptr trt_custom_op = std::make_unique(onnxruntime::kTensorrtExecutionProvider, nullptr); - trt_custom_op->SetName(plugin_creator->getPluginName()); - custom_op_domain->custom_ops_.push_back(trt_custom_op.release()); + created_custom_op_list.push_back(std::make_unique(onnxruntime::kTensorrtExecutionProvider, nullptr)); // Make sure TensorRTCustomOp object won't be cleaned up + created_custom_op_list.back().get()->SetName(plugin_creator->getPluginName()); + custom_op_domain->custom_ops_.push_back(created_custom_op_list.back().get()); registered_plugin_names.insert(plugin_name); } - domain_list.push_back(custom_op_domain.release()); + custom_op_domain->domain_ = "trt.plugins"; + domain_list.push_back(custom_op_domain.get()); } catch (const std::exception&) { LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins"; } return Status::OK(); } -common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) { - std::vector domain_list; - std::string extra_plugin_lib_paths{""}; - if (info.has_trt_options) { - if (!info.extra_plugin_lib_paths.empty()) { - extra_plugin_lib_paths = info.extra_plugin_lib_paths; - } - } else { - const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths); - if (!extra_plugin_lib_paths_env.empty()) { - extra_plugin_lib_paths = extra_plugin_lib_paths_env; - } - } - auto status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths); - if (!domain_list.empty()) { - info.custom_op_domain_list = domain_list; - } - return Status::OK(); -} - void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) { if (domain != nullptr) { for (auto ptr : domain->custom_ops_) { diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 3178c13d30eec..f48110aa7ee5b 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1713,17 +1713,9 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessi ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id) { API_IMPL_BEGIN - auto factory = onnxruntime::TensorrtProviderFactoryCreator::Create(device_id); - if (!factory) { - return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library"); - } - - options->provider_factories.push_back(factory); - - std::string extra_plugin_lib_paths = onnxruntime::Env::Default().GetEnvironmentVar("trt_extra_plugin_lib_paths"); - AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths); - - return nullptr; + OrtTensorRTProviderOptionsV2 tensorrt_options; + tensorrt_options.device_id = device_id; + return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &tensorrt_options); API_IMPL_END } @@ -1741,33 +1733,8 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtS ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options) { API_IMPL_BEGIN - - std::shared_ptr factory; - -#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) - auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0"; - // If EP context configs are provided in session options, we need to propagate them to provider options - if (ep_context_cache_enabled_from_sess_options) { - OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options); - - onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &trt_options_converted); - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted); - } else { - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); - } -#else - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); -#endif - - if (!factory) { - return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library"); - } - - options->provider_factories.push_back(factory); - - AddTensorRTCustomOpDomainToSessionOption(options, ""); - - return nullptr; + OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options); + return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &trt_options_converted); API_IMPL_END } @@ -1906,11 +1873,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, // if provider options already have the EP context configs provided, the configs in session options will be ignored // since provider options has higher priority than session options. if (!ep_context_cache_enabled_from_provider_options && ep_context_cache_enabled_from_sess_options) { - // We need to create a new provider options V2 object and copy from provider_options, due to the "const" object pointed by provider_options can't be modified. - // Note: No need to worry about tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will + // This function might need to update the "const" OrtTensorRTProviderOptionsV2 object which can't be modified. + // Therefore, we need to create a new OrtTensorRTProviderOptionsV2 object and copy from tensorrt_options and use this new object to create the factory instead. + // Note: No need to worry about new_tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will // create a factory object that copies any provider options from tensorrt_options including "const char*" provider options. OrtTensorRTProviderOptionsV2 new_tensorrt_options = *tensorrt_options; // copy and assign from tensorrt_options - onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &new_tensorrt_options); factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&new_tensorrt_options); } else { diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f7ed5520727db..8e13982ca6861 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -443,9 +443,9 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti if (it != options.end()) { trt_extra_plugin_lib_paths = it->second; } - std::vector domain_list; - tensorrt_provider_info->GetTensorRTCustomOpDomainList(domain_list, trt_extra_plugin_lib_paths); - for (auto ptr : domain_list) { + std::vector custom_op_domains; + tensorrt_provider_info->GetTensorRTCustomOpDomainList(custom_op_domains, trt_extra_plugin_lib_paths); + for (auto ptr : custom_op_domains) { if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) { so.custom_op_domains_.push_back(ptr); } else { From 656ca66186c7fd362abd8f33915bd0f96483bf43 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 26 Jan 2024 07:37:05 +0800 Subject: [PATCH 08/11] [js/webgpu] Support uniforms for conv, conv transpose, conv grouped (#18753) --- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 125 +++++++------ .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 154 ++++++++-------- .../ops/3rd-party/conv_backprop_webgpu.ts | 174 +++++++++++------- .../ops/3rd-party/matmul_packed_webgpu.ts | 108 +++++------ .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 86 +++++---- .../wasm/jsep/webgpu/ops/conv-transpose.ts | 15 +- js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 18 +- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 39 ++-- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 43 +++-- 9 files changed, 418 insertions(+), 344 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 3638938df7dbe..1a03621512888 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -21,8 +21,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ProgramInfo, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvAttributes} from '../conv'; import {getActivationSnippet} from '../fuse-utils'; @@ -88,10 +88,10 @@ const conv2dCommonSnippet = let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; - let WRow = ${col} / (filterDims[1] * inChannels); - let WCol = ${col} / inChannels % filterDims[1]; - let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0]; - let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1]; + let WRow = ${col} / (i32(uniforms.w_shape[1]) * inChannels); + let WCol = ${col} / inChannels % i32(uniforms.w_shape[1]); + let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0]; + let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1]; let xCh = ${col} % inChannels; var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0); // The bounds checking is always needed since we use it to pad zero for @@ -108,7 +108,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < uniforms.dimAOuter && col < uniforms.dimInner) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : @@ -117,7 +117,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < uniforms.dimInner && col < uniforms.dimBOuter) { + if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); @@ -129,9 +129,8 @@ const conv2dCommonSnippet = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); const bType = isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, resType); + const applyActivation = getActivationSnippet(attributes, resType); const userCode = ` - ${activationFunction} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { ${isChannelsLast ? sampleX : sampleW} } @@ -142,7 +141,7 @@ const conv2dCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) { let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueIn; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; @@ -181,31 +180,46 @@ export const createConv2DMatMulProgramInfo = LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; - const tileAOuter = workGroupSize[1] * elementsPerThread[1]; const tileBOuter = workGroupSize[0] * elementsPerThread[0]; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); - const fitAOuter = dimAOuter % tileAOuter === 0; const fitBOuter = dimBOuter % tileBOuter === 0; const fitInner = dimInner % tileInner === 0; - const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; - const t = tensorTypeToWsglStorageType(inputs[0].dataType); - // TODO: support component 2, 3. - const components = isVec4 ? 4 : 1; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - const x = - inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); - const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); - const inputVariables = [x, w]; + const programUniforms: ProgramUniform[] = [ + {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, + {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides}, + {type: 'int32', data: attributes.dilations} + ]; + if (attributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, + {name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2}, + {name: 'dilation', type: 'i32', length: 2} + ]; + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } - let declareFunctions = ` + // TODO: support component 2, 3. + const components = isVec4 ? 4 : 1; + const t = tensorTypeToWsglStorageType(inputs[0].dataType); + let declareFunctions = ` fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value); } @@ -213,51 +227,50 @@ export const createConv2DMatMulProgramInfo = let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3)); setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); }`; - if (hasBias) { - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); - inputVariables.push(bias); - - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - - declareFunctions += ` + const x = inputVariable( + 'x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); + const inputVariables = [x, w]; + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; - } - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - programUniforms.push(...createTensorShapeVariables(outputShape)); - return { - name: 'Conv2DMatMul', - shaderCache: {hint: attributes.cacheKey}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - programUniforms, - }), - getShaderSource: (shaderHelper: ShaderHelper) => ` + } + + return ` ${utilFunctions('uniforms.result_strides')} //struct Uniforms { xShape : vec4, wShape : vec4, outShape : vec4, // outShapeStrides: vec3, filterDims : vec2, pad : vec2, stride : vec2, // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; - ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .declareVariables(...inputVariables, output)} - const filterDims : vec2 = vec2(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]}); - const pad : vec2 = vec2(${attributes.pads[0]}, ${attributes.pads[1]}); - const stride : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${declareFunctions} ${ conv2dCommonSnippet( isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, attributes, elementsSize[0], elementsSize[1], elementsSize[2], t)} - ${ + ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined, - sequentialAccessByThreads)}` + sequentialAccessByThreads)}`; + }; + return { + name: 'Conv2DMatMul', + shaderCache: { + hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${ + tileAOuter};${tileBOuter};${tileInner}`, + inputDependencies + }, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms, + }), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index d425155857e14..33e50a9a39cb9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -21,8 +21,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ProgramInfo, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; import {getActivationSnippet} from '../fuse-utils'; @@ -74,21 +74,21 @@ const conv2dTransposeCommonSnippet = col % outWidth); `; - const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]'; - const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]'; + const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])'; + const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])'; const row = isChannelsLast ? 'row' : 'col'; const col = isChannelsLast ? 'col' : 'row'; const readASnippet = ` - let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; - let WRow = ${col} / (filterDims[1] * inChannels); - let WCol = ${col} / inChannels % filterDims[1]; - let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); - let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); + let WRow = ${col} / (uniforms.filter_dims[1] * inChannels); + let WCol = ${col} / inChannels % uniforms.filter_dims[1]; + let xR = f32(outRow - uniforms.pads[0] + uniforms.dilations[0] * WRow) / f32(uniforms.strides[0]); + let xC = f32(outCol - uniforms.pads[1] + uniforms.dilations[1] * WCol) / f32(uniforms.strides[1]); if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { return ${type}(0.0); } @@ -103,25 +103,25 @@ const conv2dTransposeCommonSnippet = const sampleA = isChannelsLast ? ` let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimInner) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${readASnippet} } return ${type}(0.0);` : ` let col = colIn * ${innerElementSize}; - if (row < uniforms.dimInner && col < uniforms.dimBOuter) { + if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${readASnippet} } return ${type}(0.0);`; const sampleW = ` let col = colIn * ${innerElementSize}; - let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; - let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); - let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; + let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; + let coordX = uniforms.filter_dims[0] - 1 - row / (uniforms.filter_dims[1] * inChannels); + let coordY = uniforms.filter_dims[1] - 1 - (row / inChannels) % uniforms.filter_dims[1]; if (${ - isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' : - 'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) { + isChannelsLast ? 'row < uniforms.dim_inner && col < uniforms.dim_b_outer' : + 'row < uniforms.dim_inner && col < uniforms.dim_a_outer'} && coordX >= 0 && coordY >= 0) { let rowInner = row % inChannels; let coord = vec4(coordX, coordY, col, rowInner); ${getWSnippet(innerElementSize)} @@ -129,9 +129,8 @@ const conv2dTransposeCommonSnippet = return ${type}(0.0); `; - const {activationFunction, applyActivation} = getActivationSnippet(attributes, type); + const applyActivation = getActivationSnippet(attributes, type); const userCode = ` - ${activationFunction} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} { ${isChannelsLast ? sampleA : sampleW} } @@ -142,7 +141,7 @@ const conv2dTransposeCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueInput; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; ${coordResSnippet} @@ -186,65 +185,64 @@ export const createConv2DTransposeMatMulProgramInfo = const innerElementSize = isVec4 ? 4 : 1; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); const components = isVec4 ? 4 : 1; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); - const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - const inputVariables = [x, w]; - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + const filterDims = + [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; + const effectiveFilterDims = [ + filterDims[0] + (attributes.dilations[0] <= 1 ? 0 : (filterDims[0] - 1) * (attributes.dilations[0] - 1)), + filterDims[1] + (attributes.dilations[1] <= 1 ? 0 : (filterDims[1] - 1) * (attributes.dilations[1] - 1)) + ]; + const pads = [ + effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), + effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2) + ]; - let declareFunctions = ''; + const programUniforms: ProgramUniform[] = [ + {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, + {type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations}, + {type: 'int32', data: filterDims}, {type: 'int32', data: pads} + ]; + if (attributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); - inputVariables.push(bias); programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - - declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { - return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; - }`; + inputDependencies.push('rank'); } - programUniforms.push(...createTensorShapeVariables(outputShape)); - return { - name: 'Conv2DTransposeMatMul', - shaderCache: {hint: attributes.cacheKey}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - programUniforms - }), - getShaderSource: (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => { + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const inputVariables = [x, w]; + + let declareFunctions = ''; + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + declareFunctions += ` + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; + }`; + } + + const uniforms: UniformsArrayType = [ + {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, + {name: 'strides', type: 'i32', length: 2}, {name: 'dilations', type: 'i32', length: 2}, + {name: 'filter_dims', type: 'i32', length: filterDims.length}, + {name: 'pads', type: 'i32', length: pads.length} + ]; + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + return ` ${utilFunctions('uniforms.result_strides')} - ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .declareVariables(...inputVariables, output)}; - const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); - const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ - attributes.kernelShape[isChannelsLast ? 2 : 3]}); - const effectiveFilterDims : vec2 = filterDims + vec2( - ${ - attributes.dilations[0] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, - ${ - attributes.dilations[1] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); - const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${ - attributes.pads[0] + attributes.pads[2]})/2, - i32(effectiveFilterDims[1]) - 1 - (${ - attributes.pads[1] + attributes.pads[3]})/2); - const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const dimAOuter : i32 = ${dimAOuter}; - const dimBOuter : i32 = ${dimBOuter}; - const dimInner : i32 = ${dimInner}; + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; ${declareFunctions} ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} ${ @@ -252,6 +250,18 @@ export const createConv2DTransposeMatMulProgramInfo = elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, - undefined, sequentialAccessByThreads)}` + undefined, sequentialAccessByThreads)}`; + }; + + return { + name: 'Conv2DTransposeMatMul', + shaderCache: + {hint: `${attributes.cacheKey};${elementsPerThread};${workGroupSize};${isVec4}`, inputDependencies}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms + }), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 50b0841a0200a..380efc8bc577a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -20,24 +20,18 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; -import {ProgramInfo} from '../../types'; -import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; const createConvTranspose2DOpProgramShaderSource = - (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false, - dataType: string): string => { - const isChannelsLast = attributes.format === 'NHWC'; + (shaderHelper: ShaderHelper, inputs: readonly TensorView[], outputShape: readonly number[], hasBias: boolean, + is1DimensionDispatch: boolean, isVec4 = false, dataType: string, uniforms: UniformsArrayType, + isChannelsLast = false): string => { const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; const channelDim = isChannelsLast ? 3 : 1; - const outputSize = ShapeUtil.size(outputShape); const workPerThread = isVec4 ? 2 : 1; - const group = attributes.group; - const wShape = inputs[1].dims; - const inputChannelsPerGroup = wShape[0] / group; - const outputChannelsPerGroup = wShape[1]; let declareFunctions = ` fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) { @@ -50,20 +44,21 @@ const createConvTranspose2DOpProgramShaderSource = }`; } const components = isVec4 ? 4 : 1; - const w = inputVariable('W', inputs[1].dataType, inputs[1].dims, components); - const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims, components); + const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); + const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components); const inputVariables = [dy, w]; if (hasBias) { - inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]], components)); + inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components)); } - const output = outputVariable('result', inputs[0].dataType, outputShape, components); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const codeSnippet4 = `{ - let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / outShape[1]; - let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % outShape[1]; + let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1]; + let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1]; let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread}; let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4; - let dyCorner = vec2(i32(r), i32(c)) - vec2(pads); + let dyCorner = vec2(i32(r), i32(c)) - vec2(uniforms.pads); // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. @@ -71,29 +66,29 @@ const createConvTranspose2DOpProgramShaderSource = for (var i = 0; i < ${workPerThread}; i++) { dotProd[i] = vec4<${dataType}>(0.0); } - for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) { - var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x); - let wRPerm = filterDims[0] - 1 - wR; - if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) || + for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) { + var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x); + let wRPerm = uniforms.filter_dims[0] - 1 - wR; + if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); - for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) { - let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y); - let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y); - let wCPerm = filterDims[1] - 1 - wC; + for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) { + let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let wCPerm = uniforms.filter_dims[1] - 1 - wC; if (wCPerm < 0) { continue; } var bDyCVal = true; var bDyCVal2 = true; - if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) || + if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) || fract(dyC) > 0.0) { bDyCVal = false; } - if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) || + if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) || fract(dyC2) > 0.0) { bDyCVal2 = false; } @@ -101,7 +96,7 @@ const createConvTranspose2DOpProgramShaderSource = let idyC: u32 = u32(dyC); let idyC2: u32 = u32(dyC2); if (bDyCVal && bDyCVal2) { - let d2Length = outBackprop[3]; + let d2Length = uniforms.Dy_shape[3]; for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -123,7 +118,7 @@ const createConvTranspose2DOpProgramShaderSource = dot(xValue, wValue3)); } } else if (bDyCVal) { - let d2Length = outBackprop[${channelDim}]; + let d2Length = uniforms.Dy_shape[${channelDim}]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -138,7 +133,7 @@ const createConvTranspose2DOpProgramShaderSource = dotProd[0] = dotProd[0] + tmpval; } } else if (bDyCVal2) { - let d2Length = outBackprop[3]; + let d2Length = uniforms.Dy_shape[3]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -167,39 +162,39 @@ const createConvTranspose2DOpProgramShaderSource = let d1 = ${output.indicesGet('outputIndices', channelDim)}; let r = ${output.indicesGet('outputIndices', rowDim)}; let c = ${output.indicesGet('outputIndices', colDim)}; - let dyCorner = vec2(i32(r), i32(c)) - pads; + let dyCorner = vec2(i32(r), i32(c)) - uniforms.pads; let dyRCorner = dyCorner.x; let dyCCorner = dyCorner.y; - let groupId = d1 / ${outputChannelsPerGroup}; - let wOutChannel = d1 - groupId * ${outputChannelsPerGroup}; + let groupId = d1 / uniforms.output_channels_per_group; + let wOutChannel = d1 - groupId * uniforms.output_channels_per_group; // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. var dotProd = ${dataType}(0.0); - for (var wR: u32 = 0; wR < effectiveFilterDims.x; wR = wR + 1) { - if (wR % dilations.x != 0) { + for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) { + if (wR % uniforms.dilations.x != 0) { continue; } - let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]); - let wRPerm = filterDims.x - 1 - wR / dilations.x; - if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || + let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]); + let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x; + if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); - for (var wC: u32 = 0; wC < effectiveFilterDims.y; wC = wC + 1) { - if (wC % dilations.y != 0) { + for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { + if (wC % uniforms.dilations.y != 0) { continue; } - let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(strides.y); - let wCPerm = filterDims.y - 1 - wC / dilations.y; - if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) || + let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y; + if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) || fract(dyC) > 0.0 || wCPerm < 0) { continue; } let idyC: u32 = u32(dyC); - var inputChannel = groupId * ${inputChannelsPerGroup}; - for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) { + var inputChannel = groupId * uniforms.input_channels_per_group; + for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) { let xValue = ${ isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; @@ -214,27 +209,11 @@ const createConvTranspose2DOpProgramShaderSource = `; return ` - ${shaderHelper.declareVariables(...inputVariables, output)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${declareFunctions} - const outShape : vec4 = vec4(${outputShape.join(',')}); - const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); - const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ - attributes.kernelShape[isChannelsLast ? 2 : 3]}); - const dilations : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const effectiveFilterDims : vec2 = filterDims + vec2( - ${ - attributes.dilations[0] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, - ${ - attributes.dilations[1] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); - const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${attributes.pads[0] + attributes.pads[2]})/2, - i32(effectiveFilterDims[1]) - 1 - (${attributes.pads[1] + attributes.pads[3]})/2); + ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}; ${isVec4 ? codeSnippet4 : codeSnippet}}`; }; @@ -257,19 +236,72 @@ export const createConvTranspose2DProgramInfo = ]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const isChannelsLast = attributes.format === 'NHWC'; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + const strides = [attributes.strides[0], attributes.strides[1]]; + const filterDims = + [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; + const dilations = [attributes.dilations[0], attributes.dilations[1]]; + const effectiveFilterDims = [ + filterDims[0] + + (attributes.dilations[0] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)), + filterDims[1] + + (attributes.dilations[1] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)) + ]; + const pads = [ + effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), + effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2 + ]; + + const isVec4 = false; + const group = attributes.group; + const wShape = inputs[1].dims; + const inputChannelsPerGroup = wShape[0] / group; + const outputChannelsPerGroup = wShape[1]; + + const programUniforms: ProgramUniform[] = [ + {type: 'int32', data: outputSize}, {type: 'uint32', data: strides}, {type: 'uint32', data: filterDims}, + {type: 'uint32', data: dilations}, {type: 'uint32', data: effectiveFilterDims}, {type: 'int32', data: pads}, + {type: 'uint32', data: inputChannelsPerGroup}, {type: 'uint32', data: outputChannelsPerGroup}, + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims) + ]; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); + + const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'strides', type: 'u32', length: strides.length}, + {name: 'filter_dims', type: 'u32', length: filterDims.length}, + {name: 'dilations', type: 'u32', length: filterDims.length}, + {name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length}, + {name: 'pads', type: 'i32', length: pads.length}, {name: 'input_channels_per_group', type: 'u32'}, + {name: 'output_channels_per_group', type: 'u32'} + ]; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return `${ + createConvTranspose2DOpProgramShaderSource( + shaderHelper, inputs, outputShape, hasBias, is1DimensionDispatch, isVec4, dataType, uniforms, + isChannelsLast)}`; + }; return { name: 'ConvTranspose2D', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: `${attributes.cacheKey};`, inputDependencies}, getRunData: () => ({ dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, outputs: [{ dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, dataType: inputs[0].dataType - }] + }], + programUniforms }), - getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( - shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false, - dataType), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 47ec16a296712..ee71110245252 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -22,7 +22,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {createTensorShapeVariables, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -112,14 +112,14 @@ fn main(@builtin(local_invocation_id) localId : vec3, ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; + let num_tiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc: array, rowPerThread>; // Loop over shared dimension. let tileRowB = localRow * ${rowPerThreadB}; - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { let inputRow = tileRow + innerRow; @@ -204,7 +204,7 @@ export const makeMatMulPackedSource = let globalColStart = i32(workgroupId.x) * ${tileBOuter}; // Loop over shared dimension. - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var inputRow = localRow; inputRow < ${tileAHight}; inputRow = inputRow + ${workgroupSize[1]}) { for (var inputCol = localCol; inputCol < ${tileAWidth}; inputCol = inputCol + ${workgroupSize[0]}) { @@ -260,7 +260,7 @@ let tileRowA = i32(localId.y) * ${rowPerThreadA}; let tileColA = i32(localId.x) * ${colPerThreadA}; let tileRowB = i32(localId.y) * ${rowPerThreadB}; // Loop over shared dimension. -for (var t = 0; t < numTiles; t = t + 1) { +for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var innerRow = 0; innerRow < ${rowPerThreadA}; innerRow = innerRow + 1) { for (var innerCol = 0; innerCol < ${colPerThreadA}; innerCol = innerCol + 1) { @@ -322,7 +322,8 @@ fn main(@builtin(local_invocation_id) localId : vec3, @builtin(workgroup_id) workgroupId : vec3) { let batch = ${splitK ? '0' : 'i32(globalId.z)'}; ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; + let num_tiles = ${ + splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc : array, rowPerThread>; @@ -379,7 +380,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < uniforms.dimAOuter && col < uniforms.dimInner) + if(row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${getAIndices()} value = ${aVariable.getByIndices('aIndices')}; @@ -391,7 +392,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < uniforms.dimInner && col < uniforms.dimBOuter) + if(row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${getBIndices()} value = ${bVariable.getByIndices('bIndices')}; @@ -401,7 +402,7 @@ const matMulReadWriteFnSource = fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) { let col = colIn * ${component}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueIn; let coords = vec3(batch, row, colIn); ${ @@ -422,16 +423,10 @@ export const createMatmulProgramInfo = isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => { const aShape = inputs[0].dims; const bShape = inputs[1].dims; - const outerDimsA = aShape.slice(0, -2); const outerDimsB = bShape.slice(0, -2); - const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); - const enableBatchUniforms = enableShapesUniforms(outerDims.length); - const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims; - const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); const batchSize = ShapeUtil.size(outerDims); - const dimAOuter = aShape[aShape.length - 2]; const dimInner = aShape[aShape.length - 1]; const dimBOuter = bShape[bShape.length - 1]; @@ -446,72 +441,67 @@ export const createMatmulProgramInfo = Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]) ]; - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const components = isVec4 ? 4 : 1; - const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; - const enableAShapesUniforms = enableShapesUniforms(aShapeTemp.length); - const aShapeOrRank = enableAShapesUniforms ? aShapeTemp.length : aShapeTemp; - + const aShapeOrRank = aShapeTemp.length; const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; - const enableBShapesUniforms = enableShapesUniforms(bShapeTemp.length); - const bShapeOrRank = enableBShapesUniforms ? bShapeTemp.length : bShapeTemp; - + const bShapeOrRank = bShapeTemp.length; const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; - - const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); - const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); - const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); - const inputVariables = [A, B]; const programUniforms: ProgramUniform[] = [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - if (enableBatchUniforms) { - programUniforms.push(...createTensorShapeVariables(outerDims)); + if (activationAttributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: activationAttributes.clipMax!}, + {type: 'float32', data: activationAttributes.clipMin!}); } - if (enableAShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(aShapeTemp)); - } - if (enableBShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(bShapeTemp)); - } - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - inputDependencies.push(enableAShapesUniforms ? 'rank' : 'dims'); - inputDependencies.push(enableBShapesUniforms ? 'rank' : 'dims'); + programUniforms.push( + ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp), + ...createTensorShapeVariables(bShapeTemp)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; const hasBias = inputs.length > 2; - const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); - const declareFunctions = matMulReadWriteFnSource( - components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], - isChannelsLast); if (hasBias) { - const biasComponents = isChannelsLast ? components : 1; - inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - inputDependencies.push('rank'); } programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => { + const batchShapeOrRank = outerDims.length; + const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + + const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); + const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); + const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); + const inputVariables = [A, B]; + if (hasBias) { + const biasComponents = isChannelsLast ? components : 1; + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + } + const uniforms: UniformsArrayType = + [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}]; + if (activationAttributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + const applyActivation = getActivationSnippet(activationAttributes, output.type.value); + const declareFunctions = matMulReadWriteFnSource( + components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], + isChannelsLast); + return ` ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .registerInternalVariables(batchDims) - .declareVariables(...inputVariables, output)} - ${activationFunction} + shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( + ...inputVariables, output)} ${declareFunctions} ${ - isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : - makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} + isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : + makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} `; - // TODO: turn clipMax and clipMin to uniforms. + }; return { name: 'MatMul', shaderCache: { - hint: activationAttributes.activationCacheKey + `${elementsPerThread}` + - `${isVec4}` + - `${isChannelsLast}`, + hint: `${elementsPerThread};${activationAttributes.activation};${isVec4};${isChannelsLast}`, inputDependencies }, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 21b4953d3f90c..f81d6577890c5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -3,9 +3,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ProgramInfo, ProgramUniform} from '../types'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; import {getActivationSnippet} from './fuse-utils'; @@ -27,52 +27,75 @@ export const createGroupedConvProgramInfo = xShape, wShape, attributes.dilations, attributes.pads, attributes.strides, isChannelLast); const outputSize = ShapeUtil.size(outputShape); - const output = outputVariable('output', inputs[0].dataType, outputShape); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); - const x = inputVariable('x', inputs[0].dataType, xShape); - const w = inputVariable('w', inputs[1].dataType, wShape); - const inputVars = [x, w]; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.dilations}, + {type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]}, + {type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup} + ]; + if (attributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), + ...createTensorShapeVariables(outputShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { - inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const strides: vec2 = vec2(${attributes.strides[0]}u, ${attributes.strides[1]}u); - const pads: vec2 = vec2(${attributes.pads[0]}u, ${attributes.pads[1]}u); - - ${shaderHelper.declareVariables(...inputVars, output)} + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const applyActivation = getActivationSnippet(attributes, output.type.value); + const x = inputVariable('x', inputs[0].dataType, xShape.length); + const w = inputVariable('w', inputs[1].dataType, wShape.length); + const inputVars = [x, w]; + if (hasBias) { + inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims)); + } - ${activationFunction} + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'dilations', type: 'u32', length: attributes.dilations.length}, + {name: 'strides', type: 'u32', length: 2}, {name: 'pads', type: 'u32', length: 2}, + {name: 'output_channels_per_group', type: 'u32'} + ]; + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let outputIndices = ${output.offsetToIndices('global_idx')}; let batch: u32 = outputIndices[0]; let output_channel: u32 = outputIndices[${isChannelLast ? 3 : 1}]; let xRCCorner: vec2 = vec2(outputIndices[${isChannelLast ? 1 : 2}], outputIndices[${ - isChannelLast ? 2 : 3}]) * strides - pads; - let group_id: u32 = output_channel / ${outputChannelsPerGroup}u; + isChannelLast ? 2 : 3}]) * uniforms.strides - uniforms.pads; + let group_id: u32 = output_channel / uniforms.output_channels_per_group; var value: ${output.type.value} = ${output.type.value}(0); - for (var wInChannel: u32 = 0u; wInChannel < ${wShape[1]}u; wInChannel++) { - let input_channel = group_id * ${wShape[1]}u + wInChannel; - for (var wHeight: u32 = 0u; wHeight < ${wShape[2]}u; wHeight++) { - let xHeight = xRCCorner.x + wHeight * ${attributes.dilations[0]}u; + for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) { + let input_channel = group_id * uniforms.w_shape[1] + wInChannel; + for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) { + let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0]; - if (xHeight < 0u || xHeight >= ${xShape[isChannelLast ? 1 : 2]}u) { + if (xHeight < 0u || xHeight >= uniforms.x_shape[${isChannelLast ? 1 : 2}]) { continue; } - for (var wWidth: u32 = 0u; wWidth < ${wShape[3]}u; wWidth++) { - let xWidth = xRCCorner.y + wWidth * ${attributes.dilations[1]}u; - if (xWidth < 0u || xWidth >= ${xShape[isChannelLast ? 2 : 3]}u) { + for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) { + let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1]; + if (xWidth < 0u || xWidth >= uniforms.x_shape[${isChannelLast ? 2 : 3}]) { continue; } let xVal = ${ - isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') : - x.get('batch', 'input_channel', 'xHeight', 'xWidth')}; + isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') : + x.get('batch', 'input_channel', 'xHeight', 'xWidth')}; let wVal = ${w.get('output_channel', 'wInChannel', 'wHeight', 'wWidth')}; value += xVal*wVal; } @@ -82,15 +105,17 @@ export const createGroupedConvProgramInfo = ${applyActivation} ${output.setByOffset('global_idx', 'value')} }`; + }; return { name: 'GroupedConv', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: attributes.cacheKey, inputDependencies}, getRunData: () => ({ outputs: [{ dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, dataType: inputs[0].dataType }], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }; @@ -114,7 +139,7 @@ export const createGroupedConvVectorizeProgramInfo = const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); + const applyActivation = getActivationSnippet(attributes, output.type.value); const x = inputVariable('x', inputs[0].dataType, xShape.length, components); const w = inputVariable('w', inputs[1].dataType, wShape.length, components); const inputVars = [x, w]; @@ -129,7 +154,6 @@ export const createGroupedConvVectorizeProgramInfo = .registerUniform('strides', 'i32', 2) .registerUniform('pads', 'i32', 2) .declareVariables(...inputVars, output)} - ${activationFunction} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let width0 = uniforms.output_shape[3]; @@ -179,7 +203,7 @@ export const createGroupedConvVectorizeProgramInfo = return { name: 'GroupedConv-Vectorize', shaderCache: { - hint: `${attributes.activationCacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, + hint: `${attributes.cacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank'] }, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 32b1d52ed94ca..33d16754c737a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -2,7 +2,6 @@ // Licensed under the MIT License. import {TensorView} from '../../tensor-view'; -import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext} from '../types'; import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; @@ -59,7 +58,6 @@ export interface ConvTransposeAttributes extends ConvAttributes { readonly outputShape: readonly number[]; } - const getAdjustedConvTransposeAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); @@ -96,11 +94,7 @@ const getAdjustedConvTransposeAttributes = // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - const cacheKey = attributes.cacheKey + [ - kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','), - dilations.join(',') - ].join('_'); - Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey}); + Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides}); return newAttributes; }; @@ -119,7 +113,7 @@ export const parseConvTransposeAttributes = (attributes: Record const wIsConst = (attributes.wIsConst as () => boolean)(); const outputPadding = attributes.outputPadding as [number, number, number, number]; const outputShape = attributes.outputShape as [number, number]; - return createAttributeWithCacheKey({ + return { autoPad, format, dilations, @@ -130,8 +124,9 @@ export const parseConvTransposeAttributes = (attributes: Record pads, strides, wIsConst, - ...activationAttributes - }); + ...activationAttributes, + cacheKey: `${attributes.format};${activationAttributes.activation};` + }; }; const validateInputs = (inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 7af2c5db49f40..5afec0389fac8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -3,7 +3,7 @@ import {TensorView} from '../../tensor-view'; import {PoolConvUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {AttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext} from '../types'; import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; @@ -110,7 +110,7 @@ const getAdjustedConvAttributes = (attributes: T, inpu // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - Object.assign(newAttributes, {kernelShape, pads, cacheKey: attributes.cacheKey}); + Object.assign(newAttributes, {kernelShape, pads}); return newAttributes; }; @@ -126,8 +126,18 @@ export const parseConvAttributes = (attributes: Record): ConvAt const strides = attributes.strides as [number, number]; const wIsConst = (attributes.w_is_const as () => boolean)(); - return createAttributeWithCacheKey( - {autoPad, format, dilations, group, kernelShape, pads, strides, wIsConst, ...activationAttributes}); + return { + autoPad, + format, + dilations, + group, + kernelShape, + pads, + strides, + wIsConst, + ...activationAttributes, + cacheKey: `${attributes.format};${activationAttributes.activation};` + }; }; const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 0b5c0db2b5112..2e0aa33a957dc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -7,30 +7,21 @@ export interface InternalActivationAttributes { readonly activation: string; readonly clipMin?: number; readonly clipMax?: number; - readonly activationCacheKey: string; } -export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): - {activationFunction: string; applyActivation: string} => { - switch (attributes.activation) { - case 'Relu': - return {activationFunction: '', applyActivation: `value = max(value, ${valueType}(0.0));`}; - case 'Sigmoid': - return { - activationFunction: '', - applyActivation: `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));` - }; - case 'Clip': - return { - activationFunction: `const clip_min_=${valueType}(${attributes.clipMin!});const clip_max_=${valueType}(${ - attributes.clipMax!});`, - applyActivation: 'value = clamp(value, clip_min_, clip_max_);' - }; - // TODO: adding other activations that can be fused. - default: - return {activationFunction: '', applyActivation: ''}; - } - }; +export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => { + switch (attributes.activation) { + case 'Relu': + return `value = max(value, ${valueType}(0.0));`; + case 'Sigmoid': + return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; + case 'Clip': + return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`; + // TODO: adding other activations that can be fused. + default: + return ''; + } +}; export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { @@ -38,7 +29,7 @@ export const parseInternalActivationAttributes = if (activation === 'Clip') { const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP]; - return {activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}`}; + return {activation, clipMax, clipMin}; } - return {activation, activationCacheKey: activation}; + return {activation}; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index de9309d1e436f..c946ea6366123 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -6,7 +6,7 @@ import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; -import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper,} from './common'; +import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common'; import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; export const createNaiveMatmulProgramInfo = @@ -27,11 +27,19 @@ export const createNaiveMatmulProgramInfo = const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); const batchSize = ShapeUtil.size(outerDims); const outputShapeInShader = [batchSize, M, N]; + const programUniforms: ProgramUniform[] = [ {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, - {type: 'uint32', data: K}, ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), - ...createTensorShapeVariables(bShape) + {type: 'uint32', data: K} ]; + if (activationAttributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: activationAttributes.clipMax!}, + {type: 'float32', data: activationAttributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), + ...createTensorShapeVariables(bShape)); if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); } @@ -42,7 +50,7 @@ export const createNaiveMatmulProgramInfo = const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); const b = inputVariable('b', inputs[1].dataType, bShape.length, components); const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value); const inputVariables = [a, b]; let processBias = ''; if (hasBias) { @@ -57,6 +65,14 @@ export const createNaiveMatmulProgramInfo = const outerDimsB = bShape.slice(0, -2); const broadCastADims = getBroadcastDims(outerDimsA, outerDims); const broadCastBDims = getBroadcastDims(outerDimsB, outerDims); + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, + {name: 'K', type: 'u32'} + ]; + if (activationAttributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { const rank = variable.rank; const name = variable.name; @@ -96,15 +112,10 @@ export const createNaiveMatmulProgramInfo = return ` ${ - shaderHelper.registerUniform('outputSize', 'u32') - .registerUniform('M', 'u32') - .registerUniform('N', 'u32') - .registerUniform('K', 'u32') - .registerInternalVariables(batchDims) - .declareVariables(...inputVariables, output)} - ${activationFunction} + shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( + ...inputVariables, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let col = (global_idx % (uniforms.N / ${components})) * ${components}; var index1 = global_idx / (uniforms.N / ${components}); let stride1 = uniforms.M / ${outputNumber}; @@ -134,8 +145,7 @@ export const createNaiveMatmulProgramInfo = return { name: 'MatMulNaive', shaderCache: { - hint: `${activationAttributes.activationCacheKey}_${components}_${aComponents}_${outputNumber}_${ - isChannelsLast}`, + hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`, inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'] }, getRunData: () => ({ @@ -166,9 +176,8 @@ export const matMul = (context: ComputeContext): void => { const N = outputShape[outputShape.length - 1]; const K = context.inputs[0].dims[context.inputs[0].dims.length - 1]; if (N < 8 && K < 8) { - context.compute( - createNaiveMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + context.compute(createNaiveMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); } else { - context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + context.compute(createMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); } }; From 8b4517218b52285efaaf8badd303c00b0e514238 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 25 Jan 2024 16:57:58 -0800 Subject: [PATCH 09/11] Remove USE_CUTLASS flag (#19271) ### Description Since Cutlass can be built with CUDA 11.4 (The minimum CUDA version for onnxruntime CUDA build), there is no need to have a flag to disable cutlass. Changes: (1) Reverted https://github.com/microsoft/onnxruntime/pull/18761 (2) remove the condition to build cutlass. (3) Fix a few build errors or warnings during testing CUDA 11.4 build. Note that SM 89 and 90 (including fp8) requires CUDA 11.8 or later. Flash attention and cutlass fused multihead attention will not be built for CUDA < 11.6. It is recommended to use CUDA 11.8 or above to build if you want to support latest GPUs. It is better to include it in 1.17.0 (otherwise, the release branch might encounter build failure with CUDA 11.4). Tests: (1) Build with flash attention and efficient attention off: **passed** (2) Build with CUDA 11.4: **passed** Example build command used in Ubuntu 20.04: ``` export CUDA_HOME=/usr/local/cuda-11.4 export CUDNN_HOME=/usr/lib/x86_64-linux-gnu/ export CUDACXX=/usr/local/cuda-11.4/bin/nvcc sh build.sh --config Release --build_shared_lib --parallel --use_cuda --cuda_version 11.4 \ --cuda_home $CUDA_HOME --cudnn_home $CUDNN_HOME --build_wheel --skip_tests \ --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \ --disable_types float8 ``` ### Motivation and Context --- cmake/CMakeLists.txt | 23 ++++++------------- cmake/external/cutlass.cmake | 20 ++++++++-------- .../cuda/collective/sharded_moe.cc | 4 ---- .../contrib_ops/cuda/collective/sharded_moe.h | 4 ---- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 8 ------- .../cuda/moe/ft_moe/compute_occupancy.h | 5 ---- .../cuda/moe/ft_moe/cutlass_heuristic.cc | 11 ++++----- .../cuda/moe/ft_moe/cutlass_heuristic.h | 2 -- .../cuda/moe/ft_moe/epilogue_helpers.h | 4 ---- .../cuda/moe/ft_moe/ft_gemm_configs.h | 4 ---- .../moe/ft_moe/gemm_moe_problem_visitor.h | 4 ---- .../cuda/moe/ft_moe/layout_traits_helper.h | 6 +---- .../cuda/moe/ft_moe/moe_cutlass_kernel.h | 4 ---- .../cuda/moe/ft_moe/moe_gemm_kernels.h | 4 ---- .../moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu | 4 ---- .../moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu | 4 ---- .../moe/ft_moe/moe_gemm_kernels_template.h | 4 ---- .../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu | 4 ---- .../contrib_ops/cuda/moe/ft_moe/moe_kernel.h | 6 +---- .../cuda/moe/ft_moe/moe_problem_visitor.h | 4 ---- .../cuda/moe/ft_moe/tile_interleaved_layout.h | 5 ---- onnxruntime/contrib_ops/cuda/moe/moe.cc | 4 ---- onnxruntime/contrib_ops/cuda/moe/moe.h | 4 ---- onnxruntime/contrib_ops/cuda/moe/moe_base.h | 4 ---- .../cuda/quantization/matmul_nbits.cu | 6 ++--- onnxruntime/test/contrib_ops/moe_test.cc | 4 ---- 26 files changed, 25 insertions(+), 131 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 7d7304630c00e..0eb224623f678 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -97,7 +97,6 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF) option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF) -cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cutlass support" ON "onnxruntime_USE_CUDA" OFF) cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF) option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) @@ -707,20 +706,16 @@ if (onnxruntime_USE_CUDA) enable_language(CUDA) message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") + if (onnxruntime_DISABLE_CONTRIB_OPS) + set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) + endif() if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) - message( STATUS "Turn off cutlass since CUDA compiler version < 11.6") - set(onnxruntime_USE_CUTLASS OFF) + message( STATUS "Turn off flash attention since CUDA compiler version < 11.6") + set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() else() - set(onnxruntime_USE_CUTLASS OFF) -endif() - -if (NOT onnxruntime_USE_CUTLASS OR onnxruntime_DISABLE_CONTRIB_OPS) - if (onnxruntime_DISABLE_CONTRIB_OPS) - message( STATUS "Turn off flash attention/memory efficient attention since contrib ops are disabled") - else() - message( STATUS "Turn off flash attention/memory efficient attention since cutlass is not enabled") - endif() set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() @@ -906,10 +901,6 @@ function(onnxruntime_set_compile_flags target_name) target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN) endif() - if (onnxruntime_USE_CUTLASS) - target_compile_definitions(${target_name} PRIVATE USE_CUTLASS) - endif() - if(USE_NEURAL_SPEED) target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED) endif() diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index efc708bd681c0..f04f4bec76cd5 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -1,13 +1,11 @@ -if (onnxruntime_USE_CUTLASS) - include(FetchContent) - FetchContent_Declare( - cutlass - URL ${DEP_URL_cutlass} - URL_HASH SHA1=${DEP_SHA1_cutlass} - ) +include(FetchContent) +FetchContent_Declare( + cutlass + URL ${DEP_URL_cutlass} + URL_HASH SHA1=${DEP_SHA1_cutlass} +) - FetchContent_GetProperties(cutlass) - if(NOT cutlass_POPULATED) - FetchContent_Populate(cutlass) - endif() +FetchContent_GetProperties(cutlass) +if(NOT cutlass_POPULATED) + FetchContent_Populate(cutlass) endif() diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 9b989dac9a94b..40a667ffd5d83 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" @@ -204,5 +202,3 @@ Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h index cbd483fddab78..5ea4ae59c4020 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" @@ -36,5 +34,3 @@ class ShardedMoE final : public NcclKernel, public MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index fa73950c9c6f5..8f368251f12c7 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -70,10 +70,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); -#ifdef USE_CUTLASS class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE); -#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); @@ -169,10 +167,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllR class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll); -#ifdef USE_CUTLASS class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE); -#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul); @@ -272,10 +268,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, -#ifdef USE_CUTLASS BuildKernelCreateInfo, BuildKernelCreateInfo, -#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -377,10 +371,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, -#ifdef USE_CUTLASS BuildKernelCreateInfo, BuildKernelCreateInfo, -#endif BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h index 9b97690fe70fd..86136ea244e23 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h @@ -13,9 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#ifdef USE_CUTLASS - #pragma once #include @@ -52,5 +49,3 @@ inline int compute_occupancy_for_kernel() { } } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc index f0abd46572a90..adc043e5689e2 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifdef USE_CUTLASS #include "cutlass_heuristic.h" @@ -66,9 +65,9 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, } // Check that the workspace has sufficient space for this split-k factor - const int ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); - const int ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); - const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + const size_t ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); + const size_t ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); + const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; if (required_ws_bytes > workspace_bytes) { return false; @@ -128,7 +127,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector= multi_processor_count * 256 ? 1 : split_k_limit; - for (int ii = 0; ii < candidate_configs.size(); ++ii) { + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { CutlassGemmConfig candidate_config = candidate_configs[ii]; TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); int occupancy = occupancies[ii]; @@ -186,5 +185,3 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector @@ -64,5 +62,3 @@ class MoeGemmRunner { }; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu index 1d0dfe7c5a647..1d9a249db4237 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu @@ -14,12 +14,8 @@ * limitations under the License. */ -#ifdef USE_CUTLASS - #include "moe_gemm_kernels_template.h" namespace ort_fastertransformer { template class MoeGemmRunner; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu index 7a5d97902ee8f..7b250e6ca9060 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu @@ -14,12 +14,8 @@ * limitations under the License. */ -#ifdef USE_CUTLASS - #include "moe_gemm_kernels_template.h" namespace ort_fastertransformer { template class MoeGemmRunner; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index 3fd0fc47055a5..66950c9b65970 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -14,8 +14,6 @@ * limitations under the License. */ -#ifdef USE_CUTLASS - // Ignore CUTLASS warnings about type punning #ifdef __GNUC__ #pragma GCC diagnostic push @@ -428,5 +426,3 @@ void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, con } } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index 9232e8d012933..f4f2b49032d23 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -16,8 +16,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include #include #include @@ -900,5 +898,3 @@ template void finalize_moe_routing_kernelLauncher(const half*, half*, const half cudaStream_t); } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index f09471de1cc2e..5cc2a3f79f003 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -16,8 +16,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "moe_gemm_kernels.h" @@ -174,6 +172,4 @@ class CutlassMoeFCRunner> { } // namespace layout } // namespace cutlass - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index 0da06192e266b..3f26a274109ad 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "moe.h" @@ -119,5 +117,3 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.h b/onnxruntime/contrib_ops/cuda/moe/moe.h index 710b914f0633d..c4d8c4dc64c57 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" @@ -26,5 +24,3 @@ class MoE final : public CudaKernel, public MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index dc8b9d57f79f6..f55a7cde2e208 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "core/common/common.h" @@ -172,5 +170,3 @@ class MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index 67384957d8dd2..d4d583906b7f4 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -89,7 +89,7 @@ __device__ __forceinline__ void Convert8xInt4To8xHalfs(uint32_t value, half2* ha asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(kOneSixteenth), "r"(kNeg64)); } -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { +__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); half2 zp_adjust2 = {zp_adjust, zp_adjust}; @@ -120,7 +120,7 @@ __device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, sums_half2[3] = sums_half2[3] + v3 * (*(reinterpret_cast(&(vec_permuted.w)))); } #else -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { +__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); half2 zp_adjust2 = {zp_adjust, zp_adjust}; @@ -144,7 +144,7 @@ __device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, } #endif -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { +__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { float4 a_vec_0 = *(reinterpret_cast(a)); float4 a_vec_1 = *(reinterpret_cast(a + 4)); diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index 844cc877f2568..ebb0261deefa5 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include "gtest/gtest.h" #include "test/common/tensor_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" @@ -423,5 +421,3 @@ TEST(MoETest, MoETest_Relu) { } // namespace test } // namespace onnxruntime - -#endif From a3f0e2422b5eb2968e3f11e93414aa1661b32e2f Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 26 Jan 2024 08:58:22 +0800 Subject: [PATCH 10/11] [js/webgpu] Support f16 uniform (#19098) ### Description ### Motivation and Context --- js/web/lib/wasm/jsep/backend-webgpu.ts | 26 +++++++++--- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 40 +++++++++++++------ js/web/lib/wasm/jsep/webgpu/ops/pad.ts | 4 +- js/web/lib/wasm/jsep/webgpu/types.ts | 2 +- .../core/providers/js/operators/pad.cc | 10 ++--- 5 files changed, 56 insertions(+), 26 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 8ca025d66550c..a48fe99570abf 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -428,13 +428,26 @@ export class WebGpuBackend { return; } // https://www.w3.org/TR/WGSL/#alignof - const baseAlignment = data.length <= 2 ? data.length * 4 : 16; + const sizeOfElement = v.type === 'float16' ? 2 : 4; + let sizeOfVecOrMat; + let baseAlignment; + if (v.type === 'float16') { + baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement); + sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length; + } else { + baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16; + sizeOfVecOrMat = 16; + } currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment; offsets.push(currentOffset); - // When data.length > 4, the uniform variable is of type array,N>, where N = - // Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * - // SizeOf(vec4). - currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4; + // For non-float16 type, when data.length > 4, the uniform variable is of type array,N>, where + // N = Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * + // SizeOf(vec4). For float16 type, when data.length > 4, the uniform variable is of type + // array,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte + // length is N * SizeOf(mat2x4). + const elementPerVecOrMat = v.type === 'float16' ? 8 : 4; + currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat : + data.length * sizeOfElement; }); // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set @@ -449,6 +462,9 @@ export class WebGpuBackend { new Int32Array(arrayBuffer, offset, data.length).set(data); } else if (v.type === 'uint32') { new Uint32Array(arrayBuffer, offset, data.length).set(data); + } else if (v.type === 'float16') { + // TODO: use Float16Array. + new Uint16Array(arrayBuffer, offset, data.length).set(data); } else { new Float32Array(arrayBuffer, offset, data.length).set(data); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index bc3265be955f0..643744108c0f4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -330,18 +330,28 @@ export const sumVector = (name: string, components: number) => { * @param name - the name of variable. * @param index - the index of variable element. * @param length - the length of variable. + * @param type - the type of variable, optional. */ -export const getElementAt = (name: string, index: number|string, length: number): string => { - if (name.startsWith('uniforms.') && length > 4) { - if (typeof (index) === 'string') { - return `${name}[(${index}) / 4][(${index}) % 4]`; - } else { - return `${name}[${Math.floor(index / 4)}][${index % 4}]`; - } - } else { - return length > 1 ? `${name}[${index}]` : name; - } -}; +export const getElementAt = + (name: string, index: number|string, length: number, type?: UniformDataElementType): string => { + if (name.startsWith('uniforms.') && length > 4) { + if (typeof (index) === 'string') { + if (type === 'f16') { + return `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]`; + } else { + return `${name}[(${index}) / 4][(${index}) % 4]`; + } + } else { + if (type === 'f16') { + return `${name}[${Math.floor(index / 8)}][${Math.floor(index % 8 / 4)}][${index % 8 % 4}]`; + } else { + return `${name}[${Math.floor(index / 4)}][${index % 4}]`; + } + } + } else { + return length > 1 ? `${name}[${index}]` : name; + } + }; /** * A helper function to get a IndicesHelper for a given input or output. @@ -688,7 +698,7 @@ export const internalVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'internal', components); -export type UniformDataElementType = 'u32'|'f32'|'i32'; +export type UniformDataElementType = 'u32'|'f16'|'f32'|'i32'; export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>; /** @@ -861,7 +871,11 @@ class ShaderHelperImpl implements ShaderHelper { const uniformSnippets: string[] = []; for (const {name, type, length} of this.uniforms) { if (length && length > 4) { - uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`); + if (type === 'f16') { + uniformSnippets.push(`@align(16) ${name}:array, ${Math.ceil(length / 8)}>`); + } else { + uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`); + } } else { const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`; uniformSnippets.push(`${name}:${typeTemp}`); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index eca3fa7d944bb..c65b741e1105a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -19,8 +19,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length < 1) { throw new Error('Too few inputs'); } - if (inputs[0].dataType !== DataType.float) { - throw new Error('Input type must be float.'); + if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.float16) { + throw new Error('Input type must be float or float16.'); } if (inputs.length >= 2) { diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index e55bfb6ba9f16..789ac70a6913a 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -24,7 +24,7 @@ export interface TensorInfo { } export interface ProgramUniform { - type: 'int32'|'float32'|'uint32'; + type: 'int32'|'float16'|'float32'|'uint32'; data: number|readonly number[]; } diff --git a/onnxruntime/core/providers/js/operators/pad.cc b/onnxruntime/core/providers/js/operators/pad.cc index 24ba85cbf6e0d..83fee35481aa6 100644 --- a/onnxruntime/core/providers/js/operators/pad.cc +++ b/onnxruntime/core/providers/js/operators/pad.cc @@ -14,7 +14,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 2, 10, kJsExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), Pad); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -24,7 +24,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 12, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), @@ -37,7 +37,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 17, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), @@ -50,7 +50,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 18, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), @@ -62,7 +62,7 @@ ONNX_OPERATOR_KERNEL_EX( 19, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), From 358650d4415d930ba3ea4de159b8191cb1696dc4 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 25 Jan 2024 17:19:04 -0800 Subject: [PATCH 11/11] Fix BigModel stable diffusion pipeline (#19277) ### Description Fix two issues: (1) We can only use single quote inside `bash -c "..."`. Current pipeline job stopped at `python3 demo_txt2img.py astronaut` and skip the following commands. In this change, we remove the remaining commands to get same effect (otherwise, the pipeline runtime might be 2 hours instead of 15 minutes). (2) Fix a typo of Stable. --- .../github/azure-pipelines/bigmodels-ci-pipeline.yml | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index ff2e7c0468a21..b767b7276b428 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -136,11 +136,11 @@ stages: - template: templates/explicitly-defined-final-tasks.yml -- stage: Stale_Diffusion +- stage: Stable_Diffusion dependsOn: - Build_Onnxruntime_Cuda jobs: - - job: Stale_Diffusion + - job: Stable_Diffusion variables: skipComponentGovernanceDetection: true CCACHE_DIR: $(Pipeline.Workspace)/ccache @@ -171,12 +171,7 @@ stages: python3 -m pip install -r requirements-cuda11.txt; \ python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com; \ echo Generate an image guided by a text prompt; \ - python3 demo_txt2img.py "astronaut riding a horse on mars"; \ - echo Generate an image with Stable Diffusion XL guided by a text prompt; \ - python3 demo_txt2img_xl.py 'starry night over Golden Gate Bridge by van gogh'; \ - python3 demo_txt2img_xl.py --enable-refiner 'starry night over Golden Gate Bridge by van gogh'; \ - echo Generate an image guided by a text prompt using LCM LoRA; \ - python3 demo_txt2img_xl.py --scheduler LCM --lora-weights latent-consistency/lcm-lora-sdxl --denoising-steps 4 "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"; \ + python3 demo_txt2img.py 'astronaut riding a horse on mars'; \ popd; \ " displayName: 'Run stable diffusion demo'