From 19fc2acdc60037c75c106b0e19deb5551d49ce67 Mon Sep 17 00:00:00 2001 From: Haicheng Wu <57973641+hwu36@users.noreply.github.com> Date: Thu, 23 Feb 2023 16:35:08 -0500 Subject: [PATCH] streamk fix (#836) Co-authored-by: Haicheng Wu --- .../gemm/kernel/gemm_universal_streamk.h | 35 ++++++------ tools/library/scripts/generator.py | 54 +++++++++++-------- 2 files changed, 48 insertions(+), 41 deletions(-) diff --git a/include/cutlass/gemm/kernel/gemm_universal_streamk.h b/include/cutlass/gemm/kernel/gemm_universal_streamk.h index 7a722cd6..eaa2a594 100644 --- a/include/cutlass/gemm/kernel/gemm_universal_streamk.h +++ b/include/cutlass/gemm/kernel/gemm_universal_streamk.h @@ -1068,7 +1068,6 @@ struct GemmUniversalStreamk { block_iters_remaining = block_iter_end - block_iter_begin; tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1); - init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining); } else @@ -1083,19 +1082,24 @@ struct GemmUniversalStreamk { return; } - // Perform this block's share of work for this tile - process_tile( - tile_work, - block_idx, - dp_start_block_idx, - block_iter_begin); - - block_iters_remaining -= tile_work.k_iters_remaining; - // Iteration-processing loop body CUTLASS_PRAGMA_NO_UNROLL - while (block_iters_remaining != 0) + while (true) { + // Perform this block's share of work for this tile + process_tile( + tile_work, + block_idx, + dp_start_block_idx, + block_iter_begin); + + block_iters_remaining -= tile_work.k_iters_remaining; + + if (block_iters_remaining == 0) + { + break; + } + // Continue to next tile __syncthreads(); @@ -1111,15 +1115,6 @@ struct GemmUniversalStreamk { tile_idx--; init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining); } - - // Perform this block's share of work for this tile - process_tile( - tile_work, - block_idx, - dp_start_block_idx, - block_iter_begin); - - block_iters_remaining -= tile_work.k_iters_remaining; } } diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 183a63df..beb5b8ec 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -47,17 +47,20 @@ def product(X, identity = 1): elements_per_thread = product(tile.threadblock_shape[:-1]) // product(tile.warp_count) // 32 // epilogue_steps return min(max_alignment, elements_per_thread) +def DefaultSwizzlingFunctor(): + return SwizzlingFunctor.Identity8; + # To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK` + # def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ - swizzling_functor = SwizzlingFunctor.Identity8): - # To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK` + swizzling_functor = DefaultSwizzlingFunctor()): if complex_transforms is None: complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] element_a, element_b, element_c, element_epilogue = data_type - + operations = [] # by default, only generate the largest tile and largest alignment @@ -69,9 +72,9 @@ def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ for tile_description in tile_descriptions: for alignment in alignment_constraints: for complex_transform in complex_transforms: - + alignment_c = min(8, alignment) - + A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) C = TensorDescription(element_c, layout[2], alignment_c) @@ -101,7 +104,7 @@ def CreateGemmUniversal3xOperator( # by default, only generate the largest tile and largest alignment if manifest.kernel_filter == '': - tile_descriptions = [tile_descriptions[0],] + tile_descriptions = [tile_descriptions[0]] for layout in layouts: for tile_description in tile_descriptions: @@ -419,7 +422,8 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme ] # Instance group conv kernel - if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == LayoutType.TensorNHWC: + if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == LayoutType.TensorNHWC and \ + tile.minimum_compute_capability >= 80: # SingleGroup kernel new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup)) @@ -526,9 +530,8 @@ def CreateConv2dFixedChannelsOperator(manifest, layout, tile_descriptions, data_ manifest.append(new_operation) operations.append(new_operation) - - return operations + return operations # Convolution for 2D operations specialized for few channels def CreateConv2dFewChannelsOperator(manifest, layout, tile_descriptions, data_type, channel_counts, \ @@ -572,7 +575,7 @@ def CreateConv2dFewChannelsOperator(manifest, layout, tile_descriptions, data_ty manifest.append(new_operation) operations.append(new_operation) - + return operations # Convolution for 3D operations @@ -1427,6 +1430,7 @@ def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version): max_cc = 1024 alignment_constraints = [16,] + alignment_constraints_small_channels = [16, 8, 4] for math_inst in math_instructions: tile_descriptions = [ @@ -1471,10 +1475,12 @@ def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version): operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, [4, 8, 16], [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, - data_type_mixed, [4, 8, 16], [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) for op in operations: if op.tile_description.threadblock_shape[1] >= 128: @@ -2110,6 +2116,7 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): smem_usage = 164 alignment_constraints = [16,] + alignment_constraints_small_channels = [16, 8, 4] for math_inst in math_instructions: tile_descriptions = [ @@ -2133,22 +2140,28 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32] data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] - + CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) - - operations = [] - - operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - + + operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + for op in operations: if op.tile_description.threadblock_shape[1] >= 128: op.C.alignment = 16 @@ -4836,7 +4849,6 @@ def GenerateSM90(manifest, cuda_version): GenerateSM75(manifest, args.cuda_version) GenerateSM80(manifest, args.cuda_version) GenerateSM90(manifest, args.cuda_version) - if 'library' in args.generator_target.split(','): manifest.emit(GeneratorTarget.Library)