
Commit

accuracy issue is fixed
Signed-off-by: Chendi Xue <[email protected]>
xuechendi committed Dec 18, 2024
1 parent bc69ebb commit 6ddfcac
Showing 2 changed files with 90 additions and 7 deletions.
65 changes: 61 additions & 4 deletions vllm/attention/backends/hpu_attn.py
@@ -29,6 +29,56 @@
logger.warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")

def prompt_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor] = None,
p: float = 0.0,
scale: Optional[float] = None,
matmul_qk_op=torch.matmul,
softmax_op=torch.softmax,
matmul_av_op=torch.matmul,
valid_seq_lengths: Optional[torch.Tensor] = None,
    fsdpa_op=None,
) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
query_heads = query.size(1)
kv_heads = key.size(1)
if attn_bias is not None:
attn_bias = attn_bias.expand(query.size(0), query_heads, query.size(2), key.size(2))
if attn_bias is not None or fsdpa_op is None:
if query_heads != kv_heads:
query = query.unflatten(1, (kv_heads, -1))
key = key.unflatten(1, (kv_heads, 1))
value = value.unflatten(1, (kv_heads, 1))
if attn_bias is not None:
attn_bias = attn_bias.unflatten(1, (kv_heads, -1))
attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
if attn_bias is not None:
attn_weights.add_(attn_bias)
attn_weights = softmax_op(attn_weights, dim=-1)
attn_weights = matmul_av_op(attn_weights, value)
if query_heads != kv_heads:
attn_weights = attn_weights.flatten(1, 2)
else:
VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get('VLLM_REMOVE_REPEAT_KV_CACHE', '1') == '1'
# TODO: remove after fusedsdpa fix for query_heads != kv_heads
if query_heads != kv_heads:
if VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE:
key = ops.repeat_kv(key, int(query_heads // kv_heads))
value = ops.repeat_kv(value, int(query_heads // kv_heads))
if attn_bias is not None:
attn_bias = attn_bias.unflatten(1, (kv_heads, -1))
softmax_mode = 'fast'
recompute_mode = True
attn_weights = fsdpa_op(query=query, key=key, value=value, attn_mask=attn_bias, dropout_p=0.0, is_causal=True,
scale=scale, softmax_mode=softmax_mode, recompute_mode=recompute_mode,
valid_sequence_lengths=valid_seq_lengths, padding_side='right')
attn_weights = attn_weights.transpose(1, 2)
return attn_weights
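
A quick, illustrative sanity check of the non-fused fallback path above (not taken from the commit; the shapes are made up, and it assumes torch is installed and prompt_attention is in scope):

import torch

# [batch, seq, heads, head_dim] inputs; 8 query heads share 2 KV heads (GQA),
# so the fallback branch unflattens the query heads into (kv_heads, group).
batch, seq, q_heads, kv_heads, head_dim = 2, 16, 8, 2, 64
query = torch.randn(batch, seq, q_heads, head_dim)
key = torch.randn(batch, seq, kv_heads, head_dim)
value = torch.randn(batch, seq, kv_heads, head_dim)
out = prompt_attention(query, key, value, scale=head_dim ** -0.5)
assert out.shape == (batch, seq, q_heads, head_dim)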

class HPUAttentionBackend(AttentionBackend):

@@ -228,8 +278,10 @@ def forward(
if attn_metadata.is_prompt:
padded_shape = attn_metadata.input_tokens_padded_tensor
seq_lens_tensor_list = seq_lens_tensor.tolist()
padded_key_tensor = torch.zeros((padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size),
                                device=key.device, dtype=key.dtype)
padded_value_tensor = torch.zeros((padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size),
                                  device=query.device, dtype=key.dtype)
start = 0
# we need to copy the key and value tensors to the padded tensors
# shape is [batch_size, entire_seq_len, num_kv_heads, head_size]
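
The loop that fills these padded buffers is collapsed in this view; a hypothetical sketch of the per-sequence copy (only padded_key_tensor, padded_value_tensor, key, value, and seq_lens_tensor_list come from the diff, everything else is assumed):

# Hypothetical sketch, not the hidden implementation: copy each prompt's
# keys/values into its own padded row, left-aligned, zeros past its length.
start = 0
for i, seq_len in enumerate(seq_lens_tensor_list):
    padded_key_tensor[i, :seq_len].copy_(key[start:start + seq_len])
    padded_value_tensor[i, :seq_len].copy_(value[start:start + seq_len])
    start += seq_len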
@@ -284,11 +336,16 @@ def forward(
attn_bias = attn_bias.tile(
(1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
elif enable_merged_prefill:
pass
else:
attn_bias = None

#self.fused_scaled_dot_product_attention = None
if enable_merged_prefill:
    prompt_attn_func = prompt_attention
else:
    prompt_attn_func = ops.prompt_attention
out = prompt_attn_func(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
32 changes: 29 additions & 3 deletions vllm/worker/hpu_model_runner.py
@@ -212,6 +212,8 @@ def __init__(self, model, block_size, dtype, enforce_eager, layer_names):
self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'1').lower() in ['1', 'true'] \
and not is_fake_hpu()
self.merged_prefill_attn_mask_compute = os.getenv('VLLM_MERGED_PREFILL_ATTN_MASK_COMPUTE',
'1').lower() in ['1', 'true']
self.block_size = block_size
self.dtype = dtype
self.layer_names = layer_names
@@ -287,6 +289,28 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
attn_metadata = prefill_metadata._replace(attn_bias=attn_bias)
return attn_metadata

def _set_merged_attn_bias(self, attn_metadata, batch_size, max_seq_len, device):
    # Build a 2D block-diagonal causal attn mask so each token can only
    # attend to earlier tokens of its own sequence.
if attn_metadata is None or not attn_metadata.is_prompt:
return attn_metadata
if not self.merged_prefill_attn_mask_compute:
return attn_metadata
#TODO: Support batch_size > 1
seq_lens = attn_metadata.seq_lens_tensor.tolist()
causal_attn_mask_tensor = torch.ones((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device)
start = 0
for i in range(batch_size):
for seq_len in seq_lens:
# create triangular mask for each sequence
causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=device, dtype=torch.bool), diagonal=1)
causal_attn_mask_tensor[i][start:start+seq_len, start:start+seq_len].copy_(causal_mask)
start += seq_len
causal_attn_mask_tensor = torch.zeros_like(causal_attn_mask_tensor, device=device, dtype=self.dtype).masked_fill_(
    causal_attn_mask_tensor, -10000)  # ideally -inf, but -10000 is used for numerical stability
causal_attn_mask_tensor = causal_attn_mask_tensor.view(
    causal_attn_mask_tensor.shape[0], 1, causal_attn_mask_tensor.shape[1], causal_attn_mask_tensor.shape[2])

attn_metadata = attn_metadata._replace(attn_bias=causal_attn_mask_tensor)
return attn_metadata
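
For intuition (illustrative only, not part of the diff): with two prompts of lengths 2 and 3 packed into max_seq_len = 5, the mask built above is block-diagonal causal, so a token never attends across prompts or into the future.

import torch

seq_lens, max_seq_len = [2, 3], 5
mask = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)  # True = masked
start = 0
for seq_len in seq_lens:
    block = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool), diagonal=1)
    mask[start:start + seq_len, start:start + seq_len] = block
    start += seq_len
bias = torch.zeros(max_seq_len, max_seq_len).masked_fill_(mask, -10000.0)
# bias[2] == [-10000, -10000, 0, -10000, -10000]: the first token of the
# second prompt can attend only to itself.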

def _set_block_mapping(self, metadata, batch_size, device, dtype):
mask = torch.arange(0,
self.block_size,
@@ -340,7 +364,9 @@ def _set_indices_and_offsets(self, metadata, block_size, is_prompt):

def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
dtype):
if attn_metadata.is_prompt and attn_metadata.enable_merged_prefill:
    attn_metadata = self._set_merged_attn_bias(attn_metadata, batch_size,
                                               seq_len, device)
elif attn_metadata.is_prompt:
attn_metadata = self._set_attn_bias(attn_metadata, batch_size,
seq_len, device, dtype)
else:
@@ -886,8 +912,8 @@ def _prepare_prompt(
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)

max_query_len = max(query_lens)
real_num_seqs = len(query_lens)
assert max_query_len > 0
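
The slot computation in this hunk maps a token's position to a flat slot in the paged KV cache; a small worked example with made-up numbers:

block_size = 128
block_table = [7, 42]        # physical block ids assigned to this sequence
i = 130                      # token index within the sequence
block_number = block_table[i // block_size]   # block_table[1] == 42
block_offset = i % block_size                 # 2
slot = block_number * block_size + block_offset
assert slot == 5378          # 42 * 128 + 2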
