From 05f866533e917144d76384d61a5b87ee41ed7a10 Mon Sep 17 00:00:00 2001 From: peixuanzuo Date: Mon, 22 Jan 2024 09:02:09 +0000 Subject: [PATCH 1/5] add skip and bias on groupnorm triton implementation --- .../rocm/diffusion/group_norm_triton.cuh | 23 +++++++--- .../rocm/diffusion/group_norm_triton.py | 43 +++++++++++++++++-- .../kernel_explorer/kernels/groupnorm_test.py | 8 +++- 3 files changed, 63 insertions(+), 11 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index b3d3e92209b39..64e4c1e05f423 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -46,8 +46,6 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { auto block_size = metadata->constants.at("BLOCK_SIZE"); auto hw_size = metadata->constants.at("HW_SIZE"); auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), - "Input skip or bias is not supported by triton kernel."); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", @@ -61,23 +59,36 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { } // Construct args for launch kernel struct { - void* X; - void* Y; + const void* src; + const void* skip; + const void* bias; + void* out; + void* add_out; const void* gamma; const void* beta; int hw; int c; int c_per_group; float eps; + bool has_skip; + bool has_bias; + bool broadcast_skip; } args = { - (void*)params->src, + (const void*)params->src, + (const void*)params->skip, + (const void*)params->bias, (void*)params->dst, + (void*)params->add_out, (const void*)params->gamma, (const void*)params->beta, params->hw, params->c, params->channels_per_group, - params->epsilon}; + params->epsilon, + params->skip != nullptr, + params->bias != nullptr, + params->broadcast_skip, + }; // Grid dim is (batch_count, groups, 1) return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py index 5368cb1cf635b..d714408dbc05b 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -12,13 +12,19 @@ @triton.jit def group_norm_kernel( input_ptr, + skip_ptr, + bias_ptr, output_ptr, + add_out_ptr, gamma_ptr, beta_ptr, img_size, c, c_per_group, eps, + has_skip, + has_bias, + broadcast_skip, BLOCK_SIZE: tl.constexpr, HW_SIZE: tl.constexpr, ACTIVATION_SILU: tl.constexpr, @@ -31,19 +37,44 @@ def group_norm_kernel( gamma_ptr += row_y * c_per_group beta_ptr += row_y * c_per_group + + add_out_ptr += row_x * stride + row_y * c_per_group + + + cols = tl.arange(0, BLOCK_SIZE) hw = tl.arange(0, HW_SIZE) offsets = hw[:, None] * c + cols[None, :] mask = (cols < c_per_group)[None, :] + bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + if has_skip: + if broadcast_skip: + broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group + bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + else: + skip_ptr += row_x * stride + row_y * c_per_group + if has_bias: + bias_ptr += row_y * c_per_group + 
bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + # Calculate mean and variance _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) for i in range(tl.cdiv(img_size, HW_SIZE)): x_ptr = input_ptr + i * HW_SIZE * c a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip and not broadcast_skip: + s_ptr = skip_ptr + i * HW_SIZE * c + s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + a += s + if has_bias or broadcast_skip: + a += bias _sum += a _square_sum += a * a + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + tl.store(add_y_ptr + offsets, a, mask=mask) # Set axis=None (or leave it unspecified) to reduce all axes. # TODO: In older Triton we have to reduce an axis at a time, but in our case @@ -57,9 +88,13 @@ def group_norm_kernel( gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32) beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32) for i in range(tl.cdiv(img_size, HW_SIZE)): - x_ptr = input_ptr + i * HW_SIZE * c y_ptr = output_ptr + i * HW_SIZE * c - x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + else: + x_ptr = input_ptr + i * HW_SIZE * c + x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) x_hat = (x - group_mean) * rstd y = x_hat * gamma + beta if ACTIVATION_SILU: @@ -77,7 +112,7 @@ def group_norm_kernel( hw_sizes = [8, 16, 32, 64, 128, 256] warps = [1, 2, 4, 8, 16] name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" -sig_pattern = "*{},*{},*fp32,*fp32,i32,i32,i32,fp32" +sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1" group_pattern = "GroupNormTriton_{}_{}" @@ -88,7 +123,7 @@ def get_function_table(): silu_suffix = "Silu" if silu else "Pass" name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) group = group_pattern.format(silu_suffix, dtype) - sig = sig_pattern.format(dtype, dtype) + sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype) kwargs = { "num_warps": warp, "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py index 8334d20e47c86..6aa28fecb3260 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py @@ -80,7 +80,13 @@ def run_group_norm( ) use_silu = silu broadcast_skip = False - channels_per_block = 0 # Compute in params initialization + if(has_skip): + skip_x_shape = skip_x.shape + b2 = len(skip_x_shape) == 2 and skip_x_shape[0] == batch_size and skip_x_shape[1] == num_channels + b4 = len(skip_x_shape) == 4 and skip_x_shape[0] == batch_size and skip_x_shape[1] == 1 and skip_x_shape[2] == 1 and skip_x_shape[3] == num_channels + if b2 or b4: + broadcast_skip = True + channels_per_block = 0 # Compute in params initialization input_d = ke.DeviceArray(input_x.astype(dtype)) skip_d = ke.DeviceArray(skip_x.astype(dtype)) From b1b055bd9a4e4793342217dd7dcbabc343521930 Mon Sep 17 00:00:00 2001 From: peixuanzuo Date: Sun, 4 Feb 2024 09:09:42 +0000 Subject: [PATCH 2/5] fix format --- .../contrib_ops/rocm/diffusion/group_norm_triton.py | 3 --- .../tools/kernel_explorer/kernels/groupnorm_test.py | 12 +++++++++--- 2 
files changed, 9 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py index d714408dbc05b..8e20aa07bd8ab 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -37,11 +37,8 @@ def group_norm_kernel( gamma_ptr += row_y * c_per_group beta_ptr += row_y * c_per_group - add_out_ptr += row_x * stride + row_y * c_per_group - - cols = tl.arange(0, BLOCK_SIZE) hw = tl.arange(0, HW_SIZE) offsets = hw[:, None] * c + cols[None, :] diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py index 6aa28fecb3260..400a9d8a7a187 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py @@ -80,13 +80,19 @@ def run_group_norm( ) use_silu = silu broadcast_skip = False - if(has_skip): + if has_skip: skip_x_shape = skip_x.shape b2 = len(skip_x_shape) == 2 and skip_x_shape[0] == batch_size and skip_x_shape[1] == num_channels - b4 = len(skip_x_shape) == 4 and skip_x_shape[0] == batch_size and skip_x_shape[1] == 1 and skip_x_shape[2] == 1 and skip_x_shape[3] == num_channels + b4 = ( + len(skip_x_shape) == 4 + and skip_x_shape[0] == batch_size + and skip_x_shape[1] == 1 + and skip_x_shape[2] == 1 + and skip_x_shape[3] == num_channels + ) if b2 or b4: broadcast_skip = True - channels_per_block = 0 # Compute in params initialization + channels_per_block = 0 # Compute in params initialization input_d = ke.DeviceArray(input_x.astype(dtype)) skip_d = ke.DeviceArray(skip_x.astype(dtype)) From 240db1443f02626364ca92ebfe66f4227e0a38ea Mon Sep 17 00:00:00 2001 From: peixuanzuo Date: Sun, 18 Feb 2024 06:03:53 +0000 Subject: [PATCH 3/5] update --- .../rocm/diffusion/group_norm_triton.cuh | 5 +++-- .../rocm/diffusion/group_norm_triton.py | 20 +++++++------------ 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index 64e4c1e05f423..031e76f1e53a6 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -46,6 +46,9 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { auto block_size = metadata->constants.at("BLOCK_SIZE"); auto hw_size = metadata->constants.at("HW_SIZE"); auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->skip != nullptr && params->broadcast_skip, "Arg broadcast_skip is not supported."); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", @@ -72,7 +75,6 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { float eps; bool has_skip; bool has_bias; - bool broadcast_skip; } args = { (const void*)params->src, (const void*)params->skip, @@ -87,7 +89,6 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { params->epsilon, params->skip != nullptr, params->bias != nullptr, - params->broadcast_skip, }; // Grid dim is (batch_count, groups, 1) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py 
index 8e20aa07bd8ab..026c08586f076 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -24,7 +24,6 @@ def group_norm_kernel( eps, has_skip, has_bias, - broadcast_skip, BLOCK_SIZE: tl.constexpr, HW_SIZE: tl.constexpr, ACTIVATION_SILU: tl.constexpr, @@ -37,8 +36,6 @@ def group_norm_kernel( gamma_ptr += row_y * c_per_group beta_ptr += row_y * c_per_group - add_out_ptr += row_x * stride + row_y * c_per_group - cols = tl.arange(0, BLOCK_SIZE) hw = tl.arange(0, HW_SIZE) offsets = hw[:, None] * c + cols[None, :] @@ -46,11 +43,8 @@ def group_norm_kernel( bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32) if has_skip: - if broadcast_skip: - broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group - bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) - else: - skip_ptr += row_x * stride + row_y * c_per_group + add_out_ptr += row_x * stride + row_y * c_per_group + skip_ptr += row_x * stride + row_y * c_per_group if has_bias: bias_ptr += row_y * c_per_group bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) @@ -61,11 +55,11 @@ def group_norm_kernel( for i in range(tl.cdiv(img_size, HW_SIZE)): x_ptr = input_ptr + i * HW_SIZE * c a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - if has_skip and not broadcast_skip: + if has_skip: s_ptr = skip_ptr + i * HW_SIZE * c s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32) a += s - if has_bias or broadcast_skip: + if has_bias: a += bias _sum += a _square_sum += a * a @@ -105,11 +99,11 @@ def group_norm_kernel( # but this will result in too many functions and slow down the compilation. with_silu = [True, False] dtypes = ["fp32", "fp16"] -blocks = [16, 32, 64, 128] +blocks = [16, 32, 64] hw_sizes = [8, 16, 32, 64, 128, 256] -warps = [1, 2, 4, 8, 16] +warps = [1, 2, 4, 8] name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" -sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1" +sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1" group_pattern = "GroupNormTriton_{}_{}" From 2f0df6937d5f2bb69163df79625974f92dd5589b Mon Sep 17 00:00:00 2001 From: peixuanzuo Date: Mon, 19 Feb 2024 08:31:20 +0000 Subject: [PATCH 4/5] model work --- onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh | 2 +- onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index 031e76f1e53a6..4a2e2ae57307d 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -80,7 +80,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { (const void*)params->skip, (const void*)params->bias, (void*)params->dst, - (void*)params->add_out, + (void*)params->skip_workspace, (const void*)params->gamma, (const void*)params->beta, params->hw, diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py index 026c08586f076..62b9d624340ff 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -99,9 +99,9 @@ def group_norm_kernel( # but this will result in too many functions and slow down the compilation. 
with_silu = [True, False] dtypes = ["fp32", "fp16"] -blocks = [16, 32, 64] +blocks = [16, 32, 64, 128] hw_sizes = [8, 16, 32, 64, 128, 256] -warps = [1, 2, 4, 8] +warps = [1, 2, 4, 8, 16] name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1" group_pattern = "GroupNormTriton_{}_{}" From 4a6f2efdda26e91144bb7a8d8d506946436b7675 Mon Sep 17 00:00:00 2001 From: peixuanzuo Date: Mon, 19 Feb 2024 09:19:59 +0000 Subject: [PATCH 5/5] add broadcast back --- .../rocm/diffusion/group_norm_triton.cuh | 5 ++--- .../contrib_ops/rocm/diffusion/group_norm_triton.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index 4a2e2ae57307d..c6ca16bfdfc80 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -46,9 +46,6 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { auto block_size = metadata->constants.at("BLOCK_SIZE"); auto hw_size = metadata->constants.at("HW_SIZE"); auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->skip != nullptr && params->broadcast_skip, "Arg broadcast_skip is not supported."); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", @@ -75,6 +72,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { float eps; bool has_skip; bool has_bias; + bool broadcast_skip; } args = { (const void*)params->src, (const void*)params->skip, @@ -89,6 +87,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { params->epsilon, params->skip != nullptr, params->bias != nullptr, + params->broadcast_skip, }; // Grid dim is (batch_count, groups, 1) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py index 62b9d624340ff..5ba96ebc117f0 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -24,6 +24,7 @@ def group_norm_kernel( eps, has_skip, has_bias, + broadcast_skip, BLOCK_SIZE: tl.constexpr, HW_SIZE: tl.constexpr, ACTIVATION_SILU: tl.constexpr, @@ -44,7 +45,11 @@ def group_norm_kernel( bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32) if has_skip: add_out_ptr += row_x * stride + row_y * c_per_group - skip_ptr += row_x * stride + row_y * c_per_group + if broadcast_skip: + broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group + bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + else: + skip_ptr += row_x * stride + row_y * c_per_group if has_bias: bias_ptr += row_y * c_per_group bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) @@ -55,11 +60,11 @@ def group_norm_kernel( for i in range(tl.cdiv(img_size, HW_SIZE)): x_ptr = input_ptr + i * HW_SIZE * c a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - if has_skip: + if has_skip and not broadcast_skip: s_ptr = skip_ptr + i * HW_SIZE * c s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32) a += s - if has_bias: + if has_bias or broadcast_skip: a += bias _sum += a _square_sum += a * a @@ -103,7 +108,7 @@ def 
group_norm_kernel( hw_sizes = [8, 16, 32, 64, 128, 256] warps = [1, 2, 4, 8, 16] name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" -sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1" +sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1" group_pattern = "GroupNormTriton_{}_{}"
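
---

Note (not part of the patches above): the series fuses an optional residual "skip" tensor (full-size or broadcast per-batch), an optional per-channel bias, and an extra `add_out` output (the pre-normalization sum) into the NHWC GroupNorm Triton kernel, with optional SiLU. Below is a minimal NumPy sketch of the semantics the final kernel (PATCH 5) appears to implement, useful as a host-side reference when testing. The function name `group_norm_skip_bias_ref` and the argument layout are hypothetical, inferred from the kernel code; it is not ONNX Runtime API.

import numpy as np


def silu(x):
    # SiLU / swish activation applied when ACTIVATION_SILU is set.
    return x / (1.0 + np.exp(-x))


def group_norm_skip_bias_ref(x, gamma, beta, groups, eps=1e-5,
                             skip=None, bias=None, activation_silu=False):
    """NumPy reference for NHWC GroupNorm with fused skip/bias (assumed semantics).

    x:     [N, H, W, C] input
    skip:  None, [N, H, W, C], or a broadcast skip of shape [N, C] / [N, 1, 1, C]
    bias:  None or per-channel [C]
    gamma, beta: per-channel [C] scale and shift
    Returns (y, add_out); add_out is the pre-normalization x + skip + bias,
    which the kernel stores to the skip workspace only when skip is present.
    """
    n, h, w, c = x.shape
    x = x.astype(np.float32)
    if skip is not None:
        s = skip.astype(np.float32)
        if s.ndim == 2:          # broadcast skip: one row per batch, repeated over H*W
            s = s.reshape(n, 1, 1, c)
        x = x + s
    if bias is not None:
        x = x + bias.astype(np.float32).reshape(1, 1, 1, c)
    add_out = x.copy() if skip is not None else None

    # Normalize per (batch, group) over H*W and channels_per_group,
    # matching the kernel's grid of (batch_count, groups, 1).
    cpg = c // groups
    xg = x.reshape(n, h * w, groups, cpg)
    mean = xg.mean(axis=(1, 3), keepdims=True)
    var = xg.var(axis=(1, 3), keepdims=True)
    y = (xg - mean) / np.sqrt(var + eps)
    y = y.reshape(n, h, w, c) * gamma.reshape(1, 1, 1, c) + beta.reshape(1, 1, 1, c)
    if activation_silu:
        y = silu(y)
    return y, add_out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 8, 8, 64)).astype(np.float32)
    skip = rng.standard_normal((2, 64)).astype(np.float32)   # broadcast skip case
    bias = rng.standard_normal(64).astype(np.float32)
    gamma = rng.standard_normal(64).astype(np.float32)
    beta = rng.standard_normal(64).astype(np.float32)
    y, add_out = group_norm_skip_bias_ref(x, gamma, beta, groups=16,
                                          skip=skip, bias=bias, activation_silu=True)
    print(y.shape, add_out.shape)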