Add flag to enable running softmax in fp32 #71

Open · wants to merge 2 commits into main
vllm_hpu_extension/ops.py (28 additions, 6 deletions)
@@ -27,9 +27,13 @@


 def grouped_max(block_max, batch_size, block_groups):
+    orig_dtype = block_max.dtype
+    if orig_dtype == torch.float16:
+        # fp16 index_reduce is not supported ATM
+        block_max = block_max.to(torch.float32)
     group_max = torch.full([batch_size + 1, *block_max.shape[1:]], -math.inf,
                            dtype=block_max.dtype, device=block_max.device)
-    group_max = group_max.index_reduce_(0, block_groups, block_max, 'amax')
+    group_max = group_max.index_reduce_(0, block_groups, block_max, 'amax').to(orig_dtype)
     group_max = group_max.index_select(0, block_groups)
     return group_max
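Reviewer note, not part of the diff: for anyone unfamiliar with index_reduce_, here is a tiny standalone sketch of what grouped_max computes and why the fp16 input takes a round trip through fp32. Shapes and values are made up for illustration.

import math
import torch

batch_size = 2
# Per-block maxima for 5 blocks and the sequence (group) each block belongs to;
# the extra slot mirrors the batch_size + 1 allocation above.
block_max = torch.tensor([1.0, 3.0, 2.0, 5.0, 4.0], dtype=torch.float16)
block_groups = torch.tensor([0, 0, 1, 1, 2])

orig_dtype = block_max.dtype
if orig_dtype == torch.float16:
    # reduce in fp32 and cast back, as the guard above does
    block_max = block_max.to(torch.float32)

group_max = torch.full([batch_size + 1], -math.inf, dtype=block_max.dtype)
group_max = group_max.index_reduce_(0, block_groups, block_max, 'amax').to(orig_dtype)
print(group_max)                                # tensor([3., 5., 4.], dtype=torch.float16)
print(group_max.index_select(0, block_groups))  # broadcast back to the per-block layout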

@@ -49,27 +53,37 @@ def block2batch(tensor, block_mapping, matmul_op=torch.matmul):


 def pipelined_pa(attn, value, block_groups, block_mapping, block_scales, batch_size,
                  matmul_av_op, batch2block_matmul_op, block2batch_matmul_op):
-    # Normalize the attention scores
+    # When fp32_softmax is enabled, attn is left in fp32 after Q@K.
+    # We can return to the native dtype after we renormalize and calculate the adjustments.
+
+    # Normalize the attention scores and cast attn to native dtype
     block_max = attn.amax(dim=-1, keepdim=True)
     adjustment_target_shape = block_max.shape
     attn = attn.sub(block_max)
     attn = attn.exp()
+    attn = attn.to(value.dtype)
     block_sums = attn.sum(dim=-1, keepdim=True)
     attn = matmul_av_op(attn, value)
     block_max = block_max.squeeze()
     block_sums = block_sums.squeeze()

     # Calculate maximum of blocks that belong to the same sequences
+    # and cast adjustments to native dtype
     group_max = grouped_max(block_max, batch_size, block_groups)
     block_adjustment = (block_max - group_max).exp()
+    block_adjustment = block_adjustment.to(value.dtype)
     sum_adjusted = block_sums.mul(block_adjustment)
-    # Sum block's sums that belongs to the same sequeneces
+
+    # Sum the block sums that belong to the same sequences
     group_sum_adjusted = block2batch(sum_adjusted, block_mapping, block2batch_matmul_op)
     group_sum_adjusted = batch2block(group_sum_adjusted, block_mapping, batch2block_matmul_op)
     sum_adjusted = sum_adjusted.view(*adjustment_target_shape)
     group_sum_adjusted = group_sum_adjusted.view(*adjustment_target_shape)
     block_adjustment = block_adjustment.view(*adjustment_target_shape)

     # For stability in case some of the sums have been zeroed out during block aggregation
     group_sum_adjusted = torch.maximum(group_sum_adjusted, sum_adjusted)

     # Post processing for the attention scores
     rescale = block_adjustment.div(group_sum_adjusted)
     attn = attn.mul(rescale)
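As a sanity check on the renormalization above (again not part of the PR): rescaling each block's exp() values by exp(block_max - group_max) / group_sum_adjusted is the standard way to merge per-block softmax pieces into the full softmax. A small self-contained example with made-up scores, ignoring the block2batch/batch2block mapping machinery, shows the identity holds:

import torch

scores = torch.tensor([[0.5, 1.5, -0.5, 2.0, 0.0, 1.0]])          # one sequence, six keys
blocks = scores.split(2, dim=-1)                                   # pretend block size is 2

block_max = [b.amax(dim=-1, keepdim=True) for b in blocks]
block_exp = [(b - m).exp() for b, m in zip(blocks, block_max)]     # attn.sub(block_max).exp()
block_sums = [e.sum(dim=-1, keepdim=True) for e in block_exp]

group_max = torch.stack(block_max).amax(dim=0)                     # grouped max over the sequence
block_adjustment = [(m - group_max).exp() for m in block_max]
group_sum_adjusted = sum(s * a for s, a in zip(block_sums, block_adjustment))

rescale = [a / group_sum_adjusted for a in block_adjustment]
merged = torch.cat([e * r for e, r in zip(block_exp, rescale)], dim=-1)
print(torch.allclose(merged, scores.softmax(dim=-1)))              # True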
@@ -103,7 +117,7 @@ def pa(attn, value, block_groups, block_mapping, block_scales, batch_size,

 pipelined_pa_enabled = 'True' if "index_reduce" in capabilities() else 'False'
 pipelined_pa_enabled = os.environ.get('VLLM_PIPELINED_PA', pipelined_pa_enabled).lower() == 'true'
 pa_impl = pipelined_pa if pipelined_pa_enabled else pa
+fp32_softmax = os.environ.get('VLLM_FP32_SOFTMAX', 'False').lower() == 'true'


 def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
             block_bias, block_scales, block_groups, scale, matmul_qk_op,
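Usage note, not part of the diff: like VLLM_PIPELINED_PA, the new flag is read once at module import time, so it has to be set in the environment before the extension is imported. A minimal sketch:

import os

os.environ["VLLM_FP32_SOFTMAX"] = "true"   # keep attention scores in fp32 through softmax
os.environ["VLLM_PIPELINED_PA"] = "true"   # keep the pipelined paged-attention path

import vllm_hpu_extension.ops as ops       # both flags are baked in from this point on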
@@ -126,7 +140,11 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
     else:
         key = key.transpose(2, 3)

-    attn = matmul_qk_op(query, key) + block_bias
+    attn = matmul_qk_op(query, key)
+    if fp32_softmax:
+        attn = attn.float()
+        htcore.mark_step()
+    attn = attn + block_bias
     attn = pa_impl(attn, value, block_groups, block_mapping, block_scales=block_scales,
                    batch_size=batch_size, matmul_av_op=matmul_av_op,
                    batch2block_matmul_op=batch2block_matmul_op, block2batch_matmul_op=block2batch_matmul_op)
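For context on what the fp32 cast buys (illustrative only, no HPU or custom matmul ops involved): doing the max-subtract, exp, sum and divide arithmetic in bf16 loses noticeably more precision than doing it in fp32 and casting back at the end, which is the trade-off the flag switches between. A rough sketch with made-up logits:

import torch

def manual_softmax(x):
    x = x - x.amax(dim=-1, keepdim=True)
    e = x.exp()
    return e / e.sum(dim=-1, keepdim=True)

torch.manual_seed(0)
attn = (torch.randn(4, 1024) * 5.0).bfloat16()        # stand-in for the Q@K result in native dtype
reference = manual_softmax(attn.double())

err_native = (manual_softmax(attn).double() - reference).abs().max()          # arithmetic in bf16
err_fp32 = (manual_softmax(attn.float()).double() - reference).abs().max()    # same inputs, arithmetic in fp32
print(f"bf16 softmax max abs error: {err_native.item():.2e}")   # several orders of magnitude larger
print(f"fp32 softmax max abs error: {err_fp32.item():.2e}")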
@@ -184,9 +202,13 @@ def prompt_attention(
     if attn_bias is not None:
         attn_bias = attn_bias.unsqueeze(2)
     attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
+    if fp32_softmax:
+        attn_weights = attn_weights.float()
+        htcore.mark_step()
     if attn_bias is not None:
-        attn_weights.add_(attn_bias)
+        attn_weights = attn_weights.add(attn_bias)
     attn_weights = softmax_op(attn_weights, dim=-1)
+    attn_weights = attn_weights.to(query.dtype)
     attn_weights = matmul_av_op(attn_weights, value)
     if query_heads != kv_heads:
         attn_weights = attn_weights.flatten(1, 2)
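One last illustrative sketch (plain torch.matmul standing in for the fused ops, shapes invented) of the dtype hand-offs the prompt path above performs when the flag is on: the scores are promoted to fp32 before the bias add and softmax, and cast back to the query dtype before the A@V matmul.

import torch

q_dtype = torch.bfloat16
attn_weights = torch.randn(2, 8, 16, 16, dtype=q_dtype)   # stand-in for matmul_qk_op(query * scale, key^T)
attn_bias = torch.zeros(2, 1, 16, 16, dtype=q_dtype)
value = torch.randn(2, 8, 16, 64, dtype=q_dtype)

attn_weights = attn_weights.float()                 # fp32_softmax: promote before bias + softmax
attn_weights = attn_weights.add(attn_bias)          # out-of-place add keeps the fp32 result
attn_weights = attn_weights.softmax(dim=-1)         # softmax runs in fp32
attn_weights = attn_weights.to(q_dtype)             # back to the query dtype before A@V
out = torch.matmul(attn_weights, value)
print(out.dtype, out.shape)                         # torch.bfloat16 torch.Size([2, 8, 16, 64])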