Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Peixuan Zuo committed Jan 26, 2024
1 parent 0145d8c commit 0c095e8
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ struct GroupNormNHWCParams {
const T* bias,
const float* gamma,
const float* beta,
void* workspace,
float* workspace,
float epsilon,
int batch_size,
int num_channels,
Expand All @@ -151,7 +151,7 @@ struct GroupNormNHWCParams {
this->bias = bias;
this->gamma = gamma;
this->beta = beta;
this->group_sum_buffer = reinterpret_cast<float*>(workspace);
this->group_sum_buffer = workspace;
this->n = batch_size;
this->h = height;
this->w = width;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ Status LaunchGroupNormKernel(
// tuning_ctx only used for ROCm EP.
ORT_UNUSED_PARAMETER(tuning_ctx);

GroupNormNHWCParams<T> params(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon,
GroupNormNHWCParams<T> params(output, add_out, input, skip, bias, gamma, beta, reinterpret_cast<float*>(workspace), epsilon,
batch_size, num_channels, height, width, num_groups, use_silu,
broadcast_skip, channels_per_block);

Expand Down

0 comments on commit 0c095e8

Please sign in to comment.