From 997a10a0deda51d0c8c4faef4bcd4081b04787d5 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Tue, 17 Dec 2024 00:11:40 +0200 Subject: [PATCH 01/25] update benchmark with bucketing strategy Signed-off-by: Chendi Xue --- benchmarks/benchmark_throughput.py | 50 ++++++++++++++++++++++++++++-- vllm/worker/hpu_model_runner.py | 1 + 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 13c62b8045785..cc90e13010579 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -5,6 +5,7 @@ import random import time from typing import List, Optional +import os import pandas as pd import torch @@ -71,17 +72,54 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, raise ValueError("output_len too small") # Load the dataset. - with open(dataset_path) as f: - dataset = json.load(f) + if os.path.splitext(dataset_path)[1] == ".json": + with open(dataset_path) as f: + dataset = json.load(f) + elif os.path.splitext(dataset_path)[1] == ".pkl": + import pandas as pd + dataset = pd.read_pickle(dataset_path) + dataset = dataset[['input', 'output']].to_dict(orient="records") + for data in dataset: + data["conversations"] = [ + {"value": data["input"]}, + {"value": data["output"]} + ] + # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data["conversations"]) >= 2] # Shuffle the dataset. - random.shuffle(dataset) + random.shuffle(dataset) # Filter out sequences that are too long or too short filtered_dataset: List[SampleRequest] = [] for data in dataset: if len(filtered_dataset) == num_requests: + if args.sort_by_len: + filtered_dataset = sorted(filtered_dataset, key=lambda x: x.prompt_len) + if args.bucket_selective: + length_map = {} + for i, request in enumerate(filtered_dataset): + length_map.setdefault(request.prompt_len, []).append(i) + ret = {} + for length, indices in length_map.items(): + bucket_size = (int(length / 128) + 1) * 128 + while len(indices) > 0: + i = indices.pop(0) + if ret.get(bucket_size, None) is None: + ret[bucket_size] = [] + ret[bucket_size].append(filtered_dataset[i]) + remain_len = bucket_size - length + while remain_len > 0: + if length_map.get(remain_len, None) is not None and len(length_map[remain_len]) > 0: + j = length_map[remain_len].pop(0) + ret[bucket_size].append(filtered_dataset[j]) + break + else: + remain_len -= 1 + # sort ret by key + ret = dict(sorted(ret.items(), key=lambda x: x[0])) + print("!!!!!!!!!!!!!!!sorted requests:", [(bucket_size, [i.prompt_len for i in req_list]) for bucket_size, req_list in ret.items()]) + filtered_dataset = [req for data in ret.items() for req in data[1]] break # Only keep the first two turns of each conversation. 
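The --sort-by-len / --bucket-selective logic above sorts the sampled requests by prompt length and then groups them into 128-token buckets, trying to pair a long prompt with a shorter one that fills the remainder of its bucket. A minimal standalone sketch of the bucket-assignment part of that idea, assuming plain (prompt, prompt_len) tuples instead of vLLM's SampleRequest objects (the selective pairing step is omitted):

from collections import defaultdict
from typing import Dict, List, Tuple

def bucket_by_length(requests: List[Tuple[str, int]],
                     bucket_step: int = 128) -> Dict[int, List[Tuple[str, int]]]:
    """Assign each request to the smallest bucket_step-multiple bucket that
    can hold its prompt, mirroring (len // 128 + 1) * 128 in the patch."""
    buckets: Dict[int, List[Tuple[str, int]]] = defaultdict(list)
    for prompt, prompt_len in sorted(requests, key=lambda r: r[1]):
        bucket_size = (prompt_len // bucket_step + 1) * bucket_step
        buckets[bucket_size].append((prompt, prompt_len))
    # Emit buckets in ascending size order, like the sorted(ret.items()) above.
    return dict(sorted(buckets.items()))

if __name__ == "__main__":
    reqs = [("x" * n, n) for n in (30, 100, 140, 250, 260)]
    for size, grouped in bucket_by_length(reqs).items():
        print(size, [n for _, n in grouped])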
@@ -445,6 +483,12 @@ def main(args: argparse.Namespace): action='store_true', default=False, help="Disable decoupled async engine frontend.") + parser.add_argument("--sort-by-len", + action='store_true', + default=False) + parser.add_argument("--bucket-selective", + action='store_true', + default=False) parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() if args.tokenizer is None: diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 7c3679d40546d..9e8af13c3c50e 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -893,6 +893,7 @@ def _prepare_prompt( max_prompt_len = max( self.bucketing_ctx.get_padded_prompt_seq_len(max(seq_lens)), self.block_size) + print(">>>>>>>> seq_lens are", seq_lens, "max_prompt_len is ", max_prompt_len) lora_ids: List[int] = [] for seq_group_metadata, context_len in zip(seq_group_metadata_list, From 79c8b8ef65d3fe2d1550b56c67a787442e689d62 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Tue, 17 Dec 2024 04:42:00 +0200 Subject: [PATCH 02/25] merge input tokens Signed-off-by: Chendi Xue --- vllm/worker/hpu_model_runner.py | 225 +++++++++++++++++++++++++++++++- 1 file changed, 221 insertions(+), 4 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 9e8af13c3c50e..e3713a81716ce 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -631,6 +631,8 @@ def __init__( self._set_gc_threshold() self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA', 'true').lower() == 'true' + self.enable_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL', + 'false').lower() == 'true' if vllm_config.speculative_config is not None \ and self.use_contiguous_pa: raise ValueError( @@ -884,8 +886,8 @@ def _prepare_prompt( block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset - slot_mapping[-1].append(slot) - + slot_mapping[-1].append(slot) + max_query_len = max(query_lens) real_num_seqs = len(query_lens) assert max_query_len > 0 @@ -893,7 +895,6 @@ def _prepare_prompt( max_prompt_len = max( self.bucketing_ctx.get_padded_prompt_seq_len(max(seq_lens)), self.block_size) - print(">>>>>>>> seq_lens are", seq_lens, "max_prompt_len is ", max_prompt_len) lora_ids: List[int] = [] for seq_group_metadata, context_len in zip(seq_group_metadata_list, @@ -1012,6 +1013,221 @@ def _prepare_prompt( slot_mapping=slot_mapping, lora_ids=lora_ids) + def _prepare_prompt_merged( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> PreparePromptMetadata: + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + lora_index_mapping: List[List[int]] = [] + lora_prompt_mapping: List[List[int]] = [] + lora_requests: Set[LoRARequest] = set() + + seq_lens: List[int] = [] + context_lens: List[int] = [] + query_lens: List[int] = [] + prefix_block_tables: List[List[int]] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] + + if len(seq_group_metadata_list) == 0: + return PreparePromptMetadata.empty() + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or 
computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + + token_chunk_size = seq_group_metadata.token_chunk_size + seq_data = seq_group_metadata.seq_data[seq_id] + context_len = seq_data.get_num_computed_tokens() + # We should use get_len here because in case of preemption + # it contains output tokens. + seq_len = min(seq_data.get_len(), context_len + token_chunk_size) + prompt_tokens = seq_data.get_token_ids()[context_len:seq_len] + seq_lens.append(seq_len) + + # NOTE: This only works for oooooooxxx style attention. + if computed_block_nums is not None and len( + computed_block_nums) > 0 and self.sliding_window is None: + # Prefix is not supported with sliding_window + context_len = len(computed_block_nums) * self.block_size + prompt_tokens = prompt_tokens[context_len:] + prefix_block_tables.append(computed_block_nums) + elif self.scheduler_config.chunked_prefill_enabled: + if seq_group_metadata.block_tables is not None: + # Prefill has chunked before. + block_table = seq_group_metadata.block_tables[seq_id] + prefix_block_tables.append(block_table) + else: + # The first prefill. + prefix_block_tables.append([]) + else: + prefix_block_tables.append([]) + # Right now, prefill start is always 0. However, this + # assumption can be changed once chunked prefill is introduced. + assert context_len == 0 + + # actual prompt lens + context_lens.append(context_len) + query_lens.append(seq_len - context_len) + input_tokens.append(prompt_tokens) + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.append(list(range(context_len, seq_len))) + + mm_data = seq_group_metadata.multi_modal_data + if mm_data: + mm_kwargs = self.multi_modal_input_mapper(mm_data) + multi_modal_kwargs_list.append(mm_kwargs) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.append([_PAD_SLOT_ID] * seq_len) + continue + + # Compute the slot mapping. + slot_mapping.append([]) + block_table = seq_group_metadata.block_tables[seq_id] + + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, seq_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
+ start_idx = 0 + if self.sliding_window is not None: + assert context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention") + start_idx = max(0, seq_len - self.sliding_window) + for i in range(context_len, seq_len): + if i < start_idx: + slot_mapping[-1].append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping[-1].append(slot) + + #input_tokens + #input_positions + #slot_mapping + #seq_lens + #context_lens + #prefix_block_list + + input_tokens_merged = list(itertools.chain.from_iterable(input_tokens)) + input_tokens_merged = [input_tokens_merged] + input_positions_merged = list(itertools.chain.from_iterable(input_positions)) + input_positions_merged = [input_positions_merged] + slot_mapping_merged = list(itertools.chain.from_iterable(slot_mapping)) + slot_mapping_merged = [slot_mapping_merged] + context_lens_merged = [sum(context_lens)] + total_seq_lens = [sum(seq_lens)] + total_query_lens = [sum(query_lens)] + + max_query_len = max(total_query_lens) + real_num_seqs = len(total_query_lens) + assert max_query_len > 0 + + print(">>>> seq_lens", seq_lens, "max_query_len", max_query_len, "real_num_seqs", real_num_seqs) + + max_prompt_len = max( + self.bucketing_ctx.get_padded_prompt_seq_len(max(total_seq_lens)), + self.block_size) + + prefix_block_list_tensor = None + + input_tokens_tensor = make_tensor_with_pad(input_tokens_merged, + max_len=max_prompt_len, + pad=0, + dtype=torch.long, + device='cpu') + + input_positions = make_tensor_with_pad(input_positions_merged, + max_len=max_prompt_len, + pad=0, + dtype=torch.long, + device='cpu') + + slot_mapping = make_tensor_with_pad(slot_mapping_merged, + max_len=max_prompt_len, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device='cpu') + + seq_lens_tensor = torch.tensor(total_seq_lens, + dtype=torch.long, + device='cpu') + + context_lens_tensor = torch.tensor(context_lens_merged, + dtype=torch.long, + device='cpu') + + # Note: num_prefill_tokens is calculated using the length of + # input_tokens after padding. 
+ num_prefill_tokens = input_tokens_tensor.numel() + input_tokens_tensor = input_tokens_tensor.to( # type: ignore + self.device, non_blocking=True) + input_positions = input_positions.to( # type: ignore + self.device, non_blocking=True) + slot_mapping = slot_mapping.to( # type: ignore + self.device, non_blocking=True) + seq_lens_tensor = seq_lens_tensor.to(self.device, non_blocking=True) + context_lens_tensor = context_lens_tensor.to(self.device, + non_blocking=True) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + block_list=prefix_block_list_tensor, + block_mapping=None, + block_usage=None, + block_indices=None, + block_offsets=None, + block_scales=None, + block_groups=None, + attn_bias=None, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + context_lens_tensor=context_lens_tensor, + num_prefills=real_num_seqs, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps= + None # FIXME(kzawora): mutli-modality will not work here + ) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + for t in multi_modal_kwargs: + if torch.is_tensor(multi_modal_kwargs[t]): + multi_modal_kwargs[t] = multi_modal_kwargs[t].to( + self.device, non_blocking=True) + + return PreparePromptMetadata(input_tokens=input_tokens_tensor, + input_positions=input_positions, + attn_metadata=attn_metadata, + seq_lens=seq_lens, + query_lens=query_lens, + lora_index_mapping=lora_index_mapping, + lora_prompt_mapping=lora_prompt_mapping, + lora_requests=lora_requests, + multi_modal_kwargs=multi_modal_kwargs, + slot_mapping=slot_mapping, + lora_ids=[]) + def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -1224,6 +1440,7 @@ def prepare_input_tensors( decode_reqs.append(seq_group_meta) # Prepare input tensors. 
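_prepare_prompt_merged above flattens the per-sequence token, position and slot-mapping lists into a single merged sequence, so the whole prefill batch runs as one long prompt that is padded once to a bucketed length. A minimal sketch of that merge-and-pad step, assuming plain Python lists and a caller-chosen padded length (the real code uses make_tensor_with_pad and the HPU bucketing context to pick the padded length):

import itertools
from typing import List

import torch

def merge_and_pad(token_lists: List[List[int]], padded_len: int,
                  pad_value: int = 0) -> torch.Tensor:
    """Concatenate per-sequence token lists and right-pad to padded_len,
    returning a [1, padded_len] tensor: one merged 'sequence' per batch."""
    merged = list(itertools.chain.from_iterable(token_lists))
    assert len(merged) <= padded_len, "bucketed length too small"
    merged += [pad_value] * (padded_len - len(merged))
    return torch.tensor([merged], dtype=torch.long)

# Three prompts of lengths 3, 2 and 4 become a single padded row of 16 tokens.
tokens = merge_and_pad([[1, 2, 3], [4, 5], [6, 7, 8, 9]], padded_len=16)
print(tokens.shape)  # torch.Size([1, 16])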
+ prepare_prompt_impl = self._prepare_prompt_merged if self.enable_merged_prefill else self._prepare_prompt ( input_tokens, input_positions, @@ -1236,7 +1453,7 @@ def prepare_input_tensors( multi_modal_kwargs, slot_mapping, lora_ids, - ) = self._prepare_prompt(prefill_reqs) + ) = prepare_prompt_impl(prefill_reqs) ( decode_input_tokens, decode_input_positions, From 552e294b92e2be31fcf4dfde9d43f5d95514811c Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Wed, 18 Dec 2024 03:18:57 +0200 Subject: [PATCH 03/25] Enable merged prefill Signed-off-by: Chendi Xue --- vllm/attention/backends/hpu_attn.py | 62 ++++++++++++++++++++++------- vllm/worker/hpu_model_runner.py | 27 +++++++++---- 2 files changed, 66 insertions(+), 23 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index c5b57cb1967f0..f5984caf96c5c 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -83,6 +83,8 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): attn_bias: Optional[torch.Tensor] seq_lens_tensor: Optional[torch.Tensor] context_lens_tensor: Optional[torch.Tensor] + enable_merged_prefill: bool = False + input_tokens_padded_tensor: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None @@ -213,6 +215,7 @@ def forward( block_offsets = kwargs.get('block_offsets', None) seq_lens_tensor = kwargs.get('seq_lens_tensor', None) attn_bias = kwargs.get('attn_bias', None) + enable_merged_prefill = attn_metadata.enable_merged_prefill if block_indices is None: block_indices = attn_metadata.block_indices if block_offsets is None: @@ -221,20 +224,48 @@ def forward( seq_lens_tensor = attn_metadata.seq_lens_tensor if attn_bias is None: # This is the case for prompt run attn_bias = attn_metadata.attn_bias - if attn_metadata.is_prompt: - key = key.unflatten(0, (block_indices.size(0), -1)) - value = value.unflatten(0, (block_indices.size(0), -1)) - if kv_cache is not None: - key_cache, value_cache = HPUPagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. 
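Before the merged keys and values can be written into the paged KV cache, they have to be scattered back into a per-sequence, right-padded layout; the explicit copy loop above does this one sequence at a time, and a later patch in this series ("rewrite split function to make fp8 work") replaces it with a vectorized split_and_pad_to_length helper. A standalone sketch of that vectorized approach and its effect on shapes:

from typing import List

import torch

def split_and_pad(merged: torch.Tensor, seq_lens: List[int],
                  padded_len: int) -> torch.Tensor:
    """Turn a merged [sum(seq_lens), num_kv_heads, head_size] tensor into a
    right-padded [len(seq_lens), padded_len, num_kv_heads, head_size] one."""
    chunks = torch.split_with_sizes(merged[:sum(seq_lens)], seq_lens, dim=0)
    padded = torch.nn.utils.rnn.pad_sequence(chunks, batch_first=True)
    # pad_sequence only pads up to max(seq_lens); extend to the bucketed length.
    extra = padded_len - padded.size(1)
    return torch.nn.functional.pad(padded, (0, 0, 0, 0, 0, extra))

kv = torch.randn(9, 8, 64)          # three sequences of lengths 3, 2 and 4
out = split_and_pad(kv, [3, 2, 4], padded_len=16)
print(out.shape)                    # torch.Size([3, 16, 8, 64])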
- key_cache = self.k_cache(key, key_cache, block_indices, - block_offsets) - value_cache = self.v_cache(value, value_cache, block_indices, - block_offsets) + if enable_merged_prefill: + if attn_metadata.is_prompt: + padded_shape = attn_metadata.input_tokens_padded_tensor + seq_lens_tensor_list = seq_lens_tensor.tolist() + padded_key_tensor = torch.zeros(padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size, device=key.device, dtype=key.dtype) + padded_value_tensor = torch.zeros(padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size, device=query.device, dtype=key.dtype) + start = 0 + # we need to copy the key and value tensors to the padded tensors + # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] + for i in range(padded_shape[0]): + padded_key_tensor[i, :seq_lens_tensor_list[i]].copy_(key[start: start + seq_lens_tensor_list[i], :, :], non_blocking=True) + padded_value_tensor[i, :seq_lens_tensor_list[i]].copy_(value[start: start + seq_lens_tensor_list[i], :, :], non_blocking=True) + start = start + seq_lens_tensor_list[i] + # shape will be [batch_size * entire_seq_len, num_kv_heads, head_size] + # then reshape it to [n_blocks, block_size, num_kv_heads * head_size] + padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1)) + padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1)) + seq_lens_tensor_merged = torch.tensor(sum(seq_lens_tensor_list), device=seq_lens_tensor.device, dtype=seq_lens_tensor.dtype).unsqueeze(0) + if kv_cache is not None: + key_cache, value_cache = HPUPagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + key_cache = self.k_cache(padded_key_tensor, key_cache, block_indices, + block_offsets) + value_cache = self.v_cache(padded_value_tensor, value_cache, block_indices, + block_offsets) + else: + if attn_metadata.is_prompt: + key = key.unflatten(0, (block_indices.size(0), -1)) + value = value.unflatten(0, (block_indices.size(0), -1)) + seq_lens_tensor_merged = seq_lens_tensor + if kv_cache is not None: + key_cache, value_cache = HPUPagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + key_cache = self.k_cache(key, key_cache, block_indices, + block_offsets) + value_cache = self.v_cache(value, value_cache, block_indices, + block_offsets) if attn_metadata.is_prompt: # Prompt run. 
@@ -256,6 +287,7 @@ def forward( else: attn_bias = None + #self.fused_scaled_dot_product_attention = None out = ops.prompt_attention( query.view(query_shape), key.view(kv_shape), @@ -266,7 +298,7 @@ def forward( matmul_qk_op=self.matmul_qk, softmax_op=self.softmax, matmul_av_op=self.matmul_av, - valid_seq_lengths=seq_lens_tensor, + valid_seq_lengths=seq_lens_tensor_merged, fsdpa_op=self.fused_scaled_dot_product_attention, ) else: diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index e3713a81716ce..6305e96a24919 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -1143,33 +1143,36 @@ def _prepare_prompt_merged( real_num_seqs = len(total_query_lens) assert max_query_len > 0 - print(">>>> seq_lens", seq_lens, "max_query_len", max_query_len, "real_num_seqs", real_num_seqs) - - max_prompt_len = max( + merged_prompt_len = max( self.bucketing_ctx.get_padded_prompt_seq_len(max(total_seq_lens)), self.block_size) + max_prompt_len = max( + self.bucketing_ctx.get_padded_prompt_seq_len(max(seq_lens)), + self.block_size) prefix_block_list_tensor = None input_tokens_tensor = make_tensor_with_pad(input_tokens_merged, - max_len=max_prompt_len, + max_len=merged_prompt_len, pad=0, dtype=torch.long, device='cpu') input_positions = make_tensor_with_pad(input_positions_merged, - max_len=max_prompt_len, + max_len=merged_prompt_len, pad=0, dtype=torch.long, device='cpu') + + input_tokens_padded_tensor = torch.tensor([len(seq_lens), max_prompt_len], dtype=torch.long, device='cpu') - slot_mapping = make_tensor_with_pad(slot_mapping_merged, + slot_mapping = make_tensor_with_pad(slot_mapping, max_len=max_prompt_len, pad=_PAD_SLOT_ID, dtype=torch.long, device='cpu') - seq_lens_tensor = torch.tensor(total_seq_lens, + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long, device='cpu') @@ -1189,9 +1192,11 @@ def _prepare_prompt_merged( seq_lens_tensor = seq_lens_tensor.to(self.device, non_blocking=True) context_lens_tensor = context_lens_tensor.to(self.device, non_blocking=True) - + input_tokens_padded_tensor = input_tokens_padded_tensor.to(self.device, + non_blocking=True) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, + enable_merged_prefill=True, block_list=prefix_block_list_tensor, block_mapping=None, block_usage=None, @@ -1207,6 +1212,7 @@ def _prepare_prompt_merged( num_prefill_tokens=num_prefill_tokens, num_decode_tokens=0, slot_mapping=slot_mapping, + input_tokens_padded_tensor=input_tokens_padded_tensor, multi_modal_placeholder_index_maps= None # FIXME(kzawora): mutli-modality will not work here ) @@ -1593,6 +1599,8 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: 'attn_bias', 'seq_lens_tensor', 'context_lens_tensor', + 'enable_merged_prefill', + 'input_tokens_padded_tensor', 'block_list', 'block_mapping', 'block_usage', @@ -1713,7 +1721,10 @@ def warmup_scenario(self, profiler = setup_profiler() profiler.start() for _ in range(times): + origin_enable_merged_prefill = self.enable_merged_prefill + self.enable_merged_prefill = False inputs = self.prepare_model_input(seqs) + self.enable_merged_prefill = origin_enable_merged_prefill is_single_step = \ self.vllm_config.scheduler_config.num_scheduler_steps == 1 if is_prompt or is_single_step: From bd8751237f3dcea14ef17a88cc1b0225b264307b Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 19 Dec 2024 01:54:31 +0200 Subject: [PATCH 04/25] accuracy issue is fixed Signed-off-by: Chendi Xue --- vllm/attention/backends/hpu_attn.py | 65 
+++++++++++++++++++++++++++-- vllm/worker/hpu_model_runner.py | 32 ++++++++++++-- 2 files changed, 90 insertions(+), 7 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index f5984caf96c5c..e724a087fbe83 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -29,6 +29,56 @@ logger.warning("Could not import HPU FusedSDPA kernel. " "vLLM will use native implementation.") +def prompt_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[torch.Tensor] = None, + p: float = 0.0, + scale: Optional[float] = None, + matmul_qk_op=torch.matmul, + softmax_op=torch.softmax, + matmul_av_op=torch.matmul, + valid_seq_lengths: Optional[torch.Tensor] = None, + fsdpa_op = None, +) -> torch.Tensor: + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + query_heads = query.size(1) + kv_heads = key.size(1) + if attn_bias is not None: + attn_bias = attn_bias.expand(query.size(0), query_heads, query.size(2), key.size(2)) + if attn_bias is not None or fsdpa_op is None: + if query_heads != kv_heads: + query = query.unflatten(1, (kv_heads, -1)) + key = key.unflatten(1, (kv_heads, 1)) + value = value.unflatten(1, (kv_heads, 1)) + if attn_bias is not None: + attn_bias = attn_bias.unflatten(1, (kv_heads, -1)) + attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2)) + if attn_bias is not None: + attn_weights.add_(attn_bias) + attn_weights = softmax_op(attn_weights, dim=-1) + attn_weights = matmul_av_op(attn_weights, value) + if query_heads != kv_heads: + attn_weights = attn_weights.flatten(1, 2) + else: + VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get('VLLM_REMOVE_REPEAT_KV_CACHE', '1') == '1' + # TODO: remove after fusedsdpa fix for query_heads != kv_heads + if query_heads != kv_heads: + if VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE: + key = ops.repeat_kv(key, int(query_heads // kv_heads)) + value = ops.repeat_kv(value, int(query_heads // kv_heads)) + if attn_bias is not None: + attn_bias = attn_bias.unflatten(1, (kv_heads, -1)) + softmax_mode = 'fast' + recompute_mode = True + attn_weights = fsdpa_op(query=query, key=key, value=value, attn_mask=attn_bias, dropout_p=0.0, is_causal=True, + scale=scale, softmax_mode=softmax_mode, recompute_mode=recompute_mode, + valid_sequence_lengths=valid_seq_lengths, padding_side='right') + attn_weights = attn_weights.transpose(1, 2) + return attn_weights class HPUAttentionBackend(AttentionBackend): @@ -228,8 +278,10 @@ def forward( if attn_metadata.is_prompt: padded_shape = attn_metadata.input_tokens_padded_tensor seq_lens_tensor_list = seq_lens_tensor.tolist() - padded_key_tensor = torch.zeros(padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size, device=key.device, dtype=key.dtype) - padded_value_tensor = torch.zeros(padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size, device=query.device, dtype=key.dtype) + padded_key_tensor = torch.zeros((padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size), + device=key.device, dtype=key.dtype) + padded_value_tensor = torch.zeros((padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size), + device=query.device, dtype=key.dtype) start = 0 # we need to copy the key and value tensors to the padded tensors # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] @@ -284,11 +336,16 @@ def forward( attn_bias = attn_bias.tile( (1, self.num_kv_heads, 1, 1)) attn_bias.add_(position_bias) + elif 
enable_merged_prefill: + pass else: attn_bias = None - #self.fused_scaled_dot_product_attention = None - out = ops.prompt_attention( + if enable_merged_prefill: + prompt_attn_func = prompt_attention + else: + prompt_attn_func = ops.prompt_attention + out = prompt_attn_func( query.view(query_shape), key.view(kv_shape), value.view(kv_shape), diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 6305e96a24919..27e0aa4e98ba1 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -212,6 +212,8 @@ def __init__(self, model, block_size, dtype, enforce_eager, layer_names): self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', '1').lower() in ['1', 'true'] \ and not is_fake_hpu() + self.merged_prefill_attn_mask_compute = os.getenv('VLLM_MERGED_PREFILL_ATTN_MASK_COMPUTE', + '1').lower() in ['1', 'true'] self.block_size = block_size self.dtype = dtype self.layer_names = layer_names @@ -287,6 +289,28 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) return attn_metadata + def _set_merged_attn_bias(self, attn_metadata, batch_size, max_seq_len, device,):# create a 2D causal attn mask to ensure I can only attend to the past + if attn_metadata is None or not attn_metadata.is_prompt: + return attn_metadata + if not self.merged_prefill_attn_mask_compute: + return attn_metadata + #TODO: Support batch_size > 1 + seq_lens = attn_metadata.seq_lens_tensor.tolist() + causal_attn_mask_tensor = torch.ones((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device) + start = 0 + for i in range(batch_size): + for seq_len in seq_lens: + # create triangular mask for each sequence + causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=device, dtype=torch.bool), diagonal=1) + causal_attn_mask_tensor[i][start:start+seq_len, start:start+seq_len].copy_(causal_mask) + start += seq_len + causal_attn_mask_tensor = (torch.zeros_like(causal_attn_mask_tensor, device=device, dtype=self.dtype).masked_fill_( + causal_attn_mask_tensor, -10000)) # should be math(-inf) but -10000 is used for numerical stability + causal_attn_mask_tensor = causal_attn_mask_tensor.view(causal_attn_mask_tensor.shape[0], 1, causal_attn_mask_tensor.shape[1], causal_attn_mask_tensor.shape[2]) + + attn_metadata = attn_metadata._replace(attn_bias=causal_attn_mask_tensor) + return attn_metadata + def _set_block_mapping(self, metadata, batch_size, device, dtype): mask = torch.arange(0, self.block_size, @@ -340,7 +364,9 @@ def _set_indices_and_offsets(self, metadata, block_size, is_prompt): def _update_metadata(self, attn_metadata, batch_size, seq_len, device, dtype): - if attn_metadata.is_prompt: + if attn_metadata.is_prompt and attn_metadata.enable_merged_prefill: + attn_metadata = self._set_merged_attn_bias(attn_metadata, batch_size, seq_len, device) + elif attn_metadata.is_prompt: attn_metadata = self._set_attn_bias(attn_metadata, batch_size, seq_len, device, dtype) else: @@ -886,8 +912,8 @@ def _prepare_prompt( block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset - slot_mapping[-1].append(slot) - + slot_mapping[-1].append(slot) + max_query_len = max(query_lens) real_num_seqs = len(query_lens) assert max_query_len > 0 From 1caf26642bd86961946069b5a61c4edf430b0dbc Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 20 Dec 2024 01:19:12 +0200 Subject: [PATCH 05/25] use logical_and_ Signed-off-by: 
Chendi Xue --- vllm/worker/hpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 27e0aa4e98ba1..a51e79f8b4b12 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -302,7 +302,7 @@ def _set_merged_attn_bias(self, attn_metadata, batch_size, max_seq_len, device,) for seq_len in seq_lens: # create triangular mask for each sequence causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=device, dtype=torch.bool), diagonal=1) - causal_attn_mask_tensor[i][start:start+seq_len, start:start+seq_len].copy_(causal_mask) + causal_attn_mask_tensor[i][start:start+seq_len, start:start+seq_len].logical_and_(causal_mask) start += seq_len causal_attn_mask_tensor = (torch.zeros_like(causal_attn_mask_tensor, device=device, dtype=self.dtype).masked_fill_( causal_attn_mask_tensor, -10000)) # should be math(-inf) but -10000 is used for numerical stability From c5287363e9f6117cc29de6e26899e67d9344bf87 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 20 Dec 2024 04:33:49 +0200 Subject: [PATCH 06/25] update Signed-off-by: Chendi Xue --- vllm/attention/backends/hpu_attn.py | 34 ++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index e724a087fbe83..69ec99f7f2ffe 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -10,6 +10,7 @@ import vllm_hpu_extension.ops as ops from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax, VLLMKVCache) +from vllm_hpu_extension.cache_ops import insert_or_update_cache from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) @@ -47,15 +48,14 @@ def prompt_attention( value = value.transpose(1, 2) query_heads = query.size(1) kv_heads = key.size(1) - if attn_bias is not None: - attn_bias = attn_bias.expand(query.size(0), query_heads, query.size(2), key.size(2)) if attn_bias is not None or fsdpa_op is None: + #if fsdpa_op is None: if query_heads != kv_heads: query = query.unflatten(1, (kv_heads, -1)) key = key.unflatten(1, (kv_heads, 1)) value = value.unflatten(1, (kv_heads, 1)) if attn_bias is not None: - attn_bias = attn_bias.unflatten(1, (kv_heads, -1)) + attn_bias = attn_bias.unsqueeze(1) attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2)) if attn_bias is not None: attn_weights.add_(attn_bias) @@ -71,15 +71,33 @@ def prompt_attention( key = ops.repeat_kv(key, int(query_heads // kv_heads)) value = ops.repeat_kv(value, int(query_heads // kv_heads)) if attn_bias is not None: - attn_bias = attn_bias.unflatten(1, (kv_heads, -1)) + attn_bias = attn_bias.unsqueeze(1) softmax_mode = 'fast' recompute_mode = True - attn_weights = fsdpa_op(query=query, key=key, value=value, attn_mask=attn_bias, dropout_p=0.0, is_causal=True, + attn_weights = fsdpa_op(query=query, key=key, value=value, attn_mask=attn_bias, dropout_p=0.0, is_causal=False, scale=scale, softmax_mode=softmax_mode, recompute_mode=recompute_mode, - valid_sequence_lengths=valid_seq_lengths, padding_side='right') + valid_sequence_lengths=None, padding_side='right') attn_weights = attn_weights.transpose(1, 2) return attn_weights + +class VLLMKVCache_dev(torch.nn.Module): + + def __init__(self): + super(VLLMKVCache_dev, self).__init__() + self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA', + 'true').lower() == 'true' + + def forward(self, input, cache, 
block_indices, block_offset): + insert_or_update_cache(input, cache, block_indices, block_offset) + return cache + + def fetch_from_cache(self, cache, blocks): + if self.use_contiguous_pa: + return cache[:blocks.size(0)] + else: + return cache.index_select(0, blocks) + class HPUAttentionBackend(AttentionBackend): @staticmethod @@ -188,8 +206,8 @@ def __init__( self.matmul_av = Matmul() self.batch2block_matmul = Matmul() self.block2batch_matmul = Matmul() - self.k_cache = VLLMKVCache() - self.v_cache = VLLMKVCache() + self.k_cache = VLLMKVCache_dev() + self.v_cache = VLLMKVCache_dev() self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \ else ModuleFusedSDPA(HPUFusedSDPA) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads From 510722e74e04921a8ddd92b8908995fadfbe5070 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 20 Dec 2024 04:34:09 +0200 Subject: [PATCH 07/25] update benchmark Signed-off-by: Chendi Xue --- benchmarks/benchmark_throughput.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index cc90e13010579..ed555284ba178 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -88,7 +88,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data["conversations"]) >= 2] # Shuffle the dataset. - random.shuffle(dataset) + #random.shuffle(dataset) # Filter out sequences that are too long or too short filtered_dataset: List[SampleRequest] = [] @@ -189,7 +189,7 @@ def run_vllm( use_beam_search = False if not use_beam_search: - for _ in range(2): + for _ in range(3): start = time.perf_counter() llm.generate(prompts, sampling_params, use_tqdm=True) end = time.perf_counter() From f6c0c84a012597fe8ff595842d6eba45dfe5d769 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 20 Dec 2024 04:42:22 +0200 Subject: [PATCH 08/25] update Signed-off-by: Chendi Xue --- vllm/attention/backends/hpu_attn.py | 26 ++++---------------------- 1 file changed, 4 insertions(+), 22 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 69ec99f7f2ffe..95fb115228c00 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -48,8 +48,8 @@ def prompt_attention( value = value.transpose(1, 2) query_heads = query.size(1) kv_heads = key.size(1) - if attn_bias is not None or fsdpa_op is None: - #if fsdpa_op is None: + #if attn_bias is not None or fsdpa_op is None: + if fsdpa_op is None: if query_heads != kv_heads: query = query.unflatten(1, (kv_heads, -1)) key = key.unflatten(1, (kv_heads, 1)) @@ -80,24 +80,6 @@ def prompt_attention( attn_weights = attn_weights.transpose(1, 2) return attn_weights - -class VLLMKVCache_dev(torch.nn.Module): - - def __init__(self): - super(VLLMKVCache_dev, self).__init__() - self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA', - 'true').lower() == 'true' - - def forward(self, input, cache, block_indices, block_offset): - insert_or_update_cache(input, cache, block_indices, block_offset) - return cache - - def fetch_from_cache(self, cache, blocks): - if self.use_contiguous_pa: - return cache[:blocks.size(0)] - else: - return cache.index_select(0, blocks) - class HPUAttentionBackend(AttentionBackend): @staticmethod @@ -206,8 +188,8 @@ def __init__( self.matmul_av = Matmul() self.batch2block_matmul = Matmul() self.block2batch_matmul 
= Matmul() - self.k_cache = VLLMKVCache_dev() - self.v_cache = VLLMKVCache_dev() + self.k_cache = VLLMKVCache() + self.v_cache = VLLMKVCache() self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \ else ModuleFusedSDPA(HPUFusedSDPA) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads From b1335426a69f454208cfe1eb4c25ff6aa519b441 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 20 Dec 2024 19:46:14 +0200 Subject: [PATCH 09/25] rewrite split function to make fp8 work Signed-off-by: Chendi Xue --- benchmarks/benchmark_throughput.py | 2 ++ vllm/attention/backends/hpu_attn.py | 50 ++++++++++++++++++----------- vllm/worker/hpu_model_runner.py | 6 ---- 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index ed555284ba178..9a2ccc0b4783a 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -159,6 +159,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, expected_output_len=output_len, multi_modal_data=multi_modal_data)) + for i, data in enumerate(filtered_dataset): + print(i, data.prompt) return filtered_dataset diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 95fb115228c00..be8abe92d2a0d 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -30,6 +30,28 @@ logger.warning("Could not import HPU FusedSDPA kernel. " "vLLM will use native implementation.") +def split_and_pad_to_length(input, target_length, seq_lens_tensor_list): + # we need to copy the key and value tensors to the padded tensors + # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] + padded_list = torch.split_with_sizes(input[:sum(seq_lens_tensor_list)], seq_lens_tensor_list, dim=0) + + padded_tensor = torch.nn.utils.rnn.pad_sequence(padded_list, batch_first=True) + p3d = (0, 0, 0, 0, 0, target_length - padded_tensor.size(1)) + padded_tensor = torch.nn.functional.pad(padded_tensor, p3d, value=0) + return padded_tensor + +def split_and_pad_to_length_2(input, target_length, seq_lens_tensor_list): + # we need to copy the key and value tensors to the padded tensors + # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] + padded_tensor = torch.zeros((len(seq_lens_tensor_list), target_length, input.size(1), input.size(2)), device=input.device, dtype=input.dtype) + + start = 0 + for i in range(len(seq_lens_tensor_list)): + padded_tensor[i, :seq_lens_tensor_list[i], :, :] = input[start: start + seq_lens_tensor_list[i], :, :] + start = start + seq_lens_tensor_list[i] + + return padded_tensor + def prompt_attention( query: torch.Tensor, key: torch.Tensor, @@ -74,9 +96,9 @@ def prompt_attention( attn_bias = attn_bias.unsqueeze(1) softmax_mode = 'fast' recompute_mode = True - attn_weights = fsdpa_op(query=query, key=key, value=value, attn_mask=attn_bias, dropout_p=0.0, is_causal=False, - scale=scale, softmax_mode=softmax_mode, recompute_mode=recompute_mode, - valid_sequence_lengths=None, padding_side='right') + attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, False, + scale, softmax_mode, recompute_mode, + None, 'right') attn_weights = attn_weights.transpose(1, 2) return attn_weights @@ -134,7 +156,6 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): seq_lens_tensor: Optional[torch.Tensor] context_lens_tensor: Optional[torch.Tensor] enable_merged_prefill: bool = False - input_tokens_padded_tensor: Optional[torch.Tensor] = None 
seq_lens: Optional[List[int]] = None encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None @@ -276,24 +297,17 @@ def forward( attn_bias = attn_metadata.attn_bias if enable_merged_prefill: if attn_metadata.is_prompt: - padded_shape = attn_metadata.input_tokens_padded_tensor - seq_lens_tensor_list = seq_lens_tensor.tolist() - padded_key_tensor = torch.zeros((padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size), - device=key.device, dtype=key.dtype) - padded_value_tensor = torch.zeros((padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size), - device=query.device, dtype=key.dtype) - start = 0 + max_len=attn_metadata.slot_mapping.size(1) + seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist() # we need to copy the key and value tensors to the padded tensors # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] - for i in range(padded_shape[0]): - padded_key_tensor[i, :seq_lens_tensor_list[i]].copy_(key[start: start + seq_lens_tensor_list[i], :, :], non_blocking=True) - padded_value_tensor[i, :seq_lens_tensor_list[i]].copy_(value[start: start + seq_lens_tensor_list[i], :, :], non_blocking=True) - start = start + seq_lens_tensor_list[i] - # shape will be [batch_size * entire_seq_len, num_kv_heads, head_size] - # then reshape it to [n_blocks, block_size, num_kv_heads * head_size] + padded_key_tensor = split_and_pad_to_length(key, max_len, seq_lens_tensor_list) + padded_value_tensor = split_and_pad_to_length(value, max_len, seq_lens_tensor_list) padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1)) padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1)) - seq_lens_tensor_merged = torch.tensor(sum(seq_lens_tensor_list), device=seq_lens_tensor.device, dtype=seq_lens_tensor.dtype).unsqueeze(0) + + #seq_lens_tensor_merged = torch.tensor(sum(seq_lens_tensor_list), device=seq_lens_tensor.device, dtype=seq_lens_tensor.dtype).unsqueeze(0) + seq_lens_tensor_merged = seq_lens_tensor if kv_cache is not None: key_cache, value_cache = HPUPagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index a51e79f8b4b12..1891ea91d6203 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -1189,8 +1189,6 @@ def _prepare_prompt_merged( pad=0, dtype=torch.long, device='cpu') - - input_tokens_padded_tensor = torch.tensor([len(seq_lens), max_prompt_len], dtype=torch.long, device='cpu') slot_mapping = make_tensor_with_pad(slot_mapping, max_len=max_prompt_len, @@ -1218,8 +1216,6 @@ def _prepare_prompt_merged( seq_lens_tensor = seq_lens_tensor.to(self.device, non_blocking=True) context_lens_tensor = context_lens_tensor.to(self.device, non_blocking=True) - input_tokens_padded_tensor = input_tokens_padded_tensor.to(self.device, - non_blocking=True) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, enable_merged_prefill=True, @@ -1238,7 +1234,6 @@ def _prepare_prompt_merged( num_prefill_tokens=num_prefill_tokens, num_decode_tokens=0, slot_mapping=slot_mapping, - input_tokens_padded_tensor=input_tokens_padded_tensor, multi_modal_placeholder_index_maps= None # FIXME(kzawora): mutli-modality will not work here ) @@ -1626,7 +1621,6 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: 'seq_lens_tensor', 'context_lens_tensor', 'enable_merged_prefill', - 'input_tokens_padded_tensor', 'block_list', 
'block_mapping', 'block_usage', From 2d6ceb991cb81edcd635e615b91ee30d1c828450 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Fri, 20 Dec 2024 22:41:03 +0000 Subject: [PATCH 10/25] Fix accuracy issue Signed-off-by: Chendi.Xue Signed-off-by: Chendi Xue --- benchmarks/benchmark_throughput.py | 4 ++-- vllm/attention/backends/hpu_attn.py | 28 ++-------------------------- vllm/model_executor/models/llama.py | 8 +++++++- vllm/model_executor/models/utils.py | 23 +++++++++++++++++++++++ vllm/worker/hpu_model_runner.py | 3 +++ 5 files changed, 37 insertions(+), 29 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 9a2ccc0b4783a..dea7fe96f32ef 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -159,8 +159,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, expected_output_len=output_len, multi_modal_data=multi_modal_data)) - for i, data in enumerate(filtered_dataset): - print(i, data.prompt) + # for i, data in enumerate(filtered_dataset): + # print(i, data.prompt) return filtered_dataset diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index be8abe92d2a0d..abc4222715380 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -19,6 +19,7 @@ HPUPagedAttentionMetadata) from vllm.logger import init_logger from vllm.utils import is_fake_hpu +from vllm.model_executor.models.utils import split_and_pad_to_length logger = init_logger(__name__) @@ -30,28 +31,6 @@ logger.warning("Could not import HPU FusedSDPA kernel. " "vLLM will use native implementation.") -def split_and_pad_to_length(input, target_length, seq_lens_tensor_list): - # we need to copy the key and value tensors to the padded tensors - # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] - padded_list = torch.split_with_sizes(input[:sum(seq_lens_tensor_list)], seq_lens_tensor_list, dim=0) - - padded_tensor = torch.nn.utils.rnn.pad_sequence(padded_list, batch_first=True) - p3d = (0, 0, 0, 0, 0, target_length - padded_tensor.size(1)) - padded_tensor = torch.nn.functional.pad(padded_tensor, p3d, value=0) - return padded_tensor - -def split_and_pad_to_length_2(input, target_length, seq_lens_tensor_list): - # we need to copy the key and value tensors to the padded tensors - # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] - padded_tensor = torch.zeros((len(seq_lens_tensor_list), target_length, input.size(1), input.size(2)), device=input.device, dtype=input.dtype) - - start = 0 - for i in range(len(seq_lens_tensor_list)): - padded_tensor[i, :seq_lens_tensor_list[i], :, :] = input[start: start + seq_lens_tensor_list[i], :, :] - start = start + seq_lens_tensor_list[i] - - return padded_tensor - def prompt_attention( query: torch.Tensor, key: torch.Tensor, @@ -306,8 +285,6 @@ def forward( padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1)) padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1)) - #seq_lens_tensor_merged = torch.tensor(sum(seq_lens_tensor_list), device=seq_lens_tensor.device, dtype=seq_lens_tensor.dtype).unsqueeze(0) - seq_lens_tensor_merged = seq_lens_tensor if kv_cache is not None: key_cache, value_cache = HPUPagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) @@ -320,7 +297,6 @@ def forward( if attn_metadata.is_prompt: key = key.unflatten(0, (block_indices.size(0), -1)) value = value.unflatten(0, (block_indices.size(0), -1)) 
- seq_lens_tensor_merged = seq_lens_tensor if kv_cache is not None: key_cache, value_cache = HPUPagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) @@ -369,7 +345,7 @@ def forward( matmul_qk_op=self.matmul_qk, softmax_op=self.softmax, matmul_av_op=self.matmul_av, - valid_seq_lengths=seq_lens_tensor_merged, + valid_seq_lengths=seq_lens_tensor, fsdpa_op=self.fused_scaled_dot_product_attention, ) else: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 44d34a4e3f20a..c4e1f0778ae2b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -56,7 +56,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) + maybe_prefix, split_and_pad_to_length) is_hpu = current_platform.is_hpu() @@ -490,6 +490,12 @@ def forward( "residual": residual }) + # we need to split result before do RMSNorm + if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt: + max_len=attn_metadata.slot_mapping.size(1) + seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist() + hidden_states = split_and_pad_to_length(hidden_states.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list) + residual = split_and_pad_to_length(residual.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 269b66806adf4..72d01e1ca4906 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -666,3 +666,26 @@ def extract_layer_index(layer_name: str) -> int: assert len(int_vals) == 1, (f"layer name {layer_name} should" " only contain one integer") return int_vals[0] + +def split_and_pad_to_length(input, target_length, seq_lens_tensor_list): + # we need to copy the key and value tensors to the padded tensors + # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] + padded_list = torch.split_with_sizes(input[:sum(seq_lens_tensor_list)], seq_lens_tensor_list, dim=0) + + padded_tensor = torch.nn.utils.rnn.pad_sequence(padded_list, batch_first=True) + pad_shape = [0] * (input.dim() - 1) * 2 + pad_shape += [0, target_length - padded_tensor.size(1)] + padded_tensor = torch.nn.functional.pad(padded_tensor, pad_shape, value=0) + return padded_tensor + +def split_and_pad_to_length_2(input, target_length, seq_lens_tensor_list): + # we need to copy the key and value tensors to the padded tensors + # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] + padded_tensor = torch.zeros((len(seq_lens_tensor_list), target_length, input.size(1), input.size(2)), device=input.device, dtype=input.dtype) + + start = 0 + for i in range(len(seq_lens_tensor_list)): + padded_tensor[i, :seq_lens_tensor_list[i], :, :] = input[start: start + seq_lens_tensor_list[i], :, :] + start = start + seq_lens_tensor_list[i] + + return padded_tensor \ No newline at end of file diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 1891ea91d6203..6ab14827b1a43 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -1168,6 +1168,9 @@ def _prepare_prompt_merged( max_query_len = max(total_query_lens) real_num_seqs = len(total_query_lens) assert max_query_len > 0 + + # print("input_tokens_merged: ", input_tokens_merged) + # print("input_positions_merged: ", 
input_positions_merged) merged_prompt_len = max( self.bucketing_ctx.get_padded_prompt_seq_len(max(total_seq_lens)), From 116dc6c583fb0d06a9c723d76c6e7a4291a76e55 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Mon, 6 Jan 2025 22:07:18 +0200 Subject: [PATCH 11/25] clean up codes in hpu-attn Signed-off-by: Chendi Xue --- vllm/attention/backends/hpu_attn.py | 32 ++++++++++++++--------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index abc4222715380..2ab8eaa7cf340 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -274,25 +274,23 @@ def forward( seq_lens_tensor = attn_metadata.seq_lens_tensor if attn_bias is None: # This is the case for prompt run attn_bias = attn_metadata.attn_bias - if enable_merged_prefill: - if attn_metadata.is_prompt: - max_len=attn_metadata.slot_mapping.size(1) - seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist() - # we need to copy the key and value tensors to the padded tensors - # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] - padded_key_tensor = split_and_pad_to_length(key, max_len, seq_lens_tensor_list) - padded_value_tensor = split_and_pad_to_length(value, max_len, seq_lens_tensor_list) - padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1)) - padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1)) + if enable_merged_prefill and attn_metadata.is_prompt and kv_cache is not None: + max_len=attn_metadata.slot_mapping.size(1) + seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist() + # we need to copy the key and value tensors to the padded tensors + # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] + padded_key_tensor = split_and_pad_to_length(key, max_len, seq_lens_tensor_list) + padded_value_tensor = split_and_pad_to_length(value, max_len, seq_lens_tensor_list) + padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1)) + padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1)) - if kv_cache is not None: - key_cache, value_cache = HPUPagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) + key_cache, value_cache = HPUPagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) - key_cache = self.k_cache(padded_key_tensor, key_cache, block_indices, - block_offsets) - value_cache = self.v_cache(padded_value_tensor, value_cache, block_indices, - block_offsets) + key_cache = self.k_cache(padded_key_tensor, key_cache, block_indices, + block_offsets) + value_cache = self.v_cache(padded_value_tensor, value_cache, block_indices, + block_offsets) else: if attn_metadata.is_prompt: key = key.unflatten(0, (block_indices.size(0), -1)) From fade38694954313acc8b858844db4bc33991dc46 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Tue, 7 Jan 2025 04:19:01 +0200 Subject: [PATCH 12/25] update warming up strategy for merged_prefill Signed-off-by: Chendi Xue --- benchmarks/benchmark_throughput.py | 4 +- vllm/attention/backends/hpu_attn.py | 37 +++++--- vllm/worker/hpu_model_runner.py | 142 +++++++++++++++++++++++----- 3 files changed, 140 insertions(+), 43 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index dea7fe96f32ef..f7347ac9a391a 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -191,9 +191,9 @@ def 
run_vllm( use_beam_search = False if not use_beam_search: - for _ in range(3): + for _ in range(1): start = time.perf_counter() - llm.generate(prompts, sampling_params, use_tqdm=True) + llm.generate(prompts, sampling_params, use_tqdm=False) end = time.perf_counter() else: prompts = [request.prompt for request in requests] diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 2ab8eaa7cf340..ad5afafdef81a 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -31,6 +31,7 @@ logger.warning("Could not import HPU FusedSDPA kernel. " "vLLM will use native implementation.") + def prompt_attention( query: torch.Tensor, key: torch.Tensor, @@ -42,7 +43,7 @@ def prompt_attention( softmax_op=torch.softmax, matmul_av_op=torch.matmul, valid_seq_lengths: Optional[torch.Tensor] = None, - fsdpa_op = None, + fsdpa_op=None, ) -> torch.Tensor: query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -65,7 +66,8 @@ def prompt_attention( if query_heads != kv_heads: attn_weights = attn_weights.flatten(1, 2) else: - VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get('VLLM_REMOVE_REPEAT_KV_CACHE', '1') == '1' + VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get( + 'VLLM_REMOVE_REPEAT_KV_CACHE', '1') == '1' # TODO: remove after fusedsdpa fix for query_heads != kv_heads if query_heads != kv_heads: if VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE: @@ -76,11 +78,12 @@ def prompt_attention( softmax_mode = 'fast' recompute_mode = True attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, False, - scale, softmax_mode, recompute_mode, - None, 'right') + scale, softmax_mode, recompute_mode, None, + 'right') attn_weights = attn_weights.transpose(1, 2) return attn_weights + class HPUAttentionBackend(AttentionBackend): @staticmethod @@ -275,22 +278,26 @@ def forward( if attn_bias is None: # This is the case for prompt run attn_bias = attn_metadata.attn_bias if enable_merged_prefill and attn_metadata.is_prompt and kv_cache is not None: - max_len=attn_metadata.slot_mapping.size(1) + max_len = attn_metadata.slot_mapping.size(1) seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist() # we need to copy the key and value tensors to the padded tensors # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] - padded_key_tensor = split_and_pad_to_length(key, max_len, seq_lens_tensor_list) - padded_value_tensor = split_and_pad_to_length(value, max_len, seq_lens_tensor_list) - padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1)) - padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1)) + padded_key_tensor = split_and_pad_to_length( + key, max_len, seq_lens_tensor_list) + padded_value_tensor = split_and_pad_to_length( + value, max_len, seq_lens_tensor_list) + padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten( + 0, (block_indices.size(0), -1)) + padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten( + 0, (block_indices.size(0), -1)) key_cache, value_cache = HPUPagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) - key_cache = self.k_cache(padded_key_tensor, key_cache, block_indices, - block_offsets) - value_cache = self.v_cache(padded_value_tensor, value_cache, block_indices, - block_offsets) + key_cache = self.k_cache(padded_key_tensor, key_cache, + block_indices, block_offsets) + value_cache = self.v_cache(padded_value_tensor, value_cache, + block_indices, block_offsets) else: if attn_metadata.is_prompt: key = 
key.unflatten(0, (block_indices.size(0), -1)) @@ -303,9 +310,9 @@ def forward( # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. key_cache = self.k_cache(key, key_cache, block_indices, - block_offsets) + block_offsets) value_cache = self.v_cache(value, value_cache, block_indices, - block_offsets) + block_offsets) if attn_metadata.is_prompt: # Prompt run. diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 6ab14827b1a43..b70eaa4e01ebb 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -19,7 +19,7 @@ import habana_frameworks.torch as htorch import habana_frameworks.torch.internal.bridge_config as bc import torch -from vllm_hpu_extension.bucketing import HPUBucketingContext +from vllm_hpu_extension.bucketing import HPUBucketingContext, generate_prompt_buckets from vllm_hpu_extension.ops import LoraMask as LoraMask from vllm_hpu_extension.ops import batch2block, block2batch from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, @@ -205,6 +205,59 @@ def get_child(parent, suffix, is_list=False): } +class HPUBucketingContextWithMergedPrefill(HPUBucketingContext): + + def generate_prompt_buckets(self): + print( + "HPUBucketingContextWithMergedPrefill - generate_prompt_buckets is called" + ) + + prompt_bs_bucket_cfg = self.global_state.prompt_bs_bucket_cfg + prompt_seq_bucket_cfg = self.global_state.prompt_seq_bucket_cfg + print("prompt_seq_bucket_cfg: ", prompt_seq_bucket_cfg) + origin_max_prompt_len = prompt_seq_bucket_cfg[2] + max_prompt_len = prompt_bs_bucket_cfg[2] * prompt_seq_bucket_cfg[2] + max_prompt_len = min(self.max_num_batched_tokens, max_prompt_len) + prompt_seq_bucket_cfg[2] = max_prompt_len + + prompt_buckets, prompt_omitted_buckets = \ + generate_prompt_buckets( + prompt_bs_bucket_cfg, + prompt_seq_bucket_cfg, + self.max_num_batched_tokens) + + print("prompt_buckets: ", prompt_buckets) + # expand + self.global_state.prompt_buckets = [] + VLLM_PROMPT_BS_BUCKET_MAX = int( + os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', 16)) + for bucket in prompt_buckets: + bs = 1 + while bs <= VLLM_PROMPT_BS_BUCKET_MAX: + seq_len = bucket[1] // bs + if seq_len <= 32: + bs = bs * 2 + continue + self.global_state.prompt_buckets.append( + (bs * bucket[0], seq_len)) + bs = bs * 2 + + self.global_state.prompt_buckets = list(filter(lambda bucket: bucket[1] <= origin_max_prompt_len, self.global_state.prompt_buckets)) + + msg = (f"Generated {len(self.global_state.prompt_buckets)} " + f"prompt buckets [bs, seq]: " + f"{list(sorted(self.global_state.prompt_buckets))}") + print(msg) + + # msg = (f"Omitted {len(prompt_omitted_buckets)} " + # "prompt buckets due to exceeded token budget " + # f"(max_num_batched_tokens={self.max_num_batched_tokens})") + # print(msg) + + # msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" + # print(msg) + + class HpuModelAdapter: def __init__(self, model, block_size, dtype, enforce_eager, layer_names): @@ -212,8 +265,9 @@ def __init__(self, model, block_size, dtype, enforce_eager, layer_names): self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', '1').lower() in ['1', 'true'] \ and not is_fake_hpu() - self.merged_prefill_attn_mask_compute = os.getenv('VLLM_MERGED_PREFILL_ATTN_MASK_COMPUTE', - '1').lower() in ['1', 'true'] + self.merged_prefill_attn_mask_compute = os.getenv( + 'VLLM_MERGED_PREFILL_ATTN_MASK_COMPUTE', + '1').lower() in ['1', 'true'] self.block_size = block_size self.dtype 
= dtype self.layer_names = layer_names @@ -289,26 +343,46 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) return attn_metadata - def _set_merged_attn_bias(self, attn_metadata, batch_size, max_seq_len, device,):# create a 2D causal attn mask to ensure I can only attend to the past + def _set_merged_attn_bias( + self, + attn_metadata, + batch_size, + max_seq_len, + device, + ): # create a 2D causal attn mask to ensure I can only attend to the past if attn_metadata is None or not attn_metadata.is_prompt: return attn_metadata if not self.merged_prefill_attn_mask_compute: return attn_metadata #TODO: Support batch_size > 1 seq_lens = attn_metadata.seq_lens_tensor.tolist() - causal_attn_mask_tensor = torch.ones((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device) + causal_attn_mask_tensor = torch.ones( + (batch_size, max_seq_len, max_seq_len), + dtype=torch.bool, + device=device) start = 0 for i in range(batch_size): for seq_len in seq_lens: # create triangular mask for each sequence - causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=device, dtype=torch.bool), diagonal=1) - causal_attn_mask_tensor[i][start:start+seq_len, start:start+seq_len].logical_and_(causal_mask) + causal_mask = torch.triu(torch.ones((seq_len, seq_len), + device=device, + dtype=torch.bool), + diagonal=1) + causal_attn_mask_tensor[i][start:start + seq_len, start:start + + seq_len].logical_and_(causal_mask) start += seq_len - causal_attn_mask_tensor = (torch.zeros_like(causal_attn_mask_tensor, device=device, dtype=self.dtype).masked_fill_( - causal_attn_mask_tensor, -10000)) # should be math(-inf) but -10000 is used for numerical stability - causal_attn_mask_tensor = causal_attn_mask_tensor.view(causal_attn_mask_tensor.shape[0], 1, causal_attn_mask_tensor.shape[1], causal_attn_mask_tensor.shape[2]) - - attn_metadata = attn_metadata._replace(attn_bias=causal_attn_mask_tensor) + causal_attn_mask_tensor = ( + torch.zeros_like(causal_attn_mask_tensor, + device=device, + dtype=self.dtype).masked_fill_( + causal_attn_mask_tensor, -10000) + ) # should be math(-inf) but -10000 is used for numerical stability + causal_attn_mask_tensor = causal_attn_mask_tensor.view( + causal_attn_mask_tensor.shape[0], 1, + causal_attn_mask_tensor.shape[1], causal_attn_mask_tensor.shape[2]) + + attn_metadata = attn_metadata._replace( + attn_bias=causal_attn_mask_tensor) return attn_metadata def _set_block_mapping(self, metadata, batch_size, device, dtype): @@ -365,7 +439,9 @@ def _set_indices_and_offsets(self, metadata, block_size, is_prompt): def _update_metadata(self, attn_metadata, batch_size, seq_len, device, dtype): if attn_metadata.is_prompt and attn_metadata.enable_merged_prefill: - attn_metadata = self._set_merged_attn_bias(attn_metadata, batch_size, seq_len, device) + attn_metadata = self._set_merged_attn_bias(attn_metadata, + batch_size, seq_len, + device) elif attn_metadata.is_prompt: attn_metadata = self._set_attn_bias(attn_metadata, batch_size, seq_len, device, dtype) @@ -403,6 +479,8 @@ def forward(self, *args, **kwargs): LoraMask.setLoraMask(kwargs.pop('lora_mask')) if self.layer_names is not None: self._prepare_cos_sin(kwargs['positions']) + print("Warming up HPU Graph - input_ids: ", input_ids.shape, + "seq_lens_tensor: ", kwargs['attn_metadata'].seq_lens_tensor) hidden_states = self.model(*args, **kwargs) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = hidden_states.index_select(0, 
selected_token_indices) @@ -648,17 +726,26 @@ def __init__( self.profiler_counter_helper = HabanaProfilerCounterHelper() self.seen_configs: set = set() self._mem_margin: Optional[int] = None - self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs, - self.max_num_prefill_seqs, - self.block_size, - self.max_num_batched_tokens) + self.enable_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL', + 'false').lower() == 'true' + # self.bucketing_ctx = HPUBucketingContext( + # self.max_num_seqs, + # self.max_num_prefill_seqs, + # self.block_size, + # self.max_num_batched_tokens) + if self.enable_merged_prefill: + self.bucketing_ctx = HPUBucketingContextWithMergedPrefill( + self.max_num_seqs, self.max_num_prefill_seqs, self.block_size, + self.max_num_batched_tokens) + else: + self.bucketing_ctx = HPUBucketingContext( + self.max_num_seqs, self.max_num_prefill_seqs, self.block_size, + self.max_num_batched_tokens) self.graphed_buckets: Set[Any] = set() self._set_gc_threshold() self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA', 'true').lower() == 'true' - self.enable_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL', - 'false').lower() == 'true' if vllm_config.speculative_config is not None \ and self.use_contiguous_pa: raise ValueError( @@ -1154,24 +1241,25 @@ def _prepare_prompt_merged( #seq_lens #context_lens #prefix_block_list - + input_tokens_merged = list(itertools.chain.from_iterable(input_tokens)) input_tokens_merged = [input_tokens_merged] - input_positions_merged = list(itertools.chain.from_iterable(input_positions)) + input_positions_merged = list( + itertools.chain.from_iterable(input_positions)) input_positions_merged = [input_positions_merged] slot_mapping_merged = list(itertools.chain.from_iterable(slot_mapping)) slot_mapping_merged = [slot_mapping_merged] context_lens_merged = [sum(context_lens)] total_seq_lens = [sum(seq_lens)] total_query_lens = [sum(query_lens)] - + max_query_len = max(total_query_lens) real_num_seqs = len(total_query_lens) assert max_query_len > 0 # print("input_tokens_merged: ", input_tokens_merged) # print("input_positions_merged: ", input_positions_merged) - + merged_prompt_len = max( self.bucketing_ctx.get_padded_prompt_seq_len(max(total_seq_lens)), self.block_size) @@ -1432,6 +1520,7 @@ def prepare_input_tensors( seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[TModelInputForHPU, SamplingMetadata]: if len(seq_group_metadata_list) == 0: + print("seq_group_metadata_list is empty") return self._model_input_cls(), None input_tokens = None @@ -1470,6 +1559,7 @@ def prepare_input_tensors( decode_reqs.append(seq_group_meta) # Prepare input tensors. 
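A note on the `_prepare_prompt_merged` path touched above: it flattens the per-sequence token, position and slot-mapping lists into one packed sequence and then pads that single sequence to a bucketed length. The sketch below is a minimal, stand-alone illustration of that flattening; `merge_prompts` is a hypothetical helper and the round-up-to-block rule merely stands in for `bucketing_ctx.get_padded_prompt_seq_len()`.

import itertools

def merge_prompts(input_tokens, input_positions, seq_lens, block_size=128, pad_id=0):
    """Pack several prompts into one sequence and pad it to a bucket boundary.

    input_tokens / input_positions: per-sequence lists of ints.
    seq_lens: real length of each sequence.
    """
    merged_tokens = list(itertools.chain.from_iterable(input_tokens))
    merged_positions = list(itertools.chain.from_iterable(input_positions))
    total_len = sum(seq_lens)
    # Stand-in for the real bucketing: round the merged length up to a multiple of block_size.
    merged_prompt_len = max(((total_len + block_size - 1) // block_size) * block_size,
                            block_size)
    pad = merged_prompt_len - total_len
    merged_tokens += [pad_id] * pad
    merged_positions += [0] * pad
    return merged_tokens, merged_positions, merged_prompt_len

tokens = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]
positions = [list(range(len(t))) for t in tokens]
seq_lens = [len(t) for t in tokens]
merged, merged_pos, padded_len = merge_prompts(tokens, positions, seq_lens, block_size=8)
# -> one packed sequence of 9 real tokens, padded to length 16

The runner still keeps the individual lengths in `seq_lens_tensor`, which is what later lets the attention bias and the KV-cache write recover the sequence boundaries inside the packed row.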
+ #print("prefill_reqs: ", prefill_reqs, "decode_reqs: ", decode_reqs) prepare_prompt_impl = self._prepare_prompt_merged if self.enable_merged_prefill else self._prepare_prompt ( input_tokens, @@ -1672,8 +1762,11 @@ def profile_run(self) -> None: max_batch_size = min(self.max_num_seqs, self.max_num_batched_tokens // max_seq_len) + origin_enable_merged_prefill = self.enable_merged_prefill + self.enable_merged_prefill = False self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, False, True) + self.enable_merged_prefill = origin_enable_merged_prefill return def warmup_scenario(self, @@ -1744,10 +1837,7 @@ def warmup_scenario(self, profiler = setup_profiler() profiler.start() for _ in range(times): - origin_enable_merged_prefill = self.enable_merged_prefill - self.enable_merged_prefill = False inputs = self.prepare_model_input(seqs) - self.enable_merged_prefill = origin_enable_merged_prefill is_single_step = \ self.vllm_config.scheduler_config.num_scheduler_steps == 1 if is_prompt or is_single_step: From 911f14b2ad90397a4d08cce330449241f68912a9 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Tue, 7 Jan 2025 22:27:21 +0000 Subject: [PATCH 13/25] fix an accuracy issue caused by selected_token_index Signed-off-by: Chendi.Xue --- vllm/worker/hpu_model_runner.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index b70eaa4e01ebb..604a93b93a31d 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -479,8 +479,10 @@ def forward(self, *args, **kwargs): LoraMask.setLoraMask(kwargs.pop('lora_mask')) if self.layer_names is not None: self._prepare_cos_sin(kwargs['positions']) - print("Warming up HPU Graph - input_ids: ", input_ids.shape, - "seq_lens_tensor: ", kwargs['attn_metadata'].seq_lens_tensor) + if kwargs['attn_metadata'].is_prompt: + print("Warming up HPU Graph - input_ids: ", input_ids, + "seq_lens_tensor: ", kwargs['attn_metadata'].seq_lens_tensor, + "selected_token_indices: ", selected_token_indices) hidden_states = self.model(*args, **kwargs) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = hidden_states.index_select(0, selected_token_indices) @@ -1613,7 +1615,10 @@ def prepare_input_tensors( # FIXME: We need to adjust selected_token_indices to accommodate # for padding - max_len = input_tokens.size(1) + if self.enable_merged_prefill: + max_len = slot_mapping.size(1) + else: + max_len = input_tokens.size(1) paddings = [max_len - q for q in query_lens] paddings = [0] + paddings[:-1] paddings = list(itertools.accumulate(paddings)) From 612abed1ecc4c02dab3e0baeb638168139532322 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Tue, 7 Jan 2025 21:31:38 +0200 Subject: [PATCH 14/25] move tolist to llamamodel fwd Signed-off-by: Chendi Xue Signed-off-by: Chendi.Xue --- vllm/attention/backends/hpu_attn.py | 2 +- vllm/model_executor/models/llama.py | 11 ++++++++--- vllm/worker/hpu_model_runner.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index ad5afafdef81a..82d6e5e3f3225 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -268,6 +268,7 @@ def forward( block_offsets = kwargs.get('block_offsets', None) seq_lens_tensor = kwargs.get('seq_lens_tensor', None) attn_bias = kwargs.get('attn_bias', None) + seq_lens_tensor_list = kwargs.get('seq_lens_tensor_list', None) enable_merged_prefill = 
attn_metadata.enable_merged_prefill if block_indices is None: block_indices = attn_metadata.block_indices @@ -279,7 +280,6 @@ def forward( attn_bias = attn_metadata.attn_bias if enable_merged_prefill and attn_metadata.is_prompt and kv_cache is not None: max_len = attn_metadata.slot_mapping.size(1) - seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist() # we need to copy the key and value tensors to the padded tensors # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] padded_key_tensor = split_and_pad_to_length( diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c4e1f0778ae2b..e0800010db7ee 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -292,6 +292,7 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], + seq_lens_tensor_list: List[int], ) -> Tuple[torch.Tensor, torch.Tensor]: if isinstance(hidden_states, torch.Tensor): skip_split = hidden_states.size()[0] == 1 @@ -313,7 +314,8 @@ def forward( hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - attn_metadata=attn_metadata) + attn_metadata=attn_metadata, + seq_lens_tensor_list=seq_lens_tensor_list) # Fully Connected hidden_states, residual = self.post_attention_layernorm( @@ -479,11 +481,15 @@ def forward( import habana_frameworks.torch as htorch htorch.core.mark_step() + if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt: + seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist() + else: + seq_lens_tensor_list = None for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, kv_caches[i - self.start_layer], - attn_metadata, residual) + attn_metadata, residual, seq_lens_tensor_list) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -493,7 +499,6 @@ def forward( # we need to split result before do RMSNorm if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt: max_len=attn_metadata.slot_mapping.size(1) - seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist() hidden_states = split_and_pad_to_length(hidden_states.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list) residual = split_and_pad_to_length(residual.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list) hidden_states, _ = self.norm(hidden_states, residual) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 604a93b93a31d..b0dac110b8ac3 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -480,7 +480,7 @@ def forward(self, *args, **kwargs): if self.layer_names is not None: self._prepare_cos_sin(kwargs['positions']) if kwargs['attn_metadata'].is_prompt: - print("Warming up HPU Graph - input_ids: ", input_ids, + print("Warming up HPU Graph - input_ids: ", input_ids.shape, "seq_lens_tensor: ", kwargs['attn_metadata'].seq_lens_tensor, "selected_token_indices: ", selected_token_indices) hidden_states = self.model(*args, **kwargs) From a3602f2fe0f45c8f0f05f242ccd1d1c9994c8167 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Wed, 8 Jan 2025 00:48:21 +0000 Subject: [PATCH 15/25] use index_put with full block_indices Signed-off-by: Chendi.Xue --- vllm/attention/backends/hpu_attn.py | 24 ++------- vllm/model_executor/models/llama.py | 17 ++---- vllm/model_executor/models/utils.py | 23 -------- vllm/worker/hpu_model_runner.py | 82 
++++++++++------------------- 4 files changed, 36 insertions(+), 110 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 82d6e5e3f3225..a153ef87641ef 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -10,7 +10,6 @@ import vllm_hpu_extension.ops as ops from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax, VLLMKVCache) -from vllm_hpu_extension.cache_ops import insert_or_update_cache from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) @@ -19,7 +18,6 @@ HPUPagedAttentionMetadata) from vllm.logger import init_logger from vllm.utils import is_fake_hpu -from vllm.model_executor.models.utils import split_and_pad_to_length logger = init_logger(__name__) @@ -268,7 +266,6 @@ def forward( block_offsets = kwargs.get('block_offsets', None) seq_lens_tensor = kwargs.get('seq_lens_tensor', None) attn_bias = kwargs.get('attn_bias', None) - seq_lens_tensor_list = kwargs.get('seq_lens_tensor_list', None) enable_merged_prefill = attn_metadata.enable_merged_prefill if block_indices is None: block_indices = attn_metadata.block_indices @@ -278,25 +275,12 @@ def forward( seq_lens_tensor = attn_metadata.seq_lens_tensor if attn_bias is None: # This is the case for prompt run attn_bias = attn_metadata.attn_bias - if enable_merged_prefill and attn_metadata.is_prompt and kv_cache is not None: - max_len = attn_metadata.slot_mapping.size(1) - # we need to copy the key and value tensors to the padded tensors - # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] - padded_key_tensor = split_and_pad_to_length( - key, max_len, seq_lens_tensor_list) - padded_value_tensor = split_and_pad_to_length( - value, max_len, seq_lens_tensor_list) - padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten( - 0, (block_indices.size(0), -1)) - padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten( - 0, (block_indices.size(0), -1)) - + if enable_merged_prefill and attn_metadata.is_prompt and kv_cache is not None: key_cache, value_cache = HPUPagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - key_cache = self.k_cache(padded_key_tensor, key_cache, + kv_cache, self.num_kv_heads, self.head_size) + key_cache = self.k_cache(key, key_cache, block_indices, block_offsets) - value_cache = self.v_cache(padded_value_tensor, value_cache, + value_cache = self.v_cache(value, value_cache, block_indices, block_offsets) else: if attn_metadata.is_prompt: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index e0800010db7ee..44d34a4e3f20a 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -56,7 +56,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix, split_and_pad_to_length) + maybe_prefix) is_hpu = current_platform.is_hpu() @@ -292,7 +292,6 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], - seq_lens_tensor_list: List[int], ) -> Tuple[torch.Tensor, torch.Tensor]: if isinstance(hidden_states, torch.Tensor): skip_split = hidden_states.size()[0] == 1 @@ -314,8 +313,7 @@ def forward( hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - attn_metadata=attn_metadata, - seq_lens_tensor_list=seq_lens_tensor_list) + 
attn_metadata=attn_metadata) # Fully Connected hidden_states, residual = self.post_attention_layernorm( @@ -481,26 +479,17 @@ def forward( import habana_frameworks.torch as htorch htorch.core.mark_step() - if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt: - seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist() - else: - seq_lens_tensor_list = None for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, kv_caches[i - self.start_layer], - attn_metadata, residual, seq_lens_tensor_list) + attn_metadata, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, "residual": residual }) - # we need to split result before do RMSNorm - if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt: - max_len=attn_metadata.slot_mapping.size(1) - hidden_states = split_and_pad_to_length(hidden_states.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list) - residual = split_and_pad_to_length(residual.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 72d01e1ca4906..269b66806adf4 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -666,26 +666,3 @@ def extract_layer_index(layer_name: str) -> int: assert len(int_vals) == 1, (f"layer name {layer_name} should" " only contain one integer") return int_vals[0] - -def split_and_pad_to_length(input, target_length, seq_lens_tensor_list): - # we need to copy the key and value tensors to the padded tensors - # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] - padded_list = torch.split_with_sizes(input[:sum(seq_lens_tensor_list)], seq_lens_tensor_list, dim=0) - - padded_tensor = torch.nn.utils.rnn.pad_sequence(padded_list, batch_first=True) - pad_shape = [0] * (input.dim() - 1) * 2 - pad_shape += [0, target_length - padded_tensor.size(1)] - padded_tensor = torch.nn.functional.pad(padded_tensor, pad_shape, value=0) - return padded_tensor - -def split_and_pad_to_length_2(input, target_length, seq_lens_tensor_list): - # we need to copy the key and value tensors to the padded tensors - # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] - padded_tensor = torch.zeros((len(seq_lens_tensor_list), target_length, input.size(1), input.size(2)), device=input.device, dtype=input.dtype) - - start = 0 - for i in range(len(seq_lens_tensor_list)): - padded_tensor[i, :seq_lens_tensor_list[i], :, :] = input[start: start + seq_lens_tensor_list[i], :, :] - start = start + seq_lens_tensor_list[i] - - return padded_tensor \ No newline at end of file diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index b0dac110b8ac3..868a1887db5b4 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -214,7 +214,6 @@ def generate_prompt_buckets(self): prompt_bs_bucket_cfg = self.global_state.prompt_bs_bucket_cfg prompt_seq_bucket_cfg = self.global_state.prompt_seq_bucket_cfg - print("prompt_seq_bucket_cfg: ", prompt_seq_bucket_cfg) origin_max_prompt_len = prompt_seq_bucket_cfg[2] max_prompt_len = prompt_bs_bucket_cfg[2] * prompt_seq_bucket_cfg[2] max_prompt_len = min(self.max_num_batched_tokens, max_prompt_len) @@ -226,23 +225,7 @@ def generate_prompt_buckets(self): prompt_seq_bucket_cfg, self.max_num_batched_tokens) - 
print("prompt_buckets: ", prompt_buckets) - # expand - self.global_state.prompt_buckets = [] - VLLM_PROMPT_BS_BUCKET_MAX = int( - os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', 16)) - for bucket in prompt_buckets: - bs = 1 - while bs <= VLLM_PROMPT_BS_BUCKET_MAX: - seq_len = bucket[1] // bs - if seq_len <= 32: - bs = bs * 2 - continue - self.global_state.prompt_buckets.append( - (bs * bucket[0], seq_len)) - bs = bs * 2 - - self.global_state.prompt_buckets = list(filter(lambda bucket: bucket[1] <= origin_max_prompt_len, self.global_state.prompt_buckets)) + self.global_state.prompt_buckets = list(filter(lambda bucket: bucket[1] <= origin_max_prompt_len and bucket[0] == 1, prompt_buckets)) msg = (f"Generated {len(self.global_state.prompt_buckets)} " f"prompt buckets [bs, seq]: " @@ -265,9 +248,6 @@ def __init__(self, model, block_size, dtype, enforce_eager, layer_names): self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', '1').lower() in ['1', 'true'] \ and not is_fake_hpu() - self.merged_prefill_attn_mask_compute = os.getenv( - 'VLLM_MERGED_PREFILL_ATTN_MASK_COMPUTE', - '1').lower() in ['1', 'true'] self.block_size = block_size self.dtype = dtype self.layer_names = layer_names @@ -352,8 +332,6 @@ def _set_merged_attn_bias( ): # create a 2D causal attn mask to ensure I can only attend to the past if attn_metadata is None or not attn_metadata.is_prompt: return attn_metadata - if not self.merged_prefill_attn_mask_compute: - return attn_metadata #TODO: Support batch_size > 1 seq_lens = attn_metadata.seq_lens_tensor.tolist() causal_attn_mask_tensor = torch.ones( @@ -425,9 +403,12 @@ def _set_block_scales(self, metadata, device): return metadata def _set_indices_and_offsets(self, metadata, block_size, is_prompt): - slot_mapping = metadata.slot_mapping.flatten() + if metadata.enable_merged_prefill and is_prompt: + slot_mapping = metadata.slot_mapping + else: + slot_mapping = metadata.slot_mapping.flatten() indices = torch.div(slot_mapping, block_size, rounding_mode="floor") - if is_prompt: + if not metadata.enable_merged_prefill and is_prompt: indices = indices.unflatten(0, (-1, block_size))[:, 0] offsets = None else: @@ -481,8 +462,7 @@ def forward(self, *args, **kwargs): self._prepare_cos_sin(kwargs['positions']) if kwargs['attn_metadata'].is_prompt: print("Warming up HPU Graph - input_ids: ", input_ids.shape, - "seq_lens_tensor: ", kwargs['attn_metadata'].seq_lens_tensor, - "selected_token_indices: ", selected_token_indices) + "seq_lens_tensor: ", kwargs['attn_metadata'].seq_lens_tensor.shape, 'slot_mapping: ', kwargs['attn_metadata'].slot_mapping.shape, 'selected_token_indices: ', selected_token_indices) hidden_states = self.model(*args, **kwargs) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = hidden_states.index_select(0, selected_token_indices) @@ -730,11 +710,6 @@ def __init__( self._mem_margin: Optional[int] = None self.enable_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL', 'false').lower() == 'true' - # self.bucketing_ctx = HPUBucketingContext( - # self.max_num_seqs, - # self.max_num_prefill_seqs, - # self.block_size, - # self.max_num_batched_tokens) if self.enable_merged_prefill: self.bucketing_ctx = HPUBucketingContextWithMergedPrefill( self.max_num_seqs, self.max_num_prefill_seqs, self.block_size, @@ -1244,13 +1219,14 @@ def _prepare_prompt_merged( #context_lens #prefix_block_list + slot_mapping_merged = list(itertools.chain.from_iterable(slot_mapping)) + slot_mapping_merged = [i for i in slot_mapping_merged if i != _PAD_SLOT_ID] + 
slot_mapping = [slot_mapping_merged] input_tokens_merged = list(itertools.chain.from_iterable(input_tokens)) input_tokens_merged = [input_tokens_merged] input_positions_merged = list( itertools.chain.from_iterable(input_positions)) input_positions_merged = [input_positions_merged] - slot_mapping_merged = list(itertools.chain.from_iterable(slot_mapping)) - slot_mapping_merged = [slot_mapping_merged] context_lens_merged = [sum(context_lens)] total_seq_lens = [sum(seq_lens)] total_query_lens = [sum(query_lens)] @@ -1284,11 +1260,15 @@ def _prepare_prompt_merged( device='cpu') slot_mapping = make_tensor_with_pad(slot_mapping, - max_len=max_prompt_len, + max_len=merged_prompt_len, pad=_PAD_SLOT_ID, dtype=torch.long, device='cpu') + max_prefill_bs = int(os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', '16')) + max_prefill_bs = max(max_prefill_bs, len(seq_lens)) + seq_lens = seq_lens + [0] * (max_prefill_bs - len(seq_lens)) + context_lens = context_lens + [0] * (max_prefill_bs - len(context_lens)) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long, device='cpu') @@ -1522,7 +1502,6 @@ def prepare_input_tensors( seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[TModelInputForHPU, SamplingMetadata]: if len(seq_group_metadata_list) == 0: - print("seq_group_metadata_list is empty") return self._model_input_cls(), None input_tokens = None @@ -1561,7 +1540,6 @@ def prepare_input_tensors( decode_reqs.append(seq_group_meta) # Prepare input tensors. - #print("prefill_reqs: ", prefill_reqs, "decode_reqs: ", decode_reqs) prepare_prompt_impl = self._prepare_prompt_merged if self.enable_merged_prefill else self._prepare_prompt ( input_tokens, @@ -1615,23 +1593,21 @@ def prepare_input_tensors( # FIXME: We need to adjust selected_token_indices to accommodate # for padding - if self.enable_merged_prefill: - max_len = slot_mapping.size(1) - else: + if not self.enable_merged_prefill: max_len = input_tokens.size(1) - paddings = [max_len - q for q in query_lens] - paddings = [0] + paddings[:-1] - paddings = list(itertools.accumulate(paddings)) - paddings_prompt_logprobs = [] - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - if seq_group_metadata.sampling_params.prompt_logprobs is not None \ - and seq_group_metadata.is_prompt: - paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i]) - paddings = torch.tensor( - paddings_prompt_logprobs if paddings_prompt_logprobs else paddings, - dtype=sampling_metadata.selected_token_indices.dtype, - device=sampling_metadata.selected_token_indices.device) - sampling_metadata.selected_token_indices.add_(paddings) + paddings = [max_len - q for q in query_lens] + paddings = [0] + paddings[:-1] + paddings = list(itertools.accumulate(paddings)) + paddings_prompt_logprobs = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + if seq_group_metadata.sampling_params.prompt_logprobs is not None \ + and seq_group_metadata.is_prompt: + paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i]) + paddings = torch.tensor( + paddings_prompt_logprobs if paddings_prompt_logprobs else paddings, + dtype=sampling_metadata.selected_token_indices.dtype, + device=sampling_metadata.selected_token_indices.device) + sampling_metadata.selected_token_indices.add_(paddings) if self.lora_config: lora_mapping = LoRAMapping( From 97ea32b0949d0e22e182f3d71adb635c03d2cbc7 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Wed, 8 Jan 2025 04:39:50 +0000 Subject: [PATCH 16/25] use fixed length for selected_token_indices Signed-off-by: Chendi.Xue --- 
vllm/attention/backends/hpu_attn.py | 1 + vllm/worker/hpu_model_runner.py | 29 ++++++++++++++++++++++++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index a153ef87641ef..c0d2cbe9818c3 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -136,6 +136,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): seq_lens_tensor: Optional[torch.Tensor] context_lens_tensor: Optional[torch.Tensor] enable_merged_prefill: bool = False + actual_num_prefills: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 868a1887db5b4..f34319152dd09 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -461,8 +461,13 @@ def forward(self, *args, **kwargs): if self.layer_names is not None: self._prepare_cos_sin(kwargs['positions']) if kwargs['attn_metadata'].is_prompt: + am = kwargs['attn_metadata'] print("Warming up HPU Graph - input_ids: ", input_ids.shape, - "seq_lens_tensor: ", kwargs['attn_metadata'].seq_lens_tensor.shape, 'slot_mapping: ', kwargs['attn_metadata'].slot_mapping.shape, 'selected_token_indices: ', selected_token_indices) + "seq_lens_tensor: ", am.seq_lens_tensor.shape, + "context_lens_tensor: ", am.context_lens_tensor.shape, + "enable_merged_prefill:", am.enable_merged_prefill, + "slot_mapping: ", am.slot_mapping.shape, + "selected_token_indices: ", selected_token_indices.shape) hidden_states = self.model(*args, **kwargs) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = hidden_states.index_select(0, selected_token_indices) @@ -1235,8 +1240,6 @@ def _prepare_prompt_merged( real_num_seqs = len(total_query_lens) assert max_query_len > 0 - # print("input_tokens_merged: ", input_tokens_merged) - # print("input_positions_merged: ", input_positions_merged) merged_prompt_len = max( self.bucketing_ctx.get_padded_prompt_seq_len(max(total_seq_lens)), @@ -1264,6 +1267,9 @@ def _prepare_prompt_merged( pad=_PAD_SLOT_ID, dtype=torch.long, device='cpu') + actual_num_prefills_tensor = torch.tensor(len(seq_lens), + dtype=torch.long, + device='cpu') max_prefill_bs = int(os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', '16')) max_prefill_bs = max(max_prefill_bs, len(seq_lens)) @@ -1289,9 +1295,12 @@ def _prepare_prompt_merged( seq_lens_tensor = seq_lens_tensor.to(self.device, non_blocking=True) context_lens_tensor = context_lens_tensor.to(self.device, non_blocking=True) + actual_num_prefills_tensor = actual_num_prefills_tensor.to( + self.device, non_blocking=True) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, enable_merged_prefill=True, + actual_num_prefills=actual_num_prefills_tensor, block_list=prefix_block_list_tensor, block_mapping=None, block_usage=None, @@ -1608,6 +1617,14 @@ def prepare_input_tensors( dtype=sampling_metadata.selected_token_indices.dtype, device=sampling_metadata.selected_token_indices.device) sampling_metadata.selected_token_indices.add_(paddings) + else: + paddings = [0] * (num_prefills - sampling_metadata.selected_token_indices.size(0)) + paddings = torch.tensor( + paddings, + dtype=sampling_metadata.selected_token_indices.dtype, + device=sampling_metadata.selected_token_indices.device) + sampling_metadata.selected_token_indices = \ + 
torch.cat((sampling_metadata.selected_token_indices, paddings), dim=0) if self.lora_config: lora_mapping = LoRAMapping( @@ -1695,6 +1712,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: 'seq_lens_tensor', 'context_lens_tensor', 'enable_merged_prefill', + 'actual_num_prefills', 'block_list', 'block_mapping', 'block_usage', @@ -2470,6 +2488,11 @@ def try_revert_dummy_output_tokens(): selected_token_indices=sampling_metadata. selected_token_indices) + # change the selected_token_indices shape after fwd, so hpu graph capture can use exactly same shape + if execute_model_kwargs['attn_metadata'].actual_num_prefills is not None: + actual_num_prefills = execute_model_kwargs['attn_metadata'].actual_num_prefills + sampling_metadata.selected_token_indices = sampling_metadata.selected_token_indices[:actual_num_prefills] + hidden_states = hidden_states[:actual_num_prefills] if self.lora_config: LoraMask.setLoraMask( lora_logits_mask.index_select( From 11ffc2f6c7b47f5d606222155b2b5e92f4acf692 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Wed, 8 Jan 2025 18:33:54 +0200 Subject: [PATCH 17/25] clean up hpu_attn codes Signed-off-by: Chendi Xue --- vllm/attention/backends/hpu_attn.py | 36 +++++++++++------------------ 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index c0d2cbe9818c3..8775eee9a797f 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -65,7 +65,7 @@ def prompt_attention( attn_weights = attn_weights.flatten(1, 2) else: VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get( - 'VLLM_REMOVE_REPEAT_KV_CACHE', '1') == '1' + 'VLLM_REMOVE_REPEAT_KV_CACHE_MERGED_PREFILL', '1') == '1' # TODO: remove after fusedsdpa fix for query_heads != kv_heads if query_heads != kv_heads: if VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE: @@ -276,28 +276,20 @@ def forward( seq_lens_tensor = attn_metadata.seq_lens_tensor if attn_bias is None: # This is the case for prompt run attn_bias = attn_metadata.attn_bias - if enable_merged_prefill and attn_metadata.is_prompt and kv_cache is not None: + if attn_metadata.is_prompt and not enable_merged_prefill: + key = key.unflatten(0, (block_indices.size(0), -1)) + value = value.unflatten(0, (block_indices.size(0), -1)) + if kv_cache is not None: key_cache, value_cache = HPUPagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - key_cache = self.k_cache(key, key_cache, - block_indices, block_offsets) - value_cache = self.v_cache(value, value_cache, - block_indices, block_offsets) - else: - if attn_metadata.is_prompt: - key = key.unflatten(0, (block_indices.size(0), -1)) - value = value.unflatten(0, (block_indices.size(0), -1)) - if kv_cache is not None: - key_cache, value_cache = HPUPagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - key_cache = self.k_cache(key, key_cache, block_indices, - block_offsets) - value_cache = self.v_cache(value, value_cache, block_indices, - block_offsets) + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. 
+ key_cache = self.k_cache(key, key_cache, block_indices, + block_offsets) + value_cache = self.v_cache(value, value_cache, block_indices, + block_offsets) if attn_metadata.is_prompt: # Prompt run. From 2ef08d6131a3d7bceb1cb816a0753455a9a2d3a4 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Wed, 8 Jan 2025 21:20:37 +0200 Subject: [PATCH 18/25] add CPU version attn_mask preparation Signed-off-by: Chendi Xue --- vllm/attention/backends/hpu_attn.py | 4 +-- vllm/worker/hpu_model_runner.py | 46 ++++++++++++++++++++++++++--- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 8775eee9a797f..2617f053b3b5d 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -308,8 +308,8 @@ def forward( attn_bias = attn_bias.tile( (1, self.num_kv_heads, 1, 1)) attn_bias.add_(position_bias) - elif enable_merged_prefill: - pass + # elif enable_merged_prefill: + # pass else: attn_bias = None diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index f34319152dd09..c1ee5fc873e33 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -332,6 +332,8 @@ def _set_merged_attn_bias( ): # create a 2D causal attn mask to ensure I can only attend to the past if attn_metadata is None or not attn_metadata.is_prompt: return attn_metadata + if attn_metadata.attn_bias is not None: + return attn_metadata #TODO: Support batch_size > 1 seq_lens = attn_metadata.seq_lens_tensor.tolist() causal_attn_mask_tensor = torch.ones( @@ -465,6 +467,7 @@ def forward(self, *args, **kwargs): print("Warming up HPU Graph - input_ids: ", input_ids.shape, "seq_lens_tensor: ", am.seq_lens_tensor.shape, "context_lens_tensor: ", am.context_lens_tensor.shape, + "attn_bias: ", am.attn_bias.shape if am.attn_bias is not None else None, "enable_merged_prefill:", am.enable_merged_prefill, "slot_mapping: ", am.slot_mapping.shape, "selected_token_indices: ", selected_token_indices.shape) @@ -715,6 +718,8 @@ def __init__( self._mem_margin: Optional[int] = None self.enable_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL', 'false').lower() == 'true' + self.enable_cpu_merged_prefill_attn = os.environ.get('VLLM_CPU_MERGED_PREFILL_ATTN', + 'false').lower() == 'true' if self.enable_merged_prefill: self.bucketing_ctx = HPUBucketingContextWithMergedPrefill( self.max_num_seqs, self.max_num_prefill_seqs, self.block_size, @@ -1278,11 +1283,44 @@ def _prepare_prompt_merged( seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long, device='cpu') - - context_lens_tensor = torch.tensor(context_lens_merged, + context_lens_tensor = torch.tensor(context_lens, dtype=torch.long, device='cpu') - + ##### Create attn_bias in CPU ##### + + if self.enable_cpu_merged_prefill_attn: + #TODO: Support batch_size > 1 + batch_size = 1 + max_seq_len = merged_prompt_len + device = 'cpu' + causal_attn_mask_tensor = torch.ones( + (batch_size, max_seq_len, max_seq_len), + dtype=torch.bool, + device=device) + start = 0 + for i in range(batch_size): + for seq_len in seq_lens: + # create triangular mask for each sequence + causal_mask = torch.triu(torch.ones((seq_len, seq_len), + device=device, + dtype=torch.bool), + diagonal=1) + causal_attn_mask_tensor[i][start:start + seq_len, start:start + + seq_len].logical_and_(causal_mask) + start += seq_len + causal_attn_mask_tensor = ( + torch.zeros_like(causal_attn_mask_tensor, + device=device, + dtype=self.model_config.dtype).masked_fill_( + causal_attn_mask_tensor, 
-10000) + ) # should be math(-inf) but -10000 is used for numerical stability + causal_attn_mask_tensor = causal_attn_mask_tensor.view( + causal_attn_mask_tensor.shape[0], 1, + causal_attn_mask_tensor.shape[1], causal_attn_mask_tensor.shape[2]) + causal_attn_mask_tensor = causal_attn_mask_tensor.to(self.device, non_blocking=True) + else: + causal_attn_mask_tensor = None + ######################## # Note: num_prefill_tokens is calculated using the length of # input_tokens after padding. num_prefill_tokens = input_tokens_tensor.numel() @@ -1308,7 +1346,7 @@ def _prepare_prompt_merged( block_offsets=None, block_scales=None, block_groups=None, - attn_bias=None, + attn_bias=causal_attn_mask_tensor, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, context_lens_tensor=context_lens_tensor, From 405243a46cb9639237c84102c7102463194bb969 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 9 Jan 2025 00:22:56 +0200 Subject: [PATCH 19/25] update CPU version attn_bias prepration and clean up Signed-off-by: Chendi Xue --- benchmarks/benchmark_throughput.py | 29 ++------ vllm/attention/backends/hpu_attn.py | 52 +++++--------- vllm/worker/hpu_model_runner.py | 101 +++++++++++++++++++++------- 3 files changed, 99 insertions(+), 83 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index f7347ac9a391a..d3e5c6bdf35f6 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -88,38 +88,15 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data["conversations"]) >= 2] # Shuffle the dataset. - #random.shuffle(dataset) + random.shuffle(dataset) # Filter out sequences that are too long or too short filtered_dataset: List[SampleRequest] = [] + prompt_lens = [] for data in dataset: if len(filtered_dataset) == num_requests: if args.sort_by_len: filtered_dataset = sorted(filtered_dataset, key=lambda x: x.prompt_len) - if args.bucket_selective: - length_map = {} - for i, request in enumerate(filtered_dataset): - length_map.setdefault(request.prompt_len, []).append(i) - ret = {} - for length, indices in length_map.items(): - bucket_size = (int(length / 128) + 1) * 128 - while len(indices) > 0: - i = indices.pop(0) - if ret.get(bucket_size, None) is None: - ret[bucket_size] = [] - ret[bucket_size].append(filtered_dataset[i]) - remain_len = bucket_size - length - while remain_len > 0: - if length_map.get(remain_len, None) is not None and len(length_map[remain_len]) > 0: - j = length_map[remain_len].pop(0) - ret[bucket_size].append(filtered_dataset[j]) - break - else: - remain_len -= 1 - # sort ret by key - ret = dict(sorted(ret.items(), key=lambda x: x[0])) - print("!!!!!!!!!!!!!!!sorted requests:", [(bucket_size, [i.prompt_len for i in req_list]) for bucket_size, req_list in ret.items()]) - filtered_dataset = [req for data in ret.items() for req in data[1]] break # Only keep the first two turns of each conversation. 
@@ -158,6 +135,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=multi_modal_data)) + prompt_lens.append(prompt_len) + print("!!!!prompt length are: ", pd.Series(prompt_lens).describe()) # for i, data in enumerate(filtered_dataset): # print(i, data.prompt) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 2617f053b3b5d..3eecda15d46ad 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -30,7 +30,7 @@ "vLLM will use native implementation.") -def prompt_attention( +def prompt_fsdpa( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -48,36 +48,20 @@ def prompt_attention( value = value.transpose(1, 2) query_heads = query.size(1) kv_heads = key.size(1) - #if attn_bias is not None or fsdpa_op is None: - if fsdpa_op is None: - if query_heads != kv_heads: - query = query.unflatten(1, (kv_heads, -1)) - key = key.unflatten(1, (kv_heads, 1)) - value = value.unflatten(1, (kv_heads, 1)) - if attn_bias is not None: - attn_bias = attn_bias.unsqueeze(1) - attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2)) + VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get( + 'VLLM_REMOVE_REPEAT_KV_CACHE_MERGED_PREFILL', '1') == '1' + # TODO: remove after fusedsdpa fix for query_heads != kv_heads + if query_heads != kv_heads: + if VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE: + key = ops.repeat_kv(key, int(query_heads // kv_heads)) + value = ops.repeat_kv(value, int(query_heads // kv_heads)) if attn_bias is not None: - attn_weights.add_(attn_bias) - attn_weights = softmax_op(attn_weights, dim=-1) - attn_weights = matmul_av_op(attn_weights, value) - if query_heads != kv_heads: - attn_weights = attn_weights.flatten(1, 2) - else: - VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get( - 'VLLM_REMOVE_REPEAT_KV_CACHE_MERGED_PREFILL', '1') == '1' - # TODO: remove after fusedsdpa fix for query_heads != kv_heads - if query_heads != kv_heads: - if VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE: - key = ops.repeat_kv(key, int(query_heads // kv_heads)) - value = ops.repeat_kv(value, int(query_heads // kv_heads)) - if attn_bias is not None: - attn_bias = attn_bias.unsqueeze(1) - softmax_mode = 'fast' - recompute_mode = True - attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, False, - scale, softmax_mode, recompute_mode, None, - 'right') + attn_bias = attn_bias.unsqueeze(1) + softmax_mode = 'fast' + recompute_mode = True + attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, False, + scale, softmax_mode, recompute_mode, None, + 'right') attn_weights = attn_weights.transpose(1, 2) return attn_weights @@ -308,13 +292,13 @@ def forward( attn_bias = attn_bias.tile( (1, self.num_kv_heads, 1, 1)) attn_bias.add_(position_bias) - # elif enable_merged_prefill: - # pass + elif enable_merged_prefill: + pass else: attn_bias = None - if enable_merged_prefill: - prompt_attn_func = prompt_attention + if enable_merged_prefill and self.prefill_use_fusedsdpa: + prompt_attn_func = prompt_fsdpa else: prompt_attn_func = ops.prompt_attention out = prompt_attn_func( diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index c1ee5fc873e33..24a22d9ea4476 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -719,7 +719,7 @@ def __init__( self.enable_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL', 'false').lower() == 'true' self.enable_cpu_merged_prefill_attn = os.environ.get('VLLM_CPU_MERGED_PREFILL_ATTN', - 
'false').lower() == 'true' + 'true').lower() == 'true' if self.enable_merged_prefill: self.bucketing_ctx = HPUBucketingContextWithMergedPrefill( self.max_num_seqs, self.max_num_prefill_seqs, self.block_size, @@ -1288,35 +1288,88 @@ def _prepare_prompt_merged( device='cpu') ##### Create attn_bias in CPU ##### - if self.enable_cpu_merged_prefill_attn: - #TODO: Support batch_size > 1 - batch_size = 1 - max_seq_len = merged_prompt_len - device = 'cpu' + def create_causal_attention_mask(batch_size, max_seq_len, seq_lens, device, dtype=torch.float32): + """ + Create a causal attention mask tensor for variable-length sequences in a batch. + + Args: + batch_size (int): Number of sequences in the batch. + max_seq_len (int): Maximum sequence length. + seq_lens (list[int]): List of sequence lengths for each sequence in the batch. + device (torch.device): Device to create the tensor on (e.g., 'cuda' or 'cpu'). + dtype (torch.dtype): Data type for the final attention mask. + + Returns: + torch.Tensor: A causal attention mask tensor. + """ + # Initialize a mask tensor with all ones causal_attn_mask_tensor = torch.ones( (batch_size, max_seq_len, max_seq_len), dtype=torch.bool, - device=device) + device=device + ) + start = 0 for i in range(batch_size): - for seq_len in seq_lens: - # create triangular mask for each sequence - causal_mask = torch.triu(torch.ones((seq_len, seq_len), - device=device, - dtype=torch.bool), - diagonal=1) - causal_attn_mask_tensor[i][start:start + seq_len, start:start + - seq_len].logical_and_(causal_mask) + seq_len = seq_lens[i] + # Create a triangular mask for the current sequence + causal_mask = torch.triu( + torch.ones((seq_len, seq_len), device=device, dtype=torch.bool), + diagonal=1 + ) + # Apply the causal mask to the corresponding part of the tensor + causal_attn_mask_tensor[i, start:start + seq_len, start:start + seq_len].logical_and_( + ~causal_mask + ) + start += seq_len + + # Convert the boolean mask to the desired dtype and apply -10000 (masked value) + causal_attn_mask_tensor = torch.zeros_like( + causal_attn_mask_tensor, + device=device, + dtype=dtype + ).masked_fill_(causal_attn_mask_tensor, -10000) + + return causal_attn_mask_tensor + + def create_causal_attention_mask_with_python(batch_size, max_seq_len, seq_lens): + """ + Create a causal attention mask for variable-length sequences in a batch using Python lists. + + Args: + batch_size (int): Number of sequences in the batch. + max_seq_len (int): Maximum sequence length. + seq_lens (list[int]): List of sequence lengths for each sequence in the batch. + + Returns: + list: A 3D causal attention mask as a list of lists. 
+ """ + # Initialize the mask tensor with all ones (boolean values) + causal_attn_mask_tensor = [[[True for _ in range(max_seq_len)] for _ in range(max_seq_len)] for _ in range(batch_size)] + + for i in range(batch_size): + start = 0 + for j in range(len(seq_lens)): + seq_len = seq_lens[j] + # Create a triangular causal mask for the current sequence + causal_mask = [[False if col <= row else True for col in range(seq_len)] for row in range(seq_len)] + + # Apply the causal mask to the corresponding positions in the batch mask + for row in range(seq_len): + for col in range(seq_len): + causal_attn_mask_tensor[i][start + row][start + col] = causal_mask[row][col] start += seq_len - causal_attn_mask_tensor = ( - torch.zeros_like(causal_attn_mask_tensor, - device=device, - dtype=self.model_config.dtype).masked_fill_( - causal_attn_mask_tensor, -10000) - ) # should be math(-inf) but -10000 is used for numerical stability - causal_attn_mask_tensor = causal_attn_mask_tensor.view( - causal_attn_mask_tensor.shape[0], 1, - causal_attn_mask_tensor.shape[1], causal_attn_mask_tensor.shape[2]) + + # Convert True/False to a mask value (-10000 for masked positions, 0 for others) + final_mask = [[[0 if not cell else -10000 for cell in row] for row in matrix] for matrix in causal_attn_mask_tensor] + final_mask_tensor = torch.tensor(final_mask, dtype=self.model_config.dtype) + final_mask_tensor = final_mask_tensor.view(final_mask_tensor.size(0), 1, final_mask_tensor.size(1), final_mask_tensor.size(2)) + + return final_mask_tensor + if self.enable_cpu_merged_prefill_attn: + #TODO: Support batch_size > 1 + batch_size = 1 + causal_attn_mask_tensor = create_causal_attention_mask_with_python(batch_size, merged_prompt_len, seq_lens) causal_attn_mask_tensor = causal_attn_mask_tensor.to(self.device, non_blocking=True) else: causal_attn_mask_tensor = None From 774c13c7e767db9c1c66c7204e0cddc3f6c643ee Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 9 Jan 2025 01:58:29 +0200 Subject: [PATCH 20/25] 111 Signed-off-by: Chendi Xue --- vllm/worker/hpu_model_runner.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 24a22d9ea4476..27af781b2e38a 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -11,6 +11,7 @@ import math import os import time +import copy from array import array from enum import IntEnum from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, @@ -335,7 +336,7 @@ def _set_merged_attn_bias( if attn_metadata.attn_bias is not None: return attn_metadata #TODO: Support batch_size > 1 - seq_lens = attn_metadata.seq_lens_tensor.tolist() + seq_lens = attn_metadata.seq_lens_tensor.tolist() #[6, 8, 5, 7 ] causal_attn_mask_tensor = torch.ones( (batch_size, max_seq_len, max_seq_len), dtype=torch.bool, @@ -720,6 +721,8 @@ def __init__( 'false').lower() == 'true' self.enable_cpu_merged_prefill_attn = os.environ.get('VLLM_CPU_MERGED_PREFILL_ATTN', 'true').lower() == 'true' + if self.enable_merged_prefill and self.enable_cpu_merged_prefill_attn: + self.causal_attn_mask_tensor_cache = dict() if self.enable_merged_prefill: self.bucketing_ctx = HPUBucketingContextWithMergedPrefill( self.max_num_seqs, self.max_num_prefill_seqs, self.block_size, @@ -1345,24 +1348,23 @@ def create_causal_attention_mask_with_python(batch_size, max_seq_len, seq_lens): list: A 3D causal attention mask as a list of lists. 
""" # Initialize the mask tensor with all ones (boolean values) - causal_attn_mask_tensor = [[[True for _ in range(max_seq_len)] for _ in range(max_seq_len)] for _ in range(batch_size)] + if batch_size not in self.causal_attn_mask_tensor_cache: + self.causal_attn_mask_tensor_cache[batch_size] = {} + if max_seq_len not in self.causal_attn_mask_tensor_cache[batch_size]: + causal_attn_mask_tensor = [[[-10000 for _ in range(max_seq_len)] for _ in range(max_seq_len)] for _ in range(batch_size)] + self.causal_attn_mask_tensor_cache[batch_size][max_seq_len] = copy.copy(causal_attn_mask_tensor) + else: + causal_attn_mask_tensor = copy.copy(self.causal_attn_mask_tensor_cache[batch_size][max_seq_len]) for i in range(batch_size): start = 0 for j in range(len(seq_lens)): - seq_len = seq_lens[j] - # Create a triangular causal mask for the current sequence - causal_mask = [[False if col <= row else True for col in range(seq_len)] for row in range(seq_len)] - - # Apply the causal mask to the corresponding positions in the batch mask + seq_len = seq_lens[j] for row in range(seq_len): - for col in range(seq_len): - causal_attn_mask_tensor[i][start + row][start + col] = causal_mask[row][col] + causal_attn_mask_tensor[i][start + row][start:start + row + 1] = [0] * (row + 1) start += seq_len - # Convert True/False to a mask value (-10000 for masked positions, 0 for others) - final_mask = [[[0 if not cell else -10000 for cell in row] for row in matrix] for matrix in causal_attn_mask_tensor] - final_mask_tensor = torch.tensor(final_mask, dtype=self.model_config.dtype) + final_mask_tensor = torch.tensor(causal_attn_mask_tensor, dtype=self.model_config.dtype) final_mask_tensor = final_mask_tensor.view(final_mask_tensor.size(0), 1, final_mask_tensor.size(1), final_mask_tensor.size(2)) return final_mask_tensor From 3826c1df5b385e6da0d8f4a46ee05f8df8b71eba Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 9 Jan 2025 05:56:35 +0200 Subject: [PATCH 21/25] update hpu attn Signed-off-by: Chendi Xue Signed-off-by: Chendi.Xue --- vllm/worker/hpu_model_runner.py | 53 +++++++++++++++------------------ 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 27af781b2e38a..f4d87c5ef7f3a 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -327,40 +327,34 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, def _set_merged_attn_bias( self, attn_metadata, - batch_size, max_seq_len, + eos_indices, device, + dtype ): # create a 2D causal attn mask to ensure I can only attend to the past if attn_metadata is None or not attn_metadata.is_prompt: return attn_metadata if attn_metadata.attn_bias is not None: return attn_metadata #TODO: Support batch_size > 1 - seq_lens = attn_metadata.seq_lens_tensor.tolist() #[6, 8, 5, 7 ] - causal_attn_mask_tensor = torch.ones( - (batch_size, max_seq_len, max_seq_len), - dtype=torch.bool, - device=device) - start = 0 - for i in range(batch_size): - for seq_len in seq_lens: - # create triangular mask for each sequence - causal_mask = torch.triu(torch.ones((seq_len, seq_len), - device=device, - dtype=torch.bool), - diagonal=1) - causal_attn_mask_tensor[i][start:start + seq_len, start:start + - seq_len].logical_and_(causal_mask) - start += seq_len - causal_attn_mask_tensor = ( - torch.zeros_like(causal_attn_mask_tensor, - device=device, - dtype=self.dtype).masked_fill_( - causal_attn_mask_tensor, -10000) - ) # should be math(-inf) but -10000 is used for numerical 
stability + # get length of each sequence + reps = attn_metadata.seq_lens_tensor + # get indices of all EOS tokens + # repeat each eos index n times along dimension 1 (n is the number of tokens in the sequence) + repeated_idx_small = torch.repeat_interleave(eos_indices, reps, dim=0) + repeated_idx = torch.zeros(max_seq_len, dtype=dtype, device=device) + repeated_idx[:repeated_idx_small.size(0)] = repeated_idx_small + repeated_idx = repeated_idx.view(1,-1).expand(max_seq_len, -1) + # create tensor with all indices from 0 to T-1 repeated T times along dimesion 1 + mask_indices = torch.arange(max_seq_len, dtype=dtype).view(-1,1).expand(-1, max_seq_len) + # create causal mask and additionally mask out all tokens from preceeding sequences + mask = mask_indices.le(repeated_idx) + causal_mask = torch.ones(max_seq_len, max_seq_len, dtype=torch.bool, device=device).tril() + causal_mask = causal_mask.logical_and(mask) + # should be math(-inf) but -10000 is used for numerical stability + causal_attn_mask_tensor = torch.zeros_like(causal_mask, device=device, dtype=dtype).masked_fill_(~causal_mask, -10000) causal_attn_mask_tensor = causal_attn_mask_tensor.view( - causal_attn_mask_tensor.shape[0], 1, - causal_attn_mask_tensor.shape[1], causal_attn_mask_tensor.shape[2]) + 1, 1, causal_attn_mask_tensor.shape[0], causal_attn_mask_tensor.shape[1]) attn_metadata = attn_metadata._replace( attn_bias=causal_attn_mask_tensor) @@ -421,11 +415,11 @@ def _set_indices_and_offsets(self, metadata, block_size, is_prompt): return metadata def _update_metadata(self, attn_metadata, batch_size, seq_len, device, - dtype): + dtype, eos_indices): if attn_metadata.is_prompt and attn_metadata.enable_merged_prefill: attn_metadata = self._set_merged_attn_bias(attn_metadata, - batch_size, seq_len, - device) + seq_len, eos_indices, + device, dtype) elif attn_metadata.is_prompt: attn_metadata = self._set_attn_bias(attn_metadata, batch_size, seq_len, device, dtype) @@ -459,7 +453,7 @@ def forward(self, *args, **kwargs): input_ids = kwargs['input_ids'] kwargs['attn_metadata'] = self._update_metadata( kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), - input_ids.device, self.dtype) + input_ids.device, self.dtype, selected_token_indices) LoraMask.setLoraMask(kwargs.pop('lora_mask')) if self.layer_names is not None: self._prepare_cos_sin(kwargs['positions']) @@ -1390,6 +1384,7 @@ def create_causal_attention_mask_with_python(batch_size, max_seq_len, seq_lens): non_blocking=True) actual_num_prefills_tensor = actual_num_prefills_tensor.to( self.device, non_blocking=True) + attn_metadata = self.attn_backend.make_metadata( is_prompt=True, enable_merged_prefill=True, From a906f364b4f16083b9618431e664e857b6b8c747 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Thu, 9 Jan 2025 06:57:57 +0000 Subject: [PATCH 22/25] update Signed-off-by: Chendi.Xue --- vllm/attention/backends/hpu_attn.py | 1 + vllm/worker/hpu_model_runner.py | 28 ++++++++++++++-------------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 3eecda15d46ad..7892522209151 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -121,6 +121,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): context_lens_tensor: Optional[torch.Tensor] enable_merged_prefill: bool = False actual_num_prefills: Optional[torch.Tensor] = None + repeated_idx_tensor: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None 
encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index f4d87c5ef7f3a..8b70be7ea738a 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -328,7 +328,6 @@ def _set_merged_attn_bias( self, attn_metadata, max_seq_len, - eos_indices, device, dtype ): # create a 2D causal attn mask to ensure I can only attend to the past @@ -338,15 +337,9 @@ def _set_merged_attn_bias( return attn_metadata #TODO: Support batch_size > 1 # get length of each sequence - reps = attn_metadata.seq_lens_tensor - # get indices of all EOS tokens - # repeat each eos index n times along dimension 1 (n is the number of tokens in the sequence) - repeated_idx_small = torch.repeat_interleave(eos_indices, reps, dim=0) - repeated_idx = torch.zeros(max_seq_len, dtype=dtype, device=device) - repeated_idx[:repeated_idx_small.size(0)] = repeated_idx_small - repeated_idx = repeated_idx.view(1,-1).expand(max_seq_len, -1) + repeated_idx = attn_metadata.repeated_idx_tensor.view(1,-1).expand(max_seq_len, -1) # create tensor with all indices from 0 to T-1 repeated T times along dimesion 1 - mask_indices = torch.arange(max_seq_len, dtype=dtype).view(-1,1).expand(-1, max_seq_len) + mask_indices = torch.arange(max_seq_len, dtype=dtype, device=device).view(-1,1).expand(-1, max_seq_len) # create causal mask and additionally mask out all tokens from preceeding sequences mask = mask_indices.le(repeated_idx) causal_mask = torch.ones(max_seq_len, max_seq_len, dtype=torch.bool, device=device).tril() @@ -415,10 +408,10 @@ def _set_indices_and_offsets(self, metadata, block_size, is_prompt): return metadata def _update_metadata(self, attn_metadata, batch_size, seq_len, device, - dtype, eos_indices): + dtype): if attn_metadata.is_prompt and attn_metadata.enable_merged_prefill: attn_metadata = self._set_merged_attn_bias(attn_metadata, - seq_len, eos_indices, + seq_len, device, dtype) elif attn_metadata.is_prompt: attn_metadata = self._set_attn_bias(attn_metadata, batch_size, @@ -453,7 +446,7 @@ def forward(self, *args, **kwargs): input_ids = kwargs['input_ids'] kwargs['attn_metadata'] = self._update_metadata( kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), - input_ids.device, self.dtype, selected_token_indices) + input_ids.device, self.dtype) LoraMask.setLoraMask(kwargs.pop('lora_mask')) if self.layer_names is not None: self._prepare_cos_sin(kwargs['positions']) @@ -714,7 +707,7 @@ def __init__( self.enable_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL', 'false').lower() == 'true' self.enable_cpu_merged_prefill_attn = os.environ.get('VLLM_CPU_MERGED_PREFILL_ATTN', - 'true').lower() == 'true' + 'false').lower() == 'true' if self.enable_merged_prefill and self.enable_cpu_merged_prefill_attn: self.causal_attn_mask_tensor_cache = dict() if self.enable_merged_prefill: @@ -1249,9 +1242,13 @@ def _prepare_prompt_merged( max_prompt_len = max( self.bucketing_ctx.get_padded_prompt_seq_len(max(seq_lens)), self.block_size) - + # get cumsum of seq_lens + repeated_idx = list(itertools.accumulate(seq_lens)) + repeated_idx = [[idx - 1] * seq_len for idx, seq_len in zip(repeated_idx, seq_lens)] + repeated_idx = list(itertools.chain.from_iterable(repeated_idx)) + [0] * (merged_prompt_len - sum(seq_lens)) prefix_block_list_tensor = None + repeated_idx_tensor = torch.tensor(repeated_idx, dtype=torch.long, device='cpu') input_tokens_tensor = make_tensor_with_pad(input_tokens_merged, 
max_len=merged_prompt_len, pad=0, @@ -1382,6 +1379,7 @@ def create_causal_attention_mask_with_python(batch_size, max_seq_len, seq_lens): seq_lens_tensor = seq_lens_tensor.to(self.device, non_blocking=True) context_lens_tensor = context_lens_tensor.to(self.device, non_blocking=True) + repeated_idx_tensor = repeated_idx_tensor.to(self.device, non_blocking=True) actual_num_prefills_tensor = actual_num_prefills_tensor.to( self.device, non_blocking=True) @@ -1389,6 +1387,7 @@ def create_causal_attention_mask_with_python(batch_size, max_seq_len, seq_lens): is_prompt=True, enable_merged_prefill=True, actual_num_prefills=actual_num_prefills_tensor, + repeated_idx_tensor=repeated_idx_tensor, block_list=prefix_block_list_tensor, block_mapping=None, block_usage=None, @@ -1801,6 +1800,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: 'context_lens_tensor', 'enable_merged_prefill', 'actual_num_prefills', + 'repeated_idx_tensor', 'block_list', 'block_mapping', 'block_usage', From 67a292381662b4ec7835a751e063378cdc6b8d3d Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 9 Jan 2025 17:11:52 +0200 Subject: [PATCH 23/25] clean up codes Signed-off-by: Chendi Xue --- vllm/attention/backends/hpu_attn.py | 4 +- vllm/worker/hpu_model_runner.py | 112 +--------------------------- 2 files changed, 4 insertions(+), 112 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 7892522209151..1d5b83c1e61f2 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -272,9 +272,9 @@ def forward( # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. key_cache = self.k_cache(key, key_cache, block_indices, - block_offsets) + block_offsets) value_cache = self.v_cache(value, value_cache, block_indices, - block_offsets) + block_offsets) if attn_metadata.is_prompt: # Prompt run. 
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 8b70be7ea738a..b92c49dd7c154 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -233,15 +233,6 @@ def generate_prompt_buckets(self): f"{list(sorted(self.global_state.prompt_buckets))}") print(msg) - # msg = (f"Omitted {len(prompt_omitted_buckets)} " - # "prompt buckets due to exceeded token budget " - # f"(max_num_batched_tokens={self.max_num_batched_tokens})") - # print(msg) - - # msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" - # print(msg) - - class HpuModelAdapter: def __init__(self, model, block_size, dtype, enforce_eager, layer_names): @@ -706,10 +697,6 @@ def __init__( self._mem_margin: Optional[int] = None self.enable_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL', 'false').lower() == 'true' - self.enable_cpu_merged_prefill_attn = os.environ.get('VLLM_CPU_MERGED_PREFILL_ATTN', - 'false').lower() == 'true' - if self.enable_merged_prefill and self.enable_cpu_merged_prefill_attn: - self.causal_attn_mask_tensor_cache = dict() if self.enable_merged_prefill: self.bucketing_ctx = HPUBucketingContextWithMergedPrefill( self.max_num_seqs, self.max_num_prefill_seqs, self.block_size, @@ -1212,13 +1199,6 @@ def _prepare_prompt_merged( slot = block_number * self.block_size + block_offset slot_mapping[-1].append(slot) - #input_tokens - #input_positions - #slot_mapping - #seq_lens - #context_lens - #prefix_block_list - slot_mapping_merged = list(itertools.chain.from_iterable(slot_mapping)) slot_mapping_merged = [i for i in slot_mapping_merged if i != _PAD_SLOT_ID] slot_mapping = [slot_mapping_merged] @@ -1227,7 +1207,6 @@ def _prepare_prompt_merged( input_positions_merged = list( itertools.chain.from_iterable(input_positions)) input_positions_merged = [input_positions_merged] - context_lens_merged = [sum(context_lens)] total_seq_lens = [sum(seq_lens)] total_query_lens = [sum(query_lens)] @@ -1239,9 +1218,6 @@ def _prepare_prompt_merged( merged_prompt_len = max( self.bucketing_ctx.get_padded_prompt_seq_len(max(total_seq_lens)), self.block_size) - max_prompt_len = max( - self.bucketing_ctx.get_padded_prompt_seq_len(max(seq_lens)), - self.block_size) # get cumsum of seq_lens repeated_idx = list(itertools.accumulate(seq_lens)) repeated_idx = [[idx - 1] * seq_len for idx, seq_len in zip(repeated_idx, seq_lens)] @@ -1270,7 +1246,7 @@ def _prepare_prompt_merged( dtype=torch.long, device='cpu') - max_prefill_bs = int(os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', '16')) + max_prefill_bs = int(os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', '8')) max_prefill_bs = max(max_prefill_bs, len(seq_lens)) seq_lens = seq_lens + [0] * (max_prefill_bs - len(seq_lens)) context_lens = context_lens + [0] * (max_prefill_bs - len(context_lens)) @@ -1281,91 +1257,7 @@ def _prepare_prompt_merged( dtype=torch.long, device='cpu') ##### Create attn_bias in CPU ##### - - def create_causal_attention_mask(batch_size, max_seq_len, seq_lens, device, dtype=torch.float32): - """ - Create a causal attention mask tensor for variable-length sequences in a batch. - - Args: - batch_size (int): Number of sequences in the batch. - max_seq_len (int): Maximum sequence length. - seq_lens (list[int]): List of sequence lengths for each sequence in the batch. - device (torch.device): Device to create the tensor on (e.g., 'cuda' or 'cpu'). - dtype (torch.dtype): Data type for the final attention mask. - - Returns: - torch.Tensor: A causal attention mask tensor. 
- """ - # Initialize a mask tensor with all ones - causal_attn_mask_tensor = torch.ones( - (batch_size, max_seq_len, max_seq_len), - dtype=torch.bool, - device=device - ) - - start = 0 - for i in range(batch_size): - seq_len = seq_lens[i] - # Create a triangular mask for the current sequence - causal_mask = torch.triu( - torch.ones((seq_len, seq_len), device=device, dtype=torch.bool), - diagonal=1 - ) - # Apply the causal mask to the corresponding part of the tensor - causal_attn_mask_tensor[i, start:start + seq_len, start:start + seq_len].logical_and_( - ~causal_mask - ) - start += seq_len - - # Convert the boolean mask to the desired dtype and apply -10000 (masked value) - causal_attn_mask_tensor = torch.zeros_like( - causal_attn_mask_tensor, - device=device, - dtype=dtype - ).masked_fill_(causal_attn_mask_tensor, -10000) - - return causal_attn_mask_tensor - - def create_causal_attention_mask_with_python(batch_size, max_seq_len, seq_lens): - """ - Create a causal attention mask for variable-length sequences in a batch using Python lists. - - Args: - batch_size (int): Number of sequences in the batch. - max_seq_len (int): Maximum sequence length. - seq_lens (list[int]): List of sequence lengths for each sequence in the batch. - - Returns: - list: A 3D causal attention mask as a list of lists. - """ - # Initialize the mask tensor with all ones (boolean values) - if batch_size not in self.causal_attn_mask_tensor_cache: - self.causal_attn_mask_tensor_cache[batch_size] = {} - if max_seq_len not in self.causal_attn_mask_tensor_cache[batch_size]: - causal_attn_mask_tensor = [[[-10000 for _ in range(max_seq_len)] for _ in range(max_seq_len)] for _ in range(batch_size)] - self.causal_attn_mask_tensor_cache[batch_size][max_seq_len] = copy.copy(causal_attn_mask_tensor) - else: - causal_attn_mask_tensor = copy.copy(self.causal_attn_mask_tensor_cache[batch_size][max_seq_len]) - - for i in range(batch_size): - start = 0 - for j in range(len(seq_lens)): - seq_len = seq_lens[j] - for row in range(seq_len): - causal_attn_mask_tensor[i][start + row][start:start + row + 1] = [0] * (row + 1) - start += seq_len - - final_mask_tensor = torch.tensor(causal_attn_mask_tensor, dtype=self.model_config.dtype) - final_mask_tensor = final_mask_tensor.view(final_mask_tensor.size(0), 1, final_mask_tensor.size(1), final_mask_tensor.size(2)) - - return final_mask_tensor - if self.enable_cpu_merged_prefill_attn: - #TODO: Support batch_size > 1 - batch_size = 1 - causal_attn_mask_tensor = create_causal_attention_mask_with_python(batch_size, merged_prompt_len, seq_lens) - causal_attn_mask_tensor = causal_attn_mask_tensor.to(self.device, non_blocking=True) - else: - causal_attn_mask_tensor = None + causal_attn_mask_tensor = None ######################## # Note: num_prefill_tokens is calculated using the length of # input_tokens after padding. 
From a2d300769fef5b1927132e3cb733ebaf0d218e59 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Fri, 10 Jan 2025 20:54:15 +0000 Subject: [PATCH 24/25] Fix fusedSDPA fp8 70B issue Signed-off-by: Chendi.Xue --- vllm/attention/backends/hpu_attn.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 1d5b83c1e61f2..e6c2925beb1fe 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -49,14 +49,12 @@ def prompt_fsdpa( query_heads = query.size(1) kv_heads = key.size(1) VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get( - 'VLLM_REMOVE_REPEAT_KV_CACHE_MERGED_PREFILL', '1') == '1' + 'VLLM_REMOVE_REPEAT_KV_CACHE', '1') == '1' # TODO: remove after fusedsdpa fix for query_heads != kv_heads if query_heads != kv_heads: if VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE: key = ops.repeat_kv(key, int(query_heads // kv_heads)) value = ops.repeat_kv(value, int(query_heads // kv_heads)) - if attn_bias is not None: - attn_bias = attn_bias.unsqueeze(1) softmax_mode = 'fast' recompute_mode = True attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, False, From c2d4c3344ae6aff591d42b034cad10f093a6c3d2 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Sun, 12 Jan 2025 07:46:21 +0000 Subject: [PATCH 25/25] increase absolute value for attn_bias to get better accuracy Signed-off-by: Chendi.Xue --- vllm/worker/hpu_model_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index b92c49dd7c154..0edeb1ffd23b5 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -336,7 +336,9 @@ def _set_merged_attn_bias( causal_mask = torch.ones(max_seq_len, max_seq_len, dtype=torch.bool, device=device).tril() causal_mask = causal_mask.logical_and(mask) # should be math(-inf) but -10000 is used for numerical stability - causal_attn_mask_tensor = torch.zeros_like(causal_mask, device=device, dtype=dtype).masked_fill_(~causal_mask, -10000) + #causal_attn_mask_tensor = torch.zeros_like(causal_mask, device=device, dtype=dtype).masked_fill_(~causal_mask, -math.inf) + #print("min: ", torch.finfo(dtype).min) + causal_attn_mask_tensor = torch.zeros_like(causal_mask, device=device, dtype=dtype).masked_fill_(~causal_mask, torch.finfo(dtype).min) causal_attn_mask_tensor = causal_attn_mask_tensor.view( 1, 1, causal_attn_mask_tensor.shape[0], causal_attn_mask_tensor.shape[1])
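The final patch swaps the hard-coded -10000 bias for torch.finfo(dtype).min. A toy illustration of the failure mode a finite constant can hit when raw attention logits are large (the numbers here are made up for the example; the actual fp8/70B accuracy issue may involve different magnitudes):

import torch

def masked_softmax(scores, mask_value):
    # Additive bias masking, as the attention kernel applies attn_bias.
    bias = torch.tensor([0.0, mask_value], dtype=scores.dtype)
    return torch.softmax(scores + bias, dim=-1)

scores = torch.tensor([0.0, 12000.0])                            # unusually large raw logit
print(masked_softmax(scores, -10000.0))                          # ~[0., 1.]: the mask leaks
print(masked_softmax(scores, torch.finfo(torch.float32).min))    # [1., 0.]: fully masked

With finfo(dtype).min the masked logit stays essentially at the most negative representable value regardless of the score it is added to, so the masked key receives no probability mass.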