diff --git a/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu b/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu index 4a95717d..28963203 100644 --- a/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu +++ b/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu @@ -45,7 +45,7 @@ epilogue/threadblock/epilogue_gemm_k_reduction.h #include #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/device/gemm_with_k_reduction.h" #include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h" #include "cutlass/reduction/device/reduce_split_k.h" #include "cutlass/reduction/kernel/reduce_split_k.h" @@ -101,6 +101,12 @@ constexpr int NumStages = 4; // Reduce A or B operand along the K dimension constexpr bool ReduceKForA = true; +// Alignment of A operand +constexpr int AlignmentA = 8; + +// Alignment of B operand +constexpr int AlignmentB = 8; + // This code section describes the epilogue part of the kernel, we use default value using EpilogueOp = cutlass::epilogue::thread::LinearCombination< ElementOutput, // Data type of output matrix. 
@@ -110,9 +116,9 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< ElementAccumulator, // Data type of accumulator ElementComputeEpilogue>; -using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction< - ElementInputA, LayoutInputA, cutlass::ComplexTransform::kNone, 8, - ElementInputB, LayoutInputB, cutlass::ComplexTransform::kNone, 8, +using Gemm = typename cutlass::gemm::device::GemmWithKReduction< + ElementInputA, LayoutInputA, + ElementInputB, LayoutInputB, ElementOutput, LayoutOutput, ElementAccumulator, MMAOp, @@ -124,10 +130,12 @@ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction< EpilogueOp, SwizzleThreadBlock, NumStages, - cutlass::arch::OpMultiplyAdd ->::GemmKernel; - -using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + AlignmentA, + AlignmentB, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone +>; // Below is the reduction kernel used in the case of parallel split-k using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;; @@ -368,21 +376,21 @@ Result profile(Options const &options) { // Fill input and output matrices on host using CUTLASS helper functions cutlass::reference::host::TensorFillRandomUniform( tensor_a.host_view(), - 1, + 1997, ElementInputA(2), ElementInputA(-2), 0); // <- Fill tensor A on host with uniform-distribution random data cutlass::reference::host::TensorFillRandomUniform( tensor_b.host_view(), - 1, + 2003, ElementInputB(2), ElementInputB(-2), 0); // <- Fill tensor B on host with uniform-distribution random data cutlass::reference::host::TensorFillRandomUniform( tensor_c.host_view(), - 1, + 2017, ElementOutput(2), ElementOutput(-2), 0); // <- Fill matrix C on host with uniform-distribution random data @@ -561,7 +569,7 @@ Result profile(Options const &options) { tensor_reduction.sync_host(); - // ReduceK in host code + // Reduce K in host code if (ReduceKForA) { for (int m = 0; m < options.problem_size.m(); 
++m) { for (int k = 0; k < options.problem_size.k(); ++k) { @@ -581,7 +589,7 @@ Result profile(Options const &options) { // Check if output from CUTLASS kernel and reference kernel are equal or not bool pass = cutlass::reference::host::TensorEquals(tensor_d.host_view(), tensor_ref_d.host_view()); - + pass &= cutlass::reference::host::TensorEquals(tensor_ref_reduction.host_view(), tensor_reduction.host_view()); diff --git a/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h b/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h index 63cd4c6e..4bab4e5f 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h +++ b/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h @@ -149,13 +149,12 @@ class EpilogueGemmKReduction { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kIterations / 4; ++i) { - ElementOutput tmp; + ElementOutput *source_ptr = reinterpret_cast(&source); cutlass::arch::global_load( - tmp, + source_ptr[i], (void *)(pointer_ + i * 32), guard[i] && LoadForSerialSplitK); - source[i] = tmp; } FragmentAccumulator sum = gemm_k_with_reduction_accumulation; diff --git a/include/cutlass/gemm/device/gemm_with_k_reduction.h b/include/cutlass/gemm/device/gemm_with_k_reduction.h new file mode 100644 index 00000000..254a7f96 --- /dev/null +++ b/include/cutlass/gemm/device/gemm_with_k_reduction.h @@ -0,0 +1,414 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Device-level GEMM operator that additionally reduces operand A or B along the K dimension. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_with_k_reduction.h" + +#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! + The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Reduce A or B operand along the K dimension + bool ReduceKForA_ = true, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB = ComplexTransform::kNone, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + 
typename PermuteDLayout = layout::NoPermute +> +class GemmWithKReduction : + public GemmUniversalBase< + typename kernel::DefaultGemmWithKReduction< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ReduceKForA_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_, + SharedMemoryClearOption::kNone + >::GemmKernel + > { + + public: + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static constexpr int kStages = Stages; + static constexpr int kAlignmentA = AlignmentA; + static constexpr int kAlignmentB = AlignmentB; + static constexpr int kAlignmentC = EpilogueOutputOp::kCount; + static constexpr ComplexTransform kTransformA = TransformA; + static constexpr ComplexTransform kTransformB = TransformB; + + using Base = GemmUniversalBase< + typename kernel::DefaultGemmWithKReduction< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ReduceKForA_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_, + SharedMemoryClearOption::kNone + >::GemmKernel + >; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output: exchanges problem size and operands. 
+template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Reduce A or B operand along the K dimension + bool ReduceKForA_, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// Operation performed by GEMM + typename Operator_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout +> +class GemmWithKReduction { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using 
TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using UnderlyingOperator = typename GemmWithKReduction< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + !ReduceKForA_, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA, + Operator, + kTransformB, + kTransformA, + GatherB, + GatherA, + ScatterD, + PermuteDLayout + >::Base; + + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmWithKReduction() = default; + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem(); + } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h b/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h index ddb78b2c..9658af2b 100644 --- a/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h +++ b/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h @@ -91,7 +91,7 @@ template < typename ElementAccumulator, /// Operator class tag typename OperatorClass, - /// + /// Reduce A or B along the K dimension bool ReduceKForA_, /// Tag indicating architecture to tune for typename ArchTag, diff --git a/include/cutlass/gemm/kernel/gemm_with_k_reduction.h b/include/cutlass/gemm/kernel/gemm_with_k_reduction.h index ab69d7f6..93e5ed43 100644 --- a/include/cutlass/gemm/kernel/gemm_with_k_reduction.h +++ b/include/cutlass/gemm/kernel/gemm_with_k_reduction.h @@ -41,6 +41,7 @@ #include "cutlass/matrix_coord.h" #include "cutlass/complex.h" #include "cutlass/semaphore.h" +#include "cutlass/layout/pitch_linear.h" #include "cutlass/trace.h" diff --git a/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h b/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h index 57be0c3a..bae24fca 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h @@ -90,7 +90,7 @@ template < typename LayoutC, /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) typename OperatorClass, - /// + /// Reduce operand A or B along K 
dimension bool ReduceKForA_, /// Number of stages int Stages = 2, diff --git a/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h b/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h index ace2d0f8..c864902f 100644 --- a/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h +++ b/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h @@ -61,7 +61,9 @@ template < /// Layout of C matrix (concept: MatrixLayout) typename LayoutC, /// Operator describing the tensor operation - typename Operator_ = arch::OpMultiplyAdd, + typename Operator_, + /// Reduce operand A or B along K dimension + bool ReduceKForA_, /// Number of partitions along K dimension int PartitionsK = 1, /// Store the accumulators in row major or column major. Row major is used @@ -78,7 +80,7 @@ struct DefaultMmaWithReductionTensorOp { // Define the warp-level tensor op using Type = cutlass::gemm::warp::MmaWithReductionTensorOp< WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, - Policy, PartitionsK, AccumulatorsInRowMajor>; + Policy, ReduceKForA_, PartitionsK, AccumulatorsInRowMajor>; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h b/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h index 420a8a50..7041689d 100644 --- a/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h @@ -81,7 +81,7 @@ template < typename LayoutC_, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) typename Policy_, - /// + /// Reduce operand A or B along K dimension bool ReduceKForA_, /// Number of partitions along K dimension int PartitionsK_ = 1,