Commit bc69ebb: Enable merged prefill
Signed-off-by: Chendi Xue <[email protected]>
xuechendi committed Dec 18, 2024
1 parent 9a210fc commit bc69ebb
Showing 2 changed files with 66 additions and 23 deletions.
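
For orientation, a minimal sketch (not part of this commit) of the idea behind merged prefill: instead of running each prompt in the batch as its own padded prefill, the prompts are concatenated into one flattened token stream, and the per-prompt lengths are kept so attention and the KV-cache writes can still be split per sequence. The helper below is hypothetical and only illustrates the shape bookkeeping.

import torch

def merge_prompts(prompts, pad_len, pad_id=0):
    # Concatenate every prompt's tokens into one stream.
    merged = [tok for p in prompts for tok in p]
    seq_lens = [len(p) for p in prompts]
    # Pad the merged stream up to the bucketed length.
    merged = merged + [pad_id] * (pad_len - len(merged))
    return torch.tensor([merged]), torch.tensor(seq_lens)

tokens, seq_lens = merge_prompts([[1, 2, 3], [4, 5]], pad_len=8)
print(tokens.shape, seq_lens.tolist())  # torch.Size([1, 8]) [3, 2]
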
62 changes: 47 additions & 15 deletions vllm/attention/backends/hpu_attn.py
@@ -83,6 +83,8 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
attn_bias: Optional[torch.Tensor]
seq_lens_tensor: Optional[torch.Tensor]
context_lens_tensor: Optional[torch.Tensor]
enable_merged_prefill: bool = False
input_tokens_padded_tensor: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
@@ -213,6 +215,7 @@ def forward(
block_offsets = kwargs.get('block_offsets', None)
seq_lens_tensor = kwargs.get('seq_lens_tensor', None)
attn_bias = kwargs.get('attn_bias', None)
enable_merged_prefill = attn_metadata.enable_merged_prefill
if block_indices is None:
block_indices = attn_metadata.block_indices
if block_offsets is None:
@@ -221,20 +224,48 @@
seq_lens_tensor = attn_metadata.seq_lens_tensor
if attn_bias is None: # This is the case for prompt run
attn_bias = attn_metadata.attn_bias
if attn_metadata.is_prompt:
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache = self.k_cache(key, key_cache, block_indices,
block_offsets)
value_cache = self.v_cache(value, value_cache, block_indices,
block_offsets)
if enable_merged_prefill:
if attn_metadata.is_prompt:
padded_shape = attn_metadata.input_tokens_padded_tensor
seq_lens_tensor_list = seq_lens_tensor.tolist()
padded_key_tensor = torch.zeros(padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size, device=key.device, dtype=key.dtype)
padded_value_tensor = torch.zeros(padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size, device=value.device, dtype=value.dtype)
start = 0
# Copy the key and value tensors into the padded tensors;
# shape is [batch_size, entire_seq_len, num_kv_heads, head_size].
for i in range(padded_shape[0]):
padded_key_tensor[i, :seq_lens_tensor_list[i]].copy_(key[start: start + seq_lens_tensor_list[i], :, :], non_blocking=True)
padded_value_tensor[i, :seq_lens_tensor_list[i]].copy_(value[start: start + seq_lens_tensor_list[i], :, :], non_blocking=True)
start = start + seq_lens_tensor_list[i]
# shape will be [batch_size * entire_seq_len, num_kv_heads, head_size]
# then reshape it to [n_blocks, block_size, num_kv_heads * head_size]
padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1))
padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1))
seq_lens_tensor_merged = torch.tensor(sum(seq_lens_tensor_list), device=seq_lens_tensor.device, dtype=seq_lens_tensor.dtype).unsqueeze(0)
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

key_cache = self.k_cache(padded_key_tensor, key_cache, block_indices,
block_offsets)
value_cache = self.v_cache(padded_value_tensor, value_cache, block_indices,
block_offsets)
else:
if attn_metadata.is_prompt:
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
seq_lens_tensor_merged = seq_lens_tensor
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache = self.k_cache(key, key_cache, block_indices,
block_offsets)
value_cache = self.v_cache(value, value_cache, block_indices,
block_offsets)

if attn_metadata.is_prompt:
# Prompt run.
@@ -256,6 +287,7 @@ def forward(
else:
attn_bias = None

out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
@@ -266,7 +298,7 @@ def forward(
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
valid_seq_lengths=seq_lens_tensor,
valid_seq_lengths=seq_lens_tensor_merged,
fsdpa_op=self.fused_scaled_dot_product_attention,
)
else:
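The merged-prefill prompt path above scatters each sequence's keys and values into a per-sequence padded tensor before regrouping them into cache blocks. A standalone sketch of that copy-and-regroup step follows; shapes and names are illustrative, not the backend's actual API.

import torch

batch, padded_len, num_kv_heads, head_size = 2, 8, 4, 16
seq_lens = [5, 3]
# Keys arrive flattened over all tokens: [sum(seq_lens), num_kv_heads, head_size].
key = torch.randn(sum(seq_lens), num_kv_heads, head_size)

padded_key = torch.zeros(batch, padded_len, num_kv_heads, head_size, dtype=key.dtype)
start = 0
for i, n in enumerate(seq_lens):
    # Copy sequence i's tokens into the front of its padded row.
    padded_key[i, :n].copy_(key[start:start + n])
    start += n

# Flatten back to [batch * padded_len, ...] and regroup into cache blocks.
block_size = 4
n_blocks = batch * padded_len // block_size
padded_key = padded_key.flatten(0, 1).unflatten(0, (n_blocks, -1))
print(padded_key.shape)  # torch.Size([4, 4, 4, 16])
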
27 changes: 19 additions & 8 deletions vllm/worker/hpu_model_runner.py
@@ -1143,33 +1143,36 @@ def _prepare_prompt_merged(
real_num_seqs = len(total_query_lens)
assert max_query_len > 0

print(">>>> seq_lens", seq_lens, "max_query_len", max_query_len, "real_num_seqs", real_num_seqs)

max_prompt_len = max(
merged_prompt_len = max(
self.bucketing_ctx.get_padded_prompt_seq_len(max(total_seq_lens)),
self.block_size)
max_prompt_len = max(
self.bucketing_ctx.get_padded_prompt_seq_len(max(seq_lens)),
self.block_size)

prefix_block_list_tensor = None

input_tokens_tensor = make_tensor_with_pad(input_tokens_merged,
max_len=max_prompt_len,
max_len=merged_prompt_len,
pad=0,
dtype=torch.long,
device='cpu')

input_positions = make_tensor_with_pad(input_positions_merged,
max_len=max_prompt_len,
max_len=merged_prompt_len,
pad=0,
dtype=torch.long,
device='cpu')

input_tokens_padded_tensor = torch.tensor([len(seq_lens), max_prompt_len], dtype=torch.long, device='cpu')

slot_mapping = make_tensor_with_pad(slot_mapping_merged,
slot_mapping = make_tensor_with_pad(slot_mapping,
max_len=max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.long,
device='cpu')

seq_lens_tensor = torch.tensor(total_seq_lens,
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.long,
device='cpu')
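
The runner now computes two padded lengths: merged_prompt_len buckets the length of the merged token stream (the padded size of input_tokens_merged), while max_prompt_len buckets the longest individual prompt and sizes slot_mapping and the per-sequence padded KV shape. A rough sketch of the distinction, with a stand-in for bucketing_ctx.get_padded_prompt_seq_len and the assumption (based on the names) that total_seq_lens tracks cumulative lengths of the merged stream:

def pad_to_bucket(n, block_size=128):
    # Stand-in for the bucketing helper: round up to a multiple of block_size.
    return ((n + block_size - 1) // block_size) * block_size

seq_lens = [300, 150, 90]          # per-prompt lengths in this batch
total_seq_lens = [300, 450, 540]   # assumed cumulative lengths after merging

merged_prompt_len = max(pad_to_bucket(max(total_seq_lens)), 128)  # pads the merged stream: 640
max_prompt_len = max(pad_to_bucket(max(seq_lens)), 128)           # pads each sequence's KV rows: 384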

@@ -1189,9 +1192,11 @@ def _prepare_prompt_merged(
seq_lens_tensor = seq_lens_tensor.to(self.device, non_blocking=True)
context_lens_tensor = context_lens_tensor.to(self.device,
non_blocking=True)

input_tokens_padded_tensor = input_tokens_padded_tensor.to(self.device,
non_blocking=True)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
enable_merged_prefill=True,
block_list=prefix_block_list_tensor,
block_mapping=None,
block_usage=None,
@@ -1207,6 +1212,7 @@ def _prepare_prompt_merged(
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
input_tokens_padded_tensor=input_tokens_padded_tensor,
multi_modal_placeholder_index_maps=
None # FIXME(kzawora): multi-modality will not work here
)
@@ -1593,6 +1599,8 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
'attn_bias',
'seq_lens_tensor',
'context_lens_tensor',
'enable_merged_prefill',
'input_tokens_padded_tensor',
'block_list',
'block_mapping',
'block_usage',
@@ -1713,7 +1721,10 @@ def warmup_scenario(self,
profiler = setup_profiler()
profiler.start()
for _ in range(times):
origin_enable_merged_prefill = self.enable_merged_prefill
self.enable_merged_prefill = False
inputs = self.prepare_model_input(seqs)
self.enable_merged_prefill = origin_enable_merged_prefill
is_single_step = \
self.vllm_config.scheduler_config.num_scheduler_steps == 1
if is_prompt or is_single_step:
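warmup_scenario temporarily forces enable_merged_prefill off around prepare_model_input and then restores the original value. The same save/restore pattern could be written as a context manager; the helper below is hypothetical and only restates what those three lines do, with the restore moved into a finally block.

from contextlib import contextmanager

@contextmanager
def merged_prefill_disabled(runner):
    # Remember the current flag, force it off for this preparation pass,
    # and restore it even if prepare_model_input raises.
    original = runner.enable_merged_prefill
    runner.enable_merged_prefill = False
    try:
        yield runner
    finally:
        runner.enable_merged_prefill = original

# Usage inside the warmup loop:
# with merged_prefill_disabled(self):
#     inputs = self.prepare_model_input(seqs)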
