Skip to content

Commit

Permalink
Adjust profiler space for SM89 (#1553)
Browse files (browse the repository at this point in the history)
  • Loading branch information
wenlei-bao authored Sep 19, 2024
1 parent 2991ce1 commit 44dae8b
Showing 1 changed file with 32 additions and 19 deletions.
51 changes: 32 additions & 19 deletions python/cutlass_library/generator.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4881,7 +4881,8 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version):
return

layouts = [
- (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor)
+ (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
+ (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor)
]

math_instructions = [
Expand Down Expand Up @@ -4935,43 +4936,49 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version):

for math_inst in math_instructions:
tile_descriptions = [
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 64], 6, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 64], 6, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 128, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc),
]

data_types = [
Expand All @@ -4981,6 +4988,12 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version):
DataType.f32,
math_inst.element_accumulator
],
+ [
+ math_inst.element_a,
+ math_inst.element_b,
+ DataType.bf16,
+ math_inst.element_accumulator
+ ],
]

operations = []
Expand Down

0 comments on commit 44dae8b

Please sign in to comment.