Skip to content

Commit

Permalink
add broadcast back
Browse files Browse the repository at this point in the history
  • Loading branch information
PeixuanZuo committed Feb 22, 2024
1 parent 2f0df69 commit 4a6f2ef
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
5 changes: 2 additions & 3 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
auto block_size = metadata->constants.at("BLOCK_SIZE");
auto hw_size = metadata->constants.at("HW_SIZE");
auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams<T>* params) -> Status {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->skip != nullptr && params->broadcast_skip, "Arg broadcast_skip is not supported.");

TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size,
"Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (",
Expand All @@ -75,6 +72,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
float eps;
bool has_skip;
bool has_bias;
bool broadcast_skip;
} args = {
(const void*)params->src,
(const void*)params->skip,
Expand All @@ -89,6 +87,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
params->epsilon,
params->skip != nullptr,
params->bias != nullptr,
params->broadcast_skip,
};

// Grid dim is (batch_count, groups, 1)
Expand Down
13 changes: 9 additions & 4 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def group_norm_kernel(
eps,
has_skip,
has_bias,
broadcast_skip,
BLOCK_SIZE: tl.constexpr,
HW_SIZE: tl.constexpr,
ACTIVATION_SILU: tl.constexpr,
Expand All @@ -44,7 +45,11 @@ def group_norm_kernel(
bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
if has_skip:
add_out_ptr += row_x * stride + row_y * c_per_group
skip_ptr += row_x * stride + row_y * c_per_group
if broadcast_skip:
broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group
bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
else:
skip_ptr += row_x * stride + row_y * c_per_group
if has_bias:
bias_ptr += row_y * c_per_group
bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
Expand All @@ -55,11 +60,11 @@ def group_norm_kernel(
for i in range(tl.cdiv(img_size, HW_SIZE)):
x_ptr = input_ptr + i * HW_SIZE * c
a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
if has_skip:
if has_skip and not broadcast_skip:
s_ptr = skip_ptr + i * HW_SIZE * c
s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
a += s
if has_bias:
if has_bias or broadcast_skip:
a += bias
_sum += a
_square_sum += a * a
Expand Down Expand Up @@ -103,7 +108,7 @@ def group_norm_kernel(
hw_sizes = [8, 16, 32, 64, 128, 256]
warps = [1, 2, 4, 8, 16]
name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}"
sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1"
sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1"
group_pattern = "GroupNormTriton_{}_{}"


Expand Down

0 comments on commit 4a6f2ef

Please sign in to comment.