From c7bca33199d75d7061d1f32d339f2c884234df7a Mon Sep 17 00:00:00 2001
From: Michal Adamczyk
Date: Thu, 9 Jan 2025 12:18:52 +0200
Subject: [PATCH 1/2] Add flag to enable running softmax in fp32

---
 vllm_hpu_extension/ops.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py
index e6726268..fca76abc 100644
--- a/vllm_hpu_extension/ops.py
+++ b/vllm_hpu_extension/ops.py
@@ -27,9 +27,13 @@
 
 
 def grouped_max(block_max, batch_size, block_groups):
+    orig_dtype = block_max.dtype
+    if orig_dtype == torch.float16:
+        # fp16 index_reduce is not supported ATM
+        block_max = block_max.to(torch.float32)
     group_max = torch.full([batch_size + 1, *block_max.shape[1:]], -math.inf,
                            dtype=block_max.dtype, device=block_max.device)
-    group_max = group_max.index_reduce_(0, block_groups, block_max, 'amax')
+    group_max = group_max.index_reduce_(0, block_groups, block_max, 'amax').to(orig_dtype)
     group_max = group_max.index_select(0, block_groups)
     return group_max
 
@@ -54,6 +58,7 @@ def pipelined_pa(attn, value, block_groups, block_mapping, block_scales, batch_s
     adjustment_target_shape = block_max.shape
     attn = attn.sub(block_max)
     attn = attn.exp()
+    attn = attn.to(value.dtype)
     block_sums = attn.sum(dim=-1, keepdim=True)
     attn = matmul_av_op(attn, value)
     block_max = block_max.squeeze()
@@ -62,6 +67,7 @@ def pipelined_pa(attn, value, block_groups, block_mapping, block_scales, batch_s
     group_max = grouped_max(block_max, batch_size, block_groups)
     block_adjustment = (block_max - group_max).exp()
     sum_adjusted = block_sums.mul(block_adjustment)
+    sum_adjusted = sum_adjusted.to(value.dtype)
     # Sum block's sums that belongs to the same sequeneces
     group_sum_adjusted = block2batch(sum_adjusted, block_mapping, block2batch_matmul_op)
     group_sum_adjusted = batch2block(group_sum_adjusted, block_mapping, batch2block_matmul_op)
@@ -72,6 +78,7 @@ def pipelined_pa(attn, value, block_groups, block_mapping, block_scales, batch_s
     group_sum_adjusted = torch.maximum(group_sum_adjusted, sum_adjusted)
     # Post processing for the attention scores
     rescale = block_adjustment.div(group_sum_adjusted)
+    rescale = rescale.to(attn.dtype)
     attn = attn.mul(rescale)
     return attn
 
@@ -103,7 +110,7 @@ def pa(attn, value, block_groups, block_mapping, block_scales, batch_size,
 pipelined_pa_enabled = 'True' if "index_reduce" in capabilities() else 'False'
 pipelined_pa_enabled = os.environ.get('VLLM_PIPELINED_PA', pipelined_pa_enabled).lower() == 'true'
 pa_impl = pipelined_pa if pipelined_pa_enabled else pa
-
+fp32_softmax = os.environ.get('VLLM_FP32_SOFTMAX', 'False').lower() == 'true'
 
 def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
             block_bias, block_scales, block_groups, scale, matmul_qk_op,
@@ -126,7 +133,11 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
     else:
         key = key.transpose(2, 3)
 
-    attn = matmul_qk_op(query, key) + block_bias
+    attn = matmul_qk_op(query, key)
+    if fp32_softmax:
+        attn = attn.float()
+        htcore.mark_step()
+    attn = attn + block_bias
     attn = pa_impl(attn, value, block_groups, block_mapping,
                    block_scales=block_scales, batch_size=batch_size, matmul_av_op=matmul_av_op,
                    batch2block_matmul_op=batch2block_matmul_op, block2batch_matmul_op=block2batch_matmul_op)
@@ -184,9 +195,13 @@ def prompt_attention(
             if attn_bias is not None:
                 attn_bias = attn_bias.unsqueeze(2)
         attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
+        if fp32_softmax:
+            attn_weights = attn_weights.float()
+            htcore.mark_step()
         if attn_bias is not None:
-            attn_weights.add_(attn_bias)
+            attn_weights = attn_weights.add(attn_bias)
         attn_weights = softmax_op(attn_weights, dim=-1)
+        attn_weights = attn_weights.to(query.dtype)
         attn_weights = matmul_av_op(attn_weights, value)
         if query_heads != kv_heads:
             attn_weights = attn_weights.flatten(1, 2)
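The VLLM_FP32_SOFTMAX flag added above is read once, at import time of vllm_hpu_extension/ops.py, and when it is set to 'true' the Q@K scores are cast to float32 before the bias add and the softmax, then cast back to the native dtype after normalization. Below is a minimal, self-contained sketch of that dataflow in plain CPU PyTorch; toy_softmax_path is an illustrative helper (not a function from the patch), the shapes are made up, and the htcore.mark_step() calls from the HPU code are omitted.

import os
import torch

# Parsed the same way the patch does it: the string 'true', case-insensitive.
fp32_softmax = os.environ.get('VLLM_FP32_SOFTMAX', 'False').lower() == 'true'

def toy_softmax_path(scores: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Mirrors the flat_pa/prompt_attention ordering: optional upcast, bias add,
    # softmax, then a cast back to the native dtype.
    if fp32_softmax:
        scores = scores.float()
    probs = torch.softmax(scores + bias, dim=-1)
    return probs.to(bias.dtype)

scores = torch.randn(2, 8, 16, dtype=torch.bfloat16)
bias = torch.zeros_like(scores)
print(toy_softmax_path(scores, bias).dtype)  # bfloat16 either way; only the softmax precision differs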
From 970cb8cd5daafe592140d5fddd521331bd54c799 Mon Sep 17 00:00:00 2001
From: Michal Adamczyk
Date: Fri, 10 Jan 2025 15:39:18 +0200
Subject: [PATCH 2/2] Tweak casts + add comments

---
 vllm_hpu_extension/ops.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py
index fca76abc..c30118de 100644
--- a/vllm_hpu_extension/ops.py
+++ b/vllm_hpu_extension/ops.py
@@ -53,7 +53,10 @@ def block2batch(tensor, block_mapping, matmul_op=torch.matmul):
 def pipelined_pa(attn, value, block_groups, block_mapping, block_scales, batch_size,
                  matmul_av_op, batch2block_matmul_op, block2batch_matmul_op):
-    # Normalize the attention scores
+    # When fp32_softmax is enabled attn is left in fp32 after Q@K
+    # We can return to native dtype after we renormalize and calculate the adjustments
+
+    # Normalize the attention scores and cast attn to native dtype
     block_max = attn.amax(dim=-1, keepdim=True)
     adjustment_target_shape = block_max.shape
     attn = attn.sub(block_max)
     attn = attn.exp()
@@ -63,22 +66,26 @@ def pipelined_pa(attn, value, block_groups, block_mapping, block_scales, batch_s
     attn = matmul_av_op(attn, value)
     block_max = block_max.squeeze()
     block_sums = block_sums.squeeze()
 
+    # Calculate maximum of blocks that belong to the same sequences
+    # and cast adjustments to native dtype
     group_max = grouped_max(block_max, batch_size, block_groups)
     block_adjustment = (block_max - group_max).exp()
+    block_adjustment = block_adjustment.to(value.dtype)
     sum_adjusted = block_sums.mul(block_adjustment)
-    sum_adjusted = sum_adjusted.to(value.dtype)
-    # Sum block's sums that belongs to the same sequeneces
+
+    # Sum block's sums that belongs to the same sequences
     group_sum_adjusted = block2batch(sum_adjusted, block_mapping, block2batch_matmul_op)
     group_sum_adjusted = batch2block(group_sum_adjusted, block_mapping, batch2block_matmul_op)
     sum_adjusted = sum_adjusted.view(*adjustment_target_shape)
     group_sum_adjusted = group_sum_adjusted.view(*adjustment_target_shape)
     block_adjustment = block_adjustment.view(*adjustment_target_shape)
+    # For stability in case some of the sums have been zeroed out during block aggretation
     group_sum_adjusted = torch.maximum(group_sum_adjusted, sum_adjusted)
+
     # Post processing for the attention scores
     rescale = block_adjustment.div(group_sum_adjusted)
-    rescale = rescale.to(attn.dtype)
     attn = attn.mul(rescale)
     return attn
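For context on the comments that patch 2 adds to pipelined_pa: the kernel computes softmax statistics per block and then renormalizes across the blocks of each sequence using a grouped max and per-block adjustments, which is where the casts back to the native dtype now sit (attn right after exp, and block_adjustment). The sketch below reproduces just that renormalization for a single sequence split into two blocks, in plain fp32 CPU PyTorch; it omits the block_mapping/block2batch bookkeeping, the torch.maximum stability clamp, and all dtype casts, so it illustrates the math only, not the HPU implementation.

import torch

torch.manual_seed(0)
scores = torch.randn(2, 64)      # one sequence split into 2 blocks of 64 attention scores
value = torch.randn(2, 64, 8)    # matching value blocks, head_dim = 8

# Per-block pass: subtract the block max, exponentiate, keep the per-block sums
# and the un-normalized partial outputs (attn after matmul_av_op in pipelined_pa).
block_max = scores.amax(dim=-1, keepdim=True)
attn = (scores - block_max).exp()
block_sums = attn.sum(dim=-1, keepdim=True)
partial = (attn.unsqueeze(1) @ value).squeeze(1)      # shape (2, 8)

# Combine pass: grouped max over the sequence's blocks, then per-block adjustments.
group_max = block_max.max()
block_adjustment = (block_max - group_max).exp()
group_sum = (block_sums * block_adjustment).sum()
rescale = block_adjustment / group_sum
combined = (partial * rescale).sum(dim=0)             # shape (8,)

# Reference: an ordinary softmax over the whole sequence at once.
reference = torch.softmax(scores.reshape(-1), dim=-1) @ value.reshape(-1, 8)
print("max difference vs. full softmax:", (combined - reference).abs().max().item())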