From 96181dabfb9daa89f866d1a5858dbe48a52b2ec7 Mon Sep 17 00:00:00 2001 From: Manish Gupta Date: Wed, 27 Sep 2023 08:18:30 -0700 Subject: [PATCH] Support for Mixed Input TensorOp (#1084) * Passing warp-level mixed input F16*(S8/U8) tests * passing device-level mixed input F16*(S8/U8) tests * add to profiler - I8 (111 TFLOPs), U (123 TFLOPs) * fast numeric conversions (I8 = 132 TFLOPs, U8 = 148 TFLOPs) * Speedup reference compilation (REVERT THIS COMMIT) * wider_add.u32_packed_sub.f16x2 (I8 = 132TFLOP/s, U8 = 170 TFLOP/s) * Improve s8->f16 cvt and support bf16*u8 @158 TFLOPs * BF16 * S8 (142 TFLOPs) * Handle mixed-input upcast on OperandA (Support [S8|U8]*[F16|BF16] * rename OpMultiplyAddMixedInput to OpMultiplyAddMixedInputUpcast * Add device-level test and profiler support for upcast on operand A * Move shfl before the cvt and reduce #shfls by 1/2 * fix smem_usage calculation for mixed_input types * uncomment the stuff (getting ready for merge) * profiler changes and mixed-input reference * mixed input reference are in a new file * use platform instead of std * comments and typo only * Use CreateGemmOperator and delete CreateMixedInputGemmOperator * copyright for new files * rebase follow-up --- include/cutlass/arch/mma.h | 10 + .../gemm/warp/default_mma_tensor_op_sm80.h | 67 +++ .../gemm/warp/mma_mixed_input_tensor_op.h | 554 ++++++++++++++++++ include/cutlass/numeric_conversion.h | 222 ++++++- python/cutlass_library/conv2d_operation.py | 5 + python/cutlass_library/conv3d_operation.py | 6 +- python/cutlass_library/gemm_operation.py | 11 +- python/cutlass_library/generator.py | 124 +++- python/cutlass_library/library.py | 11 +- python/cutlass_library/rank_2k_operation.py | 4 + python/cutlass_library/rank_k_operation.py | 4 + python/cutlass_library/symm_operation.py | 4 + python/cutlass_library/trmm_operation.py | 4 + test/unit/core/CMakeLists.txt | 1 + test/unit/core/fast_numeric_conversion.cu | 176 ++++++ test/unit/gemm/device/CMakeLists.txt | 15 + ...s8n_f16t_mixed_input_tensor_op_f16_sm80.cu | 97 +++ ...u8n_f16t_mixed_input_tensor_op_f16_sm80.cu | 97 +++ ...16n_f16t_mixed_input_tensor_op_f16_sm80.cu | 97 +++ ...16n_f16t_mixed_input_tensor_op_f16_sm80.cu | 97 +++ test/unit/gemm/device/testbed_universal.h | 9 +- test/unit/gemm/warp/CMakeLists.txt | 1 + test/unit/gemm/warp/gemm_mixed_input_sm80.cu | 322 ++++++++++ tools/library/CMakeLists.txt | 1 + .../src/reference/gemm_fp_mixed_input.cu | 138 +++++ .../initialize_reference_operations.cu | 2 + 26 files changed, 2065 insertions(+), 14 deletions(-) create mode 100644 include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h create mode 100644 test/unit/core/fast_numeric_conversion.cu create mode 100644 test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu create mode 100644 test/unit/gemm/warp/gemm_mixed_input_sm80.cu create mode 100644 tools/library/src/reference/gemm_fp_mixed_input.cu diff --git a/include/cutlass/arch/mma.h b/include/cutlass/arch/mma.h index f7c59e63..7a70114d 100644 --- a/include/cutlass/arch/mma.h +++ b/include/cutlass/arch/mma.h @@ -68,14 +68,24 @@ struct OpMultiplyAddFastF16 {}; ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Tag 
indicating the input data types are mixed and the narrower type is +/// upcasted to the wider type +struct OpMultiplyAddMixedInputUpcast {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Tag indicating the input is converted to 2 (big and small) TF32 components // Perform 3xTF32 or 4xTF32 for every F32 output element struct OpMultiplyAddFastF32 {}; +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Tag indicating the input is converted to 2 (big and small) TF32 components // Perform 3xTF32 or 4xTF32 for every complex output element struct OpMultiplyAddComplexFastF32 {}; +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Helper for determining whether staged accumulation should be used for a given operator template struct UseStagedAccumulation { diff --git a/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h b/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h index d4d8026a..9572f2e3 100644 --- a/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h +++ b/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h @@ -38,6 +38,7 @@ #include "cutlass/numeric_types.h" #include "cutlass/arch/mma.h" #include "cutlass/gemm/warp/mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_mixed_input_tensor_op.h" #include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" #include "cutlass/gemm/warp/default_mma_tensor_op.h" @@ -227,6 +228,72 @@ struct DefaultMmaTensorOp< ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial Specialization - inputs are mixed types - uses wider datatype internally. +/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32) +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Element type of A matrix + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Element type of B matrix + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp< + WarpShape_, + GemmShape<16, 8, 16>, // InstructionShape + ElementA, // Element type of A matrix in Global Memory + LayoutA, // Layout of A matrix in Global Memory + ElementB, // Element type of B matrix in Global Memory + LayoutB, // Layout of B matrix in Global Memory + ElementC, // Element type of C matrix in Global Memory + LayoutC, // Layout of C matrix in Global Memory + arch::OpMultiplyAddMixedInputUpcast, // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype + PartitionsK, AccumulatorsInRowMajor> { + + + // Check if the ElementA and ElementB are of different data types + static_assert(!platform::is_same::value, + "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type"); + + // Data type used for internal computation - use the wider of the two data types for mma.sync operands + using ElementOperand = typename platform::conditional<(sizeof(ElementA) > sizeof(ElementB)), + ElementA, ElementB>::type; + + // Operand datatypes in the internal MMA instruction - use the wider of the two data types + using MmaElementA = ElementOperand; + using MmaElementB = ElementOperand; + using MmaElementC = ElementC; + + // Uses + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + GemmShape<16, 8, 16>, + 32, + MmaElementA, cutlass::layout::RowMajor, + MmaElementB, cutlass::layout::ColumnMajor, + MmaElementC, cutlass::layout::RowMajor, + arch::OpMultiplyAdd + >, + cutlass::MatrixShape<1, 1> >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaMixedInputTensorOp< + WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + Policy, PartitionsK, AccumulatorsInRowMajor>; +}; + } // namespace warp } // namespace gemm } // namespace cutlass diff --git a/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h b/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h new file mode 100644 index 00000000..ee58e39d --- /dev/null +++ b/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h @@ -0,0 +1,554 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations targeting + Tensor Cores. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +//////////////////////////////////////////////////////////////////////////////// +// Shuffle registers for layout conversion +//////////////////////////////////////////////////////////////////////////////// +template < + /// Element type for the operand in registers for the mma.sync + typename ElementMma_, + /// Element type for the operand in shared memory for ldmatrix + typename ElementLoad_, + /// Number of mma.sync operations performed along rows or columns + int NumMmaInstructions, + /// Number of elements in warp fragment + int NumElementsInWarpFragment, + /// Number of elements in mma fragment + int NumElementsInMmaFragment, + /// Identifies A or B multiplicand + Operand Operand_, + /// + typename Enable = void > +struct FragmentShuffler { + public: + using ElementMma = ElementMma_; + using ElementLoad = ElementLoad_; + + static int const kNumMmaInstructions = NumMmaInstructions; + static int const kNumElementsInWarpFragment = NumElementsInWarpFragment; + static int const kNumElementsInMmaFragment = NumElementsInMmaFragment; + static Operand const kOperand = Operand_; + + using WarpFragment = Array; + using MmaFragment = Array; + + CUTLASS_DEVICE + WarpFragment operator()(WarpFragment const &src) { + return src; + } +}; +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8) +/// for operand A multiplicand going through upcasting. 
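+/// The 8b operand arrives from `ldmatrix` in its narrow thread layout, so before upcasting each
+/// thread exchanges registers with neighboring lanes via warp shuffles and re-packs bytes with
+/// `__byte_perm`, leaving the widened 16b fragments in the registers that the 16b `mma.sync`
+/// operand layout expects.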
+template < + /// Element type for the operand in registers for the mma.sync + typename ElementMma_, + /// Element type for the operand in shared memory for ldmatrix + typename ElementLoad_, + /// Number of mma.sync operations performed along rows or columns + int NumMmaInstructions, + /// Number of elements in warp fragment + int NumElementsInWarpFragment, + /// Number of elements in mma fragment + int NumElementsInMmaFragment +> +struct FragmentShuffler ::value == 16) && + (sizeof_bits::value == 8)>::type> { +public: + using ElementMma = ElementMma_; + using ElementLoad = ElementLoad_; + + static int const kNumMmaInstructions = NumMmaInstructions; + static int const kNumElementsInWarpFragment = NumElementsInWarpFragment; + static int const kNumElementsInMmaFragment = NumElementsInMmaFragment; + static Operand const kOperand = Operand::kA; + + using WarpFragment = Array; + using MmaFragment = Array; + + static uint32_t const kSelectBytesEvenThread = 0x5410; + static uint32_t const kSelectBytesOddThread = 0x7632; + +private: + int delta_up_; + int delta_down_; + int odd_even_lane_id_; + uint32_t byte_selector_; + +public: + CUTLASS_DEVICE + FragmentShuffler() { + int lane_id = cutlass::arch::LaneId(); + delta_up_ = (lane_id & 1) + ((lane_id & 2) >> 1); + delta_down_ = 2 - delta_up_; + odd_even_lane_id_ = static_cast(lane_id & 1); + byte_selector_ = odd_even_lane_id_ * kSelectBytesOddThread + + (1 - odd_even_lane_id_) * kSelectBytesEvenThread; + } + + CUTLASS_DEVICE + WarpFragment operator()(WarpFragment const &src) { + + WarpFragment result; + MmaFragment const* mma_frag_src_ptr = reinterpret_cast(&src); + MmaFragment* mma_frag_dst_ptr = reinterpret_cast(&result); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kNumMmaInstructions; n++) { + + uint32_t const* src_ptr = reinterpret_cast(&mma_frag_src_ptr[n]); + uint32_t *dst_ptr = reinterpret_cast(&mma_frag_dst_ptr[n]); + + // Shuffle data within the warp, pull from other threads within the warp + uint32_t tmp0 = __shfl_up_sync(0xFFFFFFFF, src_ptr[0], delta_up_); + uint32_t tmp1 = __shfl_down_sync(0xFFFFFFFF, src_ptr[0], delta_down_); + uint32_t tmp2 = __shfl_up_sync(0xFFFFFFFF, src_ptr[1], delta_up_); + uint32_t tmp3 = __shfl_down_sync(0xFFFFFFFF, src_ptr[1], delta_down_); + + // Reorder the data within the 32-bit word (4x8b) required for mma.sync + dst_ptr[0] = __byte_perm(tmp0, tmp2, byte_selector_); + dst_ptr[1] = __byte_perm(tmp1, tmp3, byte_selector_); + } + + return result; + } + +}; +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8) +/// for operand B multiplicand going through upcasting. 
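+/// The mechanics mirror the operand A shuffler above; since each thread's B fragment of the
+/// 16x8x16 instruction holds half as many 8b elements, only one shuffle pair and one byte
+/// permute are needed per `mma.sync`.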
+template < + /// Element type for the operand in registers for the mma.sync + typename ElementMma_, + /// Element type for the operand in shared memory for ldmatrix + typename ElementLoad_, + /// Number of mma.sync operations performed along rows or columns + int NumMmaInstructions, + /// Number of elements in warp fragment + int NumElementsInWarpFragment, + /// Number of elements in mma fragment + int NumElementsInMmaFragment +> +struct FragmentShuffler ::value == 16) && + (sizeof_bits::value == 8)>::type> { +public: + using ElementMma = ElementMma_; + using ElementLoad = ElementLoad_; + + static int const kNumMmaInstructions = NumMmaInstructions; + static int const kNumElementsInWarpFragment = NumElementsInWarpFragment; + static int const kNumElementsInMmaFragment = NumElementsInMmaFragment; + static Operand const kOperand = Operand::kB; + + using WarpFragment = Array; + using MmaFragment = Array; + + static uint32_t const kSelectBytesEvenThread = 0x5410; + static uint32_t const kSelectBytesOddThread = 0x7632; + +private: + int delta_up_; + int delta_down_; + int odd_even_lane_id_; + uint32_t byte_selector_; + +public: + CUTLASS_DEVICE + FragmentShuffler() { + int lane_id = cutlass::arch::LaneId(); + delta_up_ = (lane_id & 1) + ((lane_id & 2) >> 1); + delta_down_ = 2 - delta_up_; + odd_even_lane_id_ = static_cast(lane_id & 1); + byte_selector_ = odd_even_lane_id_ * kSelectBytesOddThread + + (1 - odd_even_lane_id_) * kSelectBytesEvenThread; + } + + CUTLASS_DEVICE + WarpFragment operator()(WarpFragment const &src) { + + WarpFragment result; + + MmaFragment const* mma_frag_src_ptr = reinterpret_cast(&src); + MmaFragment* mma_frag_dst_ptr = reinterpret_cast(&result); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kNumMmaInstructions; n++) { + + uint32_t const* src_ptr = reinterpret_cast(&mma_frag_src_ptr[n]); + uint32_t* dst_ptr = reinterpret_cast(&mma_frag_dst_ptr[n]); + + // Shuffle data within the warp, pull from other threads within the warp + uint32_t tmp0 = __shfl_up_sync(0xFFFFFFFF, src_ptr[0], delta_up_); + uint32_t tmp1 = __shfl_down_sync(0xFFFFFFFF, src_ptr[0], delta_down_); + + // Reorder the data within the 32-bit word (4x8b) required for mma.sync + dst_ptr[0] = __byte_perm(tmp0, tmp1, byte_selector_); + } + + return result; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// +// Data type conversion +//////////////////////////////////////////////////////////////////////////////// +template < + /// Destination type + typename ElementDst_, + /// Source type + typename ElementSrc_, + /// Number of elements + int N, + /// + typename Enable = void> +struct FragmentConverter { + + using ElementDst = ElementDst_; + using ElementSrc = ElementSrc_; + + // Operand fragment registers in destination and source types + using DestinationFragment = Array; + using SourceFragment = Array; + + FastNumericArrayConverter convert; + + CUTLASS_DEVICE + DestinationFragment operator()(SourceFragment const &src) const { + return convert(src); + } +}; +//////////////////////////////////////////////////////////////////////////////// + +// Partial specialization for when Destination type is the *same* as +// Source type +template < + /// Data type + typename Element, + /// Number of elements + int N, + /// + typename Enable> +struct FragmentConverter { + + using DestinationFragment = Array; + using SourceFragment = Array; + + CUTLASS_DEVICE + DestinationFragment operator()(SourceFragment const &src) const { + return src; + } +}; + +} // namespace detail + +/// 
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool +> +class MmaMixedInputTensorOp { +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Underlying arch::Mma instruction datatype for A operand + using MmaElementA = typename ArchMmaOperator::ElementA; + + /// Underlying arch::Mma instruction datatype for B operand + using MmaElementB = typename ArchMmaOperator::ElementB; + + /// Underlying arch::Mma instruction datatype for C operand + using MmaElementC = typename ArchMmaOperator::ElementC; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + /// + // static int const kLoadShapeK = InstructionShape::kK * + // (sizeof_bits::value / sizeof_bits::value); + +public: + + /// Iterates over the A operand in Shared Memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kA, ElementA, LayoutA, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile in registers (loaded from Shared Memory) + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile in registers (for use in Mma instruction) + using TransformedFragmentA = + Array; + + /// Underlying arch::Mma instruction operand fragement for matrix A + using MmaOperandA = typename 
ArchMmaOperator::FragmentA; + + /// Iterates over the B operand in Shared Memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kB, ElementB, LayoutB, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for B tile in registers (loaded from Shared Memory) + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile in registers (for use in Mma instruction) + using TransformedFragmentB = + Array; + + /// Underlying arch::Mma instruction operand fragement for matrix B + using MmaOperandB = typename ArchMmaOperator::FragmentB; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator< + MatrixShape, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Underlying arch::Mma instruction operand fragement for matrix C + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + /// Number of mma operations performed + using MmaIterations = MatrixShape< + (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN + >; + + +public: + + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaMixedInputTensorOp() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + TransformedFragmentA const &A, + TransformedFragmentB const &B, + FragmentC const &C + ) const { + + D = C; + + MmaOperandA const *ptr_A = reinterpret_cast(&A); + MmaOperandB const *ptr_B = reinterpret_cast(&B); + MmaOperandC *ptr_D = reinterpret_cast(&D); + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } + } + + /// Transform the operand warp fragment register to the required data types and layout + /// for the `cultass::arch::Mma` + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + + // Shuffle data within warp to obtain the mma.sync operand layout + detail::FragmentShuffler shuffler_A; + FragmentA tmp_A; + tmp_A = shuffler_A(A); + + // Convert the A operand to the Mma Instruction operand type + detail::FragmentConverter convert_A; + dst_A = convert_A(tmp_A); + + + // Shuffle data within warp to obtain the mma.sync operand layout + detail::FragmentShuffler shuffler_B; + FragmentB tmp_B; + tmp_B = shuffler_B(B); + + // Convert the B operand to the Mma Instruction operand type + detail::FragmentConverter convert_B; + dst_B = convert_B(tmp_B); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index f53bb731..a3ad138b 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -2340,7 +2340,8 @@ struct NumericArrayConverter { /// Conversion operator for Array. See the comments before /// FastLinearCombinationClamp. template + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest, + typename Enable = void> struct FastNumericArrayConverter { using result_type = Array; using source_type = Array; @@ -2441,6 +2442,225 @@ struct FastNumericArrayConverter { result_type operator()(source_type const &s) const { return convert(s); } }; + +/// Partial specialization for Array <= Array +template +struct FastNumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + + #if 0 // Scalar conversion (Please keep this code for reference for vectorized version below) + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + int16_t tmp = source[i] + 26112 /* 0x6600 */; + result[i] = reinterpret_cast(tmp) - 1536.0_hf; + } + #endif + + // Vectorized s8->f16 conversion using packed instructions + uint32_t const* source_ptr = reinterpret_cast(&source); + uint32_t* result_ptr = reinterpret_cast(&result); + + // Pack s8x2 (s8[1], s8[0]) -> s16x2 (sext.s8[1], sext.s8[0]) + // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt) + // The inline ptx below uses `msb=0` and `msb=1` from the above link to sign extend the sign-bit in 0, 1, 2, 3 bytes of s8x4 + // into result_ptr[0] and result_ptr[1]'s 08-15 and 24-31 bits, respectively. + // Note that `__byte_perm(source_ptr[0], source_ptr[0], 0x9180);` won't achieve the same and doesn't sign extend the sign-bit. 
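+    // For example, the byte 0xFE (-2) has to widen to 0xFFFE rather than 0x00FE for the masked-xor
+    // and packed f16 subtraction below to recover the correct value.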
+ // Thus, we use inline ptx `prmt.b32` instruction for the desired sign extend from `s8x2` to `s16x2`. + asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(result_ptr[0]) : "r"(source_ptr[0]), "n"(0x9180)); + asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(result_ptr[1]) : "r"(source_ptr[0]), "n"(0xB3A2)); + + // In the absense of add.s16x2 instruction, use bit-wise operation to execute signed addition with magic numbers to achieve + // the same result as add.s16x2 instruction. + // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3) + // For a logical operation F(a, b, c) the value of kImmLut can be computed by applying the same operation to + // three predefined constant values as follows: + // ta = 0xF0; + // tb = 0xCC; + // tc = 0xAA; + // kImmLut = F(ta, tb, tc); + // If we want F = ((a & b) ^ c) then set kImmLut = (0xF0 & 0xCC) ^ 0xAA + static constexpr uint32_t kImmLut = (0xF0 & 0xCC) ^ 0xAA; + + // The bit-wise operation executed below is `result_ptr[0] = (result_ptr[0] & 0x03FF03FF) ^ 0x66006600;` + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : + "=r"(result_ptr[0]) : "r"(result_ptr[0]), "n"(0x03FF03FF), "n"(0x66006600), "n"(kImmLut)); + // The bit-wise operation executed below is `result_ptr[1] = (result_ptr[1] & 0x03FF03FF) ^ 0x66006600;` + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : + "=r"(result_ptr[1]) : "r"(result_ptr[1]), "n"(0x03FF03FF), "n"(0x66006600), "n"(kImmLut)); + + // Packed sub.f16x2 with magic number to obtain final converted result + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[0]) : "r"(result_ptr[0]), "r"(0x66006600)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[1]) : "r"(result_ptr[1]), "r"(0x66006600)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + + +/// Partial specialization for Array <= Array +template +struct FastNumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + + uint32_t const* source_ptr = reinterpret_cast(&source); + uint32_t* result_ptr = reinterpret_cast(&result); + + result_ptr[0] = __byte_perm(source_ptr[0], 0x0, 0x4140); + result_ptr[1] = __byte_perm(source_ptr[0], 0x0, 0x4342); + + asm volatile("add.u32 %0, %1, %2;\n" : "=r"(result_ptr[0]) : "r"(result_ptr[0]), "r"(0x66006600)); + asm volatile("add.u32 %0, %1, %2;\n" : "=r"(result_ptr[1]) : "r"(result_ptr[1]), "r"(0x66006600)); + + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[0]) : "r"(result_ptr[0]), "r"(0x66006600)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[1]) : "r"(result_ptr[1]), "r"(0x66006600)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template +struct FastNumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + Array tmp; + + uint32_t const* source_ptr = reinterpret_cast(&source); + uint32_t* tmp_ptr = reinterpret_cast(&tmp); + + // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of u8x4 source and stores + // the result in tmp (without introducing extra cvt.u32.u8 instruction) + tmp_ptr[0] 
= __byte_perm(source_ptr[0], 0x4B000000, 0x7650); + tmp_ptr[1] = __byte_perm(source_ptr[0], 0x4B000000, 0x7651); + tmp_ptr[2] = __byte_perm(source_ptr[0], 0x4B000000, 0x7652); + tmp_ptr[3] = __byte_perm(source_ptr[0], 0x4B000000, 0x7653); + + // Subtract the magic number 0x4B000000 from tmp in floating-point arithmetic to obtain final result + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + tmp[i] = reinterpret_cast(tmp_ptr[i]) - 8388608.f; + } + + // on 3456x4096x8192 runs at 158 TFLOP/s + // Convert f32x2 to bf16x2 using `cvt.rn.b16x2.f32` instruction + NumericArrayConverter convert_f32_to_bf16; + result = convert_f32_to_bf16(tmp); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template +struct FastNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + using intermediate_float_type = Array; + using intermediate_int32_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + intermediate_float_type tmp; + + uint32_t const* source_ptr = reinterpret_cast(&source); + uint32_t* tmp_ptr = reinterpret_cast(&tmp); + + // s8x4 (s[3], s[2], s8[1], s8[0]) -> s16x4 (sext.s8[3], sext.s8[2], sext.s8[1], sext.s8[0]) + // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt) + // The inline ptx below uses `msb=0` and `msb=1` from the above link to sext the sign-bit in 0, 1, 2, 3 bytes of s8x4 + // sext without unpacking each s8 out of s8x4 into a separate register a.ka. without using shifts (SHFL). + asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[0]) : "r"(source_ptr[0]), "n"(0x8880)); + asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[1]) : "r"(source_ptr[0]), "n"(0x9991)); + asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[2]) : "r"(source_ptr[0]), "n"(0xAAA2)); + asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[3]) : "r"(source_ptr[0]), "n"(0xBBB3)); + + // Convert s32x4 to f32x4 using fast numeric array converter + FastNumericArrayConverter convert_s32_to_f32_; + tmp = convert_s32_to_f32_(reinterpret_cast(tmp[0])); + + // Convert f32x2 to bf16x2 using `cvt.rn.b16x2.f32` instruction + NumericArrayConverter convert_f32_to_bf16_; + result = convert_f32_to_bf16_(tmp); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for FastNumericArrayConverter to vectorize over 4 elements. 
+/// source `S` as 8b integers (S8 or U8) -> destination `T` as 16b floating-point (F16 or BF16) +template +struct FastNumericArrayConverter::value || platform::is_same::value) && + (platform::is_same::value || platform::is_same::value)>::type> { + static_assert(!(N % 4), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + FastNumericArrayConverter convert_vector_; + result_type result; + + Array *result_ptr = + reinterpret_cast *>(&result); + Array const *source_ptr = + reinterpret_cast const *>(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { return convert(s); } + +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Defines preferred rounding mode for a pair of types diff --git a/python/cutlass_library/conv2d_operation.py b/python/cutlass_library/conv2d_operation.py index b59771ef..fcfcd24a 100644 --- a/python/cutlass_library/conv2d_operation.py +++ b/python/cutlass_library/conv2d_operation.py @@ -62,6 +62,11 @@ def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, self.stride_support = stride_support self.swizzling_functor = swizzling_functor self.group_mode = group_mode + + # + def is_mixed_input(self): + return self.A.element != self.B.element + # def is_complex(self): complex_operators = [ diff --git a/python/cutlass_library/conv3d_operation.py b/python/cutlass_library/conv3d_operation.py index 0a3265bb..5ab1b900 100644 --- a/python/cutlass_library/conv3d_operation.py +++ b/python/cutlass_library/conv3d_operation.py @@ -60,7 +60,11 @@ def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, self.iterator_algorithm = iterator_algorithm self.stride_support = stride_support self.swizzling_functor = swizzling_functor - + + # + def is_mixed_input(self): + return self.A.element != self.B.element + # def core_name(self): ''' The basic operation kind is prefixed with a letter indicating the accumulation type. 
''' diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index e92b891f..ad62422c 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -88,6 +88,10 @@ def is_complex(self): ] return self.tile_description.math_instruction.math_operation in complex_operators + # + def is_mixed_input(self): + return self.A.element != self.B.element + # def is_planar_complex(self): return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray) @@ -149,14 +153,19 @@ def extended_name(self): if self.C.element != self.tile_description.math_instruction.element_accumulator and \ self.A.element != self.tile_description.math_instruction.element_accumulator: extended_name = "${element_c}_${core_name}_${element_a}" + if self.is_mixed_input(): + extended_name += "_${element_b}" elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ self.A.element != self.tile_description.math_instruction.element_accumulator: extended_name = "${core_name}_${element_a}" + if self.is_mixed_input(): + extended_name += "_${element_b}" else: extended_name = "${core_name}" extended_name = SubstituteTemplate(extended_name, { 'element_a': DataTypeNames[self.A.element], + 'element_b': DataTypeNames[self.B.element], 'element_c': DataTypeNames[self.C.element], 'core_name': self.core_name() }) @@ -235,7 +244,7 @@ def procedural_name(self): ex = self.extended_name(), tb = threadblock, l = self.layout_name(), - a = str(self.A.alignment)) + a = str(max(self.A.alignment, self.B.alignment))) # def configuration_name(self): diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index facd5d96..ee6bb2ce 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -103,11 +103,14 @@ def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ for tile_description in tile_descriptions: for alignment in alignment_constraints: for complex_transform in complex_transforms: - - alignment_c = min(8, alignment) - - A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) - B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) + + # If alignment is a tuple or a list, then we have different alignments for A and B + alignment_a = alignment if isinstance(alignment, int) else alignment[0] + alignment_b = alignment if isinstance(alignment, int) else alignment[1] + alignment_c = min(8, alignment_a) + + A = TensorDescription(element_a, layout[0], alignment_a, complex_transform[0]) + B = TensorDescription(element_b, layout[1], alignment_b, complex_transform[1]) C = TensorDescription(element_c, layout[2], alignment_c) new_operation = GemmOperation(GemmKind.Universal, tile_description.minimum_compute_capability, \ @@ -2150,6 +2153,116 @@ def GenerateSM80_PlanarComplexTensorOp_16816(manifest, cuda_version): CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints, complex_transforms) + +# +def GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + # Upcast on Operand A + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.s8, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + 
DataType.s8, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.u8, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.u8, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.s8, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the inner list + # contains the alignment constraints for [operandA, operandB]. + alignment_constraints = [[16, 8],] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_b, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Upcast on Operand B + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.s8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.s8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.u8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.u8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the inner list + # contains the alignment constraints for [operandA, operandB]. 
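+  # Both operands still use 128-bit vector accesses: 8 elements of the 16b type on A and
+  # 16 elements of the 8b type on B.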
+ alignment_constraints = [[8, 16],] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + # def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): @@ -4083,6 +4196,7 @@ def GenerateSM80(manifest, cuda_version): GenerateSM80_TensorOp_884_symm(manifest, cuda_version) GenerateSM80_TensorOp_884_symm_complex(manifest, cuda_version) GenerateSM80_TensorOp_884_symm_complex_gaussian(manifest, cuda_version) + GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version) GenerateSM80_TensorOp_16832_TN(manifest, cuda_version) GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version) GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version) diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index a1d75c21..66c7f940 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -289,6 +289,7 @@ class ComplexMultiplyOp(enum.Enum): class MathOperation(enum.Enum): multiply_add = enum_auto() multiply_add_saturate = enum_auto() + multiply_add_mixed_input_upcast = enum_auto() xor_popc = enum_auto() and_popc = enum_auto() multiply_add_fast_bf16 = enum_auto() @@ -302,6 +303,7 @@ class MathOperation(enum.Enum): MathOperationTag = { MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd', MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', + MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast', MathOperation.xor_popc: 'cutlass::arch::OpXorPopc', MathOperation.and_popc: 'cutlass::arch::OpAndPopc', MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16', @@ -964,8 +966,13 @@ def CalculateSmemUsage(operation): cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md else: # Few BLAS3 operations only have A tensor - smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * cta_shape[2] // 8 + \ - DataTypeSize[operation.A.element] * cta_shape[1] * cta_shape[2] // 8 + data_type_size_a = DataTypeSize[operation.A.element] + data_type_size_b = DataTypeSize[operation.A.element] + if operation.is_mixed_input(): + data_type_size_b = DataTypeSize[operation.B.element] + + smem_per_stage = data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + \ + data_type_size_b * cta_shape[1] * cta_shape[2] // 8 smem_usage = smem_per_stage * stages return (smem_usage >> 10) diff --git a/python/cutlass_library/rank_2k_operation.py b/python/cutlass_library/rank_2k_operation.py index 4b3bab30..dfa5f070 100644 --- a/python/cutlass_library/rank_2k_operation.py +++ b/python/cutlass_library/rank_2k_operation.py @@ -79,6 +79,10 @@ def is_complex(self): return self.tile_description.math_instruction.math_operation in complex_operators return False + # + def is_mixed_input(self): + return self.A.element != self.B.element + # def is_planar_complex(self): return False diff --git a/python/cutlass_library/rank_k_operation.py b/python/cutlass_library/rank_k_operation.py index 993df7ca..5868d20d 100644 --- a/python/cutlass_library/rank_k_operation.py +++ b/python/cutlass_library/rank_k_operation.py @@ -77,6 +77,10 @@ def is_complex(self): return self.tile_description.math_instruction.math_operation in 
complex_operators return False + # + def is_mixed_input(self): + return False + # def is_planar_complex(self): return False diff --git a/python/cutlass_library/symm_operation.py b/python/cutlass_library/symm_operation.py index 5b2a1603..e97245b1 100644 --- a/python/cutlass_library/symm_operation.py +++ b/python/cutlass_library/symm_operation.py @@ -79,6 +79,10 @@ def is_complex(self): return self.tile_description.math_instruction.math_operation in complex_operators return False + # + def is_mixed_input(self): + return self.A.element != self.B.element + # def is_planar_complex(self): return False diff --git a/python/cutlass_library/trmm_operation.py b/python/cutlass_library/trmm_operation.py index b2b0577f..fe2c1f93 100644 --- a/python/cutlass_library/trmm_operation.py +++ b/python/cutlass_library/trmm_operation.py @@ -81,6 +81,10 @@ def is_planar_complex(self): # return self.trmm_kind in (TrmmKind.PlanarComplex, TrmmKind.PlanarComplexArray) return False + # + def is_mixed_input(self): + return self.A.element != self.B.element + # def accumulator_type(self): accum = self.tile_description.math_instruction.element_accumulator diff --git a/test/unit/core/CMakeLists.txt b/test/unit/core/CMakeLists.txt index 0abcf71e..6c97ed7e 100644 --- a/test/unit/core/CMakeLists.txt +++ b/test/unit/core/CMakeLists.txt @@ -41,6 +41,7 @@ cutlass_test_unit_add_executable( tensor_view.cu matrix_coord.cu numeric_conversion.cu + fast_numeric_conversion.cu functional.cu ) diff --git a/test/unit/core/fast_numeric_conversion.cu b/test/unit/core/fast_numeric_conversion.cu new file mode 100644 index 00000000..1eeb8e8d --- /dev/null +++ b/test/unit/core/fast_numeric_conversion.cu @@ -0,0 +1,176 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Unit tests for conversion operators. +*/ + +#include "../common/cutlass_unit_test.h" + +#include "cutlass/numeric_conversion.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/util/host_tensor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace core { +namespace kernel { + +/// Simple conversion function +template +__global__ void convert( + cutlass::Array *destination, + cutlass::Array const *source) { + + cutlass::FastNumericArrayConverter convert; + + *destination = convert(*source); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_test_integer_range_limited() { + const int kN = Count; + + dim3 grid(1, 1); + dim3 block(1, 1); + + cutlass::HostTensor destination({1, kN}); + cutlass::HostTensor source({1, kN}); + + for (int i = 0; i < kN; ++i) { + source.host_data()[i] = Source(i % 4); + } + + source.sync_device(); + + convert<<< grid, block >>>( + reinterpret_cast *>(destination.device_data()), + reinterpret_cast const *>(source.device_data()) + ); + + destination.sync_host(); + + for (int i = 0; i < kN; ++i) { + EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i])); + } +} + + +template +void run_test_integer_range_all() { + const int kN = Count; + + dim3 grid(1, 1); + dim3 block(1, 1); + + cutlass::HostTensor destination({1, kN}); + cutlass::HostTensor source({1, kN}); + + int const kIntSourceMin = std::numeric_limits::min(); + int const kIntSourceMax = std::numeric_limits::max(); + int const kIntRange = kIntSourceMax - kIntSourceMin + 1; + + for (int i = 0; i < kN; ++i) { + source.host_data()[i] = Source(kIntSourceMin + (i % kIntRange)); + + } + + source.sync_device(); + + convert<<< grid, block >>>( + reinterpret_cast *>(destination.device_data()), + reinterpret_cast const *>(source.device_data()) + ); + + destination.sync_host(); + + // Verify conversion + bool passed = true; + for (int i = 0; i < kN; ++i) { + if(!(float(destination.host_data()[i]) == float(source.host_data()[i]))) { + passed = false; + break; + } + } + EXPECT_TRUE(passed) << " FastNumericArrayConverter failed"; + + // Print out results for the failed conversion. 
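+  // Every element is dumped as `source -> destination`, which makes a wrong magic-number
+  // constant in the converter easy to spot.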
+ if (!passed) { + for (int i = 0; i < kN; ++i) { + std::cout << "source(" << float(source.host_data()[i]) << ") -> " + << "destination ("<< float(destination.host_data()[i]) << ")" << std::endl; + } + } + std::flush(std::cout); +} + +} // namespace kernel +} // namespace core +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// +TEST(FastNumericConversion, s32_to_f32) { + int const kN = 4; + using Source = int; + using Destination = float; + test::core::kernel::run_test_integer_range_limited(); +} + +TEST(FastNumericConversion, s8_to_f16_array) { + int const kN = 256; + using Source = int8_t; + using Destination = cutlass::half_t; + test::core::kernel::run_test_integer_range_all(); +} + +TEST(FastNumericConversion, u8_to_f16_array) { + int const kN = 256; + using Source = uint8_t; + using Destination = cutlass::half_t; + test::core::kernel::run_test_integer_range_all(); +} + +TEST(FastNumericConversion, u8_to_bf16_array) { + int const kN = 256; + using Source = uint8_t; + using Destination = cutlass::bfloat16_t; + test::core::kernel::run_test_integer_range_all(); +} + +TEST(FastNumericConversion, s8_to_bf16_array) { + int const kN = 256; + using Source = int8_t; + using Destination = cutlass::bfloat16_t; + test::core::kernel::run_test_integer_range_all(); +} diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 56deefdb..752239ab 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -341,6 +341,21 @@ cutlass_test_unit_add_executable( sm80_gemm_f16_f16_f32_tensor_op_f32.cu ) +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80 + + BATCH_SOURCES ON + BATCH_SIZE 4 + + # Upcast on Operand A + gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu + gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu + + # Upcast on Operand B + gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu + gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu +) + cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_f64 diff --git a/test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu new file mode 100644 index 00000000..e991c027 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_f16t_s8t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) { + + using ElementA = cutlass::half_t; + using ElementB = int8_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 8, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu new file mode 100644 index 00000000..eae1cb10 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_f16t_u8t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) { + + using ElementA = cutlass::half_t; + using ElementB = uint8_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 8, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} 
+//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu new file mode 100644 index 00000000..a0753cda --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_s8t_f16t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) { + + using ElementA = int8_t; + using ElementB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu new file mode 100644 index 00000000..ad153ba3 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Tests for device-wide GEMM interface
+
+*/
+
+#include <iostream>
+
+#include "../../common/cutlass_unit_test.h"
+#include "cutlass/cutlass.h"
+
+#include "cutlass/gemm/device/gemm_universal.h"
+
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/reference/host/gemm.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/tensor_view_io.h"
+
+#include "testbed_universal.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
+
+
+TEST(SM80_Device_GemmUniversal_u8t_f16t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
+
+  using ElementA = uint8_t;
+  using ElementB = cutlass::half_t;
+  using ElementOutput = cutlass::half_t;
+  using ElementAccumulator = cutlass::half_t;
+
+  using Gemm = cutlass::gemm::device::GemmUniversal<
+    ElementA,
+    cutlass::layout::RowMajor,
+    ElementB,
+    cutlass::layout::ColumnMajor,
+    ElementOutput,
+    cutlass::layout::RowMajor,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm80,
+    cutlass::gemm::GemmShape<128, 128, 64>,
+    cutlass::gemm::GemmShape<64, 64, 64>,
+    cutlass::gemm::GemmShape<16, 8, 16>,
+    cutlass::epilogue::thread::LinearCombination<
+        ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+        ElementAccumulator, ElementAccumulator>,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    4,   // Stages
+    16,  // AlignmentA
+    8,   // AlignmentB
+    cutlass::arch::OpMultiplyAddMixedInputUpcast,
+    cutlass::ComplexTransform::kNone,
+    cutlass::ComplexTransform::kNone
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
+}
+////////////////////////////////////////////////////////////////////////////////
+
+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
\ No newline at end of file
diff --git a/test/unit/gemm/device/testbed_universal.h b/test/unit/gemm/device/testbed_universal.h
index a849b593..eed780cf 100644
--- a/test/unit/gemm/device/testbed_universal.h
+++ b/test/unit/gemm/device/testbed_universal.h
@@ -103,16 +103,17 @@ struct TestbedUniversal {
       double scope_max, scope_min;
       int bits_input = cutlass::sizeof_bits<Element>::value;
       int bits_output = cutlass::sizeof_bits<typename Gemm::ElementC>::value;
+      bool is_unsigned_int = std::numeric_limits<Element>::is_integer && !std::numeric_limits<Element>::is_signed;
 
       if (bits_input == 1) {
         scope_max = 2;
         scope_min = 0;
       } else if
(bits_input <= 8) { - scope_max = 2; - scope_min = -2; + scope_max = is_unsigned_int ? 4 : 2; + scope_min = is_unsigned_int ? 0 : -2; } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; + scope_max = is_unsigned_int ? 10 : 5; + scope_min = is_unsigned_int ? 0 : -5; } else { scope_max = 8; scope_min = -8; diff --git a/test/unit/gemm/warp/CMakeLists.txt b/test/unit/gemm/warp/CMakeLists.txt index 1415da43..0e14745b 100644 --- a/test/unit/gemm/warp/CMakeLists.txt +++ b/test/unit/gemm/warp/CMakeLists.txt @@ -37,6 +37,7 @@ cutlass_test_unit_add_executable( gemm_complex_sm80.cu gemm_sparse_sm80.cu gemm_gaussian_complex_sm80.cu + gemm_mixed_input_sm80.cu gemm_sm90.cu gemm_complex_sm90.cu wmma_sm70.cu diff --git a/test/unit/gemm/warp/gemm_mixed_input_sm80.cu b/test/unit/gemm/warp/gemm_mixed_input_sm80.cu new file mode 100644 index 00000000..89d56e10 --- /dev/null +++ b/test/unit/gemm/warp/gemm_mixed_input_sm80.cu @@ -0,0 +1,322 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Unit tests for thread-level GEMM +*/ + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" + +#include "cutlass/gemm/warp/default_mma_tensor_op.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + + +//////////////////////////////////////////////////////////////////////////////// +/// F32 <= F16 * I8 + F32 (Upcast on Operand B) +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 128x128x64_64x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using ElementA = cutlass::half_t; + using ElementB = int8_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type; + + test::gemm::warp::TransformTestbed >() + .run(); +} + + +TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 64x64x64_64x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using ElementA = cutlass::half_t; + using ElementB = int8_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type; + + test::gemm::warp::TransformTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// +/// F32 <= I8 * F16 + F32 (Upcast on Operand A) +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 128x128x64_64x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using ElementA = int8_t; + using ElementB = cutlass::half_t;; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type; + + test::gemm::warp::TransformTestbed >() + .run(); +} + + +TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 64x64x64_64x64x64_16x8x16) { + using Shape = 
cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using ElementA = int8_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type; + + test::gemm::warp::TransformTestbed >() + .run(); +} + + +//////////////////////////////////////////////////////////////////////////////// +/// F32 <= F16 * U8 + F32 (Upcast on Operand B) +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 64x64x64_64x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using ElementA = cutlass::half_t; + using ElementB = uint8_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type; + + test::gemm::warp::TransformTestbed >() + .run(); +} + +TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 128x128x64_64x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using ElementA = cutlass::half_t; + using ElementB = uint8_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type; + + test::gemm::warp::TransformTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// +/// F32 <= U8 * F16 + F32 (Upcast on Operand A) +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 64x64x64_64x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using ElementA = uint8_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type; + + test::gemm::warp::TransformTestbed >() + 
.run(); +} + +TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 128x128x64_64x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using ElementA = uint8_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type; + + test::gemm::warp::TransformTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// +/// F32 <= B16 * U8 + F32 (Upcast on Operand B) +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using ElementA = cutlass::bfloat16_t; + using ElementB = uint8_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type; + + test::gemm::warp::TransformTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// +/// F32 <= B16 * U8 + F32 (Upcast on Operand B) +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_bf16, 64x64x64_64x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using ElementA = uint8_t; + using ElementB = cutlass::bfloat16_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type; + + test::gemm::warp::TransformTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// +/// F32 <= B16 * I8 + F32 (Upcast on Operand B) +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_i8, 64x64x64_64x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using ElementA = cutlass::bfloat16_t; + using ElementB = int8_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = 
cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type; + + test::gemm::warp::TransformTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// +/// F32 <= B16 * I8 + F32 (Upcast on Operand B) +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using ElementA = int8_t; + using ElementB = cutlass::bfloat16_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type; + + test::gemm::warp::TransformTestbed >() + .run(); +} + +#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) \ No newline at end of file diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index a11ebcf6..d4b00c92 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt @@ -229,6 +229,7 @@ cutlass_add_cutlass_library( src/reference/gemm_fp8in_fp32out.cu src/reference/gemm_fp32out.cu src/reference/gemm_fp_other.cu + src/reference/gemm_fp_mixed_input.cu src/reference/initialize_reference_operations.cu # cutlass reduction instances in cutlass library diff --git a/tools/library/src/reference/gemm_fp_mixed_input.cu b/tools/library/src/reference/gemm_fp_mixed_input.cu new file mode 100644 index 00000000..ea1c88ba --- /dev/null +++ b/tools/library/src/reference/gemm_fp_mixed_input.cu @@ -0,0 +1,138 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) { + // half_t mixed with 8-bit integer input + make_gemm_real_canonical_layouts< + int8_t, + half_t, + half_t, + half_t, + half_t + >(manifest); + + make_gemm_real_canonical_layouts< + uint8_t, + half_t, + half_t, + half_t, + half_t + >(manifest); + + make_gemm_real_canonical_layouts< + uint8_t, + half_t, + half_t, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + int8_t, + half_t, + half_t, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + half_t, + uint8_t, + half_t, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + half_t, + int8_t, + half_t, + float, + float + >(manifest); + + // bfloat16_t mixed with 8-bit integer input + make_gemm_real_canonical_layouts< + uint8_t, + bfloat16_t, + bfloat16_t, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + int8_t, + bfloat16_t, + bfloat16_t, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + bfloat16_t, + uint8_t, + bfloat16_t, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + bfloat16_t, + int8_t, + bfloat16_t, + float, + float + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/initialize_reference_operations.cu b/tools/library/src/reference/initialize_reference_operations.cu index 15ce5228..cc92f91f 100644 --- a/tools/library/src/reference/initialize_reference_operations.cu +++ b/tools/library/src/reference/initialize_reference_operations.cu @@ -56,6 +56,7 @@ void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest); void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest); void initialize_gemm_reference_operations_fp32out(Manifest &manifest); void initialize_gemm_reference_operations_fp_other(Manifest &manifest); +void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest); void initialize_conv2d_reference_operations(Manifest &manifest); void initialize_conv3d_reference_operations(Manifest &manifest); @@ -82,6 +83,7 @@ void initialize_reference_operations(Manifest &manifest) { initialize_gemm_reference_operations_fp32out(manifest); 
initialize_gemm_reference_operations_fp_other(manifest); + initialize_gemm_reference_operations_fp_mixed_input(manifest); } ///////////////////////////////////////////////////////////////////////////////////////////////////
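
For readers who want to exercise the conversion path covered by the FastNumericConversion unit tests above, the following is a minimal device-side sketch. It assumes the FastNumericArrayConverter<Destination, Source, N> interface from cutlass/numeric_conversion.h and mirrors the structure of the test kernel; the kernel name is illustrative and is not added by this patch.

    #include "cutlass/array.h"
    #include "cutlass/numeric_types.h"
    #include "cutlass/numeric_conversion.h"

    // Upcasts a packed array of four int8_t values to half_t in registers.
    __global__ void upcast_s8_to_f16(
        cutlass::Array<cutlass::half_t, 4> *destination,
        cutlass::Array<int8_t, 4> const *source) {

      cutlass::FastNumericArrayConverter<cutlass::half_t, int8_t, 4> convert;
      *destination = convert(*source);
    }

The same pattern applies to the uint8_t and bfloat16_t combinations listed in the tests; only the Source and Destination element types change.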
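
A host-side usage sketch for the mixed-input GemmUniversal configurations instantiated by the SM80 device tests is shown below. The argument list follows the GemmUniversal::Arguments pattern used by testbed_universal.h; the helper name, problem size, scalars, and leading dimensions (which assume row-major A, column-major B, row-major C/D) are illustrative assumptions, not part of this patch.

    #include "cutlass/cutlass.h"
    #include "cutlass/gemm/device/gemm_universal.h"
    #include "cutlass/util/device_memory.h"

    // Gemm is expected to be a mixed-input GemmUniversal type such as the one
    // defined in gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu.
    template <typename Gemm>
    cutlass::Status run_mixed_input_gemm(
        typename Gemm::ElementA const *ptr_A,
        typename Gemm::ElementB const *ptr_B,
        typename Gemm::ElementC const *ptr_C,
        typename Gemm::ElementC *ptr_D,
        int M, int N, int K) {

      using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute;

      cutlass::gemm::GemmCoord problem_size(M, N, K);

      typename Gemm::Arguments arguments{
        cutlass::gemm::GemmUniversalMode::kGemm,
        problem_size,
        1,                                        // batch count
        {ElementCompute(1), ElementCompute(0)},   // alpha, beta
        ptr_A, ptr_B, ptr_C, ptr_D,
        int64_t(M) * K, int64_t(K) * N,           // batch strides A, B
        int64_t(M) * N, int64_t(M) * N,           // batch strides C, D
        K, K, N, N                                // lda, ldb, ldc, ldd
      };

      Gemm gemm_op;

      size_t workspace_size = Gemm::get_workspace_size(arguments);
      cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

      cutlass::Status status = gemm_op.can_implement(arguments);
      if (status != cutlass::Status::kSuccess) {
        return status;
      }

      status = gemm_op.initialize(arguments, workspace.get());
      if (status != cutlass::Status::kSuccess) {
        return status;
      }

      return gemm_op();  // launches the kernel on the default stream
    }

Instantiating this helper with the f16 x s8 Gemm type from the device test above runs the upcast-on-operand-B kernel; the upcast-on-operand-A variants are used the same way with the operand types and alignments swapped.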