Commit 8ae2948: update

PeixuanZuo committed Feb 19, 2024
1 parent d730730

Showing 2 changed files with 10 additions and 15 deletions.
onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh (5 changes: 3 additions & 2 deletions)
@@ -46,6 +46,9 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
     auto block_size = metadata->constants.at("BLOCK_SIZE");
     auto hw_size = metadata->constants.at("HW_SIZE");
     auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams<T>* params) -> Status {
+      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
+          params->skip != nullptr && params->broadcast_skip, "Arg broadcast_skip is not supported.");
+
       TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
           params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size,
           "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (",
@@ -72,7 +75,6 @@
         float eps;
         bool has_skip;
         bool has_bias;
-        bool broadcast_skip;
       } args = {
           (const void*)params->src,
           (const void*)params->skip,
@@ -87,7 +89,6 @@
           params->epsilon,
           params->skip != nullptr,
           params->bias != nullptr,
-          params->broadcast_skip,
       };
 
       // Grid dim is (batch_count, groups, 1)
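Note on the guards in the first hunk: the new check rejects broadcast skip inputs up front, and the pre-existing check right below it encodes the contract that BLOCK_SIZE be the next power of two of channels_per_group, i.e. the unique power of two in [channels_per_group, 2 * channels_per_group). Below is a minimal Python sketch of that predicate (illustrative only, not repo code; block_size is assumed to come from the power-of-two candidates enumerated in group_norm_triton.py):

```python
# Mirrors the C++ condition:
#   channels_per_group > block_size || channels_per_group * 2 <= block_size
def block_size_supported(channels_per_group: int, block_size: int) -> bool:
    unsupported = channels_per_group > block_size or channels_per_group * 2 <= block_size
    return not unsupported

assert block_size_supported(10, 16)      # 16 is the next power of 2 of 10
assert not block_size_supported(10, 8)   # too small: 10 channels do not fit
assert not block_size_supported(10, 32)  # too large: 16 already suffices
```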
onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py (20 changes: 7 additions & 13 deletions)
@@ -24,7 +24,6 @@ def group_norm_kernel(
     eps,
     has_skip,
     has_bias,
-    broadcast_skip,
     BLOCK_SIZE: tl.constexpr,
     HW_SIZE: tl.constexpr,
     ACTIVATION_SILU: tl.constexpr,
@@ -37,20 +36,15 @@
     gamma_ptr += row_y * c_per_group
     beta_ptr += row_y * c_per_group
 
+    add_out_ptr += row_x * stride + row_y * c_per_group
+
     cols = tl.arange(0, BLOCK_SIZE)
     hw = tl.arange(0, HW_SIZE)
     offsets = hw[:, None] * c + cols[None, :]
     mask = (cols < c_per_group)[None, :]
 
     bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
     if has_skip:
-        if broadcast_skip:
-            broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group
-            bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
-        else:
-            skip_ptr += row_x * stride + row_y * c_per_group
-        add_out_ptr += row_x * stride + row_y * c_per_group
+        skip_ptr += row_x * stride + row_y * c_per_group
     if has_bias:
         bias_ptr += row_y * c_per_group
         bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
@@ -61,11 +55,11 @@
     for i in range(tl.cdiv(img_size, HW_SIZE)):
         x_ptr = input_ptr + i * HW_SIZE * c
         a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
-        if has_skip and not broadcast_skip:
+        if has_skip:
             s_ptr = skip_ptr + i * HW_SIZE * c
             s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
             a += s
-        if has_bias or broadcast_skip:
+        if has_bias:
             a += bias
         _sum += a
         _square_sum += a * a
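With broadcast_skip gone, the accumulation loop above is uniform: each HW_SIZE x BLOCK_SIZE tile is loaded, the skip tensor (same NHWC shape as the input) is added elementwise, the per-channel bias is added, and the loop builds _sum and _square_sum so the kernel can derive mean and variance as E[x] and E[x^2] - E[x]^2. The following NumPy reference for one (batch, group) slice is a sketch, assuming an NHWC layout flattened to (img_size, c_per_group); it is not code from this repo:

```python
import numpy as np

def group_stats(x, skip=None, bias=None):
    """x: (img_size, c_per_group) slice; skip: same shape; bias: (c_per_group,)."""
    a = x.astype(np.float32)
    if skip is not None:
        a = a + skip.astype(np.float32)  # elementwise, mirrors `a += s` above
    if bias is not None:
        a = a + bias.astype(np.float32)  # broadcast over rows, mirrors `a += bias`
    mean = a.mean()
    var = (a * a).mean() - mean * mean   # E[x^2] - E[x]^2 from _sum/_square_sum
    return mean, var

x = np.random.rand(256, 10).astype(np.float16)  # img_size=256, c_per_group=10
s = np.random.rand(256, 10).astype(np.float16)
mean, var = group_stats(x, skip=s)
```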
@@ -105,11 +99,11 @@
 # but this will result in too many functions and slow down the compilation.
 with_silu = [True, False]
 dtypes = ["fp32", "fp16"]
-blocks = [16, 32, 64, 128]
+blocks = [16, 32, 64]
 hw_sizes = [8, 16, 32, 64, 128, 256]
-warps = [1, 2, 4, 8, 16]
+warps = [1, 2, 4, 8]
 name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}"
-sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1"
+sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1"
 group_pattern = "GroupNormTriton_{}_{}"
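These module-level lists drive ahead-of-time compilation of one Triton kernel per (activation, dtype, block, hw_size, warps) combination, so trimming blocks and warps cuts the build from 2 x 2 x 4 x 6 x 5 = 480 variants to 2 x 2 x 3 x 6 x 4 = 288. The signature likewise loses its trailing i1 now that broadcast_skip is gone, leaving 13 arguments: seven pointers, three i32 shape scalars, the fp32 eps, and the two i1 flags has_skip/has_bias (BLOCK_SIZE, HW_SIZE, and ACTIVATION_SILU are constexpr and stay out of the signature). A sketch of the expansion these patterns imply (the "silu"/"pass" tags are an assumption, not the repo's exact naming):

```python
from itertools import product

with_silu = [True, False]
dtypes = ["fp32", "fp16"]
blocks = [16, 32, 64]
hw_sizes = [8, 16, 32, 64, 128, 256]
warps = [1, 2, 4, 8]
name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}"

names = [
    name_pattern.format("silu" if silu else "pass", dt, b, hw, w)
    for silu, dt, b, hw, w in product(with_silu, dtypes, blocks, hw_sizes, warps)
]
print(len(names))  # 288 kernels to compile, down from 480 before this commit
```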
