[ROCm] SkipGroupNorm triton #19408

Merged: 5 commits on Feb 22, 2024
23 changes: 17 additions & 6 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
@@ -46,8 +46,6 @@
     auto block_size = metadata->constants.at("BLOCK_SIZE");
     auto hw_size = metadata->constants.at("HW_SIZE");
     auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams<T>* params) -> Status {
-      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr),
-                                                "Input skip or bias is not supported by triton kernel.");
       TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
           params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size,
           "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (",
@@ -61,23 +59,36 @@
       }
       // Construct args for launch kernel
       struct {
-        void* X;
-        void* Y;
+        const void* src;
+        const void* skip;
+        const void* bias;
+        void* out;
+        void* add_out;
         const void* gamma;
         const void* beta;
         int hw;
         int c;
         int c_per_group;
         float eps;
+        bool has_skip;
+        bool has_bias;
+        bool broadcast_skip;
       } args = {
-          (void*)params->src,
+          (const void*)params->src,
+          (const void*)params->skip,
+          (const void*)params->bias,
           (void*)params->dst,
+          (void*)params->skip_workspace,
Check warning on line 81 in onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh (GitHub Actions / Lint C++): [cpplint] reported by reviewdog 🐶 Using C-style cast. Use reinterpret_cast<void*>(...) instead [readability/casting] [4]
           (const void*)params->gamma,
           (const void*)params->beta,
           params->hw,
           params->c,
           params->channels_per_group,
-          params->epsilon};
+          params->epsilon,
+          params->skip != nullptr,
+          params->bias != nullptr,
+          params->broadcast_skip,
+      };

       // Grid dim is (batch_count, groups, 1)
       return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args));
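Note on the argument wiring: the args struct above is handed to LaunchTritonKernel as a raw argument buffer, so its field order has to line up with the Triton kernel's parameter list and with the signature string registered in group_norm_triton.py (both updated below). A small illustrative cross-check of that ordering, using only the names that appear in this diff; this is not code from the PR:

# Host-side struct fields (group_norm_triton.cuh) next to the Triton kernel
# parameters (group_norm_triton.py); the trailing i1 flags in the signature
# correspond to the three bool fields.
cuh_struct_fields = [
    "src", "skip", "bias", "out", "add_out", "gamma", "beta",
    "hw", "c", "c_per_group", "eps", "has_skip", "has_bias", "broadcast_skip",
]
triton_kernel_params = [
    "input_ptr", "skip_ptr", "bias_ptr", "output_ptr", "add_out_ptr",
    "gamma_ptr", "beta_ptr", "img_size", "c", "c_per_group", "eps",
    "has_skip", "has_bias", "broadcast_skip",
]
assert len(cuh_struct_fields) == len(triton_kernel_params) == 14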
39 changes: 35 additions & 4 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
@@ -12,13 +12,19 @@
 @triton.jit
 def group_norm_kernel(
     input_ptr,
+    skip_ptr,
+    bias_ptr,
     output_ptr,
+    add_out_ptr,
     gamma_ptr,
     beta_ptr,
     img_size,
     c,
     c_per_group,
     eps,
+    has_skip,
+    has_bias,
+    broadcast_skip,
     BLOCK_SIZE: tl.constexpr,
     HW_SIZE: tl.constexpr,
     ACTIVATION_SILU: tl.constexpr,
@@ -36,14 +42,35 @@ def group_norm_kernel(
     offsets = hw[:, None] * c + cols[None, :]
     mask = (cols < c_per_group)[None, :]

+    bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    if has_skip:
+        add_out_ptr += row_x * stride + row_y * c_per_group
+        if broadcast_skip:
+            broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group
+            bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
+        else:
+            skip_ptr += row_x * stride + row_y * c_per_group
+    if has_bias:
+        bias_ptr += row_y * c_per_group
+        bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
+
     # Calculate mean and variance
     _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32)
     _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32)
     for i in range(tl.cdiv(img_size, HW_SIZE)):
         x_ptr = input_ptr + i * HW_SIZE * c
         a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+        if has_skip and not broadcast_skip:
+            s_ptr = skip_ptr + i * HW_SIZE * c
+            s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+            a += s
+        if has_bias or broadcast_skip:
+            a += bias
         _sum += a
         _square_sum += a * a
+        if has_skip:
+            add_y_ptr = add_out_ptr + i * HW_SIZE * c
+            tl.store(add_y_ptr + offsets, a, mask=mask)

     # Set axis=None (or leave it unspecified) to reduce all axes.
     # TODO: In older Triton we have to reduce an axis at a time, but in our case
@@ -57,9 +84,13 @@
     gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32)
     beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32)
     for i in range(tl.cdiv(img_size, HW_SIZE)):
-        x_ptr = input_ptr + i * HW_SIZE * c
         y_ptr = output_ptr + i * HW_SIZE * c
-        x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+        if has_skip:
+            add_y_ptr = add_out_ptr + i * HW_SIZE * c
+            x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+        else:
+            x_ptr = input_ptr + i * HW_SIZE * c
+            x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
         x_hat = (x - group_mean) * rstd
         y = x_hat * gamma + beta
         if ACTIVATION_SILU:
@@ -77,7 +108,7 @@
 hw_sizes = [8, 16, 32, 64, 128, 256]
 warps = [1, 2, 4, 8, 16]
 name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}"
-sig_pattern = "*{},*{},*fp32,*fp32,i32,i32,i32,fp32"
+sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1"
 group_pattern = "GroupNormTriton_{}_{}"

@@ -88,7 +119,7 @@ def get_function_table():
         silu_suffix = "Silu" if silu else "Pass"
         name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp)
         group = group_pattern.format(silu_suffix, dtype)
-        sig = sig_pattern.format(dtype, dtype)
+        sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype)
         kwargs = {
             "num_warps": warp,
             "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)},
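For orientation, the extended kernel now computes SkipGroupNorm: the optional skip and bias are added to the input, the sum is written to add_out, and group normalization (with optional SiLU) is applied to that sum. Below is a minimal NumPy reference sketch of the same computation, assuming NHWC layout; the function name and defaults are illustrative, not part of the PR.

import numpy as np

def skip_group_norm_ref(x, gamma, beta, num_groups, eps=1e-5, skip=None, bias=None, silu=False):
    # x: (N, H, W, C); gamma, beta: (C,); skip: (N, H, W, C), (N, C) or (N, 1, 1, C); bias: (C,)
    n, h, w, c = x.shape
    add_out = x.astype(np.float32)
    if skip is not None:
        if skip.size == n * c:  # broadcast skip over H and W
            skip = skip.reshape(n, 1, 1, c)
        add_out = add_out + skip.astype(np.float32)
    if bias is not None:
        add_out = add_out + bias.reshape(1, 1, 1, c).astype(np.float32)
    # Mean and variance per (batch, group), reduced over H*W and the channels of the group.
    g = add_out.reshape(n, h * w, num_groups, c // num_groups)
    mean = g.mean(axis=(1, 3), keepdims=True)
    var = g.var(axis=(1, 3), keepdims=True)
    y = ((g - mean) / np.sqrt(var + eps)).reshape(n, h, w, c)
    y = y * gamma.astype(np.float32) + beta.astype(np.float32)
    if silu:
        y = y * (1.0 / (1.0 + np.exp(-y)))  # SiLU
    return y.astype(x.dtype), add_out.astype(x.dtype)

The second return value mirrors the add_out_ptr buffer that the kernel writes when has_skip is set.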
12 changes: 12 additions & 0 deletions onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py
@@ -80,6 +80,18 @@ def run_group_norm(
     )
     use_silu = silu
     broadcast_skip = False
+    if has_skip:
+        skip_x_shape = skip_x.shape
+        b2 = len(skip_x_shape) == 2 and skip_x_shape[0] == batch_size and skip_x_shape[1] == num_channels
+        b4 = (
+            len(skip_x_shape) == 4
+            and skip_x_shape[0] == batch_size
+            and skip_x_shape[1] == 1
+            and skip_x_shape[2] == 1
+            and skip_x_shape[3] == num_channels
+        )
+        if b2 or b4:
+            broadcast_skip = True
     channels_per_block = 0  # Compute in params initialization

     input_d = ke.DeviceArray(input_x.astype(dtype))
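The check added above treats a skip of shape (batch, channels) or (batch, 1, 1, channels) as a broadcast skip, replicated over the spatial dimensions, while a full (batch, height, width, channels) skip is added element-wise. A standalone illustration of that check with made-up sizes:

import numpy as np

batch_size, height, width, num_channels = 2, 8, 8, 320  # hypothetical sizes

def is_broadcast_skip(skip_x):
    s = skip_x.shape
    b2 = len(s) == 2 and s[0] == batch_size and s[1] == num_channels
    b4 = len(s) == 4 and s[0] == batch_size and s[1] == 1 and s[2] == 1 and s[3] == num_channels
    return b2 or b4

assert not is_broadcast_skip(np.zeros((batch_size, height, width, num_channels), np.float16))
assert is_broadcast_skip(np.zeros((batch_size, num_channels), np.float16))
assert is_broadcast_skip(np.zeros((batch_size, 1, 1, num_channels), np.float16))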