diff --git a/byte_infer_perf/llm_perf/README.md b/byte_infer_perf/llm_perf/README.md
index f0d0cadb..4a07df9e 100644
--- a/byte_infer_perf/llm_perf/README.md
+++ b/byte_infer_perf/llm_perf/README.md
@@ -54,10 +54,8 @@ Vendors can refer to this document for guidance on building backend: [Byte LLM P
 ## Models
 The following models are planned to be supported:
 * [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
-* [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B)
+* [shenzhi-wang/Llama3-70B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-70B-Chinese-Chat)
 * [tiiuae/falcon-180B](https://huggingface.co/tiiuae/falcon-180B)
+  - test_accuracy is temporarily unavailable.
 * [mistralai/Mixtral-8x22B-v0.1](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1)
-
-The following models are outdated and will be removed in future vesions:
-* [hfl/chinese-llama-2-13b](https://huggingface.co/hfl/chinese-llama-2-13b)
-
+  - test_accuracy is temporarily unavailable.
diff --git a/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py b/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py
index 83f979e6..af04d055 100644
--- a/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py
+++ b/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py
@@ -1,6 +1,7 @@
 import os
 import time
 from multiprocessing import Queue
+from typing import List
 
 import torch
 import torch.nn as nn
@@ -9,6 +10,61 @@ from llm_perf.core.mp_engine import CoreMpEngine
 from llm_perf.utils.logger import logger
 
 
+
+
+# context:
+# input_ids: [1, s_q]
+# attention_mask = [1, s_q]
+# full_attention_mask = [1, 1, s_q, s_kv] (sq == s_kv)
+def get_context_masks(
+    input_ids : torch.Tensor,
+    padding_mask : torch.Tensor
+):
+    # input_ids: [1, q_len]
+    # padding_mask = [1, q_len]
+    _, q_len = input_ids.shape
+
+    # [1, q_len, q_len]
+    full_attention_mask = torch.ones(
+        1, q_len, q_len,
+        device=input_ids.device
+    )
+    full_attention_mask.tril_()
+    full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
+    full_attention_mask -= padding_mask.unsqueeze(-1) - 1
+    full_attention_mask = (full_attention_mask < 0.5).bool()
+    full_attention_mask.unsqueeze_(1)
+    return full_attention_mask
+
+
+# decode
+# input_ids: [bs, 1]
+# attention_mask = [bs, 1]
+# full_attention_mask = [bs, 1, 1, s_kv]
+def get_decode_masks(
+    input_ids : torch.Tensor,
+    all_kv_len: List[int]
+):
+    # input_ids: [batch_size, 1]
+    # padding_mask: [batch_size, 1 + max_kv_len]
+    batch_size, q_len = input_ids.shape
+    max_qkv_len = q_len + max(all_kv_len)
+
+    # [batch_size, 1, max_qkv_len]
+    padding_mask = []
+    for i in range(batch_size):
+        cur_qkv_len = q_len + all_kv_len[i]
+        mask_per_batch = [1] * cur_qkv_len + [0] * (max_qkv_len - cur_qkv_len)
+        padding_mask.append(mask_per_batch)
+    full_attention_mask = torch.tensor(
+        padding_mask,
+        device=input_ids.device
+    ).unsqueeze_(1)
+    full_attention_mask = (full_attention_mask < 0.5).bool()
+    full_attention_mask.unsqueeze_(1)
+    return full_attention_mask
+
+
 class GpuMpEngine(CoreMpEngine):
     def __init__(self, world_size: int, model_impl: nn.Module, xpu_cfg) -> None:
         super().__init__(world_size, model_impl, xpu_cfg)
@@ -25,6 +81,18 @@ def build_inputs(self, forward_inputs):
         forward_inputs["attention_mask"] = torch.tensor(
             forward_inputs["attention_mask"]
         ).cuda()
+
+        is_context = forward_inputs["is_context"]
+        if is_context:
+            forward_inputs["full_attention_mask"] = get_context_masks(
+                forward_inputs["input_ids"],
+                forward_inputs["attention_mask"]
+            )
+        else:
+            forward_inputs["full_attention_mask"] = get_decode_masks(
+                forward_inputs["input_ids"],
+                forward_inputs["all_kv_len"]
+            )
         return forward_inputs
diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/__init__.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/__init__.py
index 5f62ed92..3401101e 100644
--- a/byte_infer_perf/llm_perf/backends/GPU/model_impl/__init__.py
+++ b/byte_infer_perf/llm_perf/backends/GPU/model_impl/__init__.py
@@ -12,11 +12,15 @@ import torch.nn as nn
 
 from .gpu_chatglm2 import GPUChatGLM2
+from .gpu_llama3 import GPULlama
 from .gpu_falcon import GPUFalcon
+from .gpu_mixtral import GPUMixtral
 
 from llm_perf.utils.logger import logger
 
 
 __all__ = {
     "chatglm2": GPUChatGLM2,
-    "falcon": GPUFalcon
+    "llama3": GPULlama,
+    "falcon": GPUFalcon,
+    "mixtral": GPUMixtral
 }
\ No newline at end of file
diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/falcon.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/falcon.py
deleted file mode 100644
index c4031d64..00000000
--- a/byte_infer_perf/llm_perf/backends/GPU/model_impl/falcon.py
+++ /dev/null
@@ -1,1360 +0,0 @@
-# coding=utf-8
-# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch Falcon model."""
-
-import os
-import math
-import warnings
-from typing import TYPE_CHECKING, Optional, Tuple, Union
-
-import torch
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
-from torch.nn import functional as F
-import torch.distributed as dist
-
-from transformers.activations import get_activation
-from transformers.modeling_attn_mask_utils import (
-    AttentionMaskConverter,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPastAndCrossAttentions,
-    CausalLMOutputWithCrossAttentions,
-    QuestionAnsweringModelOutput,
-    SequenceClassifierOutputWithPast,
-    TokenClassifierOutput,
-)
-from transformers.modeling_utils import PreTrainedModel
-from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
-from transformers.utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal_2_10,
-    logging,
-)
-from transformers.models.falcon.configuration_falcon import FalconConfig
-
-
-if TYPE_CHECKING:
-    from transformers.configuration_utils import PretrainedConfig
-
-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
-logger = logging.get_logger(__name__)
-
-from transformers.models.deprecated._archive_maps import FALCON_PRETRAINED_MODEL_ARCHIVE_LIST  # noqa: F401, E402
-
-
-_CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
-_CONFIG_FOR_DOC = "FalconConfig"
-
-
-# NOTE(Hesslow): Unfortunately we did not fuse
matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations. -# In order not to degrade the quality of our HF-port, we keep these characteristics in the final model. -class FalconLinear(nn.Linear): - def forward(self, input: torch.Tensor) -> torch.Tensor: - hidden_states = input @ self.weight.T - if self.bias is None: - return hidden_states - return hidden_states + self.bias - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon -class FalconRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon -# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied) -class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding): - """FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - t = t / self.scaling_factor - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon -# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied) -class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding): - """FalconRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: - batch_size, seq_length = attention_mask.shape - closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) - base = torch.tensor( - 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 - ) - powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) - slopes = torch.pow(base, powers) - - if closest_power_of_2 != num_heads: - extra_base = torch.tensor( - 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 - ) - num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) - extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32) - slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) - - # Note: alibi will added to the attention bias that will be applied to the query, key product of attention - # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) - # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) - # => the query_length dimension will then be broadcasted correctly - # This is more or less identical to T5's relative position bias: - # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 - arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] - alibi = slopes[..., None].bfloat16() * arange_tensor - return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) - - -# Copied from transformers.models.bloom.modeling_bloom.dropout_add -def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: - """ - Dropout add function - - Args: - x (`torch.tensor`, *required*): - input tensor - residual (`torch.tensor`, *required*): - residual tensor - prob (`float`, *required*): - dropout probability - training (`bool`, *required*): - training mode - """ - out = F.dropout(x, p=prob, training=training) - out = residual + out - return out - - -class FalconAttention(nn.Module): - def __init__(self, config: FalconConfig): - super().__init__() - - # dist info - self.mp_size = 
int(os.environ.get("WORLD_SIZE", "1")) - self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) - - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.split_size = self.hidden_size - self.hidden_dropout = config.hidden_dropout - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self._use_sdpa = config._attn_implementation == "sdpa" - - if self.head_dim * self.num_heads != self.hidden_size: - raise ValueError( - f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" - f" {self.num_heads})." - ) - - if config.rotary: - self._init_rope() - - # Layer-wise attention scaling - self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) - self.beta = self.inv_norm_factor - if config.new_decoder_architecture: - qkv_out_dim = (config.num_kv_heads * 2 // self.mp_size + config.num_attention_heads // self.mp_size) * self.head_dim - elif config.multi_query: - qkv_out_dim = self.hidden_size // self.mp_size + 2 * self.head_dim - else: - qkv_out_dim = 3 * self.hidden_size // self.mp_size - - - self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias) - self.new_decoder_architecture = config.new_decoder_architecture - self.multi_query = config.multi_query - self.dense = FalconLinear(self.hidden_size // self.mp_size, self.hidden_size, bias=config.bias) - self.attention_dropout = nn.Dropout(config.attention_dropout) - self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1 - - # Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Falcon - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = FalconRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = FalconLinearScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = FalconDynamicNTKScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv` - - Args: - fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim] - - Returns: - query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] - value: [batch_size, seq_length, num_heads, head_dim] - """ - if self.new_decoder_architecture: - batch, seq_len, _ = fused_qkv.shape - qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim) - query = qkv[:, :, :, :-2] - key = qkv[:, :, :, [-2]] - value = qkv[:, :, :, [-1]] - key = torch.broadcast_to(key, query.shape) - value = torch.broadcast_to(value, query.shape) - - query, key, value = [x.flatten(2, 3) for x in (query, key, value)] 
- return query, key, value - elif not self.multi_query: - batch_size, seq_length, three_times_hidden_size = fused_qkv.shape - fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) - return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] - else: - batch_size, seq_length, three_times_hidden_size = fused_qkv.shape - fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim) - return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :] - - # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads - def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: - """ - Merge heads together over the last dimension - - Args: - x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim] - - Returns: - torch.tensor: [batch_size, seq_length, num_heads * head_dim] - """ - # What we want to achieve is: - # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim - batch_size_and_num_heads, seq_length, _ = x.shape - batch_size = batch_size_and_num_heads // self.num_heads - - # First view to decompose the batch size - # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim - x = x.view(batch_size, self.num_heads, seq_length, self.head_dim) - - # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim - x = x.permute(0, 2, 1, 3) - - # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim - return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - **kwargs, - ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" - ) - - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - - batch_size, query_length, _, _ = query_layer.shape - - query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim) - key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - - kv_seq_len = key_layer.shape[-2] - if layer_past is not None: - kv_seq_len += layer_past[0].shape[-2] - if alibi is None: - cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) - query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) - - if layer_past is not None: - past_key, past_value = layer_past - # concatenate along seq_length dimension: - # - key: [batch_size, self.num_heads, kv_length, head_dim] - # - value: [batch_size, self.num_heads, kv_length, head_dim] - key_layer = torch.cat((past_key, key_layer), dim=-2) - value_layer = torch.cat((past_value, value_layer), dim=-2) - - kv_length = key_layer.shape[-2] - if use_cache: - present = (key_layer, value_layer) - else: - present = None - - if self._use_sdpa and query_layer.device.type == "cuda" and attention_mask is not None: - # For torch<=2.1.2, SDPA with memory-efficient backend is bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - query_layer = query_layer.contiguous() - key_layer = key_layer.contiguous() - value_layer = value_layer.contiguous() - - if alibi is None: - if self._use_sdpa and not output_attentions: - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attention_mask, - 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) - - attention_scores = None - else: - attention_scores = query_layer @ key_layer.transpose(-1, -2) - attention_scores /= math.sqrt(self.head_dim) - - attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) - # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi). 
- attn_output = attention_scores @ value_layer - - attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim) - attn_output = attn_output.permute(0, 2, 1, 3) - attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - - attn_output = self.dense(attn_output) - - if output_attentions: - return attn_output, present, attention_scores - else: - return attn_output, present - - else: - if self._use_sdpa and not output_attentions and head_mask is None: - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attn_mask=attention_mask, - dropout_p=self.attention_dropout.p if self.training else 0.0, - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - - attn_output = self.dense(attn_output) - else: - matmul_result = query_layer @ key_layer.transpose(-1, -2) - - # change view to [batch_size, num_heads, q_length, kv_length] - attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) - - # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] - input_dtype = attention_scores.dtype - # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` - if input_dtype == torch.float16 or input_dtype == torch.bfloat16: - attention_scores = attention_scores.to(torch.float32) - - attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) - attention_logits *= self.inv_norm_factor - attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) - # [batch_size, num_heads, q_length, kv_length] - attention_probs = self.attention_dropout(attention_probs) - - if head_mask is not None: - attention_probs = attention_probs * head_mask - - # change view [batch_size, num_heads, q_length, kv_length] - attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) - - # matmul: [batch_size * num_heads, q_length, head_dim] - attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1) - - # change view [batch_size, q_length, num_heads * head_dim] - attn_output = self._merge_heads(attn_output) - - attn_output = self.dense(attn_output) - - if output_attentions: - return attn_output, present, attention_probs - else: - return attn_output, present - - -class FalconFlashAttention2(FalconAttention): - """ - Falcon flash attention module. This module inherits from `FalconAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - **kwargs, - ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - - batch_size, query_length, _, _ = query_layer.shape - - query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim) - key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - - kv_seq_len = key_layer.shape[-2] - if layer_past is not None: - kv_seq_len += layer_past[0].shape[-2] - if alibi is None: - cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) - query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) - - if layer_past is not None and use_cache: - past_key, past_value = layer_past - # concatenate along seq_length dimension: - # - key: [batch_size, self.num_heads, kv_length, head_dim] - # - value: [batch_size, self.num_heads, kv_length, head_dim] - key_layer = torch.cat((past_key, key_layer), dim=-2) - value_layer = torch.cat((past_value, value_layer), dim=-2) - - past_key_value = (key_layer, value_layer) if use_cache else None - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_layer = query_layer.transpose(1, 2) - key_layer = key_layer.transpose(1, 2) - value_layer = value_layer.transpose(1, 2) - - if alibi is not None: - raise ValueError("`alibi` is not supported when `use_flash_attn` is True") - - attn_dropout = self.config.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. 
- input_dtype = query_layer.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.query_key_value.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_layer = query_layer.to(target_dtype) - key_layer = key_layer.to(target_dtype) - value_layer = value_layer.to(target_dtype) - - attn_output = self._flash_attention_forward( - query_layer, key_layer, value_layer, attention_mask, query_length, dropout=attn_dropout - ) - - attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - attn_output = self.dense(attn_weights) - - if not output_attentions: - attn_weights = None - - return attn_output, past_key_value, attn_weights - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward - def _flash_attention_forward( - self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`float`): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
- causal = self.is_causal and query_length != 1 - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal - ) - - return attn_output - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. 
- attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -class FalconMLP(nn.Module): - def __init__(self, config: FalconConfig): - super().__init__() - - # dist info - self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) - self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) - - hidden_size = config.hidden_size - - self.dense_h_to_4h = FalconLinear( - hidden_size, - config.ffn_hidden_size // self.mp_size, - bias=config.bias - ) - self.act = get_activation(config.activation) - self.dense_4h_to_h = FalconLinear( - config.ffn_hidden_size // self.mp_size, - hidden_size, - bias=config.bias - ) - - self.hidden_dropout = config.hidden_dropout - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.act(self.dense_h_to_4h(x)) - x = self.dense_4h_to_h(x) - if self.mp_size > 1: - dist.all_reduce(x) - return x - - -FALCON_ATTENTION_CLASSES = { - "eager": FalconAttention, - "sdpa": FalconAttention, # FalconAttention originally implemented both a forward with & without SDPA - "flash_attention_2": FalconFlashAttention2, -} - - -class FalconDecoderLayer(nn.Module): - def __init__(self, config: FalconConfig): - super().__init__() - hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - - self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config) - self.mlp = FalconMLP(config) - self.hidden_dropout = config.hidden_dropout - self.config = config - - if config.new_decoder_architecture: - # The layer norm before self-attention - self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - # The layer norm before the MLP - self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - else: - self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - if not config.parallel_attn: - self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - - def forward( - self, - hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - **kwargs, - ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - residual = hidden_states - - if self.config.new_decoder_architecture: - attention_layernorm_out = self.ln_attn(hidden_states) - mlp_layernorm_out = self.ln_mlp(hidden_states) - else: - attention_layernorm_out = self.input_layernorm(hidden_states) - - # Self attention. 
- attn_outputs = self.self_attention( - attention_layernorm_out, - layer_past=layer_past, - attention_mask=attention_mask, - position_ids=position_ids, - alibi=alibi, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs, - ) - - attention_output = attn_outputs[0] - - if not self.config.new_decoder_architecture: - if self.config.parallel_attn: - mlp_layernorm_out = attention_layernorm_out - else: - residual = dropout_add( - attention_output, residual, self.config.attention_dropout, training=self.training - ) - mlp_layernorm_out = self.post_attention_layernorm(residual) - - outputs = attn_outputs[1:] - - # MLP. - mlp_output = self.mlp(mlp_layernorm_out) - - if self.config.new_decoder_architecture or self.config.parallel_attn: - mlp_output += attention_output - - output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training) - - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] - - return outputs # hidden_states, present, attentions - - -FALCON_START_DOCSTRING = r""" - - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`FalconConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -FALCON_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` - (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. - - If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as - `input_ids`. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_hidden_layers`): - Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see - `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have - their past given to this model should not be passed as `input_ids` as they have already been computed. - - Each element of `past_key_values` is a tuple (past_key, past_value): - - past_key: [batch_size * num_heads, head_dim, kv_length] - - past_value: [batch_size * num_heads, kv_length, head_dim] - attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. 
- - [What are attention masks?](../glossary#attention-mask) - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - - If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see - `past_key_values`). - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. -""" - - -class FalconPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = FalconConfig - base_model_prefix = "transformer" - supports_gradient_checkpointing = True - _no_split_modules = ["FalconDecoderLayer"] - _supports_flash_attn_2 = True - _supports_sdpa = True - - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - if isinstance(module, nn.Linear) or isinstance(module, FalconLinear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa - @classmethod - def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig": - # NOTE: Falcon supported SDPA from PyTorch 2.0. We keep it like that for backward compatibility (automatically use SDPA for torch>=2.0). - if hard_check_only: - if not is_torch_greater_or_equal_than_2_0: - raise ImportError("PyTorch SDPA requirements in Transformers are not met. 
Please install torch>=2.0.") - - if not is_torch_greater_or_equal_than_2_0: - return config - - _is_bettertransformer = getattr(cls, "use_bettertransformer", False) - if _is_bettertransformer: - return config - - if not hard_check_only: - config._attn_implementation = "sdpa" - return config - - -@add_start_docstrings( - "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.", - FALCON_START_DOCSTRING, -) -class FalconModel(FalconPreTrainedModel): - def __init__(self, config: FalconConfig): - super().__init__(config) - - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.use_alibi = config.alibi - - # Embedding + LN Embedding - self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) - - # Transformer blocks - self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" - - # Final Layer Norm - self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.word_embeddings - - def set_input_embeddings(self, new_embeddings: torch.Tensor): - self.word_embeddings = new_embeddings - - @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=BaseModelOutputWithPastAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - # Compute alibi tensor: check build_alibi_tensor documentation - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[-2] - - if self.use_alibi: - mask = ( - torch.ones( - (batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long - ) - if attention_mask is None - else attention_mask - ) - alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype) - else: - alibi = None - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0) - - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - if alibi is None: - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - elif head_mask is None: - alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) - - # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched. - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - # We take care to integrate alibi bias in the attention_mask here. - min_dtype = torch.finfo(alibi.dtype).min - attention_mask = torch.masked_fill( - alibi / math.sqrt(self.config.hidden_size // self.num_heads), - attention_mask < -1, - min_dtype, - ) - - # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend - # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - if seq_length > 1 and attention_mask.device.type == "cuda": - attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype) - else: - # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case. 
- attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape batch_size x num_heads x N x N - # head_mask has shape n_layer x batch x num_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - alibi, - attention_mask, - position_ids, - head_mask[i], - layer_past, - use_cache, - output_attentions, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=alibi, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - - # Add last hidden state - hidden_states = self.ln_f(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -@add_start_docstrings( - "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).", - FALCON_START_DOCSTRING, -) -class FalconForCausalLM(FalconPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config: FalconConfig): - super().__init__(config) - self.transformer = FalconModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings: torch.Tensor): - self.lm_head = new_embeddings - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - **kwargs, - ) -> dict: - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE. 
- if not self.transformer.use_alibi and attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - return { - "input_ids": input_ids, - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - - @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=CausalLMOutputWithCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - batch_size, seq_length, vocab_size = shift_logits.shape - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) - ) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def _reorder_cache( - self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. 
- """ - - # Get a copy of `beam_idx` on all the devices where we need those indices. - device_to_beam_idx = { - past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past - } - reordered_past = tuple( - ( - layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), - layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), - ) - for layer_past in past - ) - return reordered_past diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_chatglm2.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_chatglm2.py index 637b25cf..246514c7 100644 --- a/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_chatglm2.py +++ b/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_chatglm2.py @@ -16,7 +16,7 @@ from llm_perf.core.ckpt_loader import CoreCkptLoader, ChatGLM2_ModelLoader from llm_perf.backends.GPU.gpu_ckpt_loader import GpuCkptLoader -from .chatglm2 import ChatGLMForConditionalGeneration, ChatGLMModel, ChatGLMConfig +from .modeling_chatglm2 import ChatGLMForConditionalGeneration, ChatGLMModel, ChatGLMConfig class GPUChatGLM2Loader(GpuCkptLoader): @@ -28,10 +28,8 @@ def __init__( ckpt_path: str = "" ): super().__init__(prefix, model, mp_size, mp_rank, ckpt_path) - self.model_config = model_config - def parallel_loader(self): self.state_dict = {} @@ -148,8 +146,6 @@ def __init__(self, xpu_cfg: Dict[str, Any]) -> None: self.transformer_model : ChatGLMForConditionalGeneration = None - - def init_inference(self): torch.cuda.set_device(self.local_rank) diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_falcon.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_falcon.py index 24172176..2cd7a734 100644 --- a/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_falcon.py +++ b/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_falcon.py @@ -1,5 +1,4 @@ import os -import json import pathlib import torch @@ -13,22 +12,23 @@ from accelerate import init_empty_weights -from llm_perf.core.ckpt_loader import CoreCkptLoader, Falcon_ModelLoader from llm_perf.backends.GPU.gpu_ckpt_loader import GpuCkptLoader +from llm_perf.core.ckpt_loader import Falcon_ModelLoader from transformers import FalconConfig -from .falcon import FalconForCausalLM +from .modeling_falcon import FalconForCausalLM class GPUFalconLoader(GpuCkptLoader): def __init__( self, - prefix, - model, model_config, - mp_size=1, mp_rank=0, - ckpt_path=None, + model : FalconForCausalLM, + model_config : FalconConfig, + ckpt_path : str = "" ): - super().__init__(prefix, model, mp_size, mp_rank, ckpt_path) + mp_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + super().__init__("", model, mp_size, local_rank, ckpt_path) self.model_config = model_config def parallel_loader(self): @@ -50,7 +50,23 @@ def parallel_loader(self): self.state_dict = model_loader.load_weight() def infusion_to_model(self): - pass + self.model.transformer.word_embeddings.weight = self.to_parameter(self.state_dict["transformer.word_embeddings.weight"]) + for i in range(self.model_config.num_hidden_layers): + self.model.transformer.h[i].ln_attn.weight = self.to_parameter(self.state_dict[f"transformer.h.{i}.ln_attn.weight"]) + self.model.transformer.h[i].ln_attn.bias = self.to_parameter(self.state_dict[f"transformer.h.{i}.ln_attn.bias"]) + + self.model.transformer.h[i].ln_mlp.weight = self.to_parameter(self.state_dict[f"transformer.h.{i}.ln_mlp.weight"]) + self.model.transformer.h[i].ln_mlp.bias = 
self.to_parameter(self.state_dict[f"transformer.h.{i}.ln_mlp.bias"]) + + self.model.transformer.h[i].self_attention.query_key_value.weight = self.to_parameter(self.state_dict[f"transformer.h.{i}.self_attention.query_key_value.weight"]) + self.model.transformer.h[i].self_attention.dense.weight = self.to_parameter(self.state_dict[f"transformer.h.{i}.self_attention.dense.weight"]) + + self.model.transformer.h[i].mlp.dense_h_to_4h.weight = self.to_parameter(self.state_dict[f"transformer.h.{i}.mlp.dense_h_to_4h.weight"]) + self.model.transformer.h[i].mlp.dense_4h_to_h.weight = self.to_parameter(self.state_dict[f"transformer.h.{i}.mlp.dense_4h_to_h.weight"]) + + self.model.transformer.ln_f.weight = self.to_parameter(self.state_dict["transformer.ln_f.weight"]) + self.model.transformer.ln_f.bias = self.to_parameter(self.state_dict["transformer.ln_f.bias"]) + self.model.lm_head.weight = self.to_parameter(self.state_dict["lm_head.weight"]) @@ -65,14 +81,12 @@ def __init__(self, xpu_cfg: Dict[str, Any]) -> None: self.model_path = self.model_config["model_path"] self.model_network = self.model_config["network"] - self.falcon_config = FalconConfig(**self.model_network) - # print(self.falcon_config) + self.falcon_config : FalconConfig = FalconConfig(**self.model_network) # dist config self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) - self.prefix = "" self.transformer_model : FalconForCausalLM = None @@ -91,22 +105,67 @@ def init_inference(self): check_memory_usage("Begin") with init_empty_weights(): - self.transformer_model = FalconForCausalLM( - self.falcon_config - ) + self.transformer_model = FalconForCausalLM(self.falcon_config) self.transformer_model.eval() + check_memory_usage("After build model") + self.load_weight(self.model_path) + + check_memory_usage("After load_weight") + + self.transformer_model.cuda() + + check_memory_usage("After model to device") + + self.kv_cache = self.init_kvcache(self.falcon_config.torch_dtype) + + dist.barrier() + + def load_weight(self, ckpt_path): - p_loader = GPUFalconLoader( - self.prefix, self.transformer_model, self.falcon_config, - self.mp_size, self.local_rank, - ckpt_path - ) + p_loader = GPUFalconLoader(self.transformer_model, self.falcon_config, ckpt_path) p_loader.parallel_loader() p_loader.infusion_to_model() + def init_kvcache(self, dtype): + max_batch_size = self.xpu_cfg["max_batch_size"] + num_layers = self.falcon_config.num_hidden_layers + max_seq_len = self.falcon_config.max_position_embeddings + hidden_size = self.falcon_config.hidden_size + q_head_num = self.falcon_config.num_attention_heads + kv_head_num = self.falcon_config.num_kv_heads + head_dim = hidden_size // q_head_num + + cur_device = self.transformer_model.device + + past_key_values = () + for i in range(num_layers): + # [max_batch_size, q_head_num, max_seq_len, head_dim] + # TODO: optimize to kv_head_num + + kv_shape = (max_batch_size, q_head_num // self.mp_size, max_seq_len, head_dim) + key_cache = torch.empty(kv_shape, dtype=dtype, device=cur_device) + value_cache = torch.empty(kv_shape, dtype=dtype, device=cur_device) + past_key_values += ((key_cache, value_cache),) + return past_key_values + + + def forward(self, inputs : Dict[str, torch.Tensor]): - pass \ No newline at end of file + model_outputs = self.transformer_model.forward( + **inputs, + past_key_values=self.kv_cache + ) + + # context: [1, seq_len] --> [1, seq_len, vocab_size] or [1, 1, vocab_size] + # decode: [max_batch_size, 1] + logits = model_outputs.logits + 
+ output_dict = { + "logits": logits + } + return output_dict + diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_llama3.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_llama3.py new file mode 100644 index 00000000..d633d862 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_llama3.py @@ -0,0 +1,162 @@ +import os +import pathlib + +import torch +import torch.nn as nn +import torch.distributed as dist + +from typing import Dict, Any +from llm_perf.utils.logger import logger +from llm_perf.utils.ps_utils import check_memory_usage +from llm_perf.utils.dist_utils import check_dist + +from accelerate import init_empty_weights + +from llm_perf.backends.GPU.gpu_ckpt_loader import GpuCkptLoader +from llm_perf.core.ckpt_loader import Llama_ModelLoader +from transformers import LlamaConfig +from .modeling_llama3 import LlamaForCausalLM + + +class GPULlamaLoader(GpuCkptLoader): + def __init__( + self, + model : LlamaForCausalLM, + model_config : LlamaConfig, + ckpt_path : str = "" + ): + mp_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + super().__init__("", model, mp_size, local_rank, ckpt_path) + self.model_config = model_config + + def parallel_loader(self): + self.state_dict = {} + + model_dir = pathlib.Path(self.ckpt_path).absolute() + if not model_dir.exists() or not model_dir.is_dir(): + if self.mp_rank == 0: + print(f"{model_dir} not exists or is not a directory") + return + + split_model_dir = model_dir.joinpath(f"TP{self.mp_size}") + if not split_model_dir.exists() or not split_model_dir.is_dir(): + if self.mp_rank == 0: + print(f"{split_model_dir} not exists or is not a directory, please split model first.") + return + + model_loader = Llama_ModelLoader(split_model_dir / f"device_{self.mp_rank}") + self.state_dict = model_loader.load_weight() + + def infusion_to_model(self): + self.model.model.embed_tokens.weight = self.to_parameter(self.state_dict["model.embed_tokens.weight"]) + for i in range(self.model_config.num_hidden_layers): + self.model.model.layers[i].input_layernorm.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.input_layernorm.weight"]) + + self.model.model.layers[i].self_attn.q_proj.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.self_attn.q_proj.weight"]) + self.model.model.layers[i].self_attn.k_proj.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.self_attn.k_proj.weight"]) + self.model.model.layers[i].self_attn.v_proj.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.self_attn.v_proj.weight"]) + self.model.model.layers[i].self_attn.o_proj.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.self_attn.o_proj.weight"]) + + self.model.model.layers[i].post_attention_layernorm.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.post_attention_layernorm.weight"]) + + self.model.model.layers[i].mlp.gate_proj.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.mlp.gate_proj.weight"]) + self.model.model.layers[i].mlp.up_proj.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.mlp.up_proj.weight"]) + self.model.model.layers[i].mlp.down_proj.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.mlp.down_proj.weight"]) + + self.model.model.norm.weight = self.to_parameter(self.state_dict["model.norm.weight"]) + self.model.lm_head.weight = self.to_parameter(self.state_dict["lm_head.weight"]) + + +class GPULlama(nn.Module): + def __init__(self, xpu_cfg: Dict[str, Any]) -> None: + 
super().__init__() + + self.xpu_cfg = xpu_cfg + self.model_config = xpu_cfg["model_config"] + + self.model_name = self.model_config["model_name"] + self.model_path = self.model_config["model_path"] + self.model_network = self.model_config["network"] + + self.llama_config : LlamaConfig = LlamaConfig(**self.model_network) + # print(self.llama_config) + + # dist config + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + self.transformer_model : LlamaForCausalLM = None + + + def init_inference(self): + torch.cuda.set_device(self.local_rank) + + if self.mp_size > 1: + logger.info(f"RANK: {self.local_rank} {self.mp_size} init_process_group...") + dist.init_process_group( + backend="nccl", + world_size=self.mp_size, + rank=self.local_rank + ) + check_dist() + + check_memory_usage("Begin") + + with init_empty_weights(): + self.transformer_model = LlamaForCausalLM(self.llama_config).to(self.llama_config.torch_dtype).eval() + + check_memory_usage("After build model") + + self.load_weight(self.model_path) + + check_memory_usage("After load_weight") + + self.transformer_model.cuda() + + check_memory_usage("After model to device") + + self.kv_cache = self.init_kvcache(self.llama_config.torch_dtype) + + dist.barrier() + + def load_weight(self, ckpt_path): + p_loader = GPULlamaLoader(self.transformer_model, self.llama_config, ckpt_path) + p_loader.parallel_loader() + p_loader.infusion_to_model() + + def init_kvcache(self, dtype): + max_batch_size = self.xpu_cfg["max_batch_size"] + num_layers = self.llama_config.num_hidden_layers + max_seq_len = self.llama_config.max_position_embeddings + hidden_size = self.llama_config.hidden_size + q_head_num = self.llama_config.num_attention_heads + kv_head_num = self.llama_config.num_key_value_heads + head_dim = hidden_size // q_head_num + + cur_device = self.transformer_model.device + + past_key_values = () + for i in range(num_layers): + kv_shape = (max_batch_size, kv_head_num // self.mp_size, max_seq_len, head_dim) + key_cache = torch.empty(kv_shape, dtype=dtype, device=cur_device) + value_cache = torch.empty(kv_shape, dtype=dtype, device=cur_device) + past_key_values += ((key_cache, value_cache),) + return past_key_values + + + def forward(self, inputs : Dict[str, torch.Tensor]): + model_outputs = self.transformer_model.forward( + **inputs, + past_key_values=self.kv_cache + ) + + # context: [1, seq_len] --> [1, seq_len, vocab_size] or [1, 1, vocab_size] + # decode: [max_batch_size, 1] + logits = model_outputs.logits + + output_dict = { + "logits": logits + } + return output_dict \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_mixtral.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_mixtral.py new file mode 100644 index 00000000..6fd6384a --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_mixtral.py @@ -0,0 +1,164 @@ +import os +import pathlib + +import torch +import torch.nn as nn +import torch.distributed as dist + +from typing import Dict, Any +from llm_perf.utils.logger import logger +from llm_perf.utils.ps_utils import check_memory_usage +from llm_perf.utils.dist_utils import check_dist + +from accelerate import init_empty_weights + +from llm_perf.backends.GPU.gpu_ckpt_loader import GpuCkptLoader +from llm_perf.core.ckpt_loader import Mixtral_ModelLoader +from transformers import MixtralConfig +from .modeling_mixtral import MixtralForCausalLM + + +class GPUMixtralLoader(GpuCkptLoader): + def __init__( + self, + model : 
MixtralForCausalLM, + model_config : MixtralConfig, + ckpt_path : str = "" + ): + mp_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + super().__init__("", model, mp_size, local_rank, ckpt_path) + self.model_config = model_config + + def parallel_loader(self): + self.state_dict = {} + + model_dir = pathlib.Path(self.ckpt_path).absolute() + if not model_dir.exists() or not model_dir.is_dir(): + if self.mp_rank == 0: + print(f"{model_dir} not exists or is not a directory") + return + + split_model_dir = model_dir.joinpath(f"TP{self.mp_size}") + if not split_model_dir.exists() or not split_model_dir.is_dir(): + if self.mp_rank == 0: + print(f"{split_model_dir} not exists or is not a directory, please split model first.") + return + + model_loader = Mixtral_ModelLoader(split_model_dir / f"device_{self.mp_rank}") + self.state_dict = model_loader.load_weight() + + def infusion_to_model(self): + self.model.model.embed_tokens.weight = self.to_parameter(self.state_dict["model.embed_tokens.weight"]) + for i in range(self.model_config.num_hidden_layers): + self.model.model.layers[i].input_layernorm.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.input_layernorm.weight"]) + + self.model.model.layers[i].self_attn.q_proj.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.self_attn.q_proj.weight"]) + self.model.model.layers[i].self_attn.k_proj.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.self_attn.k_proj.weight"]) + self.model.model.layers[i].self_attn.v_proj.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.self_attn.v_proj.weight"]) + self.model.model.layers[i].self_attn.o_proj.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.self_attn.o_proj.weight"]) + + self.model.model.layers[i].post_attention_layernorm.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.post_attention_layernorm.weight"]) + + self.model.model.layers[i].block_sparse_moe.gate.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.block_sparse_moe.gate.weight"]) + for j in range(self.model_config.num_local_experts): + self.model.model.layers[i].block_sparse_moe.experts[j].w1.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.block_sparse_moe.experts.{j}.w1.weight"]) + self.model.model.layers[i].block_sparse_moe.experts[j].w2.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.block_sparse_moe.experts.{j}.w2.weight"]) + self.model.model.layers[i].block_sparse_moe.experts[j].w3.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.block_sparse_moe.experts.{j}.w3.weight"]) + + self.model.model.norm.weight = self.to_parameter(self.state_dict["model.norm.weight"]) + self.model.lm_head.weight = self.to_parameter(self.state_dict["lm_head.weight"]) + + +class GPUMixtral(nn.Module): + def __init__(self, xpu_cfg: Dict[str, Any]) -> None: + super().__init__() + + self.xpu_cfg = xpu_cfg + self.model_config = xpu_cfg["model_config"] + + self.model_name = self.model_config["model_name"] + self.model_path = self.model_config["model_path"] + self.model_network = self.model_config["network"] + + self.mixtral_config : MixtralConfig = MixtralConfig(**self.model_network) + + # dist config + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + self.transformer_model : MixtralForCausalLM = None + + + def init_inference(self): + torch.cuda.set_device(self.local_rank) + + if self.mp_size > 1: + logger.info(f"RANK: {self.local_rank} 
{self.mp_size} init_process_group...") + dist.init_process_group( + backend="nccl", + world_size=self.mp_size, + rank=self.local_rank + ) + check_dist() + + check_memory_usage("Begin") + + with init_empty_weights(): + self.transformer_model = MixtralForCausalLM(self.mixtral_config) + self.transformer_model.eval() + + check_memory_usage("After build model") + + self.load_weight(self.model_path) + + check_memory_usage("After load_weight") + + self.transformer_model.cuda() + + check_memory_usage("After model to device") + + self.kv_cache = self.init_kvcache(self.mixtral_config.torch_dtype) + + dist.barrier() + + def load_weight(self, ckpt_path): + p_loader = GPUMixtralLoader(self.transformer_model, self.mixtral_config, ckpt_path) + p_loader.parallel_loader() + p_loader.infusion_to_model() + + def init_kvcache(self, dtype): + max_batch_size = self.xpu_cfg["max_batch_size"] + num_layers = self.mixtral_config.num_hidden_layers + max_seq_len = self.mixtral_config.max_position_embeddings + hidden_size = self.mixtral_config.hidden_size + q_head_num = self.mixtral_config.num_attention_heads + kv_head_num = self.mixtral_config.num_key_value_heads + head_dim = hidden_size // q_head_num + + cur_device = self.transformer_model.device + + past_key_values = () + for i in range(num_layers): + kv_shape = (max_batch_size, kv_head_num // self.mp_size, max_seq_len, head_dim) + key_cache = torch.empty(kv_shape, dtype=dtype, device=cur_device) + value_cache = torch.empty(kv_shape, dtype=dtype, device=cur_device) + past_key_values += ((key_cache, value_cache),) + return past_key_values + + + def forward(self, inputs : Dict[str, torch.Tensor]): + model_outputs = self.transformer_model.forward( + **inputs, + past_key_values=self.kv_cache + ) + + # context: [1, seq_len] --> [1, seq_len, vocab_size] or [1, 1, vocab_size] + # decode: [max_batch_size, 1] + logits = model_outputs.logits + + output_dict = { + "logits": logits + } + return output_dict \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/chatglm2.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_chatglm2.py similarity index 96% rename from byte_infer_perf/llm_perf/backends/GPU/model_impl/chatglm2.py rename to byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_chatglm2.py index d9ac55af..c48cd963 100644 --- a/byte_infer_perf/llm_perf/backends/GPU/model_impl/chatglm2.py +++ b/byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_chatglm2.py @@ -891,54 +891,7 @@ def get_masks(self, input_ids, past_key_values, padding_mask=None): full_attention_mask -= padding_mask.unsqueeze(-1) - 1 full_attention_mask = (full_attention_mask < 0.5).bool() full_attention_mask.unsqueeze_(1) - return full_attention_mask - - - def get_context_masks( - self, - input_ids : torch.Tensor, - padding_mask : torch.Tensor - ): - # input_ids: [1, q_len] - # padding_mask = [1, q_len] - batch_size, q_len = input_ids.shape - - # [1, q_len, q_len] - full_attention_mask = torch.ones( - 1, q_len, q_len, - device=input_ids.device - ) - full_attention_mask.tril_() - full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) - full_attention_mask -= padding_mask.unsqueeze(-1) - 1 - full_attention_mask = (full_attention_mask < 0.5).bool() - full_attention_mask.unsqueeze_(1) - return full_attention_mask - - def get_decode_masks( - self, - input_ids : torch.Tensor, - all_kv_len: List[int] - ): - # input_ids: [batch_size, 1] - # padding_mask: [batch_size, 1 + max_kv_len] - batch_size, q_len = input_ids.shape - max_qkv_len = q_len + 
max(all_kv_len) - - # [batch_size, 1, max_qkv_len] - padding_mask = [] - for i in range(batch_size): - cur_qkv_len = q_len + all_kv_len[i] - mask_per_batch = [1] * cur_qkv_len + [0] * (max_qkv_len - cur_qkv_len) - padding_mask.append(mask_per_batch) - full_attention_mask = torch.tensor( - padding_mask, - device=input_ids.device - ).unsqueeze_(1) - full_attention_mask = (full_attention_mask < 0.5).bool() - full_attention_mask.unsqueeze_(1) - return full_attention_mask - + return full_attention_mask def get_position_ids(self, input_ids, device): batch_size, seq_length = input_ids.shape @@ -1067,13 +1020,7 @@ def forward( # if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): # full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) - - is_context = kwargs.get("is_context") - all_kv_len = kwargs.get("all_kv_len") - if is_context: - full_attention_mask = self.get_context_masks(input_ids, attention_mask) - else: - full_attention_mask = self.get_decode_masks(input_ids, all_kv_len) + full_attention_mask = kwargs.get("full_attention_mask") # Rotary positional embeddings diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_falcon.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_falcon.py new file mode 100644 index 00000000..895786ab --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_falcon.py @@ -0,0 +1,793 @@ +# coding=utf-8 +# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
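+#
+# NOTE: this is a trimmed copy of the upstream transformers Falcon implementation adapted
+# for this backend: attention and MLP projections are sharded across tensor-parallel ranks
+# derived from the WORLD_SIZE / LOCAL_RANK environment variables, key/value states are
+# written in place into a preallocated cache using the `valid_slot_ids` / `all_q_len` /
+# `all_kv_len` kwargs, the causal mask is taken from the `full_attention_mask` kwarg
+# instead of being rebuilt here, the alibi and head_mask paths are unused, and the
+# forward pass returns only the logits.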
+"""PyTorch Falcon model.""" + +import os +import math +import warnings +from typing import TYPE_CHECKING, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss +from torch.nn import functional as F +import torch.distributed as dist + +from transformers.activations import get_activation +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0 +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, +) +from transformers.models.falcon.configuration_falcon import FalconConfig + + +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + + +logger = logging.get_logger(__name__) + +from transformers.models.deprecated._archive_maps import FALCON_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402 + + +_CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b" +_CONFIG_FOR_DOC = "FalconConfig" + + +# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations. +# In order not to degrade the quality of our HF-port, we keep these characteristics in the final model. +class FalconLinear(nn.Linear): + def forward(self, input: torch.Tensor) -> torch.Tensor: + hidden_states = input @ self.weight.T + if self.bias is None: + return hidden_states + return hidden_states + self.bias + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon +class FalconRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + # 64 + self.dim = dim + + # 2048 + self.max_position_embeddings = max_position_embeddings + + # 10000.0 + self.base = base + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon +# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied) +class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding): + """FalconRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon +# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied) +class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding): + """FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: + batch_size, seq_length = attention_mask.shape + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 + ) + powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != num_heads: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 + ) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=num_heads, query_length=1, 
key_length=max_length) + # => the query_length dimension will then be broadcasted correctly + # This is more or less identical to T5's relative position bias: + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 + arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] + alibi = slopes[..., None].bfloat16() * arange_tensor + return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) + + +# Copied from transformers.models.bloom.modeling_bloom.dropout_add +def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: + """ + Dropout add function + + Args: + x (`torch.tensor`, *required*): + input tensor + residual (`torch.tensor`, *required*): + residual tensor + prob (`float`, *required*): + dropout probability + training (`bool`, *required*): + training mode + """ + out = F.dropout(x, p=prob, training=training) + out = residual + out + return out + + +class FalconAttention(nn.Module): + def __init__(self, config: FalconConfig): + super().__init__() + + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.split_size = self.hidden_size + self.hidden_dropout = config.hidden_dropout + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self._use_sdpa = config._attn_implementation == "sdpa" + + if self.head_dim * self.num_heads != self.hidden_size: + raise ValueError( + f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" + f" {self.num_heads})." 
+ ) + + if config.rotary: + self._init_rope() + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.beta = self.inv_norm_factor + if config.new_decoder_architecture: + qkv_out_dim = (config.num_kv_heads * 2 // self.mp_size + config.num_attention_heads // self.mp_size) * self.head_dim + elif config.multi_query: + qkv_out_dim = self.hidden_size // self.mp_size + 2 * self.head_dim + else: + qkv_out_dim = 3 * self.hidden_size // self.mp_size + + + self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias) + self.new_decoder_architecture = config.new_decoder_architecture + self.multi_query = config.multi_query + self.dense = FalconLinear(self.hidden_size // self.mp_size, self.hidden_size, bias=config.bias) + self.attention_dropout = nn.Dropout(config.attention_dropout) + self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1 + + self.num_heads = config.num_attention_heads // self.mp_size + self.num_kv_heads = config.num_kv_heads // self.mp_size + + + + + # Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Falcon + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = FalconRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = FalconLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = FalconDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv` + + Args: + fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim] + + Returns: + query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] + value: [batch_size, seq_length, num_heads, head_dim] + """ + if self.new_decoder_architecture: + batch, seq_len, _ = fused_qkv.shape + qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim) + query = qkv[:, :, :, :-2] + key = qkv[:, :, :, [-2]] + value = qkv[:, :, :, [-1]] + key = torch.broadcast_to(key, query.shape) + value = torch.broadcast_to(value, query.shape) + + query, key, value = [x.flatten(2, 3) for x in (query, key, value)] + return query, key, value + elif not self.multi_query: + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) + return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] + else: + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim) + return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :] + + # Copied from 
transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads + def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: + """ + Merge heads together over the last dimension + + Args: + x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim] + + Returns: + torch.tensor: [batch_size, seq_length, num_heads * head_dim] + """ + # What we want to achieve is: + # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim + batch_size_and_num_heads, seq_length, _ = x.shape + batch_size = batch_size_and_num_heads // self.num_heads + + # First view to decompose the batch size + # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim + x = x.view(batch_size, self.num_heads, seq_length, self.head_dim) + + # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim + x = x.permute(0, 2, 1, 3) + + # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim + return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + **kwargs, + ): + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, query_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) + value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) + + is_context = kwargs.get("is_context") + valid_slot_ids = kwargs.get("valid_slot_ids") + all_q_len = kwargs.get("all_q_len") + all_kv_len = kwargs.get("all_kv_len") + + if is_context: + max_kv_len = max(all_kv_len) + else: + max_kv_len = max(all_q_len) + max(all_kv_len) + + cos, sin = self.rotary_emb(value_layer, seq_len=max_kv_len) + query_layer, key_layer = apply_rotary_pos_emb( + query_layer, key_layer, + cos, sin, + position_ids + ) + + # layer_past: [max_batch_size, q_head_num, max_seq_len, head_dim] + if is_context: + slot_id = valid_slot_ids[0] + q_len = all_q_len[0] + layer_past[0][slot_id:slot_id+1, :, :q_len, :] = key_layer + layer_past[1][slot_id:slot_id+1, :, :q_len, :] = value_layer + else: + batch_size, _, q_len, _ = key_layer.shape + max_qkv_len = q_len + max(all_kv_len) + for i, slot_id in enumerate(valid_slot_ids): + q_len = all_q_len[i] + kv_len = all_kv_len[i] + layer_past[0][slot_id:slot_id+1, :, kv_len:kv_len+q_len, :] = key_layer[i, :, :, :] + layer_past[1][slot_id:slot_id+1, :, kv_len:kv_len+q_len, :] = value_layer[i, :, :, :] + + cur_k_cache = layer_past[0][:, :, :max_qkv_len, :] + cur_v_cache = layer_past[1][:, :, :max_qkv_len, :] + select_slots = torch.tensor(valid_slot_ids, device=key_layer.device) + key_layer = torch.index_select(cur_k_cache, 0, select_slots) + value_layer = torch.index_select(cur_v_cache, 0, select_slots) + + present = (key_layer, value_layer) + + + if self._use_sdpa and 
query_layer.device.type == "cuda" and attention_mask is not None: + # For torch<=2.1.2, SDPA with memory-efficient backend is bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + + if self._use_sdpa and not output_attentions: + attn_output = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + ~attention_mask + ) + attention_scores = None + else: + attention_scores = query_layer @ key_layer.transpose(-1, -2) + attention_scores /= math.sqrt(self.head_dim) + + attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) + # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi). + attn_output = attention_scores @ value_layer + + attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim) + attn_output = attn_output.permute(0, 2, 1, 3) + attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + + attn_output = self.dense(attn_output) + + if output_attentions: + return attn_output, present, attention_scores + else: + return attn_output, present + + + + + + + + +class FalconMLP(nn.Module): + def __init__(self, config: FalconConfig): + super().__init__() + + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + hidden_size = config.hidden_size + + self.dense_h_to_4h = FalconLinear( + hidden_size, + config.ffn_hidden_size // self.mp_size, + bias=config.bias + ) + self.act = get_activation(config.activation) + self.dense_4h_to_h = FalconLinear( + config.ffn_hidden_size // self.mp_size, + hidden_size, + bias=config.bias + ) + + self.hidden_dropout = config.hidden_dropout + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.act(self.dense_h_to_4h(x)) + x = self.dense_4h_to_h(x) + return x + + +class FalconDecoderLayer(nn.Module): + def __init__(self, config: FalconConfig): + super().__init__() + + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.self_attention = FalconAttention(config) + self.mlp = FalconMLP(config) + self.hidden_dropout = config.hidden_dropout + self.config = config + + if config.new_decoder_architecture: + # The layer norm before self-attention + self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + # The layer norm before the MLP + self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + else: + self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + if not config.parallel_attn: + self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + **kwargs, + ): + residual = hidden_states + + if self.config.new_decoder_architecture: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) + else: + attention_layernorm_out 
= self.input_layernorm(hidden_states) + + # Self attention. + attn_outputs = self.self_attention( + attention_layernorm_out, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs, + ) + + attention_output = attn_outputs[0] + + if not self.config.new_decoder_architecture: + if self.config.parallel_attn: + mlp_layernorm_out = attention_layernorm_out + else: + residual = dropout_add( + attention_output, residual, self.config.attention_dropout, training=self.training + ) + mlp_layernorm_out = self.post_attention_layernorm(residual) + + outputs = attn_outputs[1:] + + # MLP. + mlp_output = self.mlp(mlp_layernorm_out) + + if self.config.new_decoder_architecture or self.config.parallel_attn: + mlp_output += attention_output + if self.mp_size > 1: + dist.all_reduce(mlp_output) + + output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training) + + outputs = (output,) + + return outputs # hidden_states, present, attentions + + + +class FalconPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = FalconConfig + base_model_prefix = "transformer" + supports_gradient_checkpointing = True + _no_split_modules = ["FalconDecoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + if isinstance(module, nn.Linear) or isinstance(module, FalconLinear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa + @classmethod + def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig": + # NOTE: Falcon supported SDPA from PyTorch 2.0. We keep it like that for backward compatibility (automatically use SDPA for torch>=2.0). + if hard_check_only: + if not is_torch_greater_or_equal_than_2_0: + raise ImportError("PyTorch SDPA requirements in Transformers are not met. 
Please install torch>=2.0.") + + if not is_torch_greater_or_equal_than_2_0: + return config + + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + if _is_bettertransformer: + return config + + if not hard_check_only: + config._attn_implementation = "sdpa" + return config + + +class FalconModel(FalconPreTrainedModel): + def __init__(self, config: FalconConfig): + super().__init__(config) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.use_alibi = config.alibi + + # Embedding + LN Embedding + self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) + + # Transformer blocks + self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" + + # Final Layer Norm + self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + + # using full_attention_mask + # context: [1, 1, seq_len, seq_len] + # decode: [batch_size, 1, 1, past_kv_len] + attention_mask = kwargs.get("full_attention_mask") + + hidden_states = self.word_embeddings(input_ids) + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=None, + use_cache=use_cache, + output_attentions=output_attentions, + alibi=None, + **kwargs + ) + hidden_states = outputs[0] + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states + ) + + +class FalconForCausalLM(FalconPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: FalconConfig): + super().__init__(config) + self.transformer = FalconModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + 
position_ids=position_ids, + head_mask=None, + inputs_embeds=None, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + **kwargs + ) + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + return CausalLMOutputWithCrossAttentions( + logits=lm_logits + ) diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_llama3.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_llama3.py new file mode 100644 index 00000000..2429ad89 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_llama3.py @@ -0,0 +1,1271 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch LLaMA model.""" + +import os +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.models.llama.configuration_llama import LlamaConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = 
eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + + @property + def sin_cached(self): + logger.warning_once( + "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" + ) + return self._sin_cached + + @property + def cos_cached(self): + logger.warning_once( + "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" + ) + return self._cos_cached + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size // self.mp_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size // self.mp_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size // self.mp_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim // self.mp_size, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim // self.mp_size, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim // self.mp_size, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size // self.mp_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + self.num_heads = self.num_heads // self.mp_size + self.num_key_value_heads = self.num_key_value_heads // self.mp_size + + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = 
None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash 
attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
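+ # With a top-left aligned mask, causal=True is only correct when q_len == kv_len.
+ # For single-token decode (query_length == 1) the query may attend to every cached key
+ # anyway, so dropping the causal flag is equivalent and avoids the wrong mask.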
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. 
Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + is_context = kwargs.get("is_context") + valid_slot_ids = kwargs.get("valid_slot_ids") + all_q_len = kwargs.get("all_q_len") + all_kv_len = kwargs.get("all_kv_len") + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + # # In case static cache is used, it is an instance attribute. + # past_key_value = getattr(self, "past_key_value", past_key_value) + + # if past_key_value is not None: + # # sin and cos are specific to RoPE models; cache_position needed for the static cache + # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if is_context: + slot_id = valid_slot_ids[0] + q_len = all_q_len[0] + past_key_value[self.layer_idx][0][slot_id:slot_id+1, :, :q_len, :] = key_states + past_key_value[self.layer_idx][1][slot_id:slot_id+1, :, :q_len, :] = value_states + else: + batch_size, _, q_len, _ = key_states.shape + max_qkv_len = q_len + max(all_kv_len) + for i, slot_id in enumerate(valid_slot_ids): + q_len = all_q_len[i] + kv_len = all_kv_len[i] + past_key_value[self.layer_idx][0][slot_id:slot_id+1, :, kv_len:kv_len+q_len, :] = key_states[i, :, :, :] + past_key_value[self.layer_idx][1][slot_id:slot_id+1, :, kv_len:kv_len+q_len, :] = value_states[i, :, :, :] + + cur_k_cache = past_key_value[self.layer_idx][0][:, :, :max_qkv_len, :] + cur_v_cache = past_key_value[self.layer_idx][1][:, :, :max_qkv_len, :] + select_slots = torch.tensor(valid_slot_ids, device=key_states.device) + key_states = torch.index_select(cur_k_cache, 0, select_slots) + value_states = torch.index_select(cur_v_cache, 0, select_slots) + + + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = ~attention_mask + # if attention_mask is not None: + # causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. 
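+ # The incoming `attention_mask` is assumed to be the precomputed boolean full-attention
+ # mask with True marking masked positions, so the inversion above yields True where
+ # attention is allowed, matching SDPA's convention for boolean attn_mask inputs.
+ # The contiguous() calls below work around the non-contiguous-input issue referenced above.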
+ if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather + # relying on the `is_causal` argument. + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0 + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + if self.mp_size > 1: + dist.all_reduce(hidden_states) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + + if self.mp_size > 1: + dist.all_reduce(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): + if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + for layer in self.model.layers: + device = layer.input_layernorm.weight.device + if hasattr(self.config, "_pre_quantization_dtype"): + dtype = self.config._pre_quantization_dtype + else: + dtype = layer.self_attn.o_proj.weight.dtype + layer.self_attn.past_key_value = cache_cls( + self.config, max_batch_size, max_cache_len, device=device, dtype=dtype + ) + + def _reset_cache(self): + for layer in self.model.layers: + layer.self_attn.past_key_value = None + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, 
sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + # embed positions + hidden_states = self.embed_tokens(input_ids) + + attention_mask = kwargs.get("full_attention_mask") + + for decoder_layer in self.layers: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=False, + use_cache=False, + cache_position=None, + **kwargs + ) + hidden_states = layer_outputs[0] + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_seen_tokens: int, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + if self.config._attn_implementation == "sdpa": + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, + # in order to dispatch on Flash Attention 2. + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache + target_length = self.config.max_position_embeddings + else: # dynamic cache + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + elif attention_mask.dim() == 4: + # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with + # cache. In that case, the 4D attention mask attends to the newest tokens only. + if attention_mask.shape[-2] < cache_position[0] + sequence_length: + offset = cache_position[0] + else: + offset = 0 + mask_shape = attention_mask.shape + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + cache_position=None, + **kwargs + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + return CausalLMOutputWithPast( + logits=logits, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs + ): + # With static cache, the `past_key_values` is None + # TODO joao: standardize interface for the different Cache classes and remove of this if + has_static_cache = False + if past_key_values is None: + past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) + has_static_cache = past_key_values is not None + + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + else: + cache_position = cache_position[-input_length:] + + if has_static_cache: + past_key_values = None + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_mixtral.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_mixtral.py new file mode 100644 index 00000000..4dd166d0 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_mixtral.py @@ -0,0 +1,1374 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
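+# Note: as in modeling_llama3.py above, the attention projections below are sharded across
+# WORLD_SIZE ranks for tensor-parallel execution; otherwise the modules follow the upstream
+# Hugging Face Mixtral implementation they were copied from.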
+""" PyTorch Mixtral model.""" +import os +import inspect +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +import torch.distributed as dist + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13 +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from transformers.models.mixtral.configuration_mixtral import MixtralConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MixtralConfig" + + +def load_balancing_loss_func( + gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None +) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + attention_mask (`torch.Tensor`, None): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + num_experts (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. 
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral +class MixtralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MixtralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral +class MixtralRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / 
(self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral +class MixtralAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim // self.mp_size, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim // self.mp_size, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim // self.mp_size, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim // self.mp_size, self.hidden_size, bias=False) + + self.rotary_emb = MixtralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + self.num_heads = self.num_heads // self.mp_size + self.num_key_value_heads = self.num_key_value_heads // self.mp_size + + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral +class MixtralFlashAttention2(MixtralAttention): + """ + Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
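The following is a minimal standalone sketch of the rotary-embedding and grouped-query helpers defined above (`rotate_half`, `apply_rotary_pos_emb`, `repeat_kv`). All shapes and values are illustrative, not the file's actual execution path; the final assert only checks the `repeat_interleave` equivalence stated in the `repeat_kv` docstring.

```python
# Illustrative RoPE + grouped-query-attention helper sketch (toy shapes).
import torch


def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq, head_dim]
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


def repeat_kv(hidden_states, n_rep):
    b, kv_heads, s, d = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    return hidden_states[:, :, None, :, :].expand(b, kv_heads, n_rep, s, d).reshape(b, kv_heads * n_rep, s, d)


bs, heads, kv_heads, seq, dim = 1, 8, 2, 4, 16
q = torch.randn(bs, heads, seq, dim)
k = torch.randn(bs, kv_heads, seq, dim)

# Precompute cos/sin exactly like the rotary embedding cache above.
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(seq).float()
emb = torch.cat([torch.outer(t, inv_freq)] * 2, dim=-1)  # [seq, dim]
cos, sin = emb.cos(), emb.sin()

position_ids = torch.arange(seq).unsqueeze(0)            # [1, seq]
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
k = repeat_kv(k, heads // kv_heads)                      # [1, 8, 4, 16]
assert k.shape == q.shape

# The docstring's claim: repeat_kv matches torch.repeat_interleave along the head dim.
k_raw = torch.randn(bs, kv_heads, seq, dim)
assert torch.equal(repeat_kv(k_raw, heads // kv_heads), k_raw.repeat_interleave(heads // kv_heads, dim=1))
```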
+ """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." 
+ ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. 
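As a side note, the unpad/pad flow described here relies on the `cu_seqlens` bookkeeping produced by `_get_unpad_data` near the top of this file. The snippet below is a pure-PyTorch illustration of that bookkeeping with a made-up padding mask; the actual `flash_attn_varlen_func` kernel is not called.

```python
# Illustration of the cumulative-sequence-length bookkeeping used by the varlen flash-attention path.
import torch
import torch.nn.functional as F


def get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch


# Two sequences of real lengths 3 and 1, right-padded to length 4.
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 0, 0, 0]])
indices, cu_seqlens, max_len = get_unpad_data(mask)
print(indices)     # tensor([0, 1, 2, 4]) -> positions of real tokens in the flattened batch
print(cu_seqlens)  # tensor([0, 3, 4], dtype=torch.int32) -> per-sequence offsets for the varlen kernel
print(max_len)     # 3
```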
+ + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, 
head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral +class MixtralSdpaAttention(MixtralAttention): + """ + Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MixtralAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + is_context = kwargs.get("is_context") + valid_slot_ids = kwargs.get("valid_slot_ids") + all_q_len = kwargs.get("all_q_len") + all_kv_len = kwargs.get("all_kv_len") + + max_kv_len = max(all_kv_len) if is_context else max(all_q_len) + max(all_kv_len) + + # kv_seq_len = key_states.shape[-2] + # if past_key_value is not None: + # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=max_kv_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # if past_key_value is not None: + # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # past_key_value: [max_batch_size, kv_head_num, max_seq_len, head_dim] + if is_context: + slot_id = valid_slot_ids[0] + q_len = all_q_len[0] + past_key_value[self.layer_idx][0][slot_id:slot_id+1, :, :q_len, :] = key_states + past_key_value[self.layer_idx][1][slot_id:slot_id+1, :, :q_len, :] = value_states + else: + batch_size, _, q_len, _ = key_states.shape + max_qkv_len = q_len + max(all_kv_len) + for i, slot_id in enumerate(valid_slot_ids): + q_len = all_q_len[i] + kv_len = all_kv_len[i] + past_key_value[self.layer_idx][0][slot_id:slot_id+1, :, kv_len:kv_len+q_len, :] = key_states[i, :, :, :] + past_key_value[self.layer_idx][1][slot_id:slot_id+1, :, kv_len:kv_len+q_len, :] = value_states[i, :, :, :] + + cur_k_cache = past_key_value[self.layer_idx][0][:, :, :max_qkv_len, :] + cur_v_cache = past_key_value[self.layer_idx][1][:, :, :max_qkv_len, :] + select_slots = torch.tensor(valid_slot_ids, device=key_states.device) + key_states = torch.index_select(cur_k_cache, 0, select_slots) + value_states = torch.index_select(cur_v_cache, 0, select_slots) + + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # if attention_mask is not None: + # if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + # raise ValueError( + # f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + # ) + + attention_mask = kwargs.get("full_attention_mask") + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. 
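The SDPA forward above writes keys/values into a preallocated per-request slot cache (`past_key_value[layer_idx]` of shape `[max_batch_size, kv_head_num, max_seq_len, head_dim]`) instead of using the `Cache` API: the context phase fills one slot with the whole prompt, the decode phase appends one token per active slot, and `index_select` gathers only the active slots. A minimal sketch of that idea, with illustrative sizes and the same variable names (`valid_slot_ids`, `all_kv_len`), follows.

```python
# Minimal sketch of the slot-based KV cache used by the SDPA path (illustrative shapes).
import torch

max_slots, kv_heads, max_seq, head_dim = 4, 2, 16, 8
k_cache = torch.zeros(max_slots, kv_heads, max_seq, head_dim)
v_cache = torch.zeros(max_slots, kv_heads, max_seq, head_dim)

# Context phase: one request, write its whole prompt into slot 0.
slot_id, q_len = 0, 5
k_new = torch.randn(1, kv_heads, q_len, head_dim)
v_new = torch.randn(1, kv_heads, q_len, head_dim)
k_cache[slot_id:slot_id + 1, :, :q_len, :] = k_new
v_cache[slot_id:slot_id + 1, :, :q_len, :] = v_new

# Decode phase: each active slot appends one new token at its current kv length.
valid_slot_ids, all_kv_len = [0, 2], [5, 3]
k_step = torch.randn(len(valid_slot_ids), kv_heads, 1, head_dim)
v_step = torch.randn(len(valid_slot_ids), kv_heads, 1, head_dim)
for i, slot in enumerate(valid_slot_ids):
    kv_len = all_kv_len[i]
    k_cache[slot:slot + 1, :, kv_len:kv_len + 1, :] = k_step[i]
    v_cache[slot:slot + 1, :, kv_len:kv_len + 1, :] = v_step[i]

# Gather only the active slots, truncated to the longest sequence, for attention.
max_qkv_len = 1 + max(all_kv_len)
select = torch.tensor(valid_slot_ids)
k_states = torch.index_select(k_cache[:, :, :max_qkv_len, :], 0, select)
v_states = torch.index_select(v_cache[:, :, :max_qkv_len, :], 0, select)
print(k_states.shape)  # torch.Size([2, 2, 6, 8])
```

Note that the custom `full_attention_mask` built in `gpu_mp_engine.py` marks masked positions as `True`, which is why the call above passes `attn_mask=~attention_mask` rather than the mask itself.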
+ if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # attn_output = torch.nn.functional.scaled_dot_product_attention( + # query_states, + # key_states, + # value_states, + # attn_mask=attention_mask, + # dropout_p=self.attention_dropout if self.training else 0.0, + # # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + # is_causal=self.is_causal and attention_mask is None and q_len > 1, + # ) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=~attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MIXTRAL_ATTENTION_CLASSES = { + "eager": MixtralAttention, + "flash_attention_2": MixtralFlashAttention2, + "sdpa": MixtralSdpaAttention, +} + + +class MixtralBlockSparseTop2MLP(nn.Module): + def __init__(self, config: MixtralConfig): + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim // self.mp_size, bias=False) + self.w2 = nn.Linear(self.ffn_dim // self.mp_size, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim // self.mp_size, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralBLockSparseTop2MLP(MixtralBlockSparseTop2MLP): + def __init__(self, *args, **kwargs): + logger.warning_once( + "MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40." + ) + super().__init__(*args, **kwargs) + + +class MixtralSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accomodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. 
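For reference, the routing statistics that this block produces feed the auxiliary balancing term computed by `load_balancing_loss_func` near the top of this file. The snippet below is a tiny numeric illustration of the no-padding branch of that loss with random logits; shapes and values are made up.

```python
# Tiny illustration of the router load-balancing loss (no-padding case):
# fraction of tokens routed to each expert x mean router probability per expert.
import torch
import torch.nn.functional as F

num_experts, top_k = 4, 2
gate_logits = (torch.randn(6, num_experts), torch.randn(6, num_experts))  # 2 layers, 6 tokens each

concatenated = torch.cat(gate_logits, dim=0)                 # [12, 4]
routing_weights = F.softmax(concatenated, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = F.one_hot(selected_experts, num_experts)       # [12, 2, 4]

tokens_per_expert = expert_mask.float().mean(dim=0)          # fraction of tokens per expert, per top-k slot
router_prob_per_expert = routing_weights.mean(dim=0)         # mean router probability per expert

loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) * num_experts
print(loss)  # close to top_k for a uniform router, larger when routing collapses onto few experts
```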
+ """ + + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) + + # Jitter parameters + self.jitter_noise = config.router_jitter_noise + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + if self.training and self.jitter_noise > 0: + hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class MixtralDecoderLayer(nn.Module): + def __init__(self, config: MixtralConfig, layer_idx: int): + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.block_sparse_moe = MixtralSparseMoeBlock(config) + self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=False, + use_cache=False, + **kwargs, + ) + if self.mp_size > 1: + dist.all_reduce(hidden_states) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + if self.mp_size > 1: + dist.all_reduce(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + return outputs + + +MIXTRAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) 
+ + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MixtralConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral +class MixtralPreTrainedModel(PreTrainedModel): + config_class = MixtralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MixtralDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MIXTRAL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. 
+ + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral +class MixtralModel(MixtralPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MixtralDecoderLayer`] + + Args: + config: MixtralConfig + """ + + def __init__(self, config: MixtralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Ignore copy + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, MoeModelOutputWithPast]: + + hidden_states = self.embed_tokens(input_ids) + + for decoder_layer in self.layers: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=False, + output_router_logits=False, + use_cache=False, + **kwargs, + ) + hidden_states = layer_outputs[0] + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states + ) + + +class MixtralForCausalLM(MixtralPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + + self.model = MixtralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + 
return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + output_router_logits=False, + return_dict=True, + **kwargs, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + return MoeCausalLMOutputWithPast( + logits=logits + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + **kwargs, + ): + # Omit tokens covered by past_key_values + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
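The `position_ids` handling a few lines below uses the standard cumsum trick to derive positions from a (possibly left-padded) `attention_mask` and then keeps only the positions of the not-yet-processed tokens. A small illustration with made-up mask values:

```python
# Sketch of the cumsum-based position_ids construction used in prepare_inputs_for_generation.
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])   # first sequence is left-padded by 2
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

# During cached decoding, only the positions of the new token(s) are kept:
new_token_count = 1
print(position_ids[:, -new_token_count:])          # tensor([[2], [4]])
```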
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/falcon_split_model.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/split_falcon.py similarity index 85% rename from byte_infer_perf/llm_perf/backends/GPU/model_impl/falcon_split_model.py rename to byte_infer_perf/llm_perf/backends/GPU/model_impl/split_falcon.py index 78e1353b..659d5032 100644 --- a/byte_infer_perf/llm_perf/backends/GPU/model_impl/falcon_split_model.py +++ b/byte_infer_perf/llm_perf/backends/GPU/model_impl/split_falcon.py @@ -2,26 +2,25 @@ import sys import pathlib import argparse +from tqdm import tqdm import torch import torch.nn as nn -from typing import List +from typing import List, Optional, Union, Tuple from accelerate import init_empty_weights from transformers import FalconConfig - FILE_DIR = pathlib.Path(__file__).parent.absolute() - sys.path.insert(0, str(FILE_DIR.parent.parent.parent.parent)) -from llm_perf.backends.GPU.model_impl.falcon import FalconForCausalLM +from byte_infer_perf.llm_perf.backends.GPU.model_impl.modeling_falcon import FalconForCausalLM from llm_perf.core.ckpt_loader import Falcon_ModelLoader def to_parameter( data : torch.Tensor, - dtype : torch.dtype =None + dtype : torch.dtype = None ): if dtype is not None: data = data.to(dtype) @@ -84,34 +83,27 @@ def split( os.environ["LOCAL_RANK"] = "0" os.environ["WORLD_SIZE"] = str(args.mp_size) - model_path = pathlib.Path(args.model_path).absolute() - split_model_path = model_path / f"TP{args.mp_size}" - split_model_path.mkdir(parents=True, exist_ok=True) - - config = FalconConfig.from_pretrained(str(model_path)) + model_config = FalconConfig.from_pretrained(str(model_path)) + print(model_config) + model_loader = Falcon_ModelLoader(model_path) state_dict = model_loader.load_weight() - # for key in state_dict.keys(): - # print(key, state_dict[key].shape, state_dict[key].dtype) - - # print("") - # print("") - # print("") + # model_config.num_hidden_layers = 4 - for i in range(config.num_hidden_layers): + p_bar = tqdm(total=model_config.num_hidden_layers, desc="split model") + for i in range(model_config.num_hidden_layers): attn_qkv = f"transformer.h.{i}.self_attention.query_key_value.weight" attn_dense = 
f"transformer.h.{i}.self_attention.dense.weight" dense_h_to_4h = f"transformer.h.{i}.mlp.dense_h_to_4h.weight" dense_4h_to_h = f"transformer.h.{i}.mlp.dense_4h_to_h.weight" - print(i) state_dict[attn_qkv] = split( state_dict[attn_qkv], args.mp_size, dim=0, - chunks=[config.num_attention_heads, config.num_kv_heads, config.num_kv_heads] + chunks=[model_config.num_attention_heads, model_config.num_kv_heads, model_config.num_kv_heads] ) state_dict[attn_dense] = split( state_dict[attn_dense], args.mp_size, @@ -126,18 +118,24 @@ def split( dim=1 ) + p_bar.update(1) + p_bar.close() + + split_model_path = model_path / f"TP{args.mp_size}" + split_model_path.mkdir(parents=True, exist_ok=True) + with init_empty_weights(): - model = FalconForCausalLM(config) + model = FalconForCausalLM(model_config) model.eval() - for i in range(args.mp_size): - print(f"store model_{i}") - + + p_bar = tqdm(total=args.mp_size, desc="save model") + for i in range(args.mp_size): output_dir = split_model_path / f"device_{i}" output_dir.mkdir(parents=True, exist_ok=True) model.transformer.word_embeddings.weight = to_parameter(state_dict["transformer.word_embeddings.weight"]) - for j in range(config.num_hidden_layers): + for j in range(model_config.num_hidden_layers): model.transformer.h[j].self_attention.query_key_value.weight = to_parameter(state_dict[f"transformer.h.{j}.self_attention.query_key_value.weight"][i]) model.transformer.h[j].self_attention.dense.weight = to_parameter(state_dict[f"transformer.h.{j}.self_attention.dense.weight"][i]) model.transformer.h[j].mlp.dense_h_to_4h.weight = to_parameter(state_dict[f"transformer.h.{j}.mlp.dense_h_to_4h.weight"][i]) @@ -152,9 +150,6 @@ def split( model.lm_head.weight = to_parameter(state_dict["lm_head.weight"]) model.save_pretrained(str(output_dir)) - - # small_state_dict = model.state_dict() - # for key in small_state_dict.keys(): - # print(key, small_state_dict[key].shape, small_state_dict[key].dtype, small_state_dict[key].device) - + p_bar.update(1) + p_bar.close() diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/split_llama.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/split_llama.py new file mode 100644 index 00000000..596b4b4f --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/GPU/model_impl/split_llama.py @@ -0,0 +1,146 @@ +import os +import sys +import pathlib +import argparse +from tqdm import tqdm + +import torch +import torch.nn as nn +from typing import List, Optional, Union, Tuple + +from accelerate import init_empty_weights +from transformers import LlamaConfig + +FILE_DIR = pathlib.Path(__file__).parent.absolute() + +sys.path.insert(0, str(FILE_DIR.parent.parent.parent.parent)) +from byte_infer_perf.llm_perf.backends.GPU.model_impl.modeling_llama3 import LlamaForCausalLM +from llm_perf.core.ckpt_loader import Llama_ModelLoader + + +def to_parameter( + data : torch.Tensor, + dtype : torch.dtype = None +): + if dtype is not None: + data = data.to(dtype) + return nn.Parameter(data, requires_grad=False) + + +def split( + src : torch.Tensor, + mp_size : int, + dim : int, + chunks : List [int]=[] +): + if len(chunks) == 0: + split_arg = src.shape[dim] // mp_size + output_tensors = torch.split(src, split_arg, dim=dim) + else: + # for example + # chunks = [32, 2, 2], sum_chunks = 36, src.shape[dim] = (32 + 2 + 2) * 128, other_dim = 128 + # mp_size = 8 + # new_chunks = [4, 1, 1] + sum_chunks = sum(chunks) + other_dim_size = src.shape[dim] // sum_chunks + + split_arg = [i * other_dim_size for i in chunks] + split_tensors = torch.split(src, 
split_arg, dim=dim) + + output_split = [] + for i, tensor in enumerate(split_tensors): + if mp_size > chunks[i]: + tensor_shape = tensor.size()[:dim] + (chunks[i], 1, other_dim_size) + tensor.size()[dim+1:] + new_tensor_shape = tensor.size()[:dim] + (chunks[i], mp_size // chunks[i], other_dim_size) + tensor.size()[dim+1:] + output_tensor_shape = tensor.size()[:dim] + (mp_size * other_dim_size,) + tensor.size()[dim+1:] + + tensor = tensor.view(tensor_shape) + tensor = tensor.expand(*new_tensor_shape) + tensor = tensor.contiguous() + tensor = tensor.view(output_tensor_shape) + + cur_split = torch.split(tensor, tensor.shape[dim] // mp_size, dim=dim) + output_split.append(cur_split) + + output_tensors = [] + for i in range(mp_size): + temp_tensors = [output_split[j][i] for j in range(len(chunks))] + tp_tensors = torch.concat(temp_tensors, dim=dim) + output_tensors.append(tp_tensors) + + output_tensors = [tensor.contiguous() for tensor in output_tensors] + + return output_tensors + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, required=True) + parser.add_argument("--mp_size", type=int, default=8, choices=[2, 4, 8]) + args = parser.parse_args() + + os.environ["LOCAL_RANK"] = "0" + os.environ["WORLD_SIZE"] = str(args.mp_size) + + model_path = pathlib.Path(args.model_path).absolute() + model_config : LlamaConfig = LlamaConfig.from_pretrained(str(model_path)) + print(model_config) + + model_loader = Llama_ModelLoader(model_path) + state_dict = model_loader.load_weight() + + # model_config.num_hidden_layers = 4 + + p_bar = tqdm(total=model_config.num_hidden_layers, desc="split model") + for i in range(model_config.num_hidden_layers): + q = f"model.layers.{i}.self_attn.q_proj.weight" + k = f"model.layers.{i}.self_attn.k_proj.weight" + v = f"model.layers.{i}.self_attn.v_proj.weight" + o = f"model.layers.{i}.self_attn.o_proj.weight" + + state_dict[q] = split(state_dict[q], args.mp_size, 0) + state_dict[k] = split(state_dict[k], args.mp_size, 0) + state_dict[v] = split(state_dict[v], args.mp_size, 0) + state_dict[o] = split(state_dict[o], args.mp_size, 1) + + gate = f"model.layers.{i}.mlp.gate_proj.weight" + up = f"model.layers.{i}.mlp.up_proj.weight" + down = f"model.layers.{i}.mlp.down_proj.weight" + + state_dict[gate] = split(state_dict[gate], args.mp_size, 0) + state_dict[up] = split(state_dict[up], args.mp_size, 0) + state_dict[down] = split(state_dict[down], args.mp_size, 1) + + p_bar.update(1) + p_bar.close() + + split_model_path = model_path / f"TP{args.mp_size}" + split_model_path.mkdir(parents=True, exist_ok=True) + + with init_empty_weights(): + model = LlamaForCausalLM(model_config).to(model_config.torch_dtype).eval() + + p_bar = tqdm(total=args.mp_size, desc="save model") + for rank in range(args.mp_size): + output_dir = split_model_path / f"device_{rank}" + output_dir.mkdir(parents=True, exist_ok=True) + + model.model.embed_tokens.weight = to_parameter(state_dict["model.embed_tokens.weight"]) + for i in range(model_config.num_hidden_layers): + model.model.layers[i].self_attn.q_proj.weight = to_parameter(state_dict[f"model.layers.{i}.self_attn.q_proj.weight"][rank]) + model.model.layers[i].self_attn.k_proj.weight = to_parameter(state_dict[f"model.layers.{i}.self_attn.k_proj.weight"][rank]) + model.model.layers[i].self_attn.v_proj.weight = to_parameter(state_dict[f"model.layers.{i}.self_attn.v_proj.weight"][rank]) + model.model.layers[i].self_attn.o_proj.weight = 
to_parameter(state_dict[f"model.layers.{i}.self_attn.o_proj.weight"][rank]) + + model.model.layers[i].mlp.gate_proj.weight = to_parameter(state_dict[f"model.layers.{i}.mlp.gate_proj.weight"][rank]) + model.model.layers[i].mlp.up_proj.weight = to_parameter(state_dict[f"model.layers.{i}.mlp.up_proj.weight"][rank]) + model.model.layers[i].mlp.down_proj.weight = to_parameter(state_dict[f"model.layers.{i}.mlp.down_proj.weight"][rank]) + + model.model.layers[i].input_layernorm.weight = to_parameter(state_dict[f"model.layers.{i}.input_layernorm.weight"]) + model.model.layers[i].post_attention_layernorm.weight = to_parameter(state_dict[f"model.layers.{i}.post_attention_layernorm.weight"]) + model.model.norm.weight = to_parameter(state_dict["model.norm.weight"]) + model.lm_head.weight = to_parameter(state_dict["lm_head.weight"]) + + model.save_pretrained(str(output_dir)) + p_bar.update(1) + p_bar.close() diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/split_mixtral.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/split_mixtral.py new file mode 100644 index 00000000..7b775be6 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/GPU/model_impl/split_mixtral.py @@ -0,0 +1,148 @@ +import os +import sys +import pathlib +import argparse +from tqdm import tqdm + +import torch +import torch.nn as nn +from typing import List, Optional, Union, Tuple + +from accelerate import init_empty_weights +from transformers import MixtralConfig + +FILE_DIR = pathlib.Path(__file__).parent.absolute() + +sys.path.insert(0, str(FILE_DIR.parent.parent.parent.parent)) +from byte_infer_perf.llm_perf.backends.GPU.model_impl.modeling_mixtral import MixtralForCausalLM +from llm_perf.core.ckpt_loader import Mixtral_ModelLoader + + +def to_parameter( + data : torch.Tensor, + dtype : torch.dtype = None +): + if dtype is not None: + data = data.to(dtype) + return nn.Parameter(data, requires_grad=False) + + +def split( + src : torch.Tensor, + mp_size : int, + dim : int, + chunks : List [int]=[] +): + if len(chunks) == 0: + split_arg = src.shape[dim] // mp_size + output_tensors = torch.split(src, split_arg, dim=dim) + else: + # for example + # chunks = [32, 2, 2], sum_chunks = 36, src.shape[dim] = (32 + 2 + 2) * 128, other_dim = 128 + # mp_size = 8 + # new_chunks = [4, 1, 1] + sum_chunks = sum(chunks) + other_dim_size = src.shape[dim] // sum_chunks + + split_arg = [i * other_dim_size for i in chunks] + split_tensors = torch.split(src, split_arg, dim=dim) + + output_split = [] + for i, tensor in enumerate(split_tensors): + if mp_size > chunks[i]: + tensor_shape = tensor.size()[:dim] + (chunks[i], 1, other_dim_size) + tensor.size()[dim+1:] + new_tensor_shape = tensor.size()[:dim] + (chunks[i], mp_size // chunks[i], other_dim_size) + tensor.size()[dim+1:] + output_tensor_shape = tensor.size()[:dim] + (mp_size * other_dim_size,) + tensor.size()[dim+1:] + + tensor = tensor.view(tensor_shape) + tensor = tensor.expand(*new_tensor_shape) + tensor = tensor.contiguous() + tensor = tensor.view(output_tensor_shape) + + cur_split = torch.split(tensor, tensor.shape[dim] // mp_size, dim=dim) + output_split.append(cur_split) + + output_tensors = [] + for i in range(mp_size): + temp_tensors = [output_split[j][i] for j in range(len(chunks))] + tp_tensors = torch.concat(temp_tensors, dim=dim) + output_tensors.append(tp_tensors) + + output_tensors = [tensor.contiguous() for tensor in output_tensors] + + return output_tensors + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", 
type=str, required=True) + parser.add_argument("--mp_size", type=int, default=8, choices=[2, 4, 8]) + args = parser.parse_args() + + os.environ["LOCAL_RANK"] = "0" + os.environ["WORLD_SIZE"] = str(args.mp_size) + + model_path = pathlib.Path(args.model_path).absolute() + model_config : MixtralConfig = MixtralConfig.from_pretrained(str(model_path)) + print(model_config) + + model_loader = Mixtral_ModelLoader(model_path) + state_dict = model_loader.load_weight() + + # model_config.num_hidden_layers = 4 + + p_bar = tqdm(total=model_config.num_hidden_layers, desc="split model") + for i in range(model_config.num_hidden_layers): + q = f"model.layers.{i}.self_attn.q_proj.weight" + k = f"model.layers.{i}.self_attn.k_proj.weight" + v = f"model.layers.{i}.self_attn.v_proj.weight" + o = f"model.layers.{i}.self_attn.o_proj.weight" + + state_dict[q] = split(state_dict[q], args.mp_size, 0) + state_dict[k] = split(state_dict[k], args.mp_size, 0) + state_dict[v] = split(state_dict[v], args.mp_size, 0) + state_dict[o] = split(state_dict[o], args.mp_size, 1) + + for j in range(model_config.num_local_experts): + w1 = f"model.layers.{i}.block_sparse_moe.experts.{j}.w1.weight" + w2 = f"model.layers.{i}.block_sparse_moe.experts.{j}.w2.weight" + w3 = f"model.layers.{i}.block_sparse_moe.experts.{j}.w3.weight" + + state_dict[w1] = split(state_dict[w1], args.mp_size, 0) + state_dict[w2] = split(state_dict[w2], args.mp_size, 1) + state_dict[w3] = split(state_dict[w3], args.mp_size, 0) + + p_bar.update(1) + p_bar.close() + + split_model_path = model_path / f"TP{args.mp_size}" + split_model_path.mkdir(parents=True, exist_ok=True) + + with init_empty_weights(): + model = MixtralForCausalLM(model_config) + model.eval() + + p_bar = tqdm(total=args.mp_size, desc="save model") + for rank in range(args.mp_size): + output_dir = split_model_path / f"device_{rank}" + output_dir.mkdir(parents=True, exist_ok=True) + + model.model.embed_tokens.weight = to_parameter(state_dict["model.embed_tokens.weight"]) + for i in range(model_config.num_hidden_layers): + model.model.layers[i].self_attn.q_proj.weight = to_parameter(state_dict[f"model.layers.{i}.self_attn.q_proj.weight"][rank]) + model.model.layers[i].self_attn.k_proj.weight = to_parameter(state_dict[f"model.layers.{i}.self_attn.k_proj.weight"][rank]) + model.model.layers[i].self_attn.v_proj.weight = to_parameter(state_dict[f"model.layers.{i}.self_attn.v_proj.weight"][rank]) + model.model.layers[i].self_attn.o_proj.weight = to_parameter(state_dict[f"model.layers.{i}.self_attn.o_proj.weight"][rank]) + model.model.layers[i].block_sparse_moe.gate.weight = to_parameter(state_dict[f"model.layers.{i}.block_sparse_moe.gate.weight"]) + for j in range(model_config.num_local_experts): + model.model.layers[i].block_sparse_moe.experts[j].w1.weight = to_parameter(state_dict[f"model.layers.{i}.block_sparse_moe.experts.{j}.w1.weight"][rank]) + model.model.layers[i].block_sparse_moe.experts[j].w2.weight = to_parameter(state_dict[f"model.layers.{i}.block_sparse_moe.experts.{j}.w2.weight"][rank]) + model.model.layers[i].block_sparse_moe.experts[j].w3.weight = to_parameter(state_dict[f"model.layers.{i}.block_sparse_moe.experts.{j}.w3.weight"][rank]) + model.model.layers[i].input_layernorm.weight = to_parameter(state_dict[f"model.layers.{i}.input_layernorm.weight"]) + model.model.layers[i].post_attention_layernorm.weight = to_parameter(state_dict[f"model.layers.{i}.post_attention_layernorm.weight"]) + model.model.norm.weight = to_parameter(state_dict["model.norm.weight"]) + model.lm_head.weight = 
to_parameter(state_dict["lm_head.weight"]) + + model.save_pretrained(str(output_dir)) + p_bar.update(1) + p_bar.close() diff --git a/byte_infer_perf/llm_perf/benchmark/bench.py b/byte_infer_perf/llm_perf/benchmark/bench.py index 3f406912..e73d2cb4 100644 --- a/byte_infer_perf/llm_perf/benchmark/bench.py +++ b/byte_infer_perf/llm_perf/benchmark/bench.py @@ -64,8 +64,8 @@ def bench_accuracy(stub, workload: Dict[str, Any], result_queue: mp.Queue): stub, index=0, prompt=question, - min_new_tokens=1, - max_new_tokens=512, + min_new_tokens=workload["min_new_tokens"], + max_new_tokens=workload["max_new_tokens"], top_p=0, top_k=1, # use greedy search for accuracy bench get_input_logits=1, @@ -102,7 +102,6 @@ def bench_performance( ): result_queue.put("@start") - accum_time = 0 perf_time: int = workload["perf_time"] * int(1e9) @@ -113,8 +112,8 @@ def bench_performance( st = time.perf_counter_ns() first_token_latency = 0 - min_new_tokens = workload["min_new_tokens"] - max_new_tokens = workload["max_new_tokens"] + min_new_tokens = workload["output_tokens"] + max_new_tokens = workload["output_tokens"] output_messages: str = "" wait_time = [] @@ -186,9 +185,9 @@ def benchmark( report_type: ReportType, input_tokens: int, result_queue: mp.Queue, - args, + host, port ): - with grpc.insecure_channel(f"{args.host}:{args.port}") as channel: + with grpc.insecure_channel(f"{host}:{port}") as channel: stub = server_pb2_grpc.InferenceStub(channel) logger.debug(f"{report_type.name} bench_{index} start") diff --git a/byte_infer_perf/llm_perf/core/ckpt_loader.py b/byte_infer_perf/llm_perf/core/ckpt_loader.py index 02cb98bd..774c306d 100644 --- a/byte_infer_perf/llm_perf/core/ckpt_loader.py +++ b/byte_infer_perf/llm_perf/core/ckpt_loader.py @@ -297,17 +297,82 @@ def load_weight(self): return self.weight_dict +from transformers import LlamaConfig +class Llama_ModelLoader(ModelLoader): + def __init__(self, model_dir : pathlib.Path): + model_config = LlamaConfig.from_pretrained(model_dir) + weight_index_config = {} + for child in model_dir.iterdir(): + if child.name.endswith(".index.json"): + with open(child, "r") as f: + weight_index_config = json.load(f) + break + + self.layer_num = model_config.num_hidden_layers + + super().__init__( + model_dir, + weight_index_config["metadata"]["total_size"], + weight_index_config["weight_map"] + ) + + def load_weight(self): + self.loaded_bytes = 0 + self.weight_dict = {} + + self.load_tensor("model.embed_tokens.weight") + for i in range(self.layer_num): + self.load_tensor(f"model.layers.{i}.input_layernorm.weight") + + self.load_tensor(f"model.layers.{i}.self_attn.q_proj.weight") + self.load_tensor(f"model.layers.{i}.self_attn.k_proj.weight") + self.load_tensor(f"model.layers.{i}.self_attn.v_proj.weight") + self.load_tensor(f"model.layers.{i}.self_attn.o_proj.weight") + + self.load_tensor(f"model.layers.{i}.post_attention_layernorm.weight") + self.load_tensor(f"model.layers.{i}.mlp.gate_proj.weight") + self.load_tensor(f"model.layers.{i}.mlp.up_proj.weight") + self.load_tensor(f"model.layers.{i}.mlp.down_proj.weight") + self.load_tensor("model.norm.weight") + self.load_tensor("lm_head.weight") -class Mixtral8x22B_ModelLoader(ModelLoader): + weight_bytes = 0 + for tensor_name in self.weight_dict: + tensor = self.weight_dict[tensor_name] + weight_bytes += tensor.numel() * tensor.element_size() + + logger.info(f"total_size: {self.total_size}, loaded_bytes: {self.loaded_bytes}, weight_bytes: {weight_bytes}") + assert self.loaded_bytes == self.total_size + assert weight_bytes == 
self.total_size + + return self.weight_dict + + + + + + +from transformers import MixtralConfig + +class Mixtral_ModelLoader(ModelLoader): def __init__( self, - model_dir : pathlib.Path, - model_config: Dict, - weight_index_config: Dict, + model_dir : pathlib.Path ) -> None: + model_config = MixtralConfig.from_pretrained(model_dir) + weight_index_config = {} + for child in model_dir.iterdir(): + if child.name.endswith(".index.json"): + with open(child, "r") as f: + weight_index_config = json.load(f) + break + + self.layer_num = model_config.num_hidden_layers + self.expert_num = model_config.num_local_experts + # parent class super().__init__( model_dir, @@ -315,11 +380,6 @@ def __init__( weight_index_config["weight_map"] ) - # model config - self.layer_num = 56 - self.expert_num = 8 - - def load_weight(self): self.loaded_bytes = 0 diff --git a/byte_infer_perf/llm_perf/launch.py b/byte_infer_perf/llm_perf/launch.py index fe06670e..94834eed 100644 --- a/byte_infer_perf/llm_perf/launch.py +++ b/byte_infer_perf/llm_perf/launch.py @@ -33,67 +33,14 @@ from llm_perf.utils.reporter import Reporter, ReportType -def load_workload(task: str) -> Dict[str, Any]: - """ - Return a list of dictionary with model Configuration - - Args: List[str] - - Returns: List[dic] - """ - modules_dir = LLM_PERF_ROOT.joinpath("workloads") - - workload_dict = None - for filepath in modules_dir.iterdir(): - if filepath.suffix == ".json" and filepath.stem == task: - with open(filepath) as file: - workload_dict = json.load(file) - break - if workload_dict is None: - logger.error(f"Task name: {task} was not found, please check your task name") - raise RuntimeError("invalid parameter") - return workload_dict - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--hardware_type", type=str, - default="GPU", - help="The backend going to be evaluted, refs to backends/", - ) - parser.add_argument( - "--task", type=str, - default="chatglm2-torch-fp16-6b", - help="The task going to be evaluted, refs to workloads/", - ) - - parser.add_argument( - "--host", type=str, - default="127.0.0.1", - help="Host for the gRPC server" - ) - parser.add_argument( - "--port", type=int, - default=51000, - help="port of the server") - - args = parser.parse_args() - return args - - - class PerfEngine: - def __init__(self) -> None: + def __init__(self, hardware, task, host, port) -> None: super().__init__() - self.args = get_args() - - self.backend_type = self.args.hardware_type - self.task = self.args.task - self.host = self.args.host - self.port = self.args.port + self.backend_type = hardware + self.task = task + self.host = host + self.port = port self.result_queue = mp.Queue() self.jobs: List[mp.Process] = [] @@ -105,55 +52,25 @@ def __del__(self): def start_engine(self) -> None: - """ - Byte MlPerf will create an virtual env for each backend to avoid dependance conflict - """ - loglevel = os.environ.get("LOG_LEVEL", "info") - setup_logger(loglevel) - # load workload workload = load_workload(self.task) - # model name model_name = workload["model"] - # test items - test_accuracy = True if "test_accuracy" in workload and bool(workload["test_accuracy"]) else False - test_perf = True if "test_perf" in workload and bool(workload["test_perf"]) else False - - # test config - test_dataset = workload["dataset"] - test_perf_time = workload["perf_time"] - - test_tp_sizes = workload["tp_sizes"] - test_batch_sizes = workload["batch_sizes"] - test_input_tokens = workload["input_tokens"] - - # generation config - min_new_tokens = 
workload["min_new_tokens"] - max_new_tokens = workload["max_new_tokens"] - - - # download model parameter and golden outputs - weight_dir = LLM_PERF_ROOT.joinpath("model_zoo", "sota", model_name) - refer_dir = LLM_PERF_ROOT.joinpath("reports", "base", model_name) - if not weight_dir.exists() or not refer_dir.exists(): - download_script = LLM_PERF_ROOT.joinpath("prepare_model.sh") - subprocess.run( - [ - "bash", download_script, - model_name, - str(test_accuracy) - ] - ) - + min_tp_size = int(workload["min_tp_size"]) + test_accuracy = bool(workload["test_accuracy"]) if "test_accuracy" in workload else False + test_perf = bool(workload["test_perf"]) if "test_perf" in workload else False if not any([test_perf, test_accuracy]): logger.info(f"End of the llm_perf, enable at least one test item") return - min_tp_size = test_tp_sizes[0] - min_batch_size = test_batch_sizes[0] - min_input_tokens = test_input_tokens[0] + + # download model parameter and golden outputs + download_cmd = f"python3 llm_perf/prepare_model.py --task {self.task} --download_model" + if test_accuracy: + download_cmd += " --download_baseline" + subprocess.run(download_cmd, shell=True) + # create and start reporter self.reporter = Reporter( @@ -161,32 +78,39 @@ def start_engine(self) -> None: backend=self.backend_type, tp_size=min_tp_size, - batch_size=min_batch_size, - input_tokens=min_input_tokens, - min_new_tokens=min_new_tokens, - max_new_tokens=max_new_tokens, + batch_size=1, + input_tokens=1024, + min_new_tokens=1, + max_new_tokens=512, test_perf=test_perf, test_accuracy=test_accuracy, ) self.reporter.start() - # 1. Accuracy Test: default batch_size & tp_size are both 1 if test_accuracy: + accuracy_config = workload["accuracy_config"] + logger.info("start test accuracy.") logger.info(f"using tp_size={min_tp_size}") - logger.info(f"using batch_size={min_batch_size}") - logger.info(f"using input_tokens={min_input_tokens}") + logger.info(f"using batch_size=1") + self.run_perf( - workload, - min_tp_size, - min_batch_size, - min_input_tokens, + accuracy_config, + min_tp_size, 1, 1024, ReportType.ACCURACY ) - # 2. 
Performance Test if test_perf: + perf_config = workload["perf_config"] + + test_tp_sizes = [] + for tp_size in perf_config["tp_sizes"]: + if tp_size >= min_tp_size: + test_tp_sizes.append(tp_size) + test_batch_sizes = perf_config["batch_sizes"] + test_input_tokens = perf_config["input_tokens"] + logger.info("start test performance.") logger.info(f"tp_sizes list: {test_tp_sizes}") logger.info(f"batch_sizes list: {test_batch_sizes}") @@ -199,10 +123,8 @@ def start_engine(self) -> None: print(f"using tp_size={tp_size}, batch_size={batch_size}, input_tokens={input_tokens}") print("*"*150) self.run_perf( - workload, - tp_size, - batch_size, - input_tokens, + perf_config, + tp_size, batch_size, input_tokens, ReportType.PERFORMANCE, ) print("\n\n\n") @@ -315,13 +237,86 @@ def start_benchmark( report_type, input_tokens, self.result_queue, - self.args, + self.host, self.port ), ) self.jobs.append(p) p.start() +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--hardware_type", type=str, + default="GPU", + help="The backend to be evaluated; see backends/", + ) + parser.add_argument( + "--task", type=str, + default="chatglm2-torch-fp16-6b", + help="The task to be evaluated; see workloads/", + ) + + parser.add_argument( + "--host", type=str, + default="127.0.0.1", + help="Host for the gRPC server" + ) + parser.add_argument( + "--port", type=int, + default=51000, + help="Port of the gRPC server") + + parser.add_argument( + "--log_level", type=str, + default=os.environ.get("LOG_LEVEL", "info"), + help="log level" + ) + + args = parser.parse_args() + return args + + + +def load_workload(task: str) -> Dict[str, Any]: + """ + Return the workload configuration for the given task. + + Args: task (str): name of a JSON file (without suffix) under llm_perf/workloads/ + + Returns: Dict[str, Any]: the parsed workload configuration + """ + modules_dir = LLM_PERF_ROOT.joinpath("workloads") + + workload_dict = None + for filepath in modules_dir.iterdir(): + if filepath.suffix == ".json" and filepath.stem == task: + with open(filepath) as file: + workload_dict = json.load(file) + break + if workload_dict is None: + logger.error(f"Task name: {task} was not found, please check your task name") + exit(-1) + return workload_dict + + + + if __name__ == "__main__": - instance = PerfEngine() + args = parse_args() + + hardware = args.hardware_type + task = args.task + host = args.host + port = args.port + + setup_logger(args.log_level) + + logger.info(f"hardware: {hardware}") + logger.info(f"task: {task}") + logger.info(f"host: {host}") + logger.info(f"port: {port}") + + instance = PerfEngine(hardware, task, host, port) instance.start_engine() diff --git a/byte_infer_perf/llm_perf/model_zoo/chatglm2-6b.json b/byte_infer_perf/llm_perf/model_zoo/chatglm2-torch-fp16-6b.json similarity index 95% rename from byte_infer_perf/llm_perf/model_zoo/chatglm2-6b.json rename to byte_infer_perf/llm_perf/model_zoo/chatglm2-torch-fp16-6b.json index 19a64e22..f7cc38c0 100644 --- a/byte_infer_perf/llm_perf/model_zoo/chatglm2-6b.json +++ b/byte_infer_perf/llm_perf/model_zoo/chatglm2-torch-fp16-6b.json @@ -3,8 +3,8 @@ "model_path": "llm_perf/model_zoo/sota/chatglm2-6b", "model_interface": "ChatGLMForConditionalGeneration", "tokenizer": { - "path": "llm_perf/model_zoo/sota/chatglm2-6b", - "add_sep_token": false + "path": "llm_perf/model_zoo/sota/chatglm2-6b", + "support_chn": true }, "network": { "_name_or_path": "THUDM/chatglm2-6b", diff --git a/byte_infer_perf/llm_perf/model_zoo/falcon-180b.json b/byte_infer_perf/llm_perf/model_zoo/falcon-torch-bf16-180b.json similarity index 100% rename from
byte_infer_perf/llm_perf/model_zoo/falcon-180b.json rename to byte_infer_perf/llm_perf/model_zoo/falcon-torch-bf16-180b.json diff --git a/byte_infer_perf/llm_perf/model_zoo/llama3-torch-bf16-70b.json b/byte_infer_perf/llm_perf/model_zoo/llama3-torch-bf16-70b.json new file mode 100644 index 00000000..f5c7fb41 --- /dev/null +++ b/byte_infer_perf/llm_perf/model_zoo/llama3-torch-bf16-70b.json @@ -0,0 +1,37 @@ +{ + "model_name": "llama3", + "model_path": "llm_perf/model_zoo/sota/llama3-70b", + "model_interface": "LlamaForCausalLM", + "tokenizer": { + "path": "llm_perf/model_zoo/sota/llama3-70b", + "support_chn": true, + "apply_chat_template": true + }, + "network": { + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128009, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 8192, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0", + "use_cache": true, + "vocab_size": 128256 + } +} \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/model_zoo/mixtral-torch-bf16-8x22b.json b/byte_infer_perf/llm_perf/model_zoo/mixtral-torch-bf16-8x22b.json new file mode 100644 index 00000000..9ff68d1c --- /dev/null +++ b/byte_infer_perf/llm_perf/model_zoo/mixtral-torch-bf16-8x22b.json @@ -0,0 +1,37 @@ +{ + "model_name": "mixtral", + "model_path": "llm_perf/model_zoo/sota/mixtral-8x22b", + "model_interface": "MixtralForCausalLM", + "tokenizer": { + "path": "llm_perf/model_zoo/sota/mixtral-8x22b" + }, + "network": { + "architectures": [ + "MixtralForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 6144, + "initializer_range": 0.02, + "intermediate_size": 16384, + "max_position_embeddings": 65536, + "model_type": "mixtral", + "num_attention_heads": 48, + "num_experts_per_tok": 2, + "num_hidden_layers": 56, + "num_key_value_heads": 8, + "num_local_experts": 8, + "output_router_logits": false, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000, + "router_aux_loss_coef": 0.001, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.0", + "use_cache": true, + "vocab_size": 32000 + } +} \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/prepare_model.py b/byte_infer_perf/llm_perf/prepare_model.py new file mode 100644 index 00000000..00e21f13 --- /dev/null +++ b/byte_infer_perf/llm_perf/prepare_model.py @@ -0,0 +1,82 @@ +import os +import sys +import pathlib +import argparse +import subprocess + +# ${prj_root}/ +BYTE_MLPERF_ROOT = pathlib.Path(__file__).parents[1].absolute() +LLM_PERF_ROOT = BYTE_MLPERF_ROOT.joinpath("llm_perf") + +task_map = { + "chatglm2-torch-fp16-6b": ("chatglm2-6b", "THUDM/chatglm2-6b"), + "llama3-torch-bf16-70b": ("llama3-70b", "shenzhi-wang/Llama3-70B-Chinese-Chat"), + "falcon-torch-bf16-180b": ("falcon-180b", "tiiuae/falcon-180B"), + "mixtral-torch-bf16-8x22b": ("mixtral-8x22b", "mistralai/Mixtral-8x22B-v0.1"), +} + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="chatglm2-torch-fp16-6b") + parser.add_argument("--download_model",
action="store_true") + parser.add_argument("--download_baseline", action="store_true") + args = parser.parse_args() + + os.chdir(LLM_PERF_ROOT) + + task_name = args.task + if task_name not in task_map: + print(f"task {task_name} not found, please check your task name") + sys.exit(-1) + + model_name = task_map[task_name][0] + model_repo_name = task_map[task_name][1] + + download_path = LLM_PERF_ROOT.joinpath("download") + download_path.mkdir(parents=True, exist_ok=True) + + if args.download_model: + sota_model_path = LLM_PERF_ROOT.joinpath("model_zoo", "sota") + sota_model_path.mkdir(parents=True, exist_ok=True) + + model_path = sota_model_path.joinpath(model_name) + if model_path.exists(): + print(f"model {model_name} already exists, skip downloading model.") + else: + print(f"downloading model {model_name}") + subprocess.run( + f"huggingface-cli download --local-dir {model_path} {model_repo_name}", + shell=True, check=True + ) + + if args.download_baseline: + gpu_baseline_path = LLM_PERF_ROOT.joinpath("reports", "base") + gpu_baseline_path.mkdir(parents=True, exist_ok=True) + + tar_file_name = f"reports_gpu_{task_name}.tar.gz" + src_path = f"https://lf-bytemlperf.17mh.cn/obj/bytemlperf-zoo/llm/{tar_file_name}" + dst_path = download_path.joinpath(tar_file_name) + + if dst_path.exists(): + print(f"baseline {model_name} already exists, skip downloading baseline.") + else: + print(f"downloading baseline {model_name}") + subprocess.run( + f"wget -O {dst_path} {src_path}", + shell=True, check=True + ) + + base_path = gpu_baseline_path.joinpath(task_name) + if base_path.exists(): + print(f"baseline {model_name} already exists, skip extracting baseline.") + else: + print(f"extracting baseline {model_name}") + subprocess.run( + f"tar -xzvf {dst_path} -C {gpu_baseline_path}", + shell=True, check=True + ) + + + + + diff --git a/byte_infer_perf/llm_perf/prepare_model.sh b/byte_infer_perf/llm_perf/prepare_model.sh deleted file mode 100644 index c74c48ad..00000000 --- a/byte_infer_perf/llm_perf/prepare_model.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash - -SHELL_FOLDER=$(cd "$(dirname "$0")"; pwd) -cd $SHELL_FOLDER/.. - -echo "******************* Downloading Model and Logits.... *******************" - -mkdir -p llm_perf/download - -SOTA_MODEL_CKPT="llm_perf/model_zoo/sota" -GPU_REPORT_BASELINE="llm_perf/reports/base" - -mkdir -p $SOTA_MODEL_CKPT -mkdir -p $GPU_REPORT_BASELINE - -MODEL=$1 -ENABLE_ACC=$2 - -# supported model: -# * chatglm2-torch-fp16-6b -# * chinese-llama2-torch-fp16-13b -# * mixtral-torch-fp16-8x7b -if [ $MODEL == "chatglm2-torch-fp16-6b" ] || - [ $MODEL == "chinese-llama2-torch-fp16-13b" ] || - [ $MODEL == "mixtral-torch-fp16-8x7b" ]; then - if [ -d "$SOTA_MODEL_CKPT/$MODEL" ]; then - echo "already exist model, skip download" - else - wget -O llm_perf/download/$MODEL.tar.gz https://lf-bytemlperf.17mh.cn/obj/bytemlperf-zoo/llm/$MODEL.tar.gz - tar xf llm_perf/download/$MODEL.tar.gz -C $SOTA_MODEL_CKPT - fi - if [ $ENABLE_ACC == "True" ]; then - if [ -d "$GPU_REPORT_BASELINE/$MODEL" ]; then - echo "already exist logits, skip download" - else - wget -O llm_perf/download/reports_gpu_$MODEL.tar.gz https://lf-bytemlperf.17mh.cn/obj/bytemlperf-zoo/llm/reports_gpu_$MODEL.tar.gz - tar xf llm_perf/download/reports_gpu_$MODEL.tar.gz -C $GPU_REPORT_BASELINE - fi - fi -else - echo "Unsupported model!" - exit -1 -fi - -echo "Extract Done." 
diff --git a/byte_infer_perf/llm_perf/server/endpoint.py b/byte_infer_perf/llm_perf/server/endpoint.py index 0620ac91..c33d078a 100644 --- a/byte_infer_perf/llm_perf/server/endpoint.py +++ b/byte_infer_perf/llm_perf/server/endpoint.py @@ -17,26 +17,35 @@ def __init__(self, xpu_cfg) -> None: super().__init__() self.xpu_cfg = xpu_cfg - - model_config = xpu_cfg["model_config"] hardware_type = xpu_cfg["hardware_type"] + model_config = xpu_cfg["model_config"] # load tokenizer - tokenizer_path = model_config["tokenizer"]["path"] - self.add_sep_token = model_config["tokenizer"].get("add_sep_token", False) - self.tokenizer : PreTrainedTokenizer = AutoTokenizer.from_pretrained( - pretrained_model_name_or_path=tokenizer_path, - local_files_only=True, - trust_remote_code=True - ) + try: + tokenizer_config = model_config["tokenizer"] + tokenizer_path = tokenizer_config["path"] + self.tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=tokenizer_path, + local_files_only=True, + trust_remote_code=True + ) + self.support_chn = tokenizer_config.get("support_chn", False) + self.apply_chat_template = tokenizer_config.get("apply_chat_template", False) + except Exception as e: + logger.error(f"load tokenizer error: {e}") + sys.exit(-1) + logger.info(f"load tokenizer: {tokenizer_path}") logger.info("*"*50) logger.info(f"bos_token_id: {self.tokenizer.bos_token_id}") logger.info(f"eos_token_id: {self.tokenizer.eos_token_id}") + logger.info(f"unk_token_id: {self.tokenizer.unk_token_id}") logger.info(f"pad_token_id: {self.tokenizer.pad_token_id}") - logger.info(f"sep_token_id: {self.tokenizer.sep_token_id}") logger.info("*"*50) + xpu_cfg["bos_token_id"] = self.tokenizer.bos_token_id + xpu_cfg["eos_token_id"] = self.tokenizer.eos_token_id + xpu_cfg["unk_token_id"] = self.tokenizer.unk_token_id xpu_cfg["pad_token_id"] = self.tokenizer.pad_token_id # import setup according to hardware_type @@ -58,7 +67,11 @@ def __del__(self): def warmup(self, max_batch_size): - prompt = "中国的首都是哪里?" + if self.support_chn: + prompt = "7年前,我的年龄是我的儿子的6倍,我的儿子今年12岁,我今年多少岁?" + else: + prompt = "7 years ago, I was 6 times older than my son. My son is 12 years old now. How old am I now?" 
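+            # warmup prompt: a short arithmetic question; the Chinese variant above is used when the tokenizer config sets "support_chn"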
+ generate_config = { "min_new_tokens": 1, "max_new_tokens": 512, @@ -90,9 +103,15 @@ async def _multiple_warmup(): async def prepare_request( self, prompt: str, generate_config: Dict[str, Any] ) -> GenerateRequest: - input_ids = self.tokenizer.encode(prompt) - if self.add_sep_token: - input_ids.append(self.tokenizer.sep_token_id) + if not self.apply_chat_template: + input_ids = self.tokenizer.encode(prompt) + else: + input_ids = self.tokenizer.apply_chat_template( + [ + {"role": "user", "content": prompt} + ], + add_generation_prompt=True + ) # create generate config config = GenerateConfig( @@ -138,9 +157,10 @@ async def streaming_inference( } if result is not None: + text = self.tokenizer.decode([result.token_id], skip_special_tokens=True, clean_up_tokenization_spaces=True) infer_outputs["choice"].update( { - "message": self.tokenizer.decode(result.token_id), + "message": text, "wait_time": result.wait_time, "model_time": result.model_time, "post_process_time": result.post_process_time diff --git a/byte_infer_perf/llm_perf/server/launch_server.py b/byte_infer_perf/llm_perf/server/launch_server.py index 938d014b..76387f16 100644 --- a/byte_infer_perf/llm_perf/server/launch_server.py +++ b/byte_infer_perf/llm_perf/server/launch_server.py @@ -53,9 +53,6 @@ async def StreamingInference( outputs={k: serialize_value(v) for k, v in result.items()}, ) - - - async def serve(port, generator: LLMPerfEndpoint) -> None: server = grpc.aio.server( migration_thread_pool=futures.ThreadPoolExecutor( @@ -113,6 +110,7 @@ def main(): # create xpu config xpu_cfg = {} + xpu_cfg["hardware_type"] = args.hardware_type xpu_cfg["tp_size"] = args.tp_size xpu_cfg["max_batch_size"] = args.max_batch_size diff --git a/byte_infer_perf/llm_perf/utils/ps_utils.py b/byte_infer_perf/llm_perf/utils/ps_utils.py index c2abbe3d..592cf006 100644 --- a/byte_infer_perf/llm_perf/utils/ps_utils.py +++ b/byte_infer_perf/llm_perf/utils/ps_utils.py @@ -7,6 +7,11 @@ from llm_perf.utils.logger import logger def check_memory_usage(tag): + + # dist config + mp_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports gc.collect() @@ -21,8 +26,10 @@ def check_memory_usage(tag): else: pass - local_rank = int(os.environ.get("LOCAL_RANK", "0")) msg = f"<<{tag}>> CPU VM State: Used = {used_GB} GB, Percent = {vm_stats.percent}% | "\ f"DEV MEM State(Rank{local_rank}): Used = {dev_mem_allocated} GB, Reserved = {dev_mem_reserved} GB" - logger.info(msg) + + if local_rank == 0: + print(msg) + # logger.info(msg) diff --git a/byte_infer_perf/llm_perf/workloads/chatglm2-6b.json b/byte_infer_perf/llm_perf/workloads/chatglm2-6b.json deleted file mode 100644 index d02d7d92..00000000 --- a/byte_infer_perf/llm_perf/workloads/chatglm2-6b.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "model": "chatglm2-6b", - "test_accuracy": true, - "test_perf": true, - "min_new_tokens": 200, - "max_new_tokens": 200, - "tp_sizes": [1, 2, 4, 8], - "batch_sizes": [1, 8], - "input_tokens": [1024, 2048], - "dataset": "llm_perf/datasets/merged_52_test.csv", - "perf_time": 100 -} \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/workloads/chatglm2-torch-fp16-6b.json b/byte_infer_perf/llm_perf/workloads/chatglm2-torch-fp16-6b.json new file mode 100644 index 00000000..ee6737b6 --- /dev/null +++ b/byte_infer_perf/llm_perf/workloads/chatglm2-torch-fp16-6b.json @@ -0,0 +1,18 @@ +{ + "model": "chatglm2-6b", + "test_accuracy": true, + 
"min_tp_size": 1, + "accuracy_config": { + "dataset": "llm_perf/datasets/merged_52_test.csv", + "min_new_tokens": 1, + "max_new_tokens": 512 + }, + "test_perf": true, + "perf_config": { + "tp_sizes": [1, 2, 4, 8], + "batch_sizes": [1, 4, 8, 16, 24, 32], + "input_tokens": [1024, 2048], + "output_tokens": 200, + "perf_time": 100 + } +} \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/workloads/falcon-180b.json b/byte_infer_perf/llm_perf/workloads/falcon-180b.json deleted file mode 100644 index 33a16066..00000000 --- a/byte_infer_perf/llm_perf/workloads/falcon-180b.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "model": "falcon-180b", - "test_accuracy": true, - "test_perf": true, - "min_new_tokens": 200, - "max_new_tokens": 200, - "tp_sizes": [8], - "batch_sizes": [1, 8], - "input_tokens": [1024, 2048], - "dataset": "llm_perf/datasets/merged_52_test.csv", - "perf_time": 100 -} \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/workloads/falcon-torch-bf16-180b.json b/byte_infer_perf/llm_perf/workloads/falcon-torch-bf16-180b.json new file mode 100644 index 00000000..6668da31 --- /dev/null +++ b/byte_infer_perf/llm_perf/workloads/falcon-torch-bf16-180b.json @@ -0,0 +1,18 @@ +{ + "model": "falcon-180b", + "test_accuracy": false, + "min_tp_size": 8, + "accuracy_config": { + "dataset": "llm_perf/datasets/merged_52_test.csv", + "min_new_tokens": 1, + "max_new_tokens": 512 + }, + "test_perf": true, + "perf_config": { + "tp_sizes": [1, 2, 4, 8], + "batch_sizes": [1, 4, 8, 16, 24, 32], + "input_tokens": [1024, 2048], + "output_tokens": 200, + "perf_time": 100 + } +} \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/workloads/llama3-torch-bf16-70b.json b/byte_infer_perf/llm_perf/workloads/llama3-torch-bf16-70b.json new file mode 100644 index 00000000..eb0394b4 --- /dev/null +++ b/byte_infer_perf/llm_perf/workloads/llama3-torch-bf16-70b.json @@ -0,0 +1,18 @@ +{ + "model": "llama3-70b", + "test_accuracy": true, + "min_tp_size": 8, + "accuracy_config": { + "dataset": "llm_perf/datasets/merged_52_test.csv", + "min_new_tokens": 1, + "max_new_tokens": 512 + }, + "test_perf": true, + "perf_config": { + "tp_sizes": [1, 2, 4, 8], + "batch_sizes": [1, 4, 8, 16, 24, 32], + "input_tokens": [1024, 2048], + "output_tokens": 200, + "perf_time": 100 + } +} \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/workloads/mixtral-torch-bf16-8x22b.json b/byte_infer_perf/llm_perf/workloads/mixtral-torch-bf16-8x22b.json new file mode 100644 index 00000000..26bcd20d --- /dev/null +++ b/byte_infer_perf/llm_perf/workloads/mixtral-torch-bf16-8x22b.json @@ -0,0 +1,18 @@ +{ + "model": "mixtral-8x22b", + "test_accuracy": false, + "min_tp_size": 8, + "accuracy_config": { + "dataset": "llm_perf/datasets/merged_52_test.csv", + "min_new_tokens": 1, + "max_new_tokens": 512 + }, + "test_perf": true, + "perf_config": { + "tp_sizes": [1, 2, 4, 8], + "batch_sizes": [1, 4, 8, 16, 24, 32], + "input_tokens": [1024, 2048], + "output_tokens": 200, + "perf_time": 100 + } +} \ No newline at end of file