fMHA: Sync FW with xFormers (#828)
* fMHA: Add support for bias+dropout in FW

* Remove 'getMaximumSharedMemoryPerBlockKb'

* fix comments

---------

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <[email protected]>
2 people authored and ttl10101 committed Feb 7, 2024
1 parent 702aa96 commit 89e90ad
Showing 12 changed files with 998 additions and 253 deletions.
53 changes: 46 additions & 7 deletions examples/41_fused_multi_head_attention/debug_utils.h
@@ -50,12 +50,17 @@
#if 1
#define PRINT_WARP_ID 0
#define PRINT_LANE_ID 0
-#define PRINT_T0_L0(msg, ...) \
+#define PRINT_B0_T0(msg, ...) \
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \
threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
threadIdx.z == 0) { \
printf(msg "\n", ##__VA_ARGS__); \
}
#define PRINT_T0(msg, ...) \
if (threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
threadIdx.z == 0) { \
printf(msg "\n", ##__VA_ARGS__); \
}
#define PRINT_TX_LX(msg, ...) \
for (int bx = 0; bx < gridDim.x; ++bx) { \
for (int by = 0; by < gridDim.y; ++by) { \
@@ -84,7 +89,7 @@
} \
}
#else
-#define PRINT_T0_L0
+#define PRINT_B0_T0
#define PRINT_TX_LX
#endif
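
For context, a minimal usage sketch of these macros (not part of the diff; it assumes debug_utils.h is included and PRINT_LANE_ID/PRINT_WARP_ID are left at 0). PRINT_B0_T0 prints only from the selected thread of block (0, 0, 0), while the new PRINT_T0 prints from that thread in every block:

// Hedged usage sketch, not from the commit.
__global__ void debug_print_kernel(int const* data) {
  // Emitted once per launch (block 0 only).
  PRINT_B0_T0("first element: %d", data[0]);
  // Emitted once per block.
  PRINT_T0("block %d sees %d", int(blockIdx.x), data[blockIdx.x]);
}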

@@ -124,7 +129,7 @@ constexpr __string_view __get_type_name() {

// Print a given array
#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \
-PRINT_T0_L0( \
+PRINT_B0_T0( \
"%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \
name, \
int(start), \
@@ -141,7 +146,7 @@ constexpr __string_view __get_type_name() {
#define PRINT_FRAG_T0_L0(name, frag) \
{ \
auto typeStr = __get_type_name<decltype(frag)>(); \
-PRINT_T0_L0("printing %s (%s)", name, typeStr.data); \
+PRINT_B0_T0("printing %s (%s)", name, typeStr.data); \
for (int _start = 0; _start < frag.size(); _start += 8) { \
PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \
} \
@@ -150,7 +155,7 @@
}
#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \
{ \
-PRINT_T0_L0("printing %s (len=%d)", name, int(length)); \
+PRINT_B0_T0("printing %s (len=%d)", name, int(length)); \
for (int _start = 0; _start < length; _start += incr) { \
PRINT_ACCUM8_T0_L0_START(" ", array, _start); \
} \
@@ -160,7 +165,7 @@

// Print a 4x4 matrix
#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \
-PRINT_T0_L0( \
+PRINT_B0_T0( \
"%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \
name, \
int(start_x), \
@@ -187,9 +192,43 @@ constexpr __string_view __get_type_name() {
PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0)

#define PRINT_PROBLEM_SIZE(name, ps) \
-PRINT_T0_L0( \
+PRINT_B0_T0( \
"%s.problem_size: {.m=%d, .n=%d, .k=%d}", \
name, \
int(ps.m()), \
int(ps.n()), \
int(ps.k()))

template <typename LambdaIterator, typename LaneOffsetT, typename AccumT>
CUTLASS_DEVICE void print_warp_accum(
AccumT accum,
LaneOffsetT lane_offset,
int32_t num_rows,
int32_t num_cols) {
bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 &&
threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0;
for (int row = 0; row < num_rows; ++row) {
for (int col = 0; col < num_cols; ++col) {
if (col % 32 == 0) {
if (is_main) {
printf("\nmat[%3d, %3d:%3d]", row, col, col + 32);
}
__syncthreads();
}
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {},
[&](int accum_m, int accum_n, int idx) {
if (row == accum_m && col == accum_n &&
(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)) {
printf(" %6.1f", float(accum[idx]));
}
},
[&](int accum_m) {});
__syncthreads();
}
if (is_main) {
printf("\n");
}
}
}
12 changes: 6 additions & 6 deletions examples/41_fused_multi_head_attention/default_fmha_grouped.h
@@ -50,9 +50,8 @@

#include "fmha_grouped.h"
#include "gemm_kernel_utils.h"
-#include "find_default_mma.h"
-#include "attention_scaling_coefs_updater.h"
-#include "mma_from_smem.h"
+#include "gemm/find_default_mma.h"
+#include "gemm/mma_from_smem.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

Expand Down Expand Up @@ -154,10 +153,10 @@ struct DefaultFMHAGrouped {
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
-using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
+using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
typename Mma::Operator::IteratorC,
ElementAccumulator,
-kWarpSize>::Updater;
+kWarpSize>::Iterator;

static_assert(MmaCore::WarpCount::kCount == kNumWarpsPerBlock, "");

@@ -240,7 +239,8 @@ struct DefaultFMHAGrouped {
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
-typename MM0::AccumulatorSharedStorage>;
+typename MM0::AccumulatorSharedStorage,
+false>; // kScaleOperandA

using Mma = typename DefaultMmaFromSmem::Mma;
using IteratorB = typename Mma::IteratorB;
143 changes: 132 additions & 11 deletions examples/41_fused_multi_head_attention/fmha_grouped.h
@@ -48,7 +48,18 @@

#include "fmha_grouped_problem_visitor.h"
#include "gemm_kernel_utils.h"
-#include "epilogue_rescale_output.h"
+#include "gemm/mma_accum_lambda_iterator.h"
+#include "epilogue/epilogue_rescale_output.h"


namespace {
static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
// source: https://stackoverflow.com/a/51549250
return (value >= 0)
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
}
}
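
The bit tricks in atomicMaxFloat work because, for non-negative floats, the IEEE-754 bit pattern compares the same way as a signed int, while for negative floats the comparison on the unsigned pattern is reversed, so an atomicMin over the unsigned view selects the float maximum. A host-side sketch of that ordering property (illustrative only, not part of the commit; as_int/as_uint are hypothetical helpers):

#include <cstdint>
#include <cstdio>
#include <cstring>

static int32_t as_int(float f) { int32_t i; std::memcpy(&i, &f, sizeof f); return i; }
static uint32_t as_uint(float f) { uint32_t u; std::memcpy(&u, &f, sizeof f); return u; }

int main() {
  // Non-negative floats: the signed-int view preserves the float order.
  std::printf("%d\n", as_int(1.5f) > as_int(0.25f));     // prints 1
  // Negative floats: the unsigned view reverses the order, so taking the
  // minimum unsigned pattern keeps the larger (less negative) float.
  std::printf("%d\n", as_uint(-0.25f) < as_uint(-1.5f)); // prints 1
  return 0;
}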

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -128,6 +139,9 @@ struct FMHAGrouped {
static int const kQueriesPerBlock = ThreadblockShape::kM;
static int const kKeysPerBlock = ThreadblockShape::kN;

static constexpr bool kSupportsDropout = false;
static constexpr bool kSupportsBias = false;

/// Warp count (concept: GemmShape)
using WarpCount = typename MM1::WarpCount;
static int const kThreadsPerWarp = 32;
@@ -619,10 +633,10 @@

// Mask out last if causal
if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) {
-auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset(
+auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
int32_t last_col;
-MM0::ScalingCoefsUpdater::iterateRows(
+MM0::AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
last_col = TileParams::query_start(threadblock_idx) + accum_m - iter_key_start;
@@ -641,14 +655,11 @@
kFullColumns,
([&] {
// Update `mi` from accum stored in registers
-// Also updates `accum` with accum[i] <-
-// exp(accum[i] * scale
-// - mi)
-MM0::ScalingCoefsUpdater::update<
-    kQueriesPerBlock,
+// Also does accum[i] <- exp(accum[i] - mi)
+iterative_softmax<
+    typename MM0::Mma::Operator::IteratorC,
kFullColumns,
-kIsFirst,
-kKeepOutputInRF>(
+kIsFirst>(
accum_o,
accum,
mi,
@@ -659,7 +670,7 @@
warp_id(),
num_keys - iter_key_start,
iteratorC_tile_offset,
-params.scale);
+kSupportsBias ? 1.0f : params.scale);
}));
}));

@@ -838,6 +849,116 @@ struct FMHAGrouped {
problem_visitor.advance(gridDim.x);
}
}

template <
typename WarpIteratorC,
bool kFullColumns,
bool kIsFirst>
CUTLASS_DEVICE static void iterative_softmax(
typename WarpIteratorC::Fragment& frag_o, // output so far
typename WarpIteratorC::Fragment& frag,
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
int8_t lane_id,
int8_t thread_id,
int8_t warp_id,
int16_t max_col,
typename WarpIteratorC::TensorCoord const& tile_offset,
float scaling) {
/* Iterates on the accumulator and corresponding position on result matrix
(1) Update `mi[r]` to the max value of the row `r`
(2) In a second iteration do the following:
(a) accum <- exp(accum - mi)
(b) m_prime <- exp(m_prime - mi)
(c) s_prime <- s_prime * m_prime + sum(accum)
All of this is done on registers, before we store all of this
on shared memory for the next matmul with Value.
*/
using Fragment = typename WarpIteratorC::Fragment;
using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
WarpIteratorC,
accum_t,
kThreadsPerWarp>::Iterator;
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (!kIsFirst) {
if (thread_id < kQueriesPerBlock) {
m_prime[thread_id] = mi[thread_id];
}
__syncthreads();
}

auto lane_offset =
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);

// First update `mi` to the max per-row
{
accum_t max;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
},
[&](int accum_m, int accum_n, int idx) {
if (kFullColumns || accum_n < max_col) {
max = cutlass::fast_max(max, frag[idx]);
}
},
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max * scaling);
});
}
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);

// Make sure we all share the update values for `mi`
__syncthreads();

if (thread_id < kQueriesPerBlock) {
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
m_prime[thread_id] = m_prime_exp;
s_prime[thread_id] *= m_prime_exp;
}
__syncthreads(); // Update output fragments
if (kKeepOutputInRF && !kIsFirst) {
accum_t mp;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mp = m_prime[accum_m]; },
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
[&](int accum_m) {});
__syncthreads();
}
// Update accum_m, accum_n, ...
{
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag[idx] = (kFullColumns || accum_n < max_col)
? exp2f(frag[idx] - mi_row)
: accum_t(0.0);
},
[&](int accum_m) {});
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { total_row = 0.0; },
[&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
[&](int accum_m) {
if (LambdaIterator::reduceSameRow(
lane_id, total_row, [](accum_t a, accum_t b) {
return a + b;
})) {
atomicAdd(&s_prime[accum_m], total_row);
}
});
}
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////
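
The iterative_softmax added above is the online-softmax recurrence: mi holds the running row maximum, s_prime the running denominator, and the previously accumulated output is rescaled by exp(old_max - new_max) whenever the maximum grows; the final division by s_prime is applied later by the rescale-output epilogue. A scalar sketch of the same recurrence for one query row (illustrative only, not part of the commit):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Attention scores and (1-D) values for one query row, split into key blocks.
  std::vector<std::vector<float>> score_blocks = {{0.1f, 2.0f}, {1.5f, -0.3f}};
  std::vector<std::vector<float>> value_blocks = {{1.0f, 2.0f}, {3.0f, 4.0f}};

  float mi = -INFINITY;  // running row maximum
  float s_prime = 0.f;   // running softmax denominator
  float acc_o = 0.f;     // running unnormalized output

  for (size_t b = 0; b < score_blocks.size(); ++b) {
    float block_max = -INFINITY;
    for (float x : score_blocks[b]) block_max = std::max(block_max, x);

    float mi_new = std::max(mi, block_max);
    float m_prime = std::exp(mi - mi_new);  // rescale factor for earlier blocks
    s_prime *= m_prime;
    acc_o *= m_prime;

    for (size_t i = 0; i < score_blocks[b].size(); ++i) {
      float p = std::exp(score_blocks[b][i] - mi_new);
      s_prime += p;
      acc_o += p * value_blocks[b][i];
    }
    mi = mi_new;
  }
  // Matches softmax(scores) . values computed in a single pass over the blocks.
  std::printf("output = %f\n", acc_o / s_prime);
  return 0;
}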
@@ -856,7 +856,9 @@ public:
p.head_dim_value = options.head_size_v;
p.num_queries = options.seq_length;
p.num_keys = options.seq_length_kv;
-p.causal = options.causal;
+if (options.causal) {
+  p.custom_mask_type = Attention::CausalFromTopLeft;
+}

// All tensors are in BMHK shapes
p.q_strideH = options.head_size;
@@ -868,6 +870,7 @@
p.q_strideB = p.q_strideM * options.seq_length;
p.k_strideB = p.k_strideM * options.seq_length_kv;
p.v_strideB = p.v_strideM * options.seq_length_kv;
p.o_strideM = p.head_dim_value * p.num_heads;
}
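
The new p.o_strideM follows the same BMHK layout rule as the existing strides: for a contiguous [batch, seq_len, num_heads, head_dim] tensor, the stride between consecutive rows along M is num_heads * head_dim, and the per-batch stride multiplies that by the sequence length. A small sketch of that rule (hypothetical sizes, not taken from the example):

#include <cstdio>

int main() {
  // Hypothetical problem sizes.
  int num_heads = 8, head_size = 64, head_size_v = 64;
  int seq_length = 1024, seq_length_kv = 1024;

  int q_strideM = head_size * num_heads;    // stride between queries (M)
  int o_strideM = head_size_v * num_heads;  // stride between output rows
  int q_strideB = q_strideM * seq_length;   // stride between batches
  int k_strideB = head_size * num_heads * seq_length_kv;

  std::printf("q_strideM=%d o_strideM=%d q_strideB=%d k_strideB=%d\n",
              q_strideM, o_strideM, q_strideB, k_strideB);
  return 0;
}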

// launch kernel :)
@@ -1005,7 +1008,9 @@ int run_attention(Options& options) {
true, // Memory is aligned
kQueriesPerBlock,
kKeysPerBlock,
-kSingleValueIteration
+kSingleValueIteration,
+false, // Supports dropout
+false // Supports bias
>;

//
@@ -42,6 +42,8 @@
This is really only for the FastF32 case - aka using TensorCores with fp32.
*/

#pragma once

#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"