Support for Mixed Input TensorOp (#1084)
* Passing warp-level mixed input F16*(S8/U8) tests

* passing device-level mixed input F16*(S8/U8) tests

* add to profiler - I8 (111 TFLOP/s), U8 (123 TFLOP/s)

* fast numeric conversions (I8 = 132 TFLOP/s, U8 = 148 TFLOP/s)

* Speedup reference compilation (REVERT THIS COMMIT)

* wider_add.u32_packed_sub.f16x2 (I8 = 132 TFLOP/s, U8 = 170 TFLOP/s; see the conversion sketch after this list)

* Improve s8->f16 cvt and support bf16*u8 @ 158 TFLOP/s

* BF16 * S8 (142 TFLOP/s)

* Handle mixed-input upcast on OperandA (support [S8|U8]*[F16|BF16])

* rename OpMultiplyAddMixedInput to OpMultiplyAddMixedInputUpcast

* Add device-level test and profiler support for upcast on operand A

* Move shfl before the cvt and reduce #shfls by 1/2

* fix smem_usage calculation for mixed_input types

* uncomment the stuff (getting ready for merge)

* profiler changes and mixed-input reference

* mixed-input references are in a new file

* use platform instead of std

* comments and typo fixes only

* Use CreateGemmOperator and delete CreateMixedInputGemmOperator

* copyright for new files

* rebase follow-up
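
The fast-conversion commits above hinge on a packed byte-to-half upcast. Below is a minimal CUDA sketch of the classic magic-number trick that the wider_add.u32_packed_sub.f16x2 message alludes to; the helper name s8x2_to_f16x2 and the exact splicing code are illustrative assumptions, not the PR's implementation:

#include <cstdint>
#include <cstring>
#include <cuda_fp16.h>

// Upcast two packed int8 values to __half2 without per-element cvt
// instructions. A byte b spliced under the f16 magic exponent 0x6400
// encodes the value 1024 + b, so flipping each byte's sign bit (bias
// +128) and then subtracting 1152 (= 1024 + 128) from both lanes with
// a single packed f16x2 subtract recovers the signed values.
__device__ __half2 s8x2_to_f16x2(uint16_t packed) {
  uint32_t biased = packed ^ 0x8080u;           // s8 + 128 per byte
  uint32_t bits = 0x64006400u                   // magic exponents
                | (biased & 0x00FFu)            // byte0 -> low half
                | ((biased & 0xFF00u) << 8);    // byte1 -> high half
  __half2 magic;
  memcpy(&magic, &bits, sizeof(magic));
  // 0x6480 is the f16 bit pattern for 1152.0; remove the bias.
  return __hsub2(magic, __half2half2(__ushort_as_half(0x6480u)));
}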
Manish Gupta authored and ttl10101 committed Feb 7, 2024
1 parent cc5e010 commit 96181da
Showing 26 changed files with 2,065 additions and 14 deletions.
10 changes: 10 additions & 0 deletions include/cutlass/arch/mma.h
@@ -68,14 +68,24 @@ struct OpMultiplyAddFastF16 {};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tag indicating the input data types are mixed and the narrower type is
/// upcast to the wider type
struct OpMultiplyAddMixedInputUpcast {};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tag indicating the input is converted to 2 (big and small) TF32 components
// Perform 3xTF32 or 4xTF32 for every F32 output element
struct OpMultiplyAddFastF32 {};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tag indicating the input is converted to 2 (big and small) TF32 components
// Perform 3xTF32 or 4xTF32 for every complex<F32> output element
struct OpMultiplyAddComplexFastF32 {};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Helper for determining whether staged accumulation should be used for a given operator
template <typename Operator>
struct UseStagedAccumulation {
67 changes: 67 additions & 0 deletions include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h
@@ -38,6 +38,7 @@
#include "cutlass/numeric_types.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"
#include "cutlass/gemm/warp/mma_mixed_input_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"

@@ -227,6 +228,72 @@ struct DefaultMmaTensorOp<

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial Specialization - inputs are mixed types - uses wider datatype internally.
/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32)
template <
/// Shape of one matrix product operation (concept: GemmShape)
typename WarpShape_,
/// Element type of A matrix
typename ElementA,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA,
/// Element type of B matrix
typename ElementB,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB,
/// Element type of C matrix
typename ElementC,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Number of partitions along K dimension
int PartitionsK,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<
WarpShape_,
GemmShape<16, 8, 16>, // InstructionShape
ElementA, // Element type of A matrix in Global Memory
LayoutA, // Layout of A matrix in Global Memory
ElementB, // Element type of B matrix in Global Memory
LayoutB, // Layout of B matrix in Global Memory
ElementC, // Element type of C matrix in Global Memory
LayoutC, // Layout of C matrix in Global Memory
arch::OpMultiplyAddMixedInputUpcast, // Tag indicating mixed-input datatypes, where the narrower datatype is upcast to the wider one
PartitionsK, AccumulatorsInRowMajor> {


// Check if the ElementA and ElementB are of different data types
static_assert(!platform::is_same<ElementA, ElementB>::value,
"DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");

// Data type used for internal computation - use the wider of the two data types for mma.sync operands
using ElementOperand = typename platform::conditional<(sizeof(ElementA) > sizeof(ElementB)),
ElementA, ElementB>::type;
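// Illustrative example (not from the diff): with ElementA = half_t (2 bytes)
// and ElementB = int8_t (1 byte), ElementOperand resolves to half_t, so the
// int8 fragments are upcast to half_t before feeding mma.sync.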

// Operand datatypes in the internal MMA instruction - use the wider of the two data types
using MmaElementA = ElementOperand;
using MmaElementB = ElementOperand;
using MmaElementC = ElementC;

// Policy of the underlying mma.sync instruction, issued on the wider datatype for both operands
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::Mma<
GemmShape<16, 8, 16>,
32,
MmaElementA, cutlass::layout::RowMajor,
MmaElementB, cutlass::layout::ColumnMajor,
MmaElementC, cutlass::layout::RowMajor,
arch::OpMultiplyAdd
>,
cutlass::MatrixShape<1, 1> >;

// Define the warp-level tensor op
using Type = cutlass::gemm::warp::MmaMixedInputTensorOp<
WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
Policy, PartitionsK, AccumulatorsInRowMajor>;
};

} // namespace warp
} // namespace gemm
} // namespace cutlass
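
For context, here is a minimal instantiation sketch of this specialization; the warp shape, layouts, and element types below are illustrative choices, not taken from the diff:

// Hypothetical F16 x S8 warp-level mixed-input tensor op. The
// OpMultiplyAddMixedInputUpcast tag selects the specialization above,
// so the int8_t fragments of B are upcast to half_t before each mma.sync.
using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
    cutlass::gemm::GemmShape<64, 64, 32>,           // WarpShape
    cutlass::gemm::GemmShape<16, 8, 16>,            // InstructionShape
    cutlass::half_t, cutlass::layout::RowMajor,     // ElementA (wider), LayoutA
    int8_t, cutlass::layout::ColumnMajor,           // ElementB (upcast to F16), LayoutB
    float, cutlass::layout::RowMajor,               // ElementC, LayoutC
    cutlass::arch::OpMultiplyAddMixedInputUpcast    // mixed-input tag
>::Type;  // resolves to cutlass::gemm::warp::MmaMixedInputTensorOp<...>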
