Skip to content

Commit

Permalink
model work
Browse files Browse the repository at this point in the history
  • Loading branch information
PeixuanZuo committed Feb 19, 2024
1 parent 8ae2948 commit 21ee9ae
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
(const void*)params->skip,
(const void*)params->bias,
(void*)params->dst,
(void*)params->add_out,
(void*)params->skip_workspace,
(const void*)params->gamma,
(const void*)params->beta,
params->hw,
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ def group_norm_kernel(
# but this will result in too many functions and slow down the compilation.
with_silu = [True, False]
dtypes = ["fp32", "fp16"]
blocks = [16, 32, 64]
blocks = [16, 32, 64, 128]
hw_sizes = [8, 16, 32, 64, 128, 256]
warps = [1, 2, 4, 8]
warps = [1, 2, 4, 8, 16]
name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}"
sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1"
group_pattern = "GroupNormTriton_{}_{}"
Expand Down

0 comments on commit 21ee9ae

Please sign in to comment.