Skip to content

Commit

Permalink
fix split_k_mode and add reduction kernel for f16 input/accum/output …
Browse files Browse the repository at this point in the history
…(#896)
  • Loading branch information
Manish Gupta authored and ttl10101 committed Feb 7, 2024
1 parent 8877318 commit 612876b
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 5 deletions.
2 changes: 2 additions & 0 deletions tools/library/src/reduction/init_reduction_operations.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////
// CUTLASS Reduction Instances //
///////////////////////////////////////////////////////////////////////////////////////////////
void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest);
void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest);
void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest);
void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest);
Expand All @@ -52,6 +53,7 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest)
//
void initialize_all_reduction_op(Manifest &manifest) {

initialize_reduce_add_linear_combination_f16_f16_f16(manifest);
initialize_reduce_add_linear_combination_f32_f32_f16(manifest);
initialize_reduce_add_linear_combination_f32_f32_f32(manifest);
initialize_reduce_add_linear_combination_f64_f64_f64(manifest);
Expand Down
34 changes: 34 additions & 0 deletions tools/library/src/reduction/reduction_device.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,40 @@ namespace library {

// naming convention initialize_reduce_[ReductionOp]_[EpilogueOp]_[ElementWorkspace]_[ElementAccumulator]_[ElementOutput]

void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest) {

using ElementWorkspace = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementCompute = cutlass::half_t;

using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementWorkspace>::value,
ElementAccumulator,
ElementCompute
>;

using ReductionOp = cutlass::reduction::thread::ReduceAdd<
ElementAccumulator,
typename EpilogueOutputOp::ElementAccumulator,
EpilogueOutputOp::kCount
>;

using Operation_reduce_add_linear_combination_f16_f16_f16 = cutlass::reduction::device::ReduceSplitK<
cutlass::reduction::kernel::ReduceSplitK<
cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
EpilogueOutputOp,
ReductionOp
>
>;

manifest.append(new ReductionOperation<
Operation_reduce_add_linear_combination_f16_f16_f16>(
"reduce_add_linear_combination_f16_f16_f16"
));
}

void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) {

using ElementWorkspace = float;
Expand Down
5 changes: 2 additions & 3 deletions tools/profiler/src/gemm_operation_profiler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
library::OperationKind::kGemm,
{
{ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"},
{ArgumentTypeID::kEnumerated, {"split_k_mode"}, "Variant of split K mode(serial, parallel)"},
{ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"},
{ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"},
{ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"},
Expand All @@ -71,6 +70,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
{ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"},
{ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"},
{ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"},
{ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "Variant of split K mode(serial, parallel)"},
{ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"},
{ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"},
},
Expand Down Expand Up @@ -298,8 +298,6 @@ void GemmOperationProfiler::GemmProblem::initialize_result(

set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind));

set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));

set_argument(result, "A", problem_space,
std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout));

Expand All @@ -313,6 +311,7 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
set_argument(result, "n", problem_space, n);
set_argument(result, "k", problem_space, k);

set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
set_argument(result, "split_k_slices", problem_space, split_k_slices);
set_argument(result, "batch_count", problem_space, batch_count);

Expand Down
5 changes: 3 additions & 2 deletions tools/profiler/src/gemm_operation_profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,8 @@ class GemmOperationProfiler : public OperationProfiler {

/// Problem structure obtained from problem space
struct GemmProblem {

cutlass::library::GemmUniversalMode mode;
cutlass::library::SplitKMode split_k_mode;
int64_t m;
int64_t n;
int64_t k;
Expand All @@ -77,6 +76,8 @@ class GemmOperationProfiler : public OperationProfiler {
int64_t ldc;
std::vector<uint8_t> alpha;
std::vector<uint8_t> beta;

cutlass::library::SplitKMode split_k_mode;
int split_k_slices;
int batch_count;

Expand Down

0 comments on commit 612876b

Please sign in to comment.