diff --git a/examples/41_fused_multi_head_attention/debug_utils.h b/examples/41_fused_multi_head_attention/debug_utils.h index 7e91a723..aafc62d6 100644 --- a/examples/41_fused_multi_head_attention/debug_utils.h +++ b/examples/41_fused_multi_head_attention/debug_utils.h @@ -50,12 +50,17 @@ #if 1 #define PRINT_WARP_ID 0 #define PRINT_LANE_ID 0 -#define PRINT_T0_L0(msg, ...) \ +#define PRINT_B0_T0(msg, ...) \ if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \ threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ threadIdx.z == 0) { \ printf(msg "\n", ##__VA_ARGS__); \ } +#define PRINT_T0(msg, ...) \ + if (threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } #define PRINT_TX_LX(msg, ...) \ for (int bx = 0; bx < gridDim.x; ++bx) { \ for (int by = 0; by < gridDim.y; ++by) { \ @@ -84,7 +89,7 @@ } \ } #else -#define PRINT_T0_L0 +#define PRINT_B0_T0 #define PRINT_TX_LX #endif @@ -124,7 +129,7 @@ constexpr __string_view __get_type_name() { // Print a given array #define PRINT_ACCUM8_T0_L0_START(name, accum, start) \ - PRINT_T0_L0( \ + PRINT_B0_T0( \ "%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \ name, \ int(start), \ @@ -141,7 +146,7 @@ constexpr __string_view __get_type_name() { #define PRINT_FRAG_T0_L0(name, frag) \ { \ auto typeStr = __get_type_name(); \ - PRINT_T0_L0("printing %s (%s)", name, typeStr.data); \ + PRINT_B0_T0("printing %s (%s)", name, typeStr.data); \ for (int _start = 0; _start < frag.size(); _start += 8) { \ PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \ } \ @@ -150,7 +155,7 @@ constexpr __string_view __get_type_name() { } #define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \ { \ - PRINT_T0_L0("printing %s (len=%d)", name, int(length)); \ + PRINT_B0_T0("printing %s (len=%d)", name, int(length)); \ for (int _start = 0; _start < length; _start += incr) { \ PRINT_ACCUM8_T0_L0_START(" ", array, _start); \ } \ @@ -160,7 +165,7 @@ constexpr __string_view __get_type_name() { // Print a 4x4 matrix #define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \ - PRINT_T0_L0( \ + PRINT_B0_T0( \ "%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \ name, \ int(start_x), \ @@ -187,9 +192,43 @@ constexpr __string_view __get_type_name() { PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0) #define PRINT_PROBLEM_SIZE(name, ps) \ - PRINT_T0_L0( \ + PRINT_B0_T0( \ "%s.problem_size: {.m=%d, .n=%d, .k=%d}", \ name, \ int(ps.m()), \ int(ps.n()), \ int(ps.k())) + +template +CUTLASS_DEVICE void print_warp_accum( + AccumT accum, + LaneOffsetT lane_offset, + int32_t num_rows, + int32_t num_cols) { + bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && + threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0; + for (int row = 0; row < num_rows; ++row) { + for (int col = 0; col < num_cols; ++col) { + if (col % 32 == 0) { + if (is_main) { + printf("\nmat[%3d, %3d:%3d]", row, col, col + 32); + } + __syncthreads(); + } + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (row == accum_m && col == accum_n && + (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)) { + printf(" %6.1f", float(accum[idx])); + } + }, + [&](int accum_m) {}); + __syncthreads(); + } + if (is_main) { + printf("\n"); + } + } +} diff --git a/examples/41_fused_multi_head_attention/default_fmha_grouped.h b/examples/41_fused_multi_head_attention/default_fmha_grouped.h index 5a1ed5c0..b0acc943 100644 --- 
a/examples/41_fused_multi_head_attention/default_fmha_grouped.h +++ b/examples/41_fused_multi_head_attention/default_fmha_grouped.h @@ -50,9 +50,8 @@ #include "fmha_grouped.h" #include "gemm_kernel_utils.h" -#include "find_default_mma.h" -#include "attention_scaling_coefs_updater.h" -#include "mma_from_smem.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_from_smem.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -154,10 +153,10 @@ struct DefaultFMHAGrouped { using IteratorA = typename DefaultMma::IteratorA; using IteratorB = typename DefaultMma::IteratorB; using Mma = typename DefaultMma::ThreadblockMma; - using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater< + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< typename Mma::Operator::IteratorC, ElementAccumulator, - kWarpSize>::Updater; + kWarpSize>::Iterator; static_assert(MmaCore::WarpCount::kCount == kNumWarpsPerBlock, ""); @@ -240,7 +239,8 @@ struct DefaultFMHAGrouped { using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, - typename MM0::AccumulatorSharedStorage>; + typename MM0::AccumulatorSharedStorage, + false>; // kScaleOperandA using Mma = typename DefaultMmaFromSmem::Mma; using IteratorB = typename Mma::IteratorB; diff --git a/examples/41_fused_multi_head_attention/epilogue_pipelined.h b/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h similarity index 100% rename from examples/41_fused_multi_head_attention/epilogue_pipelined.h rename to examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h diff --git a/examples/41_fused_multi_head_attention/epilogue_rescale_output.h b/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h similarity index 100% rename from examples/41_fused_multi_head_attention/epilogue_rescale_output.h rename to examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h diff --git a/examples/41_fused_multi_head_attention/epilogue_thread_apply_logsumexp.h b/examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h similarity index 100% rename from examples/41_fused_multi_head_attention/epilogue_thread_apply_logsumexp.h rename to examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h diff --git a/examples/41_fused_multi_head_attention/fmha_grouped.h b/examples/41_fused_multi_head_attention/fmha_grouped.h index 58f47d74..f71ca22b 100644 --- a/examples/41_fused_multi_head_attention/fmha_grouped.h +++ b/examples/41_fused_multi_head_attention/fmha_grouped.h @@ -48,7 +48,18 @@ #include "fmha_grouped_problem_visitor.h" #include "gemm_kernel_utils.h" -#include "epilogue_rescale_output.h" +#include "gemm/mma_accum_lambda_iterator.h" +#include "epilogue/epilogue_rescale_output.h" + + +namespace { + static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { + // source: https://stackoverflow.com/a/51549250 + return (value >= 0) + ? 
__int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); +} +} ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -128,6 +139,9 @@ struct FMHAGrouped { static int const kQueriesPerBlock = ThreadblockShape::kM; static int const kKeysPerBlock = ThreadblockShape::kN; + static constexpr bool kSupportsDropout = false; + static constexpr bool kSupportsBias = false; + /// Warp count (concept: GemmShape) using WarpCount = typename MM1::WarpCount; static int const kThreadsPerWarp = 32; @@ -619,10 +633,10 @@ struct FMHAGrouped { // Mask out last if causal if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) { - auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( lane_id(), warp_id(), iteratorC_tile_offset); int32_t last_col; - MM0::ScalingCoefsUpdater::iterateRows( + MM0::AccumLambdaIterator::iterateRows( lane_offset, [&](int accum_m) { last_col = TileParams::query_start(threadblock_idx) + accum_m - iter_key_start; @@ -641,14 +655,11 @@ struct FMHAGrouped { kFullColumns, ([&] { // Update `mi` from accum stored in registers - // Also updates `accum` with accum[i] <- - // exp(accum[i] * scale - // - mi) - MM0::ScalingCoefsUpdater::update< - kQueriesPerBlock, + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax< + typename MM0::Mma::Operator::IteratorC, kFullColumns, - kIsFirst, - kKeepOutputInRF>( + kIsFirst>( accum_o, accum, mi, @@ -659,7 +670,7 @@ struct FMHAGrouped { warp_id(), num_keys - iter_key_start, iteratorC_tile_offset, - params.scale); + kSupportsBias ? 1.0f : params.scale); })); })); @@ -838,6 +849,116 @@ struct FMHAGrouped { problem_visitor.advance(gridDim.x); } } + + template < + typename WarpIteratorC, + bool kFullColumns, + bool kIsFirst> + CUTLASS_DEVICE static void iterative_softmax( + typename WarpIteratorC::Fragment& frag_o, // output so far + typename WarpIteratorC::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int16_t max_col, + typename WarpIteratorC::TensorCoord const& tile_offset, + float scaling) { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` + (2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + + All of this is done on registers, before we store all of this + on shared memory for the next matmul with Value. 
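  For intuition, the same bookkeeping can be written in scalar form for a single
  row, processing the scores block by block. This is an illustration only:
  online_softmax_row, the std::vector blocks and the plain floats are stand-ins
  for the warp-distributed fragments, not kernel code.

  #include <algorithm>
  #include <cmath>
  #include <limits>
  #include <utility>
  #include <vector>

  // Online softmax bookkeeping for one row of the attention matrix.
  // Returns {row_max, row_sum_of_exp}; each inner vector is one key block.
  inline std::pair<float, float> online_softmax_row(
      const std::vector<std::vector<float>>& blocks) {
    float mi = -std::numeric_limits<float>::infinity();
    float s_prime = 0.f;
    for (const auto& block : blocks) {
      float prev_mi = mi;
      for (float s : block) mi = std::max(mi, s);          // (1) update row max
      float m_prime = std::exp(prev_mi - mi);              // (2b) rescale factor
      s_prime *= m_prime;                                  // first half of (2c)
      for (float s : block) s_prime += std::exp(s - mi);   // (2a) + rest of (2c)
    }
    return {mi, s_prime};
  }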
+ */ + using Fragment = typename WarpIteratorC::Fragment; + using LambdaIterator = typename DefaultMmaAccumLambdaIterator< + WarpIteratorC, + accum_t, + kThreadsPerWarp>::Iterator; + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + if (!kIsFirst) { + if (thread_id < kQueriesPerBlock) { + m_prime[thread_id] = mi[thread_id]; + } + __syncthreads(); + } + + auto lane_offset = + LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + max = -cutlass::platform::numeric_limits::infinity(); + }, + [&](int accum_m, int accum_n, int idx) { + if (kFullColumns || accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max * scaling); + }); + } + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + // Make sure we all share the update values for `mi` + __syncthreads(); + + if (thread_id < kQueriesPerBlock) { + auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id])); + m_prime[thread_id] = m_prime_exp; + s_prime[thread_id] *= m_prime_exp; + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !kIsFirst) { + accum_t mp; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mp = m_prime[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; }, + [&](int accum_m) {}); + __syncthreads(); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = (kFullColumns || accum_n < max_col) + ? 
exp2f(frag[idx] - mi_row) + : accum_t(0.0); + }, + [&](int accum_m) {}); + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (LambdaIterator::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { + return a + b; + })) { + atomicAdd(&s_prime[accum_m], total_row); + } + }); + } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu index 45d6813a..d3ffef76 100644 --- a/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu +++ b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu @@ -856,7 +856,9 @@ public: p.head_dim_value = options.head_size_v; p.num_queries = options.seq_length; p.num_keys = options.seq_length_kv; - p.causal = options.causal; + if (options.causal) { + p.custom_mask_type = Attention::CausalFromTopLeft; + } // All tensors are in BMHK shapes p.q_strideH = options.head_size; @@ -868,6 +870,7 @@ public: p.q_strideB = p.q_strideM * options.seq_length; p.k_strideB = p.k_strideM * options.seq_length_kv; p.v_strideB = p.v_strideM * options.seq_length_kv; + p.o_strideM = p.head_dim_value * p.num_heads; } // launch kernel :) @@ -1005,7 +1008,9 @@ int run_attention(Options& options) { true, // Memory is aligned kQueriesPerBlock, kKeysPerBlock, - kSingleValueIteration + kSingleValueIteration, + false, // Supports dropout + false // Supports bias >; // diff --git a/examples/41_fused_multi_head_attention/find_default_mma.h b/examples/41_fused_multi_head_attention/gemm/find_default_mma.h similarity index 99% rename from examples/41_fused_multi_head_attention/find_default_mma.h rename to examples/41_fused_multi_head_attention/gemm/find_default_mma.h index 9c62c8c1..2e6b35b6 100644 --- a/examples/41_fused_multi_head_attention/find_default_mma.h +++ b/examples/41_fused_multi_head_attention/gemm/find_default_mma.h @@ -42,6 +42,8 @@ This is really only for the FastF32 case - aka using TensorCores with fp32. */ +#pragma once + #include "cutlass/gemm/threadblock/default_mma.h" #include "cutlass/gemm/threadblock/default_mma_core_simt.h" #include "cutlass/gemm/threadblock/default_mma_core_sm70.h" diff --git a/examples/41_fused_multi_head_attention/attention_scaling_coefs_updater.h b/examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h similarity index 71% rename from examples/41_fused_multi_head_attention/attention_scaling_coefs_updater.h rename to examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h index 4de04ef9..ad2b7e02 100644 --- a/examples/41_fused_multi_head_attention/attention_scaling_coefs_updater.h +++ b/examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h @@ -36,137 +36,15 @@ #include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" #include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" #include "cutlass/matrix_shape.h" -#include "gemm_kernel_utils.h" -namespace { - -static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { - // source: https://stackoverflow.com/a/51549250 - return (value >= 0) - ? 
__int_as_float(atomicMax((int*)addr, __float_as_int(value))) - : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); -} -} // namespace - -/* Iterates on the accumulator and corresponding position on result matrix - -(1) Update `mi[r]` to the max value of the row `r` -(2) In a second iteration do the following: - (a) accum <- exp(accum - mi) - (b) m_prime <- exp(m_prime - mi) - (c) s_prime <- s_prime * m_prime + sum(accum) - -All of this is done on registers, before we store all of this -on shared memory for the next matmul with Value. - -We have multiple implementations, because each configuration has a different way -of iterating in the accumulators. +/* +TensorCores have different accumulator layouts. +This file provides a class to easily map the accumulator +i-th element with the corresponding matrix row/col. */ -template -struct RegisterOps { - template < - int kQueriesPerBlock, - bool kFullColumns, - bool kIsFirst, - bool kKeepOutputInRF> - CUTLASS_DEVICE static void update( - typename T::Fragment& frag_o, // output so far - typename T::Fragment& frag, - cutlass::Array& mi, - cutlass::Array& m_prime, - cutlass::Array& s_prime, - int8_t lane_id, - int8_t thread_id, - int8_t warp_id, - int16_t max_col, - typename T::TensorCoord const& tile_offset, - float scaling) { - // Convert to `accum_t` (rather than double) - constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E - if (!kIsFirst) { - if (thread_id < kQueriesPerBlock) { - m_prime[thread_id] = mi[thread_id]; - } - __syncthreads(); - } - - auto lane_offset = BASE::get_lane_offset(lane_id, warp_id, tile_offset); - - // First update `mi` to the max per-row - { - accum_t max; - BASE::iterateRows( - lane_offset, - [&](int accum_m) { - max = -cutlass::platform::numeric_limits::infinity(); - }, - [&](int accum_m, int accum_n, int idx) { - if (kFullColumns || accum_n < max_col) { - max = cutlass::fast_max(max, frag[idx]); - } - }, - [&](int accum_m) { - // Having 4x atomicMax seems faster than reduce within warp - // first... - atomicMaxFloat(&mi[accum_m], max * scaling); - }); - } - frag = cutlass::multiplies()(scaling * kLog2e, frag); - - // Make sure we all share the update values for `mi` - __syncthreads(); - - if (thread_id < kQueriesPerBlock) { - auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id])); - m_prime[thread_id] = m_prime_exp; - s_prime[thread_id] *= m_prime_exp; - } - __syncthreads(); // Update output fragments - if (kKeepOutputInRF && !kIsFirst) { - accum_t mp; - BASE::iterateRows( - lane_offset, - [&](int accum_m) { mp = m_prime[accum_m]; }, - [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; }, - [&](int accum_m) {}); - __syncthreads(); - } - // Update accum_m, accum_n, ... - { - accum_t mi_row, total_row; - BASE::iterateRows( - lane_offset, - [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; }, - [&](int accum_m, int accum_n, int idx) { - frag[idx] = (kFullColumns || accum_n < max_col) - ? 
exp2f(frag[idx] - mi_row) - : accum_t(0.0); - }, - [&](int accum_m) {}); - BASE::iterateRows( - lane_offset, - [&](int accum_m) { total_row = 0.0; }, - [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, - [&](int accum_m) { - if (BASE::reduceSameRow( - lane_id, total_row, [](accum_t a, accum_t b) { - return a + b; - })) { - atomicAdd(&s_prime[accum_m], total_row); - } - }); - } - } -}; - template -struct AttentionScalingCoefsUpdaterSm80 - : RegisterOps< - AttentionScalingCoefsUpdaterSm80, - T, - accum_t, - kWarpSize> { +struct AccumLambdaIteratorSm80 { static_assert( cutlass::platform:: is_same::value, @@ -239,12 +117,7 @@ struct AttentionScalingCoefsUpdaterSm80 }; template -struct AttentionScalingCoefsUpdaterVolta - : RegisterOps< - AttentionScalingCoefsUpdaterVolta, - T, - accum_t, - kWarpSize> { +struct AccumLambdaIteratorSm70 { static_assert( cutlass::platform:: is_same::value, @@ -357,12 +230,7 @@ struct AttentionScalingCoefsUpdaterVolta }; template -struct AttentionScalingCoefsUpdaterSimt - : RegisterOps< - AttentionScalingCoefsUpdaterSimt, - T, - accum_t, - kWarpSize> { +struct AccumLambdaIteratorSimt { using Policy = typename T::Policy; using Iterations = typename T::Iterations; using Element = typename T::Element; @@ -436,11 +304,11 @@ struct AttentionScalingCoefsUpdaterSimt }; template -struct DefaultAttentionScalingCoefsUpdater; +struct DefaultMmaAccumLambdaIterator; // Simt template -struct DefaultAttentionScalingCoefsUpdater< +struct DefaultMmaAccumLambdaIterator< cutlass::gemm::warp::MmaSimtTileIterator< S, cutlass::gemm::Operand::kC, @@ -451,7 +319,7 @@ struct DefaultAttentionScalingCoefsUpdater< 1>, accum_t, kWarpSize> { - using Iterator = typename cutlass::gemm::warp::MmaSimtTileIterator< + using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator< S, cutlass::gemm::Operand::kC, accum_t, @@ -459,13 +327,12 @@ struct DefaultAttentionScalingCoefsUpdater< P, 1, 1>; - using Updater = - AttentionScalingCoefsUpdaterSimt; + using Iterator = AccumLambdaIteratorSimt; }; // TensorOp - Volta template -struct DefaultAttentionScalingCoefsUpdater< +struct DefaultMmaAccumLambdaIterator< cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< S1, accum_t, @@ -474,15 +341,14 @@ struct DefaultAttentionScalingCoefsUpdater< cutlass::MatrixShape<1, 1>>, accum_t, kWarpSize> { - using Iterator = + using WarpIterator = typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< S1, accum_t, cutlass::layout::RowMajor, S2, cutlass::MatrixShape<1, 1>>; - using Updater = - AttentionScalingCoefsUpdaterVolta; + using Iterator = AccumLambdaIteratorSm70; }; // TensorOp - Sm75+ @@ -492,7 +358,7 @@ template < typename S3, typename accum_t, int kWarpSize> -struct DefaultAttentionScalingCoefsUpdater< +struct DefaultMmaAccumLambdaIterator< cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< S1, accum_t, @@ -501,13 +367,12 @@ struct DefaultAttentionScalingCoefsUpdater< S3>, accum_t, kWarpSize> { - using Iterator = + using WarpIterator = typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< S1, accum_t, cutlass::layout::RowMajor, S2, S3>; - using Updater = - AttentionScalingCoefsUpdaterSm80; + using Iterator = AccumLambdaIteratorSm80; }; diff --git a/examples/41_fused_multi_head_attention/mma_from_smem.h b/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h similarity index 84% rename from examples/41_fused_multi_head_attention/mma_from_smem.h rename to examples/41_fused_multi_head_attention/gemm/mma_from_smem.h index d2ceaf02..993af37a 
100644 --- a/examples/41_fused_multi_head_attention/mma_from_smem.h +++ b/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h @@ -43,22 +43,26 @@ #include "cutlass/epilogue/threadblock/default_epilogue_simt.h" #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/functional.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" #include "cutlass/matrix_shape.h" #include "cutlass/numeric_conversion.h" #include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" #include "cutlass/transform/threadblock/vector_iterator.h" -#include "attention_scaling_coefs_updater.h" +#include "../epilogue/epilogue_thread_apply_logsumexp.h" +#include "../gemm/mma_accum_lambda_iterator.h" +#include "../gemm_kernel_utils.h" +#include "../iterators/make_residual_last.h" +#include "../iterators/transpose_warp_iterator.h" +#include "../iterators/warp_iterator_from_smem.h" #include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" #include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" #include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" -#include "epilogue_thread_apply_logsumexp.h" -#include "gemm_kernel_utils.h" -#include "iterators/make_residual_last.h" -#include "iterators/transpose_warp_iterator.h" -#include "iterators/warp_iterator_from_smem.h" namespace cutlass { namespace gemm { @@ -246,6 +250,78 @@ class MmaBaseFromSharedMemory { : warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} }; +namespace { + +// has necessary trait compliance with WarpIteratorFromSmem but doesn't do +// anything, can be default initialized, and uses fragment that takes up +// (almost) no space. this warp iterator is selected at compile time when +// elementwise on-the-fly scaling for operand A is disabled, in which case +// operations related to loading scale factors for operand A get wiped out by +// the compiler. +template +class NoOpWarpIteratorScale { + public: + // in pipelined+multistage MMA implementations we keep an array of fragments. + // if we aren't using scaling we don't want to waste registers on fragments + // of scale elements, so ideally this would be sized 0. + // using size 1 is kind of a hack to get around arrays of zero-sized objects + // not being allowed. the compiler is probably smart enough to wipe it out + // anyways. + using Fragment = cutlass::Array; + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale() {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale(TensorRef const&, int) {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& add_tile_offset( + typename TensorRef::TensorCoord const&) { + return *this; + } + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& operator++() { + return *this; + } + + CUTLASS_DEVICE + void load(Fragment&) const {} +}; + +// if scaling is enabled, performs fragment elementwise multiplication between +// fragment and its scaling factor. +template +class FragmentElementwiseScaler; + +// specialization for scaling being enabled. 
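A standalone analogue of this enabled/disabled split, showing how selecting an empty type through a compile-time conditional lets the scale path vanish when it is off. RealScaleLoader, NoOpScaleLoader and PipelineSketch are illustrative names, not part of this diff:

#include <type_traits>

struct RealScaleLoader {
  float scale = 2.0f;
  float load() const { return scale; }   // stands in for a shared-memory load
};

struct NoOpScaleLoader {
  float load() const { return 1.0f; }    // multiply by 1.0f folds away
};

template <bool kScaleOperandA>
struct PipelineSketch {
  using ScaleLoader = typename std::
      conditional<kScaleOperandA, RealScaleLoader, NoOpScaleLoader>::type;
  float run(float a, float b) const {
    ScaleLoader loader;                  // empty object when scaling is disabled
    return (a * loader.load()) * b;      // scaling multiply becomes a no-op
  }
};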
+template +class FragmentElementwiseScaler { + public: + // cast scale_frag to correct type then apply elementwise to fragment + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const& scale_frag) { + Fragment converted_scale_frag = cutlass::NumericArrayConverter< + typename Fragment::Element, + typename FragmentScale::Element, + FragmentScale::kElements>()(scale_frag); + return cutlass::multiplies()(frag, converted_scale_frag); + } +}; + +// specialization for scaling being disabled. doesn't do anything and should +// just get wiped out by the compiler. +template +class FragmentElementwiseScaler { + public: + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const&) { + return frag; + } +}; +} // namespace + //////////////////////////////////////////////////////////////////////////////// // Taken from // https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h @@ -259,6 +335,10 @@ template < // BEGIN smem /// Iterates over the intermediate accumulator tile in shared memory typename WarpIteratorA, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, // Accumulator type typename AccumulatorSharedStorage, // END smem @@ -297,6 +377,15 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< loads fragments of A_scale from shared memory if operand A scaling is + ///< enabled. otherwise no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA, + NoOpWarpIteratorScale>::type; + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory using ElementC = ElementC_; ///< Data type of accumulator matrix @@ -333,8 +422,20 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< private: using WarpFragmentA = typename Operator::FragmentA; + + /// fragment type of OperandA elementwise scaling matrix. (almost) empty + /// if operand A scaling is disabled. + using WarpFragmentAScale = typename WarpIteratorAScale::Fragment; + using WarpFragmentB = typename Operator::FragmentB; + /// applies scaling factor to operand A fragment if operand A scaling is + /// enabled. otherwise no-op. + using FragmentAScaler = FragmentElementwiseScaler< + WarpFragmentA, + WarpFragmentAScale, + ScaleOperandA>; + protected: // /// Iterator to write threadblock-scoped tile of A operand to shared memory // SmemIteratorA smem_iterator_A_; @@ -346,7 +447,46 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< /// accumulator tile WarpIteratorA warp_tile_iterator_A_; + /// Iterator to load a warp-scoped tile of A_scale from intermediate + /// accumulator tile (only used if ScaleOperandA_ is true) + WarpIteratorAScale warp_tile_iterator_A_scale_; + public: + /// constructor for MMA with operand A scaling enabled. 
+ CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + // shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + // warp iterator over A tile held in shared memory + WarpIteratorA warp_iter_a, + // warp iterator over A_scale tile held in shared memory + WarpIteratorAScale warp_iter_a_scale, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(warp_iter_a), + warp_tile_iterator_A_scale_(warp_iter_a_scale), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_A_scale_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + /// Construct from tensor references CUTLASS_DEVICE MmaPipelinedFromSharedMemory( @@ -429,19 +569,26 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< __syncthreads(); + // remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op + // if scaling is disabled. 
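The constructor above maps the linear warp_idx onto (m, n, k) warp coordinates before adding per-warp tile offsets. A worked example of that mapping for a hypothetical WarpCount of {kM = 2, kN = 2} (values chosen only for illustration):

#include <cstdio>

int main() {
  const int kM = 2, kN = 2;  // hypothetical WarpCount::kM / WarpCount::kN
  for (int warp_idx = 0; warp_idx < 8; ++warp_idx) {
    int warp_idx_mn = warp_idx % (kM * kN);
    int warp_idx_k = warp_idx / (kM * kN);
    int warp_idx_m = warp_idx_mn % kM;
    int warp_idx_n = warp_idx_mn / kM;
    // e.g. warp_idx 5 -> m = 1, n = 0, k = 1
    std::printf("warp %d -> m=%d n=%d k=%d\n",
                warp_idx, warp_idx_m, warp_idx_n, warp_idx_k);
  }
  return 0;
}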
+ // Pair of fragments used to overlap shared memory loads and math // instructions WarpFragmentA warp_frag_A[2]; + WarpFragmentAScale warp_frag_A_scale[2]; WarpFragmentB warp_frag_B[2]; warp_frag_A[0].clear(); + warp_frag_A_scale[0].clear(); warp_frag_B[0].clear(); this->warp_tile_iterator_B_.set_kgroup_index(0); this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]); this->warp_tile_iterator_B_.load(warp_frag_B[0]); ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; ++this->warp_tile_iterator_B_; Operator warp_mma; @@ -503,9 +650,12 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< (warp_mma_k + 1) % Base::kWarpGemmIterations); this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_scale_.load( + warp_frag_A_scale[(warp_mma_k + 1) % 2]); this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; ++this->warp_tile_iterator_B_; if (warp_mma_k == 0) { @@ -521,7 +671,8 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< warp_mma( accum, - warp_frag_A[warp_mma_k % 2], + FragmentAScaler::apply( + warp_frag_A[warp_mma_k % 2], warp_frag_A_scale[warp_mma_k % 2]), warp_frag_B[warp_mma_k % 2], accum); } @@ -541,6 +692,10 @@ template < typename Shape1_, /// Iterates over the intermediate accumulator tile in shared memory typename WarpIteratorA1_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, // Accumulator type typename AccumulatorSharedStorage, /// Iterates over tiles of B operand in global memory @@ -580,7 +735,14 @@ class MmaMultistageFromSharedMemory using SmemIteratorB1 = SmemIteratorB1_; using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate ///< accumulator tile in shared memory - + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< warp level iterator over A_scale matrix tile kept in shared memory. + ///< if elementwise A scaling is disabled then everything this does is no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA1, + NoOpWarpIteratorScale>::type; ///< Data type of accumulator matrix using ElementC = ElementC_; ///< Layout of accumulator matrix @@ -628,10 +790,20 @@ class MmaMultistageFromSharedMemory private: using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + /// fragment of OperandA scale matrix. if operand A scaling is disabled this + /// is (almost) empty. + using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment; using WarpLoadedFragmentB1 = typename Operator1::FragmentB; using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + /// applies elementwise scaling to fragment of A. if operand A scaling is + /// disabled this is a no-op. + using FragmentAScaler = FragmentElementwiseScaler< + WarpLoadedFragmentA1, + WarpLoadedFragmentA1Scale, + ScaleOperandA>; + private: // // Data members @@ -641,12 +813,54 @@ class MmaMultistageFromSharedMemory /// accumulator tile WarpIteratorA1 warp_tile_iterator_A1_; + /// Iterator to load a warp-scoped tile of A1_scale operand from shared memory + /// if operand A scaling is disabled everything this does is a no-op. 
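The mainloops in this file keep two fragments per operand and prefetch fragment (k + 1) while fragment k is being consumed. A minimal sketch of that ping-pong pattern in isolation (pipelined_loop and its callbacks are illustrative, not the CUTLASS iterators):

#include <array>
#include <cstddef>

// Two-slot software pipeline: load tile k + 1 while computing on tile k.
template <typename Tile, typename LoadFn, typename ComputeFn>
void pipelined_loop(std::size_t num_tiles, LoadFn load, ComputeFn compute) {
  std::array<Tile, 2> frag{};            // ping-pong buffers
  if (num_tiles == 0) {
    return;
  }
  load(frag[0], 0);                      // prologue: fetch the first tile
  for (std::size_t k = 0; k < num_tiles; ++k) {
    if (k + 1 < num_tiles) {
      load(frag[(k + 1) % 2], k + 1);    // prefetch the next tile
    }
    compute(frag[k % 2], k);             // math on the current tile
  }
}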
+ WarpIteratorAScale warp_tile_iterator_A1_scale_; + /// Iterator to write threadblock-scoped tile of B operand to shared memory SmemIteratorB1 smem_iterator_B1_; bool prologue_done_; public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + // shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + // warp level iterator over operand A tile kept in shared memory + WarpIteratorA1 warp_tile_iterator_A1, + // warp level iterator over operand A elementwise scale tile kept in + // shared memory. + WarpIteratorAScale warp_tile_iterator_A1_scale, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(warp_tile_iterator_A1), + warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale), + smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn_1 = + warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + warp_tile_iterator_A1_scale_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + /// Construct from tensor references CUTLASS_DEVICE MmaMultistageFromSharedMemory( @@ -842,9 +1056,13 @@ class MmaMultistageFromSharedMemory cutlass::arch::cp_async_wait(); __syncthreads(); + // remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty + // if scaling is disabled. + // Pair of fragments used to overlap shared memory loads and math // instructions WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2]; WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; @@ -854,6 +1072,9 @@ class MmaMultistageFromSharedMemory warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); ++warp_tile_iterator_A1_; + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + this->warp_tile_iterator_B_.set_kgroup_index(0); this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]); ++this->warp_tile_iterator_B_; @@ -864,7 +1085,8 @@ class MmaMultistageFromSharedMemory warp_mma1.transform( warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], - warp_loaded_frag_A1[0], + FragmentAScaler::apply( + warp_loaded_frag_A1[0], warp_loaded_frag_A1_scale[0]), warp_loaded_frag_B1[0]); // tf32x3 kernels use staging accumulation. 
warp_mma uses a temporary @@ -909,17 +1131,22 @@ class MmaMultistageFromSharedMemory warp_mma_k < Base::kWarpGemmIterations1 - 1) { warp_tile_iterator_A1_.load( warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_scale_.load( + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); this->warp_tile_iterator_B_.load( warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); } ++warp_tile_iterator_A1_; + ++warp_tile_iterator_A1_scale_; ++this->warp_tile_iterator_B_; if (warp_mma_k > 0) warp_mma1.transform( warp_transformed_frag_A1[warp_mma_k % 2], warp_transformed_frag_B1[warp_mma_k % 2], - warp_loaded_frag_A1[warp_mma_k % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_A1_scale[warp_mma_k % 2]), warp_loaded_frag_B1[warp_mma_k % 2]); if (platform::is_same< @@ -1009,7 +1236,9 @@ class MmaMultistageFromSharedMemory warp_mma1.transform( warp_transformed_frag_A1[(warp_mma_k + 1) % 2], warp_transformed_frag_B1[(warp_mma_k + 1) % 2], - warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]), warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); } } @@ -1119,6 +1348,9 @@ struct DefaultWarpIteratorAFromSharedMemory< template < typename Mma_, typename AccumulatorSharedStorage, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, bool kTransposeA = false> struct DefaultMmaFromSharedMemory; @@ -1151,6 +1383,9 @@ template < /// Transformation applied to B operand typename TransformB_, typename AccumulatorSharedStorage_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, bool kTransposeA> struct DefaultMmaFromSharedMemory< MmaPipelined< @@ -1165,6 +1400,7 @@ struct DefaultMmaFromSharedMemory< TransformA_, TransformB_>, AccumulatorSharedStorage_, + kScaleOperandA, kTransposeA> { static constexpr int kWarpSize = 32; using SmemAccumulatorLayout = cutlass::layout::RowMajor; @@ -1198,6 +1434,7 @@ struct DefaultMmaFromSharedMemory< using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory< Shape_, WarpIteratorA, + kScaleOperandA, AccumulatorSharedStorage_, IteratorB, SmemIteratorB_, @@ -1238,6 +1475,9 @@ template < /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear, typename AccumulatorSharedStorage_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, bool kTransposeA> struct DefaultMmaFromSharedMemory< MmaMultistage< @@ -1254,6 +1494,7 @@ struct DefaultMmaFromSharedMemory< Stages, SharedMemoryClear>, AccumulatorSharedStorage_, + kScaleOperandA, kTransposeA> { static constexpr int kWarpSize = 32; @@ -1301,6 +1542,7 @@ struct DefaultMmaFromSharedMemory< typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory< Shape_, WarpIteratorA, + kScaleOperandA, AccumulatorSharedStorage_, IteratorB, SmemIteratorB_, @@ -1637,18 +1879,17 @@ struct B2bGemm< // NOTE: accum is attn.T // TODO: Optimize for each architecture static constexpr int WarpSize = 32; - using RegistersIter = typename DefaultAttentionScalingCoefsUpdater< - IteratorC, - accum_t, - WarpSize>::Updater; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; auto lane_offset = - RegistersIter::get_lane_offset(lane_id, 
warp_id, tile_coords); + AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); cutlass::Array lse_prefetched; lse_prefetched.clear(); int rowIdx = 0; int colIdx = 0; - RegistersIter::iterateRows( + AccumLambdaIterator::iterateRows( lane_offset, [&](int accum_m) { ++rowIdx; @@ -1777,18 +2018,17 @@ struct B2bGemm< // NOTE: accum is attn.T // TODO: Optimize for each architecture static constexpr int WarpSize = 32; - using RegistersIter = typename DefaultAttentionScalingCoefsUpdater< - IteratorC, - accum_t, - WarpSize>::Updater; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; auto lane_offset = - RegistersIter::get_lane_offset(lane_id, warp_id, tile_coords); + AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); cutlass::Array lse_prefetched; lse_prefetched.clear(); int rowIdx = 0; int colIdx = 0; - RegistersIter::iterateRows( + AccumLambdaIterator::iterateRows( lane_offset, [&](int accum_m) { ++rowIdx; diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h index 48c47edb..0564bcef 100644 --- a/examples/41_fused_multi_head_attention/kernel_forward.h +++ b/examples/41_fused_multi_head_attention/kernel_forward.h @@ -29,16 +29,26 @@ * **************************************************************************************************/ +#pragma once + +#ifdef HAS_PYTORCH +#include +#include +#endif + +#include #include #include #include "cutlass/bfloat16.h" +#include "cutlass/fast_math.h" #include "cutlass/gemm/gemm.h" #include "cutlass/layout/matrix.h" #include "cutlass/layout/vector.h" +#include "cutlass/matrix.h" #include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" -#include "attention_scaling_coefs_updater.h" #include "cutlass/epilogue/threadblock/default_epilogue_simt.h" #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" @@ -54,11 +64,12 @@ #include "cutlass/platform/platform.h" #include "cutlass/transform/threadblock/predicated_tile_iterator.h" #include "debug_utils.h" -#include "epilogue_pipelined.h" -#include "epilogue_rescale_output.h" -#include "find_default_mma.h" +#include "epilogue/epilogue_pipelined.h" +#include "epilogue/epilogue_rescale_output.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_from_smem.h" #include "gemm_kernel_utils.h" -#include "mma_from_smem.h" +#include "transform/tile_smem_loader.h" #include @@ -73,6 +84,12 @@ constexpr int getWarpsPerSm() { ? 16 : 12); } +static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { + // source: https://stackoverflow.com/a/51549250 + return (value >= 0) + ? 
__int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); +} } // namespace template < @@ -83,10 +100,20 @@ template < // If Q/K/V are correctly aligned in memory and we can run a fast kernel bool isAligned_, int kQueriesPerBlock, - int kKeysPerBlock, - bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock` - > + int kKeysPerBlock_, + bool kSingleValueIteration_, // = `value.shape[-1] <= kKeysPerBlock` + // This is quite slower on V100 for some reason + // Set to false if you know at compile-time you will never need dropout + bool kSupportsDropout_ = true, + bool kSupportsBias_ = true> struct AttentionKernel { + enum CustomMaskType { + NoCustomMask = 0, + CausalFromTopLeft = 1, + CausalFromBottomRight = 2, + NumCustomMaskTypes, + }; + using scalar_t = scalar_t_; using accum_t = float; using lse_scalar_t = float; @@ -95,7 +122,11 @@ struct AttentionKernel { // Using `accum_t` improves perf on f16 at the cost of // numerical errors using output_accum_t = accum_t; + static constexpr bool kSupportsDropout = kSupportsDropout_; + static constexpr bool kSupportsBias = kSupportsBias_; + static constexpr int kKeysPerBlock = kKeysPerBlock_; static constexpr bool kIsAligned = isAligned_; + static constexpr bool kSingleValueIteration = kSingleValueIteration_; static constexpr int32_t kAlignLSE = 32; // block size of backward static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 && cutlass::sizeof_bits::value == 16; @@ -117,10 +148,15 @@ struct AttentionKernel { struct Params { // Input tensors scalar_t* query_ptr; // [num_queries, num_heads, head_dim] - scalar_t* key_ptr; // [num_keys, num_heads, head_dim] + scalar_t* key_ptr; // [num_keys, num_heads, head_dim] scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value] - int32_t* cu_seqlens_q_ptr = nullptr; - int32_t* cu_seqlens_k_ptr = nullptr; + scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys] + int32_t* seqstart_q_ptr = nullptr; + int32_t* seqstart_k_ptr = nullptr; + + int32_t* causal_diagonal_ptr = nullptr; + int32_t* seqlen_k_ptr = nullptr; + uint32_t causal_diagonal_offset = 0; // Output tensors output_t* output_ptr; // [num_queries, num_heads, head_dim_value] @@ -137,26 +173,38 @@ struct AttentionKernel { int32_t num_queries; int32_t num_keys; - bool causal; + uint8_t custom_mask_type = NoCustomMask; int32_t q_strideM; int32_t k_strideM; int32_t v_strideM; + int32_t bias_strideM = 0; + + int32_t o_strideM = 0; // Everything below is only used in `advance_to_block` // and shouldn't use registers int32_t q_strideH; int32_t k_strideH; int32_t v_strideH; + int32_t bias_strideH = 0; + int64_t q_strideB; int64_t k_strideB; int64_t v_strideB; + int32_t bias_strideB = 0; + int32_t num_batches; int32_t num_heads; - CUTLASS_HOST_DEVICE int32_t o_strideM() const { - return head_dim_value * num_heads; - } + // dropout + bool use_dropout; + unsigned long long dropout_batch_head_rng_offset; + float dropout_prob; +#ifdef HAS_PYTORCH + at::PhiloxCudaState rng_engine_inputs; +#endif + // Moves pointers to what we should process // Returns "false" if there is no work to do CUTLASS_DEVICE bool advance_to_block() { @@ -166,18 +214,33 @@ struct AttentionKernel { auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; + if (kSupportsDropout) { + dropout_batch_head_rng_offset = + batch_id * num_heads * num_queries * num_keys + + head_id * num_queries * num_keys; + } + int64_t q_start, k_start; // Advance to current 
batch - in case of different sequence lengths - if (cu_seqlens_q_ptr != nullptr) { - assert(cu_seqlens_k_ptr != nullptr); - cu_seqlens_q_ptr += batch_id; - cu_seqlens_k_ptr += batch_id; - q_start = cu_seqlens_q_ptr[0]; - k_start = cu_seqlens_k_ptr[0]; - int64_t q_next_start = cu_seqlens_q_ptr[1]; - int64_t k_next_start = cu_seqlens_k_ptr[1]; + if (seqstart_q_ptr != nullptr) { + assert(seqstart_k_ptr != nullptr); + seqstart_q_ptr += batch_id; + + q_start = seqstart_q_ptr[0]; + int64_t q_next_start = seqstart_q_ptr[1]; + int64_t k_end; + seqstart_k_ptr += batch_id; + + if (seqlen_k_ptr) { + k_start = seqstart_k_ptr[0]; + k_end = k_start + seqlen_k_ptr[batch_id]; + } else { + k_start = seqstart_k_ptr[0]; + k_end = seqstart_k_ptr[1]; + } + num_queries = q_next_start - q_start; - num_keys = k_next_start - k_start; + num_keys = k_end - k_start; if (query_start >= num_queries) { return false; @@ -186,9 +249,10 @@ struct AttentionKernel { query_ptr += batch_id * q_strideB; key_ptr += batch_id * k_strideB; value_ptr += batch_id * v_strideB; - output_ptr += int64_t(batch_id * num_queries) * o_strideM(); + output_ptr += int64_t(batch_id * num_queries) * o_strideM; if (output_accum_ptr != nullptr) { - output_accum_ptr += int64_t(batch_id * num_queries) * o_strideM(); + output_accum_ptr += + int64_t(batch_id * num_queries) * (head_dim_value * num_heads); } q_start = 0; k_start = 0; @@ -197,42 +261,84 @@ struct AttentionKernel { // Advance to the current batch / head / query_start query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH; key_ptr += k_start * k_strideM + head_id * k_strideH; + value_ptr += k_start * v_strideM + head_id * v_strideH; - output_ptr += int64_t(q_start + query_start) * o_strideM() + - head_id * head_dim_value; + output_ptr += + int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value; + if (kSupportsBias && attn_bias_ptr != nullptr) { + attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH); + } if (output_accum_ptr != nullptr) { - output_accum_ptr += int64_t(q_start + query_start) * o_strideM() + + output_accum_ptr += + int64_t(q_start + query_start) * (head_dim_value * num_heads) + head_id * head_dim_value; } else { // Accumulate directly in the destination buffer (eg for f32) output_accum_ptr = (accum_t*)output_ptr; } + if (logsumexp_ptr != nullptr) { // lse[batch_id, head_id, query_start] logsumexp_ptr += batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; } - num_queries -= query_start; - if (causal) { + // Custom masking + if (causal_diagonal_ptr) { + causal_diagonal_offset = causal_diagonal_ptr[batch_id]; + } + if (custom_mask_type == CausalFromBottomRight) { + causal_diagonal_offset += num_keys - num_queries; + } + if (custom_mask_type == CausalFromTopLeft || + custom_mask_type == CausalFromBottomRight) { + // the bottom row of the current block is query_start + kQueriesPerBlock + // the last active key is then query_start + causal_diagonal_offset + + // kQueriesPerBlock so num_keys is the min between actual num_keys and + // this to avoid extra computations num_keys = cutlass::fast_min( - int32_t(query_start + kQueriesPerBlock), num_keys); + int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock), + num_keys); } + + num_queries -= query_start; num_batches = 0; // no longer used after + // If num_queries == 1, and there is only one key head we're wasting + // 15/16th of tensor core compute In that case : + // - we only launch kernels for head_id % kQueriesPerBlock == 0 + // - we iterate over heads instead 
of queries (strideM = strideH) + if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) { + if (head_id % kQueriesPerBlock != 0) + return false; + q_strideM = q_strideH; + num_queries = num_heads; + num_heads = 1; // unused but here for intent + // remove causal since n_query = 1 + // otherwise, offset would change with head ! + custom_mask_type = NoCustomMask; + o_strideM = head_dim_value; + } + // Make sure the compiler knows these variables are the same on all // the threads of the warp. query_ptr = warp_uniform(query_ptr); key_ptr = warp_uniform(key_ptr); value_ptr = warp_uniform(value_ptr); + if (kSupportsBias) { + attn_bias_ptr = warp_uniform(attn_bias_ptr); + } output_ptr = warp_uniform(output_ptr); output_accum_ptr = warp_uniform(output_accum_ptr); logsumexp_ptr = warp_uniform(logsumexp_ptr); num_queries = warp_uniform(num_queries); num_keys = warp_uniform(num_keys); + num_heads = warp_uniform(num_heads); head_dim = warp_uniform(head_dim); head_dim_value = warp_uniform(head_dim_value); + o_strideM = warp_uniform(o_strideM); + custom_mask_type = warp_uniform(custom_mask_type); return true; } @@ -242,6 +348,7 @@ struct AttentionKernel { num_heads, num_batches); } + __host__ dim3 getThreadsGrid() const { return dim3(kWarpSize, kNumWarpsPerBlock, 1); } @@ -296,16 +403,24 @@ struct AttentionKernel { using IteratorA = typename DefaultMma::IteratorA; using IteratorB = typename DefaultMma::IteratorB; using Mma = typename DefaultMma::ThreadblockMma; - using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater< + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< typename Mma::Operator::IteratorC, accum_t, - kWarpSize>::Updater; + kWarpSize>::Iterator; static_assert( MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * MmaCore::WarpCount::kK == kNumWarpsPerBlock, ""); + // used for efficient load of bias tile Bij from global to shared memory + using BiasLoader = TileSmemLoader< + scalar_t, + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + // Epilogue to store to shared-memory in a format that we can use later for // the second matmul using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< @@ -367,7 +482,8 @@ struct AttentionKernel { using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, - typename MM0::AccumulatorSharedStorage>; + typename MM0::AccumulatorSharedStorage, + false>; // kScaleOperandA using Mma = typename DefaultMmaFromSmem::Mma; using IteratorB = typename Mma::IteratorB; using WarpCount = typename Mma::WarpCount; @@ -404,7 +520,10 @@ struct AttentionKernel { struct SharedStorageEpilogueAtEnd : ScalingCoefs { struct SharedStorageAfterMM0 { // Everything here might be overwritten during MM0 - typename MM0::AccumulatorSharedStorage si; + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; typename MM1::SharedStorageMM1 mm1; }; @@ -423,7 +542,10 @@ struct AttentionKernel { struct SharedStorageEpilogueInLoop : ScalingCoefs { struct SharedStorageAfterMM0 { // Everything here might be overwritten during MM0 - typename MM0::AccumulatorSharedStorage si; + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; typename MM1::SharedStorageMM1 mm1; typename MM1::DefaultEpilogue::SharedStorage epilogue; }; @@ -448,6 +570,18 @@ struct AttentionKernel { CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); 
CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV); + if (kSupportsBias) { + CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ); + XFORMERS_CHECK( + p.bias_strideB % kAlignmentQ == 0, + "attn_bias is not correctly aligned"); + XFORMERS_CHECK( + p.bias_strideH % kAlignmentQ == 0, + "attn_bias is not correctly aligned"); + XFORMERS_CHECK( + p.bias_strideM % kAlignmentQ == 0, + "attn_bias is not correctly aligned"); + } XFORMERS_CHECK( p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned"); XFORMERS_CHECK( @@ -460,6 +594,12 @@ struct AttentionKernel { p.k_strideH % kAlignmentK == 0, "key is not correctly aligned"); XFORMERS_CHECK( p.v_strideH % kAlignmentV == 0, "value is not correctly aligned"); + XFORMERS_CHECK( + p.causal_diagonal_ptr == nullptr || p.custom_mask_type != NoCustomMask, + "`causal_diagonal_ptr` is only useful when `custom_mask_type` is causal"); + XFORMERS_CHECK( + p.custom_mask_type < NumCustomMaskTypes, + "invalid value for `custom_mask_type`"); return true; } @@ -472,8 +612,8 @@ struct AttentionKernel { SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); auto& m_prime = shared_storage.m_prime; auto& s_prime = shared_storage.s_prime; - [[maybe_unused]] auto& si = shared_storage.after_mm0.si; auto& mi = shared_storage.mi; + const uint32_t query_start = blockIdx.x * kQueriesPerBlock; static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); if (thread_id() < kQueriesPerBlock) { @@ -488,7 +628,7 @@ struct AttentionKernel { auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { using OutputTileIterator = typename MM1::OutputTileIterator; return OutputTileIterator( - typename OutputTileIterator::Params{(int32_t)p.o_strideM()}, + typename OutputTileIterator::Params{(int32_t)p.o_strideM}, p.output_ptr, typename OutputTileIterator::TensorCoord{ p.num_queries, p.head_dim_value}, @@ -500,7 +640,8 @@ struct AttentionKernel { typename MM1::OutputTileIteratorAccum { using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; return OutputTileIteratorAccum( - typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()}, + typename OutputTileIteratorAccum::Params{ + (int32_t)(p.head_dim_value * p.num_heads)}, p.output_accum_ptr, typename OutputTileIteratorAccum::TensorCoord{ p.num_queries, p.head_dim_value}, @@ -508,6 +649,27 @@ struct AttentionKernel { {0, col}); }; +#ifdef HAS_PYTORCH + curandStatePhilox4_32_10_t curand_state_init; + if (kSupportsDropout && p.use_dropout) { + const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); + + // each element of the attention matrix P with shape + // (batch_sz, n_heads, n_queries, n_keys) is associated with a single + // offset in RNG sequence. we initialize the RNG state with offset that + // starts at the beginning of a (n_queries, n_keys) matrix for this + // block's batch_id and head_id + // initializing rng state is very expensive, so we run once per kernel, + // rather than once per iteration. each iteration takes a copy of the + // initialized RNG state and offsets it as needed. 
+ curand_init( + std::get<0>(seeds), + 0, + std::get<1>(seeds) + p.dropout_batch_head_rng_offset, + &curand_state_init); + } +#endif + // Iterate through keys for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; iter_key_start += kKeysPerBlock) { @@ -600,16 +762,65 @@ struct AttentionKernel { (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + (my_warp_id / MM0::Mma::WarpCount::kM)}; + // multiply by scaling factor + if (kSupportsBias) { + accum = + cutlass::multiplies()(p.scale, accum); + } + + // apply attention bias if applicable + if (kSupportsBias && p.attn_bias_ptr != nullptr) { + // load bias tile Bij into shared memory + typename MM0::BiasLoader::GmemTileIterator bias_iter( + {cutlass::layout::RowMajor(p.bias_strideM)}, + // attn_bias_pointer points to matrix of size (n_queries, n_keys) + // for the relevant batch_id and head_id + p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start, + {problem_size_0_m, problem_size_0_n}, + thread_id()); + cutlass::TensorRef bias_tensor_ref( + shared_storage.after_mm0.bias.data(), + cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); + typename MM0::BiasLoader::SmemTileIterator smem_tile_iter( + bias_tensor_ref, thread_id()); + MM0::BiasLoader::load(bias_iter, smem_tile_iter); + + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { + accum[idx] += bias_tensor_ref.at({accum_m, accum_n}); + } + }, + [&](int accum_m) {}); + } + // Mask out last if causal - if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) { + // This is only needed if upper-right corner of current query / key block + // intersects the mask Coordinates of upper-right corner of current block + // is y=query_start x=min(iter_key_start + kKeysPerBlock, num_keys)) The + // first masked element is x = y + offset -> query_start + offset There is + // intersection (and we need to mask) if min(iter_key_start + + // kKeysPerBlock, num_keys)) >= query_start + offset + if (p.custom_mask_type && + cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >= + (query_start + p.causal_diagonal_offset)) { auto query_start = blockIdx.x * kQueriesPerBlock; - auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( lane_id(), warp_id(), iteratorC_tile_offset); int32_t last_col; - MM0::ScalingCoefsUpdater::iterateRows( + MM0::AccumLambdaIterator::iterateRows( lane_offset, [&](int accum_m) { - last_col = query_start + accum_m - iter_key_start; + // last absolute col is (last absolute query + offset) + // last local col is (last absolute query + offset - + // iter_key_start) + last_col = query_start + accum_m + p.causal_diagonal_offset - + iter_key_start; }, [&](int accum_m, int accum_n, int idx) { if (accum_n > last_col) { @@ -625,14 +836,11 @@ struct AttentionKernel { kFullColumns, ([&] { // Update `mi` from accum stored in registers - // Also updates `accum` with accum[i] <- - // exp(accum[i] * scale - // - mi) - MM0::ScalingCoefsUpdater::update< - kQueriesPerBlock, + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax< + typename MM0::Mma::Operator::IteratorC, kFullColumns, - kIsFirst, - kKeepOutputInRF>( + kIsFirst>( accum_o, accum, mi, @@ -643,7 +851,7 @@ struct 
AttentionKernel { warp_id(), p.num_keys - iter_key_start, iteratorC_tile_offset, - p.scale); + kSupportsBias ? 1.0f : p.scale); })); })); @@ -659,6 +867,69 @@ struct AttentionKernel { __syncthreads(); +#ifdef HAS_PYTORCH + // apply dropout (if applicable) after we've written Pij to smem. + // dropout is applied by multiplying each element of Pij by: + // - 0 with probability dropout_p + // - 1 / (1 - dropout_p) with probability 1 - dropout_p + // + // for backward purposes we want to be able to map each element of the + // attention matrix to the same random uniform number as the one we used + // in forward, without needing to use the same iteration order or having + // to store the dropout matrix. it's possible to do this in registers but + // it ends up being very slow because each thread having noncontiguous + // strips of the Pij tile means we have to skip around a lot, and also + // have to generate a single random number at a time + if (kSupportsDropout && p.use_dropout) { + auto si = shared_storage.after_mm0.si.accum_ref(); + // each thread handles a contiguous sequence of elements from Sij, all + // coming from the same row. the reason they have to come from the same + // row is that sampling random numbers from a contiguous random + // number sequence is much more efficient than jumping around, and the + // linear offset of each element of S (the global matrix) maps to an + // offset in a random number sequence. for S, the end of a row and the + // beginning of the next have adjacent offsets, but for Sij, this is not + // necessarily the case. + const int num_threads = blockDim.x * blockDim.y * blockDim.z; + const int threads_per_row = + cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n); + const int elts_per_thread = cutlass::round_nearest( + cutlass::ceil_div(problem_size_0_n, threads_per_row), 4); + + const int thread_i = thread_id() / threads_per_row; + const int thread_start_j = + (thread_id() % threads_per_row) * elts_per_thread; + + if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) { + curandStatePhilox4_32_10_t curand_state = curand_state_init; + skipahead( + static_cast<uint64_t>( + (query_start + thread_i) * p.num_keys + + (iter_key_start + thread_start_j)), + &curand_state); + const float dropout_scale = 1.0 / (1.0 - p.dropout_prob); + + // apply dropout scaling to elements this thread is responsible for, + // in chunks of 4 + for (int sij_start_col_idx = thread_start_j; sij_start_col_idx < + cutlass::fast_min(thread_start_j + elts_per_thread, + problem_size_0_n); + sij_start_col_idx += 4) { + const float4 rand_uniform_quad = curand_uniform4(&curand_state); + + CUTLASS_PRAGMA_UNROLL + for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { + si.at({thread_i, sij_start_col_idx + quad_idx}) *= + static_cast<scalar_t>( + dropout_scale * + ((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob)); + } + } + } + __syncthreads(); // p.use_dropout should have same value kernel-wide + } +#endif + // // MATMUL: Attn . V // Run the matmul `attn @ V` for a block of attn and V.
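
Aside on the dropout bookkeeping above: each element of the attention matrix is keyed to a fixed offset in the Philox stream (row * num_keys + col, drawn in groups of four with curand_uniform4), which is what lets a backward pass recompute the same mask without storing it. Below is a minimal standalone CUDA sketch of that mapping only; the kernel name, launch shape, and the rng_offset argument are illustrative assumptions, not code from this PR.

#include <cstdint>
#include <cuda_runtime.h>
#include <curand_kernel.h>

// Regenerate the keep/drop decisions for one (num_queries, num_keys)
// attention matrix. Element (row, col) always maps to Philox offset
// row * num_keys + col, independent of the tile iteration order used by
// the fused kernel, so forward and backward agree without storing a mask.
__global__ void regenerate_dropout_mask(
    float* mask, // [num_queries * num_keys], receives 0 or 1/(1-p)
    uint64_t seed,
    uint64_t rng_offset, // start of this (batch, head) matrix in the stream
    int num_queries,
    int num_keys,
    float dropout_p) {
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = (blockIdx.x * blockDim.x + threadIdx.x) * 4; // 4 values per draw
  if (row >= num_queries || col >= num_keys) {
    return;
  }
  curandStatePhilox4_32_10_t state;
  // Same initialization as the forward kernel: seed, subsequence 0,
  // per-(batch, head) offset.
  curand_init(seed, 0, rng_offset, &state);
  // Jump to the linear offset of element (row, col) inside the matrix.
  skipahead(static_cast<unsigned long long>(row) * num_keys + col, &state);
  const float scale = 1.0f / (1.0f - dropout_p);
  float4 quad = curand_uniform4(&state);
  const float* q = &quad.x;
  for (int i = 0; i < 4 && col + i < num_keys; ++i) {
    // keep with probability (1 - dropout_p); kept values are rescaled
    mask[row * num_keys + col + i] = (q[i] > dropout_p) ? scale : 0.0f;
  }
}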
@@ -830,6 +1101,116 @@ struct AttentionKernel { } } + template < + typename WarpIteratorC, + bool kFullColumns, + bool kIsFirst> + CUTLASS_DEVICE static void iterative_softmax( + typename WarpIteratorC::Fragment& frag_o, // output so far + typename WarpIteratorC::Fragment& frag, + cutlass::Array<accum_t, kQueriesPerBlock>& mi, + cutlass::Array<accum_t, kQueriesPerBlock>& m_prime, + cutlass::Array<accum_t, kQueriesPerBlock>& s_prime, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int16_t max_col, + typename WarpIteratorC::TensorCoord const& tile_offset, + float scaling) { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` + (2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + + All of this is done on registers, before we store all of this + on shared memory for the next matmul with Value. + */ + using Fragment = typename WarpIteratorC::Fragment; + using LambdaIterator = typename DefaultMmaAccumLambdaIterator< + WarpIteratorC, + accum_t, + kWarpSize>::Iterator; + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + if (!kIsFirst) { + if (thread_id < kQueriesPerBlock) { + m_prime[thread_id] = mi[thread_id]; + } + __syncthreads(); + } + + auto lane_offset = + LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + max = -cutlass::platform::numeric_limits<accum_t>::infinity(); + }, + [&](int accum_m, int accum_n, int idx) { + if (kFullColumns || accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max * scaling); + }); + } + frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag); + + // Make sure we all share the updated values for `mi` + __syncthreads(); + + if (thread_id < kQueriesPerBlock) { + auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id])); + m_prime[thread_id] = m_prime_exp; + s_prime[thread_id] *= m_prime_exp; + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !kIsFirst) { + accum_t mp; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mp = m_prime[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; }, + [&](int accum_m) {}); + __syncthreads(); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = (kFullColumns || accum_n < max_col) + ? 
exp2f(frag[idx] - mi_row) + : accum_t(0.0); + }, + [&](int accum_m) {}); + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (LambdaIterator::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { + return a + b; + })) { + atomicAdd(&s_prime[accum_m], total_row); + } + }); + } + } + static CUTLASS_DEVICE int8_t lane_id() { return threadIdx.x; } @@ -849,3 +1230,7 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) } AK::attention_kernel(p); } + +template <typename AK> +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched(typename AK::Params params); diff --git a/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h b/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h new file mode 100644 index 00000000..345bc5bb --- /dev/null +++ b/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h @@ -0,0 +1,88 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +template < + typename scalar_t, // scalar type + typename ThreadblockTileShape, // size of tile to load + int Threads, // number of participating threads + int ElementsPerAccess> // thread access width in elements +class TileSmemLoader { + public: + using SmemTile = + cutlass::AlignedBuffer<scalar_t, ThreadblockTileShape::kCount>; + + using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< + cutlass::layout::PitchLinearShape< + ThreadblockTileShape::kColumn, // contiguous + ThreadblockTileShape::kRow>, // strided + Threads, // Threads + ElementsPerAccess>; // ElementsPerAccess + + using GmemTileIterator = + cutlass::transform::threadblock::PredicatedTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using Fragment = typename GmemTileIterator::Fragment; + + /// load a tile from global memory into shared memory + CUTLASS_DEVICE + static void load( + GmemTileIterator tile_load_iter, + SmemTileIterator tile_store_iter) { + Fragment tb_frag; + tb_frag.clear(); + tile_load_iter.load(tb_frag); + tile_store_iter.store(tb_frag); + + __syncthreads(); + } +};
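
For reference, the update that iterative_softmax performs per key block can be written out as a plain scalar recurrence. The host C++ sketch below is illustrative only: it follows the (a)/(b)/(c) steps from the comment in iterative_softmax but ignores the base-2 exp2f/kLog2e trick, the warp-level reductions, and the register layout. mi starts at -infinity and s_prime at 0, and after the last key block the accumulated output row is divided by s_prime.

#include <algorithm>
#include <cmath>
#include <vector>

// One streaming-softmax step for a single query row. `scores` holds the raw
// dot products for the current key block (assumed non-empty) and is
// overwritten with the un-normalized probabilities exp(score - mi).
void online_softmax_step(
    std::vector<float>& scores, // scores for this key block (in/out)
    std::vector<float>& acc_o,  // un-normalized output row (head_dim_value)
    float& mi,                  // running max, initialized to -infinity
    float& s_prime) {           // running sum of exp(score - mi), init to 0
  float block_max = *std::max_element(scores.begin(), scores.end());
  float mi_new = std::max(mi, block_max);
  // m_prime: rescale factor for everything computed with the old max
  float m_prime = std::exp(mi - mi_new);
  float block_sum = 0.f;
  for (float& s : scores) {
    s = std::exp(s - mi_new); // (a) accum <- exp(accum - mi)
    block_sum += s;
  }
  // (c) s_prime <- s_prime * m_prime + sum(accum)
  s_prime = s_prime * m_prime + block_sum;
  for (float& o : acc_o) {
    o *= m_prime; // previously accumulated output used the old max
  }
  mi = mi_new;
  // caller then accumulates acc_o += scores @ V_block and, after the last
  // key block, divides acc_o by s_prime.
}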
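
The MM0::BiasLoader used in the bias path earlier in this diff exposes the same GmemTileIterator / SmemTileIterator / load() interface as TileSmemLoader, so it is presumably an instantiation of this template. A hedged sketch of such an instantiation follows; the tile sizes, thread count, and half_t element type are placeholder assumptions, not values taken from the PR.

#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "transform/tile_smem_loader.h"

// Illustrative threadblock configuration (assumed, not from the PR).
constexpr int kQueriesPerBlock = 64;
constexpr int kKeysPerBlock = 64;
constexpr int kNumThreads = 128;

// One (kQueriesPerBlock x kKeysPerBlock) bias tile per threadblock, loaded
// with 128-bit accesses (8 half_t elements per access per thread).
using BiasLoader = TileSmemLoader<
    cutlass::half_t,
    cutlass::MatrixShape<kQueriesPerBlock, kKeysPerBlock>,
    kNumThreads,
    128 / cutlass::sizeof_bits<cutlass::half_t>::value>;

// In-kernel usage then mirrors the bias path above: build a GmemTileIterator
// over the global attn_bias tile, an SmemTileIterator over shared memory,
// and call BiasLoader::load(gmem_iter, smem_iter).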