Skip to content

Commit

Permalink
fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
PeixuanZuo committed Feb 4, 2024
1 parent 1d98750 commit a88e084
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
3 changes: 0 additions & 3 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,8 @@ def group_norm_kernel(
gamma_ptr += row_y * c_per_group
beta_ptr += row_y * c_per_group


add_out_ptr += row_x * stride + row_y * c_per_group



cols = tl.arange(0, BLOCK_SIZE)
hw = tl.arange(0, HW_SIZE)
offsets = hw[:, None] * c + cols[None, :]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,19 @@ def run_group_norm(
)
use_swish = swish
broadcast_skip = False
if(has_skip):
if has_skip:
skip_x_shape = skip_x.shape
b2 = len(skip_x_shape) == 2 and skip_x_shape[0] == batch_size and skip_x_shape[1] == num_channels
b4 = len(skip_x_shape) == 4 and skip_x_shape[0] == batch_size and skip_x_shape[1] == 1 and skip_x_shape[2] == 1 and skip_x_shape[3] == num_channels
b4 = (
len(skip_x_shape) == 4
and skip_x_shape[0] == batch_size
and skip_x_shape[1] == 1
and skip_x_shape[2] == 1
and skip_x_shape[3] == num_channels
)
if b2 or b4:
broadcast_skip = True
channels_per_block = 0 # Compute in params initialization
channels_per_block = 0 # Compute in params initialization

input_d = ke.DeviceArray(input_x.astype(dtype))
skip_d = ke.DeviceArray(skip_x.astype(dtype))
Expand Down

0 comments on commit a88e084

Please sign in to comment.