decoder MMHA kernel support INT8 SCALE_Q_INSTEAD_OF_K and SCALE_P_INS… #2085

Open · wants to merge 1 commit into base: main
@@ -54,6 +54,11 @@ namespace kernels
 #define MMHA_FP8_SCALE_P_INSTEAD_OF_V
 #endif // !defined ENABLE_FP8

+// Apply the INT8 scaling to Q instead of K.
+#define MMHA_INT8_SCALE_Q_INSTEAD_OF_K
+// Apply the INT8 scaling to P instead of V.
+#define MMHA_INT8_SCALE_P_INSTEAD_OF_V
+
 // Below are knobs to extend FP32 accumulation for higher FP16 accuracy

 // Does not seem to affect the accuracy that much
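The two new INT8 knobs mirror the FP8 ones above. With a per-tensor scale, q · (k_scale · k) == (k_scale · q) · k, so Q can be scaled once before it is stored to shared memory instead of dequantizing every K element at every timestep. A minimal host-side sketch of the identity and the op-count argument (all names here are hypothetical, not from this PR):

    #include <cstdint>
    #include <cstdio>

    // Baseline: dequantize every cached K element (one extra multiply per
    // element, repeated for every timestep in the KV cache).
    float qk_dequant_k(float const* q, int8_t const* k, float k_scale, int dh)
    {
        float acc = 0.f;
        for (int i = 0; i < dh; ++i)
            acc += q[i] * (k_scale * static_cast<float>(k[i]));
        return acc;
    }

    // Pre-scaled: fold k_scale into Q once; each timestep then dots against raw INT8 K.
    float qk_prescale_q(float const* q_scaled, int8_t const* k, int dh)
    {
        float acc = 0.f;
        for (int i = 0; i < dh; ++i)
            acc += q_scaled[i] * static_cast<float>(k[i]);
        return acc;
    }

    int main()
    {
        float q[4] = {0.5f, -1.f, 2.f, 0.25f};
        int8_t k[4] = {10, -20, 30, 40};
        float k_scale = 0.05f;
        float q_scaled[4];
        for (int i = 0; i < 4; ++i)
            q_scaled[i] = k_scale * q[i]; // done once, amortized over all timesteps
        printf("%f == %f\n", qk_dequant_k(q, k, k_scale, 4), qk_prescale_q(q_scaled, k, 4));
        return 0;
    }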
@@ -959,8 +964,12 @@ inline __device__ void Logit_value_fma(
     float logit = is_mask ? 0.f : reinterpret_cast<float*>(logits_smem)[0];
     if constexpr (INT8_KV_CACHE)
     {
+#ifdef MMHA_INT8_SCALE_P_INSTEAD_OF_V
+        out = fma(logit, v_vec, out);
+#else
         V_vec_accum v_vec_ = mul<V_vec_accum, float, V_vec_m>(v_scale, v_vec);
         out = fma(logit, cast_to_float(v_vec_), out);
+#endif // MMHA_INT8_SCALE_P_INSTEAD_OF_V
     }
     else if constexpr (FP8_KV_CACHE)
     {
@@ -979,8 +988,12 @@ inline __device__ void Logit_value_fma(
     Tk logit = is_mask ? Tk(0.f) : logits_smem[0];
     if constexpr (INT8_KV_CACHE)
     {
+#ifdef MMHA_INT8_SCALE_P_INSTEAD_OF_V
+        out = fma(logit, v_vec, out);
+#else
         V_vec_accum v_vec_ = mul<V_vec_accum, float, V_vec_m>(v_scale, v_vec);
         out = fma(logit, v_vec_, out);
+#endif // MMHA_INT8_SCALE_P_INSTEAD_OF_V
     }
     else if constexpr (FP8_KV_CACHE)
     {
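Both Logit_value_fma overloads apply the same transformation on the V side: since sum_t p_t · (v_scale · v_t) == v_scale · (sum_t p_t · v_t), the per-element multiply by v_scale can leave the inner fma loop entirely and be applied once, which this PR does by folding it into the logit normalization further down. A sketch of the identity, with hypothetical names:

    #include <cstdint>

    // Baseline: dequantize each V element inside the accumulation loop.
    float pv_dequant_v(float const* p, int8_t const* v, float v_scale, int steps)
    {
        float out = 0.f;
        for (int t = 0; t < steps; ++t)
            out += p[t] * (v_scale * static_cast<float>(v[t]));
        return out;
    }

    // Scale-P variant: accumulate against raw INT8 V and apply v_scale once at
    // the end (in the kernel it rides along in logit_scale / inv_sum instead).
    float pv_scale_p(float const* p, int8_t const* v, float v_scale, int steps)
    {
        float out = 0.f;
        for (int t = 0; t < steps; ++t)
            out += p[t] * static_cast<float>(v[t]);
        return v_scale * out;
    }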
@@ -1312,9 +1325,18 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
     static constexpr bool ENABLE_8BITS_K_CACHE = sizeof(TKcache) == 1;
     static constexpr bool ENABLE_8BITS_KV_CACHE = sizeof(Tcache) == 1;
     // FP8 KV Cache.
+#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
     static constexpr bool FP8_K_CACHE = std::is_same<TKcache, __nv_fp8_e4m3>::value;
+#else
+    static constexpr bool FP8_K_CACHE = false;
+#endif // MMHA_FP8_SCALE_Q_INSTEAD_OF_K
     static constexpr bool FP8_KV_CACHE = std::is_same<Tcache, __nv_fp8_e4m3>::value;
     // INT8 KV Cache.
+#ifdef MMHA_INT8_SCALE_Q_INSTEAD_OF_K
+    static constexpr bool INT8_K_CACHE = std::is_same<TKcache, int8_t>::value;
+#else
+    static constexpr bool INT8_K_CACHE = false;
+#endif // MMHA_INT8_SCALE_Q_INSTEAD_OF_K
     static constexpr bool INT8_KV_CACHE = std::is_same<Tcache, int8_t>::value;

     // The size of a warp.
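Guarding the *_K_CACHE flags behind the macros means that when a knob is compiled out, the flag is constant false and if constexpr dead-strips the pre-scaled path, leaving the original per-element dequantization code. A reduced, standalone sketch of that dispatch pattern (hypothetical names):

    #include <cstdint>
    #include <type_traits>

    #define MMHA_INT8_SCALE_Q_INSTEAD_OF_K

    template <typename TKcache>
    int qk_path()
    {
    #ifdef MMHA_INT8_SCALE_Q_INSTEAD_OF_K
        constexpr bool INT8_K_CACHE = std::is_same<TKcache, int8_t>::value;
    #else
        constexpr bool INT8_K_CACHE = false;
    #endif
        if constexpr (INT8_K_CACHE)
            return 1; // dot on raw INT8 K, Q already carries k_scale
        else
            return 0; // generic path: dequantize K inside the dot
    }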
@@ -1734,21 +1756,19 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
     {

         // Store the Q values to shared memory.
-#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
-        if constexpr (FP8_K_CACHE)
+        if constexpr (FP8_K_CACHE || INT8_K_CACHE)
         {
             // There are many more elements from K than elements from Q so we pre-scale Q instead
             // of scaling all the elements from K. It helps reduce the number of ops.
             Qk_vec_k scaled_q;
             zero(scaled_q);
             if (is_valid_qk_vec)
             {
-                scaled_q = mul<Qk_vec_k, Tk, Qk_vec_k>(k_scale_quant_orig, q);
+                scaled_q = mul<Qk_vec_k, T_scale, Qk_vec_k>(k_scale_quant_orig, q);
             }
             reinterpret_cast<Qk_vec_k*>(&q_smem[qk_vec_idx])[0] = scaled_q;
         }
         else
-#endif
         {
             // Set padded Dh to 0 for the correctness of QK (when Dh != Dh_Max).
             Qk_vec_k zero_q;
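The switch from mul<Qk_vec_k, Tk, Qk_vec_k> to mul<Qk_vec_k, T_scale, Qk_vec_k> looks necessary because k_scale_quant_orig is carried as the scale type (a float) rather than the compute type Tk, and the middle template argument declares the scalar operand's type. A shape-only sketch of the kind of overload being selected, under that assumption:

    #include <cuda_runtime.h>

    // mul<Dst, A, B>(a, b): Dst is the result type, A the scalar's declared type.
    template <typename Dst, typename A, typename B>
    __device__ Dst mul(A a, B b);

    // float scale times a float4 vector: the shape used when pre-scaling Q.
    template <>
    __device__ float4 mul<float4, float, float4>(float a, float4 b)
    {
        return make_float4(a * b.x, a * b.y, a * b.z, a * b.w);
    }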
@@ -2012,13 +2032,11 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
         // Compute the dot product between Q and K.
         // Note that dot will convert 8bit vec to the accumulation data type (float by default).
         float qk_ = 0.f;
-#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
-        if constexpr (FP8_K_CACHE)
+        if constexpr (FP8_K_CACHE || INT8_K_CACHE)
         {
             qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
         }
         else
-#endif // MMHA_FP8_SCALE_Q_INSTEAD_OF_K
         {
             if constexpr (ENABLE_8BITS_K_CACHE)
             {
@@ -2158,13 +2176,11 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
         // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
         // Note that dot will convert 8bit vec to the accumulation data type (float by default).
         float qk_ = 0.f;
-#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
-        if constexpr (FP8_K_CACHE)
+        if constexpr (FP8_K_CACHE || INT8_K_CACHE)
         {
             qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
         }
         else
-#endif // MMHA_FP8_SCALE_Q_INSTEAD_OF_K
         {
             if constexpr (ENABLE_8BITS_K_CACHE)
             {
@@ -2338,12 +2354,19 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
     // Compute the sum.
     sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);

-    // Normalize the logits.
+    // Normalize the logits.
+    float logit_scale = 1.0f;
 #ifdef MMHA_FP8_SCALE_P_INSTEAD_OF_V
-    float logit_scale = (FP8_KV_CACHE ? kv_scale_quant_orig_f : 1.0f);
-#else
-    float logit_scale = 1.f;
+    if constexpr (FP8_KV_CACHE) {
+        logit_scale = kv_scale_quant_orig_f;
+    }
 #endif // MMHA_FP8_SCALE_P_INSTEAD_OF_V
+#ifdef MMHA_INT8_SCALE_P_INSTEAD_OF_V
+    if constexpr (INT8_KV_CACHE) {
+        logit_scale = kv_scale_quant_orig_f;
+    }
+#endif // MMHA_INT8_SCALE_P_INSTEAD_OF_V
+
     float inv_sum = __fdividef(logit_scale, sum + 1.e-6f);

     int const normlization_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length;
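This is where the deferred V scale from the SCALE_P knobs lands: inv_sum becomes kv_scale_quant_orig_f / (sum + eps), so normalizing the logits also dequantizes the INT8 (or FP8) V values that the fma loop consumed raw. Worked arithmetic with made-up numbers:

    #include <cstdio>

    int main()
    {
        float sum = 2.5f;      // softmax denominator from block_sum
        float v_scale = 0.05f; // kv_scale_quant_orig_f
        float inv_sum = v_scale / (sum + 1.e-6f); // logit_scale folded in
        float p = 1.25f;       // unnormalized logit
        signed char v = 40;    // raw INT8 V element
        // One multiply gives the normalized, dequantized contribution:
        printf("%f\n", (p * inv_sum) * v); // == (p / sum) * (v_scale * v)
        return 0;
    }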
@@ -2431,6 +2431,15 @@ inline __device__ float4 mul(float4 a, int32_t b)
     return fc;
 }

+///////////////////////////////////////////////////////////////////////////////////////////////
+
+template <>
+inline __device__ Float4_ mul(float4 a, int32_t b)
+{
+    float4 fc = mul<float4, float4, int32_t>(a, b);
+    return reinterpret_cast<Float4_&>(fc);
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 inline __device__ float sum(float v)
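The new specialization serves callers that accumulate into Float4_ (the struct-of-two-float2 form) when V arrives as four INT8 lanes packed in an int32_t: it reuses the existing float4-by-int32_t multiply and reinterprets the result, which is safe because both types are four contiguous floats. A standalone approximation of what it computes, with the helper types re-declared here for illustration:

    #include <cuda_runtime.h>
    #include <cstdint>

    struct Float4_ // simplified stand-in for the kernel's helper struct
    {
        float2 x, y;
    };

    // Four packed int8 lanes (one int32_t) times a float4, elementwise;
    // mirrors the existing mul(float4, int32_t) helper this PR builds on.
    __device__ float4 mul_f4_i8x4(float4 a, int32_t b)
    {
        char4 c = reinterpret_cast<char4&>(b);
        return make_float4(a.x * c.x, a.y * c.y, a.z * c.z, a.w * c.w);
    }

    __device__ Float4_ mul_Float4(float4 a, int32_t b)
    {
        float4 fc = mul_f4_i8x4(a, b);
        // float4 and Float4_ share the same layout (four contiguous floats),
        // so this reinterpret is a pure relabeling of the bits.
        return reinterpret_cast<Float4_&>(fc);
    }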