Skip to content

Commit

Permalink
fix line_length
Browse files Browse the repository at this point in the history
  • Loading branch information
PeixuanZuo committed Feb 4, 2024
1 parent 1e0f537 commit 1421c7a
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 57 deletions.
9 changes: 6 additions & 3 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ auto GetCKGroupNormNHWCTypeStringAndOps() {
auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix;
auto invoker = impl->MakeInvokerPointer();

auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GroupNormNHWCTunableParams<T>* params) -> Status {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), "Skip is not supported");
auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](
const GroupNormNHWCTunableParams<T>* params) -> Status {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr),
"Skip is not supported");
if constexpr (WithSwish) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!params->use_silu, "Swish version only support groupnorm with swish");
Expand All @@ -59,7 +61,8 @@ auto GetCKGroupNormNHWCTypeStringAndOps() {
params->use_silu, "Pass version only support groupnorm without swish");
}
std::vector<ck::index_t> in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group};
std::vector<ck::index_t> in_out_strides{params->h * params->w * params->c, params->w * params->c, params->c, params->channels_per_group, 1};
std::vector<ck::index_t> in_out_strides{params->h * params->w * params->c, params->w * params->c,
params->c, params->channels_per_group, 1};
std::vector<ck::index_t> gamma_beta_strides{0, 0, 0, params->channels_per_group, 1};
std::vector<ck::index_t> reduce_dims{1, 2, 4};

Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams<T> {

std::string Signature() const override {
std::string swish_suffix = this->use_silu ? "_silu" : "_pass";
std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" + std::to_string(this->c) + "_" + std::to_string(this->groups) + swish_suffix;
std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" +
std::to_string(this->c) + "_" + std::to_string(this->groups) + swish_suffix;
return sig;
}
};
Expand Down
24 changes: 12 additions & 12 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ Status LaunchGroupNormKernel(
bool broadcast_skip,
int channels_per_block) {
GroupNormNHWCTunableParams<T> params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta,
reinterpret_cast<float*>(workspace), epsilon, batch_size, num_channels, height, width,
num_groups, use_silu, broadcast_skip, channels_per_block);
reinterpret_cast<float*>(workspace), epsilon, batch_size, num_channels,
height, width, num_groups, use_silu, broadcast_skip, channels_per_block);

if (params.channels_per_block % params.channels_per_group != 0 ||
params.channels_per_block > kMaxSize ||
Expand All @@ -59,17 +59,17 @@ Status LaunchGroupNormKernel(
return GroupNormNHWCStaticSelection(&params);
}

template Status LaunchGroupNormKernel<half>(RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* add_out,
const half* input, const half* skip, const half* bias, const float* gamma,
const float* beta, void* workspace, float epsilon, int batch_size,
int num_channels, int height, int width, int num_groups, bool use_silu,
bool broadcast_skip, int channels_per_block);
template Status LaunchGroupNormKernel<half>(RocmTuningContext* tuning_ctx, Stream* stream, half* output,
half* add_out, const half* input, const half* skip, const half* bias,
const float* gamma, const float* beta, void* workspace, float epsilon,
int batch_size, int num_channels, int height, int width, int num_groups,
bool use_silu, bool broadcast_skip, int channels_per_block);

template Status LaunchGroupNormKernel<float>(RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* add_out,
const float* input, const float* skip, const float* bias, const float* gamma,
const float* beta, void* workspace, float epsilon, int batch_size,
int num_channels, int height, int width, int num_groups, bool use_silu,
bool broadcast_skip, int channels_per_block);
template Status LaunchGroupNormKernel<float>(RocmTuningContext* tuning_ctx, Stream* stream, float* output,
float* add_out, const float* input, const float* skip, const float* bias,
const float* gamma, const float* beta, void* workspace, float epsilon,
int batch_size, int num_channels, int height, int width, int num_groups,
bool use_silu, bool broadcast_skip, int channels_per_block);

} // namespace rocm
} // namespace contrib
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
auto block_size = metadata->constants.at("BLOCK_SIZE");
auto hw_size = metadata->constants.at("HW_SIZE");
auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams<T>* params) -> Status {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), "Skip is not supported");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr),
"Skip is not supported");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size,
"Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", params->channels_per_group, ").");
"Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (",
params->channels_per_group, ").");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ").");
if constexpr (WithSwish) {
Expand Down
26 changes: 15 additions & 11 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ void groupNormNHWCScale(const GroupNormNHWCTunableParams<T>* params) {
// The number of instances.
grid.z = params->n;

#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \
GroupNormNHWCScaleKernel<T, VecSize> \
<<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>( \
params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \
params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group, \
params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block, \
params->use_silu); \
#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \
GroupNormNHWCScaleKernel<T, VecSize> \
<<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>( \
params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \
params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, \
params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group, \
params->hw, params->hw_per_block, params->use_silu); \
break;

// Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
Expand Down Expand Up @@ -127,8 +127,10 @@ template <typename T, int ThreadsPerBlock, int VecSize>
class GroupNormNHWCOp {
public:
Status operator()(const GroupNormNHWCTunableParams<T>* params) {
HIP_RETURN_IF_ERROR(hipMemsetAsync(
params->group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), params->StreamHandle()));
HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer,
0,
GetGroupNormWorkspaceSizeInBytes(params->n, params->groups),
params->StreamHandle()));
auto status = GroupNormNHWCSumOp<T, ThreadsPerBlock, VecSize>(params);
ORT_RETURN_IF_ERROR(status);
HIP_RETURN_IF_ERROR(hipGetLastError());
Expand All @@ -155,8 +157,10 @@ class GroupNormNHWCOp {

template <typename T>
Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams<T>* params) {
HIP_RETURN_IF_ERROR(hipMemsetAsync(
params->group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), params->StreamHandle()));
HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer,
0,
GetGroupNormWorkspaceSizeInBytes(params->n, params->groups),
params->StreamHandle()));
groupNormNHWCSum<T>(params);
HIP_RETURN_IF_ERROR(hipGetLastError());
groupNormNHWCScale<T>(params);
Expand Down
Loading

0 comments on commit 1421c7a

Please sign in to comment.