streamk fix (#836)

Co-authored-by: Haicheng Wu <[email protected]>
2 people authored and ttl10101 committed Feb 7, 2024
1 parent 89e90ad commit 19fc2ac
Showing 2 changed files with 48 additions and 41 deletions.
include/cutlass/gemm/kernel/gemm_universal_streamk.h (15 additions, 20 deletions)

@@ -1068,7 +1068,6 @@ struct GemmUniversalStreamk {
       block_iters_remaining = block_iter_end - block_iter_begin;
 
       tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1);
-
       init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
     }
     else
@@ -1083,19 +1082,24 @@ struct GemmUniversalStreamk {
       return;
     }
 
-    // Perform this block's share of work for this tile
-    process_tile(
-      tile_work,
-      block_idx,
-      dp_start_block_idx,
-      block_iter_begin);
-
-    block_iters_remaining -= tile_work.k_iters_remaining;
-
     // Iteration-processing loop body
     CUTLASS_PRAGMA_NO_UNROLL
-    while (block_iters_remaining != 0)
+    while (true)
     {
+      // Perform this block's share of work for this tile
+      process_tile(
+        tile_work,
+        block_idx,
+        dp_start_block_idx,
+        block_iter_begin);
+
+      block_iters_remaining -= tile_work.k_iters_remaining;
+
+      if (block_iters_remaining == 0)
+      {
+        break;
+      }
+
       // Continue to next tile
       __syncthreads();
 
@@ -1111,15 +1115,6 @@ struct GemmUniversalStreamk {
         tile_idx--;
         init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
       }
-
-      // Perform this block's share of work for this tile
-      process_tile(
-        tile_work,
-        block_idx,
-        dp_start_block_idx,
-        block_iter_begin);
-
-      block_iters_remaining -= tile_work.k_iters_remaining;
     }
 
   }
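
Net effect of the three hunks above: the stream-K persistent loop used to call
process_tile once before the loop and again at the bottom of every iteration;
it now calls it exactly once, inside a post-test while (true) loop that breaks
once the block's iterations are exhausted. As far as the control flow shown
here goes, both shapes visit the same tile sequence; the rewrite consolidates
the duplicated body. A minimal sketch of the same transformation in plain
Python (hypothetical process/advance helpers, not the kernel's real API):

def loop_before(tile, remaining, process, advance):
    # Old shape: the body appears twice, once before the loop and once
    # at the bottom of each iteration.
    remaining = process(tile, remaining)
    while remaining != 0:
        tile = advance(tile)  # sync, step tile_idx, re-init tile_work
        remaining = process(tile, remaining)
    return remaining

def loop_after(tile, remaining, process, advance):
    # New shape: a single copy of the body in a post-test loop.
    while True:
        remaining = process(tile, remaining)
        if remaining == 0:
            break
        tile = advance(tile)
    return remaining
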
tools/library/scripts/generator.py (33 additions, 21 deletions)

@@ -47,17 +47,20 @@ def product(X, identity = 1):
   elements_per_thread = product(tile.threadblock_shape[:-1]) // product(tile.warp_count) // 32 // epilogue_steps
   return min(max_alignment, elements_per_thread)
 
+def DefaultSwizzlingFunctor():
+  return SwizzlingFunctor.Identity8;
+  # To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK`
+
 #
 def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \
   alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \
-  swizzling_functor = SwizzlingFunctor.Identity8):
-  # To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK`
+  swizzling_functor = DefaultSwizzlingFunctor()):
 
   if complex_transforms is None:
     complex_transforms = [(ComplexTransform.none, ComplexTransform.none),]
 
   element_a, element_b, element_c, element_epilogue = data_type
 
   operations = []
 
   # by default, only generate the largest tile and largest alignment
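
As the relocated comment notes, stream-K decomposition stays opt-in per call
site. A minimal sketch of such a call, assuming the manifest, layouts,
tile_descriptions, data_type, and alignment_constraints values that
generator.py already has in scope:

# Hypothetical call site: override the Identity8 default so these basic
# GEMM operations are generated with stream-K decomposition instead.
operations = CreateGemmOperator(manifest, layouts, tile_descriptions,
  data_type, alignment_constraints,
  swizzling_functor = SwizzlingFunctor.StreamK)
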
@@ -69,9 +72,9 @@ def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \
     for tile_description in tile_descriptions:
       for alignment in alignment_constraints:
         for complex_transform in complex_transforms:
-
           alignment_c = min(8, alignment)
+
           A = TensorDescription(element_a, layout[0], alignment, complex_transform[0])
           B = TensorDescription(element_b, layout[1], alignment, complex_transform[1])
           C = TensorDescription(element_c, layout[2], alignment_c)
@@ -101,7 +104,7 @@ def CreateGemmUniversal3xOperator(
 
   # by default, only generate the largest tile and largest alignment
   if manifest.kernel_filter == '':
-    tile_descriptions = [tile_descriptions[0],]
+    tile_descriptions = [tile_descriptions[0]]
 
   for layout in layouts:
     for tile_description in tile_descriptions:
@@ -419,7 +422,8 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \
     ]
 
     # Instance group conv kernel
-    if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == LayoutType.TensorNHWC:
+    if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == LayoutType.TensorNHWC and \
+      tile.minimum_compute_capability >= 80:
       # SingleGroup kernel
       new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
         A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup))
@@ -526,9 +530,8 @@ def CreateConv2dFixedChannelsOperator(manifest, layout, tile_descriptions, data_type, channel_counts, \
 
         manifest.append(new_operation)
         operations.append(new_operation)
-
-        return operations
 
+  return operations
 
 # Convolution for 2D operations specialized for few channels
 def CreateConv2dFewChannelsOperator(manifest, layout, tile_descriptions, data_type, channel_counts, \
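
With indentation restored, this hunk reads as a classic early-return fix: the
old return operations appears to have sat inside the generation loop, so
CreateConv2dFixedChannelsOperator returned after the first combination it
visited; the new return is dedented to function scope and runs only after the
loop completes. A minimal sketch of the pattern (hypothetical names, not the
generator's real loop nest):

def create_ops_early_return(items):
    operations = []
    for item in items:
        operations.append(item)
        return operations  # bug: exits the loop after the first item

def create_ops_fixed(items):
    operations = []
    for item in items:
        operations.append(item)
    return operations      # fix: returns the complete list
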
@@ -572,7 +575,7 @@ def CreateConv2dFewChannelsOperator(manifest, layout, tile_descriptions, data_type, channel_counts, \
 
         manifest.append(new_operation)
         operations.append(new_operation)
-
+
   return operations
 
 # Convolution for 3D operations
@@ -1427,6 +1430,7 @@ def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version):
   max_cc = 1024
 
   alignment_constraints = [16,]
+  alignment_constraints_small_channels = [16, 8, 4]
 
   for math_inst in math_instructions:
     tile_descriptions = [
@@ -1471,10 +1475,12 @@ def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version):
 
     operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
       data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
+
     operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions,
-      data_type_mixed, [4, 8, 16], [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
+      data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
+
     operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions,
-      data_type_mixed, [4, 8, 16], [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
+      data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
 
     for op in operations:
       if op.tile_description.threadblock_shape[1] >= 128:
@@ -2110,6 +2116,7 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
   smem_usage = 164
 
   alignment_constraints = [16,]
+  alignment_constraints_small_channels = [16, 8, 4]
 
   for math_inst in math_instructions:
     tile_descriptions = [
@@ -2133,22 +2140,28 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
 
     data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32]
     data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32]
 
     CreateGemmOperator(manifest, layouts, tile_descriptions, \
       data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination)
 
-    operations = []
-
-    operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
-      data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
-
     conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
     CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
       data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination)
 
+    operations = []
+
+    operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
+      data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
+
     operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
       data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
 
+    operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions,
+      data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
+
+    operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions,
+      data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
+
     for op in operations:
       if op.tile_description.threadblock_shape[1] >= 128:
         op.C.alignment = 16
@@ -4836,7 +4849,6 @@ def GenerateSM90(manifest, cuda_version):
   GenerateSM75(manifest, args.cuda_version)
   GenerateSM80(manifest, args.cuda_version)
   GenerateSM90(manifest, args.cuda_version)
-
   if 'library' in args.generator_target.split(','):
     manifest.emit(GeneratorTarget.Library)
 