From 253b66a5fb69ba0eafe5070843d25034a79739d8 Mon Sep 17 00:00:00 2001
From: peixuanzuo
Date: Mon, 19 Feb 2024 09:19:59 +0000
Subject: [PATCH] add broadcast back

---
 .../rocm/diffusion/group_norm_triton.cuh            |  5 ++---
 .../contrib_ops/rocm/diffusion/group_norm_triton.py | 13 +++++++++----
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
index 4a2e2ae57307d..c6ca16bfdfc80 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
@@ -46,9 +46,6 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
       auto block_size = metadata->constants.at("BLOCK_SIZE");
       auto hw_size = metadata->constants.at("HW_SIZE");
       auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams<T>* params) -> Status {
-        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-            params->skip != nullptr && params->broadcast_skip, "Arg broadcast_skip is not supported.");
-
         TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
             params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size,
             "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (",
@@ -75,6 +72,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
           float eps;
           bool has_skip;
           bool has_bias;
+          bool broadcast_skip;
         } args = {
             (const void*)params->src,
             (const void*)params->skip,
@@ -89,6 +87,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
             params->epsilon,
             params->skip != nullptr,
             params->bias != nullptr,
+            params->broadcast_skip,
         };
 
         // Grid dim is (batch_count, groups, 1)
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
index 62b9d624340ff..5ba96ebc117f0 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
@@ -24,6 +24,7 @@ def group_norm_kernel(
     eps,
     has_skip,
     has_bias,
+    broadcast_skip,
     BLOCK_SIZE: tl.constexpr,
     HW_SIZE: tl.constexpr,
     ACTIVATION_SILU: tl.constexpr,
@@ -44,7 +45,11 @@ def group_norm_kernel(
         bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
     if has_skip:
         add_out_ptr += row_x * stride + row_y * c_per_group
-        skip_ptr += row_x * stride + row_y * c_per_group
+        if broadcast_skip:
+            broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group
+            bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
+        else:
+            skip_ptr += row_x * stride + row_y * c_per_group
     if has_bias:
         bias_ptr += row_y * c_per_group
         bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
@@ -55,11 +60,11 @@ def group_norm_kernel(
     for i in range(tl.cdiv(img_size, HW_SIZE)):
         x_ptr = input_ptr + i * HW_SIZE * c
         a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
-        if has_skip:
+        if has_skip and not broadcast_skip:
             s_ptr = skip_ptr + i * HW_SIZE * c
             s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
             a += s
-        if has_bias:
+        if has_bias or broadcast_skip:
             a += bias
         _sum += a
         _square_sum += a * a
@@ -103,7 +108,7 @@
 hw_sizes = [8, 16, 32, 64, 128, 256]
 warps = [1, 2, 4, 8, 16]
 name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}"
-sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1"
+sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1"
 group_pattern = "GroupNormTriton_{}_{}"
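
Note on the semantics this patch enables: with broadcast_skip, the skip tensor holds one row of C values per batch (indexed as skip_ptr + row_x * c + row_y * c_per_group), so it is loaded once per (batch, group), folded into the per-channel bias vector, and thereby broadcast across all H*W positions, instead of being re-loaded for every HW tile as in the elementwise-skip path. The following is a minimal NumPy sketch of that intent, not part of the patch; the function name, argument names, and the assumption that skip has shape (N, C) when broadcast_skip is true and (N, H, W, C) otherwise are illustrative only, and the kernel's optional SiLU activation and add_out output are omitted.

import numpy as np

def group_norm_nhwc_ref(x, gamma, beta, groups, eps,
                        skip=None, bias=None, broadcast_skip=False):
    # x: (N, H, W, C); gamma, beta: (C,)
    # skip: (N, C) if broadcast_skip else (N, H, W, C); bias: (C,) or None
    n, h, w, c = x.shape
    y = x.astype(np.float32)
    if skip is not None:
        s = skip.astype(np.float32)
        if broadcast_skip:
            s = s.reshape(n, 1, 1, c)  # one skip row per batch, broadcast over H and W
        y = y + s
    if bias is not None:
        y = y + bias.astype(np.float32)  # per-channel bias, broadcast over N, H, W
    # Statistics are computed per (batch, group) over H, W and the group's channels,
    # matching the kernel's accumulation of _sum / _square_sum after skip and bias.
    g = y.reshape(n, h * w, groups, c // groups)
    mean = g.mean(axis=(1, 3), keepdims=True)
    var = g.var(axis=(1, 3), keepdims=True)
    g = (g - mean) / np.sqrt(var + eps)
    return g.reshape(n, h, w, c) * gamma + beta

A sketch like this can serve as a cross-check for the new path, e.g. comparing the kernel output against group_norm_nhwc_ref with a (2, 320) skip and groups=32 versus the same skip materialized to (2, H, W, 320) with broadcast_skip=False; both should agree, which is the equivalence the kernel change relies on.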