Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre committed Feb 22, 2024
2 parents 01866c2 + 05ed89f commit 89ebc2a
Showing 18 changed files with 131 additions and 32 deletions.
15 changes: 13 additions & 2 deletions java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.providers;
@@ -14,7 +14,18 @@ public enum CoreMLFlags implements OrtFlags {
/** Enables CoreML on subgraphs. */
ENABLE_ON_SUBGRAPH(2), // COREML_FLAG_ENABLE_ON_SUBGRAPH(0x002)
/** Only enable usage of CoreML if the device has an Apple Neural Engine. */
ONLY_ENABLE_DEVICE_WITH_ANE(4); // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004),
ONLY_ENABLE_DEVICE_WITH_ANE(4), // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004)
/**
* Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also
* allow inputs with dynamic shapes. However, the performance may be negatively impacted if inputs
* have dynamic shapes.
*/
ONLY_ALLOW_STATIC_INPUT_SHAPES(8), // COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES(0x008)
/**
* Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or
* later.
*/
CREATE_MLPROGRAM(16); // COREML_FLAG_CREATE_MLPROGRAM(0x010)

/** The native value of the enum. */
public final int value;
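The two new flags extend the CoreML EP's existing bit-mask scheme, so several options can be OR-ed together before being handed to the native provider. A minimal sketch of how the numeric values shown in the comments above combine (pure arithmetic; the actual session-configuration call is not shown here, and the Java API folds an EnumSet of these flags into a single int in the same way):

```python
# Flag values mirrored from the comments in CoreMLFlags above.
ENABLE_ON_SUBGRAPH = 0x002
ONLY_ENABLE_DEVICE_WITH_ANE = 0x004
ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008
CREATE_MLPROGRAM = 0x010

# Folding a set of flags into one integer mask.
selected = (CREATE_MLPROGRAM, ONLY_ALLOW_STATIC_INPUT_SHAPES)
mask = 0
for flag in selected:
    mask |= flag
assert mask == 0x018
```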
6 changes: 3 additions & 3 deletions js/react_native/yarn.lock
@@ -3701,9 +3701,9 @@ invariant@^2.2.4:
loose-envify "^1.0.0"

ip@^1.1.5:
version "1.1.8"
resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48"
integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg==
version "1.1.9"
resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396"
integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ==

is-absolute@^1.0.0:
version "1.0.0"
4 changes: 2 additions & 2 deletions js/web/README.md
@@ -12,7 +12,7 @@ The [Open Neural Network Exchange](http://onnx.ai/) (ONNX) is an open standard f

With ONNX Runtime Web, web developers can score models directly on browsers with various benefits including reducing server-client communication and protecting user privacy, as well as offering install-free and cross-platform in-browser ML experience.

ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web complies the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend.
ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web compiles the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend.

See [Compatibility](#Compatibility) and [Operators Supported](#Operators) for a list of platforms and operators ONNX Runtime Web currently supports.

@@ -22,7 +22,7 @@ Refer to [ONNX Runtime JavaScript examples](https://github.com/microsoft/onnxrun

## Documents

### Developement
### Development

Refer to the following links for development information:

23 changes: 17 additions & 6 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
@@ -46,8 +46,6 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
auto block_size = metadata->constants.at("BLOCK_SIZE");
auto hw_size = metadata->constants.at("HW_SIZE");
auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams<T>* params) -> Status {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr),
"Input skip or bias is not supported by triton kernel.");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size,
"Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (",
@@ -61,23 +59,36 @@
}
// Construct args for launch kernel
struct {
void* X;
void* Y;
const void* src;
const void* skip;
const void* bias;
void* out;
void* add_out;
const void* gamma;
const void* beta;
int hw;
int c;
int c_per_group;
float eps;
bool has_skip;
bool has_bias;
bool broadcast_skip;
} args = {
(void*)params->src,
(const void*)params->src,
(const void*)params->skip,
(const void*)params->bias,
(void*)params->dst,
(void*)params->skip_workspace,
(const void*)params->gamma,
(const void*)params->beta,
params->hw,
params->c,
params->channels_per_group,
params->epsilon};
params->epsilon,
params->skip != nullptr,
params->bias != nullptr,
params->broadcast_skip,
};

// Grid dim is (batch_count, groups, 1)
return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args));
39 changes: 35 additions & 4 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
@@ -12,13 +12,19 @@
@triton.jit
def group_norm_kernel(
input_ptr,
skip_ptr,
bias_ptr,
output_ptr,
add_out_ptr,
gamma_ptr,
beta_ptr,
img_size,
c,
c_per_group,
eps,
has_skip,
has_bias,
broadcast_skip,
BLOCK_SIZE: tl.constexpr,
HW_SIZE: tl.constexpr,
ACTIVATION_SILU: tl.constexpr,
@@ -36,14 +42,35 @@ def group_norm_kernel(
offsets = hw[:, None] * c + cols[None, :]
mask = (cols < c_per_group)[None, :]

bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
if has_skip:
add_out_ptr += row_x * stride + row_y * c_per_group
if broadcast_skip:
broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group
bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
else:
skip_ptr += row_x * stride + row_y * c_per_group
if has_bias:
bias_ptr += row_y * c_per_group
bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)

# Calculate mean and variance
_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32)
_square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32)
for i in range(tl.cdiv(img_size, HW_SIZE)):
x_ptr = input_ptr + i * HW_SIZE * c
a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
if has_skip and not broadcast_skip:
s_ptr = skip_ptr + i * HW_SIZE * c
s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
a += s
if has_bias or broadcast_skip:
a += bias
_sum += a
_square_sum += a * a
if has_skip:
add_y_ptr = add_out_ptr + i * HW_SIZE * c
tl.store(add_y_ptr + offsets, a, mask=mask)

# Set axis=None (or leave it unspecified) to reduce all axes.
# TODO: In older Triton we have to reduce an axis at a time, but in our case
@@ -57,9 +84,13 @@
gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32)
beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32)
for i in range(tl.cdiv(img_size, HW_SIZE)):
x_ptr = input_ptr + i * HW_SIZE * c
y_ptr = output_ptr + i * HW_SIZE * c
x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
if has_skip:
add_y_ptr = add_out_ptr + i * HW_SIZE * c
x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
else:
x_ptr = input_ptr + i * HW_SIZE * c
x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
x_hat = (x - group_mean) * rstd
y = x_hat * gamma + beta
if ACTIVATION_SILU:
@@ -77,7 +108,7 @@
hw_sizes = [8, 16, 32, 64, 128, 256]
warps = [1, 2, 4, 8, 16]
name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}"
sig_pattern = "*{},*{},*fp32,*fp32,i32,i32,i32,fp32"
sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1"
group_pattern = "GroupNormTriton_{}_{}"


@@ -88,7 +119,7 @@ def get_function_table():
silu_suffix = "Silu" if silu else "Pass"
name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp)
group = group_pattern.format(silu_suffix, dtype)
sig = sig_pattern.format(dtype, dtype)
sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype)
kwargs = {
"num_warps": warp,
"constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)},
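The Triton changes fold the optional skip/bias add into the normalization pass and write the pre-normalization sum to add_out. A NumPy sketch of the intended math on NHWC data, assuming skip is either full-shaped (N, H, W, C), broadcast from (N, C), or broadcast from (N, 1, 1, C); the function and variable names here are illustrative, not part of the change:

```python
import numpy as np

def group_norm_ref(x, gamma, beta, groups, eps, skip=None, bias=None, silu=False):
    # Reference for the fused GroupNorm(+skip+bias) that the Triton kernel computes on NHWC data.
    n, h, w, c = x.shape
    a = x.astype(np.float32)
    if skip is not None:
        s = skip.reshape(n, 1, 1, c) if skip.ndim == 2 else skip  # (N, C) broadcast, or full/(N, 1, 1, C) skip
        a = a + s.astype(np.float32)
    if bias is not None:
        a = a + bias.reshape(1, 1, 1, c).astype(np.float32)
    add_out = a  # pre-normalization sum; the kernel stores this to add_out when a skip input is present
    g = a.reshape(n, h * w, groups, c // groups)
    mean = g.mean(axis=(1, 3), keepdims=True)
    rstd = 1.0 / np.sqrt(g.var(axis=(1, 3), keepdims=True) + eps)
    y = ((g - mean) * rstd).reshape(n, h, w, c)
    y = y * gamma.reshape(1, 1, 1, c) + beta.reshape(1, 1, 1, c)
    if silu:
        y = y * (1.0 / (1.0 + np.exp(-y)))  # ACTIVATION_SILU
    return y, add_out
```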
17 changes: 14 additions & 3 deletions onnxruntime/core/providers/cuda/nn/conv.cc
@@ -326,7 +326,8 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)

ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
gsl::narrow_cast<int>(conv_attrs_.group),
CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType<CudaT>()));
CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType<CudaT>(),
UseTF32()));

if (context->InputCount() >= 3) {
const Tensor* B = context->Input<Tensor>(2);
@@ -351,8 +352,13 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)

if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) {
// set math type to tensor core before algorithm search
if constexpr (std::is_same<T, MLFloat16>::value)
if constexpr (std::is_same<T, MLFloat16>::value) {
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH));
} else if constexpr (std::is_same<T, float>::value) {
if (!UseTF32()) {
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH));
}
}

cudnnConvolutionFwdAlgoPerf_t perf;
int algo_count = 1;
@@ -399,6 +405,8 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory));
if (std::is_same<T, MLFloat16>::value) {
perf.mathType = CUDNN_TENSOR_OP_MATH;
} else if (std::is_same<T, float>::value && !UseTF32()) {
perf.mathType = CUDNN_FMA_MATH;
} else {
perf.mathType = CUDNN_DEFAULT_MATH;
}
@@ -480,7 +488,8 @@ Status CudnnConvolutionDescriptor::Set(
const gsl::span<const int64_t>& dilations,
int groups,
cudnnConvolutionMode_t mode,
cudnnDataType_t data_type) {
cudnnDataType_t data_type,
bool use_tf32) {
if (!desc_)
CUDNN_RETURN_IF_ERROR(cudnnCreateConvolutionDescriptor(&desc_));

@@ -513,6 +522,8 @@
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH));
if (data_type == CUDNN_DATA_HALF) {
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH));
} else if (data_type == CUDNN_DATA_FLOAT && !use_tf32) {
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_FMA_MATH));
}

return Status::OK();
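The UseTF32() gate above forces CUDNN_FMA_MATH for FP32 convolutions when TF32 is disabled, because TF32 tensor-core math keeps only a 10-bit mantissa for the multiply inputs. A small Python sketch of that input rounding (an approximation of TF32 quantization for illustration, not cuDNN's actual code path) shows the precision one opts back into by disabling TF32:

```python
import struct

def round_to_tf32(x: float) -> float:
    # Approximate TF32 input rounding: keep 10 of FP32's 23 mantissa bits (truncation shown for simplicity).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= ~((1 << 13) - 1)  # drop the low 13 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

a, b = 1.0000001, 3.1415927
exact = a * b                               # full-precision FP32-style multiply (FMA math path)
tf32 = round_to_tf32(a) * round_to_tf32(b)  # inputs quantized the way TF32 tensor cores do
print(exact, tf32, abs(exact - tf32))
```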
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cuda/nn/conv.h
@@ -29,7 +29,8 @@ class CudnnConvolutionDescriptor final {
const gsl::span<const int64_t>& dilations,
int groups,
cudnnConvolutionMode_t mode,
cudnnDataType_t data_type);
cudnnDataType_t data_type,
bool use_tf32);

operator cudnnConvolutionDescriptor_t() const { return desc_; }

10 changes: 8 additions & 2 deletions onnxruntime/core/providers/cuda/nn/conv_transpose.cc
@@ -167,7 +167,8 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations,
gsl::narrow_cast<int>(conv_transpose_attrs_.group), mode,
CudnnTensor::GetDataType<CudaT>()));
CudnnTensor::GetDataType<CudaT>(),
UseTF32()));

if (has_bias) {
const auto& b_shape = p.B->Shape();
@@ -187,8 +188,13 @@
GetScratchBuffer<void>(AlgoSearchWorkspaceSize, context->GetComputeStream());

// set math type to tensor core before algorithm search
if constexpr (std::is_same<T, MLFloat16>::value)
if constexpr (std::is_same<T, MLFloat16>::value) {
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH));
} else if constexpr (std::is_same<T, float>::value) {
if (!UseTF32()) {
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH));
}
}

cudnnConvolutionBwdDataAlgoPerf_t perf;
int algo_count = 1;
12 changes: 12 additions & 0 deletions onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py
@@ -80,6 +80,18 @@ def run_group_norm(
)
use_silu = silu
broadcast_skip = False
if has_skip:
skip_x_shape = skip_x.shape
b2 = len(skip_x_shape) == 2 and skip_x_shape[0] == batch_size and skip_x_shape[1] == num_channels
b4 = (
len(skip_x_shape) == 4
and skip_x_shape[0] == batch_size
and skip_x_shape[1] == 1
and skip_x_shape[2] == 1
and skip_x_shape[3] == num_channels
)
if b2 or b4:
broadcast_skip = True
channels_per_block = 0 # Compute in params initialization

input_d = ke.DeviceArray(input_x.astype(dtype))
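The new check treats a skip input as broadcast when its shape is (batch, channels) or (batch, 1, 1, channels); both line up against NHWC data without a full-size skip tensor. A small NumPy illustration of the rule, with arbitrarily chosen shapes:

```python
import numpy as np

def is_broadcast_skip(skip_shape, batch_size, num_channels):
    # Mirrors the b2/b4 conditions added above.
    return skip_shape == (batch_size, num_channels) or skip_shape == (batch_size, 1, 1, num_channels)

batch, height, width, channels = 2, 4, 4, 8
x = np.zeros((batch, height, width, channels), dtype=np.float16)

skip_2d = np.ones((batch, channels), dtype=np.float16)        # b2: needs a reshape before adding
skip_4d = np.ones((batch, 1, 1, channels), dtype=np.float16)  # b4: broadcasts as-is
skip_full = np.ones_like(x)                                   # full-size skip, not broadcast

assert is_broadcast_skip(skip_2d.shape, batch, channels)
assert is_broadcast_skip(skip_4d.shape, batch, channels)
assert not is_broadcast_skip(skip_full.shape, batch, channels)
assert (x + skip_2d.reshape(batch, 1, 1, channels)).shape == x.shape
assert (x + skip_4d).shape == x.shape
```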
@@ -106,6 +106,8 @@ void TritonOpExecutor::ExecuteByFuncName(const std::string& func_name, const Inl
PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyLong_FromLongLong(std::stoll(kv.second.first)));
} else if (kv.second.second == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyFloat_FromDouble(std::stod(kv.second.first)));
} else if (kv.second.second == ONNX_NAMESPACE::TensorProto_DataType_STRING) {
PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyUnicode_FromString(kv.second.first.c_str()));
} else {
ORT_THROW("Unsupported kwargs data type: ", kv.second.second);
}
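With the new branch, string attributes survive the round trip from ONNX node attributes into the Python kwargs dict handed to the Triton function, alongside the existing int64 and float paths. A rough Python-side picture of that conversion (the kwarg names below are hypothetical; the C++ stores every value as a string tagged with its original TensorProto type):

```python
from onnx import TensorProto

# (value-as-string, original attribute type) pairs, as stored in kwargs_ by TritonOp.
stored_kwargs = {
    "num_stages": ("3", TensorProto.INT64),
    "eps": ("1e-05", TensorProto.FLOAT),
    "reduce_op": ("sum", TensorProto.STRING),  # string kwargs are what this change enables
}

def to_python_kwargs(stored):
    # Mirrors the conversion in TritonOpExecutor::ExecuteByFuncName.
    out = {}
    for name, (text, dtype) in stored.items():
        if dtype == TensorProto.INT64:
            out[name] = int(text)    # PyLong_FromLongLong(std::stoll(...))
        elif dtype == TensorProto.FLOAT:
            out[name] = float(text)  # PyFloat_FromDouble(std::stod(...))
        elif dtype == TensorProto.STRING:
            out[name] = text         # PyUnicode_FromString(...)
        else:
            raise ValueError(f"Unsupported kwargs data type: {dtype}")
    return out

print(to_python_kwargs(stored_kwargs))  # {'num_stages': 3, 'eps': 1e-05, 'reduce_op': 'sum'}
```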
3 changes: 2 additions & 1 deletion orttraining/orttraining/python/training/ort_triton/_utils.py
@@ -141,13 +141,14 @@ def get_reduce_info(node: NodeProto, graph: GraphProto, input_rank: int) -> Tupl


def next_power_of_2(n: int) -> int:
assert n <= 2**32, "32-bit only"
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
n += 1
return n

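The bit-twiddling above rounds an integer up to the next power of two by smearing the highest set bit to the right; the assert documents the 32-bit assumption. A few sanity checks, with an equivalent written using int.bit_length for comparison (illustrative only, not part of the change):

```python
def next_power_of_2_via_bit_length(n: int) -> int:
    # Same result as the shift-and-or version for n >= 1, without the 32-bit limit.
    return 1 << (n - 1).bit_length()

for n in (1, 2, 3, 5, 16, 17, 1000, 2**31 + 1):
    expected = next_power_of_2_via_bit_length(n)
    assert expected >= n and (expected & (expected - 1)) == 0  # a power of two, not below n

# Examples: 3 -> 4, 16 -> 16, 17 -> 32, 1000 -> 1024
```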
5 changes: 4 additions & 1 deletion orttraining/orttraining/training_ops/cpu/triton/triton_op.h
@@ -25,12 +25,15 @@ class TritonOp final : public OpKernel {
attr.first == "onnx_string") {
continue;
}
// Support int64 and float only for now, skip other types.
// Support int64, float and string only for now, skip other types.
if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_INT) {
kwargs_.insert({attr.first, {std::to_string(attr.second.i()), ONNX_NAMESPACE::TensorProto_DataType_INT64}});
} else if (attr.second.type() ==
ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_FLOAT) {
kwargs_.insert({attr.first, {std::to_string(attr.second.f()), ONNX_NAMESPACE::TensorProto_DataType_FLOAT}});
} else if (attr.second.type() ==
ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_STRING) {
kwargs_.insert({attr.first, {attr.second.s(), ONNX_NAMESPACE::TensorProto_DataType_STRING}});
}
}
}
3 changes: 2 additions & 1 deletion orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc
@@ -114,7 +114,8 @@ Status ConvGrad<T>::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor&
ORT_RETURN_IF_ERROR(args_.y_tensor.Set(dy_dims, args_.params.data_type));
ORT_RETURN_IF_ERROR(args_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
args_.params.data_type));
args_.params.data_type,
UseTF32()));

if (dB) {
const TensorShape& db_shape = dB->Shape();
6 changes: 4 additions & 2 deletions orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc
@@ -233,11 +233,13 @@ bool ConvParamsEqual::operator()(const ConvParams& a, const ConvParams& b) const
}

template <typename T_Perf>
Status AlgoIterator<T_Perf>::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results) {
Status AlgoIterator<T_Perf>::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results, bool use_tf32) {
perf_results.resize(1);
perf_results[0].algo = AlgoSearch<T_Perf>::DEFAULT_ALGO;
if (args.params.data_type == CUDNN_DATA_HALF) {
perf_results[0].mathType = CUDNN_TENSOR_OP_MATH;
} else if (args.params.data_type == CUDNN_DATA_FLOAT && !use_tf32) {
perf_results[0].mathType = CUDNN_FMA_MATH;
} else {
perf_results[0].mathType = CUDNN_DEFAULT_MATH;
}
@@ -256,7 +258,7 @@ Status AlgoIterator<T_Perf>::TryAll(const CUDAExecutionProvider* provider, const

std::vector<T_Perf> perf_results;
ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault
? OnlyDefaultAlgorithm(args_, perf_results)
? OnlyDefaultAlgorithm(args_, perf_results, provider->UseTF32())
: AlgoSearch<T_Perf>::FindAlgorithms(args_, provider, allocator, perf_results));
for (auto& algo_perf : perf_results) {
if (f(algo_perf) == Status::OK()) {
2 changes: 1 addition & 1 deletion orttraining/orttraining/training_ops/cuda/nn/conv_shared.h
@@ -75,7 +75,7 @@ class AlgoIterator {
Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator,
std::function<Status(const T_Perf& perf)> f);

static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results);
static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results, bool use_tf32);

private:
const ConvArgs& args_;
