diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index c736551432..9f327154a6 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -4881,7 +4881,8 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version): return layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor) + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor) ] math_instructions = [ @@ -4935,43 +4936,49 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version): for math_inst in math_instructions: tile_descriptions = [ + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([256, 128, 64], 6, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 256, 64], 6, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 32, 128, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), ] data_types = [ @@ -4981,6 +4988,12 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version): DataType.f32, math_inst.element_accumulator ], + [ + math_inst.element_a, + math_inst.element_b, + DataType.bf16, + math_inst.element_accumulator + ], ] operations = []