Add skip and bias support to the GroupNorm Triton implementation
PeixuanZuo committed Feb 4, 2024
1 parent 1421c7a commit 1d98750
Showing 3 changed files with 63 additions and 11 deletions.
23 changes: 17 additions & 6 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
@@ -46,8 +46,6 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
auto block_size = metadata->constants.at("BLOCK_SIZE");
auto hw_size = metadata->constants.at("HW_SIZE");
auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams<T>* params) -> Status {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr),
"Skip is not supported");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size,
"Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (",
@@ -61,23 +59,36 @@
}
// Construct args for launch kernel
struct {
void* X;
void* Y;
const void* src;
const void* skip;
const void* bias;
void* out;
void* add_out;
const void* gamma;
const void* beta;
int hw;
int c;
int c_per_group;
float eps;
bool has_skip;
bool has_bias;
bool broadcast_skip;
} args = {
(void*)params->src,
(const void*)params->src,
(const void*)params->skip,
(const void*)params->bias,
(void*)params->dst,
(void*)params->add_out,

[GitHub Actions / Lint C++ warning on line 81 (cpplint via reviewdog): Using C-style cast. Use reinterpret_cast<void*>(...) instead [readability/casting] [4]]
(const void*)params->gamma,
(const void*)params->beta,
params->hw,
params->c,
params->channels_per_group,
params->epsilon};
params->epsilon,
params->skip != nullptr,
params->bias != nullptr,
params->broadcast_skip,
};

// Grid dim is (batch_count, groups, 1)
return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args));
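The host-side change above forwards the new skip and bias inputs, the extra add_out output, and the has_skip / has_bias / broadcast_skip flags to the Triton kernel. As orientation for the kernel changes that follow, here is a minimal NumPy sketch of the fused computation, assuming NHWC layout and the skip/bias shapes exercised by the test script in this commit; the function and argument names are illustrative, not part of the source:

import numpy as np


def group_norm_skip_bias_reference(x, gamma, beta, num_groups, eps, skip=None, bias=None, swish=False):
    """Illustrative NumPy reference of the fused op; not the shipped kernel.

    x:    (N, H, W, C) input in NHWC layout.
    skip: None, (N, H, W, C), (N, C), or (N, 1, 1, C); the last two broadcast over H*W.
    bias: None or (C,).
    Returns (y, add_out); add_out = x + skip + bias is the extra tensor the kernel
    writes through params->add_out whenever a skip input is present.
    """
    n, h, w, c = x.shape
    add_out = x.astype(np.float32).copy()
    if skip is not None:
        add_out += skip.reshape(n, 1, 1, c) if skip.ndim == 2 else skip.astype(np.float32)
    if bias is not None:
        add_out += bias.reshape(1, 1, 1, c).astype(np.float32)

    # Normalize each (batch, group) block over all H*W positions and its group channels.
    g = add_out.reshape(n, h * w, num_groups, c // num_groups)
    mean = g.mean(axis=(1, 3), keepdims=True)
    rstd = 1.0 / np.sqrt(g.var(axis=(1, 3), keepdims=True) + eps)
    y = ((g - mean) * rstd).reshape(n, h, w, c) * gamma + beta
    if swish:  # ACTIVATION_SWISH: y * sigmoid(y)
        y = y * (1.0 / (1.0 + np.exp(-y)))
    return y, add_out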
43 changes: 39 additions & 4 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
@@ -12,13 +12,19 @@
@triton.jit
def group_norm_kernel(
input_ptr,
skip_ptr,
bias_ptr,
output_ptr,
add_out_ptr,
gamma_ptr,
beta_ptr,
img_size,
c,
c_per_group,
eps,
has_skip,
has_bias,
broadcast_skip,
BLOCK_SIZE: tl.constexpr,
HW_SIZE: tl.constexpr,
ACTIVATION_SWISH: tl.constexpr,
@@ -31,19 +37,44 @@ def group_norm_kernel(
gamma_ptr += row_y * c_per_group
beta_ptr += row_y * c_per_group

add_out_ptr += row_x * stride + row_y * c_per_group

cols = tl.arange(0, BLOCK_SIZE)
hw = tl.arange(0, HW_SIZE)
offsets = hw[:, None] * c + cols[None, :]
mask = (cols < c_per_group)[None, :]

bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
if has_skip:
if broadcast_skip:
broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group
bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
else:
skip_ptr += row_x * stride + row_y * c_per_group
if has_bias:
bias_ptr += row_y * c_per_group
bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)

# Calculate mean and variance
_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32)
_square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32)
for i in range(tl.cdiv(img_size, HW_SIZE)):
x_ptr = input_ptr + i * HW_SIZE * c
a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
if has_skip and not broadcast_skip:
s_ptr = skip_ptr + i * HW_SIZE * c
s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
a += s
if has_bias or broadcast_skip:
a += bias
_sum += a
_square_sum += a * a
if has_skip:
add_y_ptr = add_out_ptr + i * HW_SIZE * c
tl.store(add_y_ptr + offsets, a, mask=mask)

# Set axis=None (or leave it unspecified) to reduce all axes.
# TODO: In older Triton we have to reduce an axis at a time, but in our case
@@ -57,9 +88,13 @@ def group_norm_kernel(
gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32)
beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32)
for i in range(tl.cdiv(img_size, HW_SIZE)):
x_ptr = input_ptr + i * HW_SIZE * c
y_ptr = output_ptr + i * HW_SIZE * c
x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
if has_skip:
add_y_ptr = add_out_ptr + i * HW_SIZE * c
x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
else:
x_ptr = input_ptr + i * HW_SIZE * c
x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
x_hat = (x - group_mean) * rstd
y = x_hat * gamma + beta
if ACTIVATION_SWISH:
@@ -77,7 +112,7 @@ def group_norm_kernel(
hw_sizes = [8, 16, 32, 64, 128, 256]
warps = [1, 2, 4, 8, 16]
name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}"
sig_pattern = "*{},*{},*fp32,*fp32,i32,i32,i32,fp32"
sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1"
group_pattern = "GroupNormTriton_{}_{}"


@@ -88,7 +123,7 @@ def get_function_table():
swish_suffix = "Swish" if swish else "Pass"
name = name_pattern.format(swish_suffix, dtype, b, hw_size, warp)
group = group_pattern.format(swish_suffix, dtype)
sig = sig_pattern.format(dtype, dtype)
sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype)
kwargs = {
"num_warps": warp,
"constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SWISH": int(swish)},
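Note that the signature string must stay in lockstep with both the kernel's non-constexpr parameter list and the field order of the host-side args struct in group_norm_triton.cuh: five data pointers in the element dtype, two fp32 pointers for gamma/beta, three i32 scalars, the fp32 epsilon, then the three new i1 flags. A small illustrative check of what the updated pattern expands to (the dtype strings tried here are an assumption):

# Illustrative only: inspect what sig_pattern.format(...) produces after this change.
sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1"
for dtype in ("fp16", "fp32"):  # assumed dtype strings
    sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype)
    # One entry per non-constexpr kernel argument, in order:
    #   input_ptr, skip_ptr, bias_ptr, output_ptr, add_out_ptr,
    #   gamma_ptr, beta_ptr, img_size, c, c_per_group, eps,
    #   has_skip, has_bias, broadcast_skip
    assert len(sig.split(",")) == 14
    print(sig)
# fp16 -> *fp16,*fp16,*fp16,*fp16,*fp16,*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1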
8 changes: 7 additions & 1 deletion
@@ -80,7 +80,13 @@ def run_group_norm(
)
use_swish = swish
broadcast_skip = False
channels_per_block = 0 # Compute in params initialization
if has_skip:
skip_x_shape = skip_x.shape
b2 = len(skip_x_shape) == 2 and skip_x_shape[0] == batch_size and skip_x_shape[1] == num_channels
b4 = len(skip_x_shape) == 4 and skip_x_shape[0] == batch_size and skip_x_shape[1] == 1 and skip_x_shape[2] == 1 and skip_x_shape[3] == num_channels
if b2 or b4:
broadcast_skip = True
channels_per_block = 0 # Compute in params initialization

input_d = ke.DeviceArray(input_x.astype(dtype))
skip_d = ke.DeviceArray(skip_x.astype(dtype))
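The shape check added above flags a skip of shape (batch_size, num_channels) or (batch_size, 1, 1, num_channels) as broadcast_skip: the same per-(batch, channel) value is added at every spatial position, which is why the kernel loads it once per program instead of once per pixel. A tiny illustrative check with arbitrarily chosen sizes:

import numpy as np

# Illustrative only: both broadcastable skip shapes add the same value at every (h, w).
batch_size, height, width, num_channels = 2, 4, 4, 8
x = np.random.rand(batch_size, height, width, num_channels).astype(np.float32)

skip_nc = np.random.rand(batch_size, num_channels).astype(np.float32)   # (N, C)
skip_n11c = skip_nc.reshape(batch_size, 1, 1, num_channels)             # (N, 1, 1, C)
skip_full = np.broadcast_to(skip_n11c, x.shape)                         # expanded to (N, H, W, C)

assert np.allclose(x + skip_nc.reshape(batch_size, 1, 1, num_channels), x + skip_full)
assert np.allclose(x + skip_n11c, x + skip_full)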
