From 612876b593ebfc35f4bcfeee7a7ef67a4cf18649 Mon Sep 17 00:00:00 2001 From: Manish Gupta Date: Thu, 30 Mar 2023 12:31:08 -0700 Subject: [PATCH] fix split_k_mode and add reduction kernel for f16 input/accum/output (#896) --- .../reduction/init_reduction_operations.cu | 2 ++ .../library/src/reduction/reduction_device.cu | 34 +++++++++++++++++++ tools/profiler/src/gemm_operation_profiler.cu | 5 ++- tools/profiler/src/gemm_operation_profiler.h | 5 +-- 4 files changed, 41 insertions(+), 5 deletions(-) diff --git a/tools/library/src/reduction/init_reduction_operations.cu b/tools/library/src/reduction/init_reduction_operations.cu index b0f16952..bd8d9bb1 100644 --- a/tools/library/src/reduction/init_reduction_operations.cu +++ b/tools/library/src/reduction/init_reduction_operations.cu @@ -42,6 +42,7 @@ namespace library { /////////////////////////////////////////////////////////////////////////////////////////////// // CUTLASS Reduction Instances // /////////////////////////////////////////////////////////////////////////////////////////////// +void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest); void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest); void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest); void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest); @@ -52,6 +53,7 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest) // void initialize_all_reduction_op(Manifest &manifest) { + initialize_reduce_add_linear_combination_f16_f16_f16(manifest); initialize_reduce_add_linear_combination_f32_f32_f16(manifest); initialize_reduce_add_linear_combination_f32_f32_f32(manifest); initialize_reduce_add_linear_combination_f64_f64_f64(manifest); diff --git a/tools/library/src/reduction/reduction_device.cu b/tools/library/src/reduction/reduction_device.cu index 2eb6ab7c..dfe8568b 100644 --- a/tools/library/src/reduction/reduction_device.cu +++ 
b/tools/library/src/reduction/reduction_device.cu @@ -43,6 +43,40 @@ namespace library { // naming convention initialize_reduce_[ReductionOp]_[EpilogueOp]_[ElementWorkspace]_[ElementAccumulator]_[ElementOutput] +void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest) { + + using ElementWorkspace = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementCompute = cutlass::half_t; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits<ElementWorkspace>::value, + ElementAccumulator, + ElementCompute + >; + + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOutputOp::ElementAccumulator, + EpilogueOutputOp::kCount + >; + + using Operation_reduce_add_linear_combination_f16_f16_f16 = cutlass::reduction::device::ReduceSplitK< + cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + > + >; + + manifest.append(new ReductionOperation< + Operation_reduce_add_linear_combination_f16_f16_f16>( + "reduce_add_linear_combination_f16_f16_f16" + )); +} + void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) { using ElementWorkspace = float; diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index 0924c033..a929ee89 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -62,7 +62,6 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options): library::OperationKind::kGemm, { {ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"}, - {ArgumentTypeID::kEnumerated, {"split_k_mode"}, "Variant of split K mode(serial, parallel)"}, {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the 
GEMM problem space"}, {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"}, {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"}, @@ -71,6 +70,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options): {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, + {ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "Variant of split K mode(serial, parallel)"}, {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"}, }, @@ -298,8 +298,6 @@ void GemmOperationProfiler::GemmProblem::initialize_result( set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind)); - set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode)); - set_argument(result, "A", problem_space, std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); @@ -313,6 +311,7 @@ void GemmOperationProfiler::GemmProblem::initialize_result( set_argument(result, "n", problem_space, n); set_argument(result, "k", problem_space, k); + set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode)); set_argument(result, "split_k_slices", problem_space, split_k_slices); set_argument(result, "batch_count", problem_space, batch_count); diff --git a/tools/profiler/src/gemm_operation_profiler.h b/tools/profiler/src/gemm_operation_profiler.h index efee650f..a01c93a0 100644 --- a/tools/profiler/src/gemm_operation_profiler.h +++ b/tools/profiler/src/gemm_operation_profiler.h @@ -66,9 +66,8 @@ class GemmOperationProfiler : public OperationProfiler { /// Problem structure 
obtained from problem space struct GemmProblem { - + cutlass::library::GemmUniversalMode mode; - cutlass::library::SplitKMode split_k_mode; int64_t m; int64_t n; int64_t k; @@ -77,6 +76,8 @@ class GemmOperationProfiler : public OperationProfiler { int64_t ldc; std::vector<uint8_t> alpha; std::vector<uint8_t> beta; + + cutlass::library::SplitKMode split_k_mode; int split_k_slices; int batch_count;