From ab5294a0f8e12201c17ca5cd7892e0c0186aa6eb Mon Sep 17 00:00:00 2001
From: Mingyan Jiang <1829166702@qq.com>
Date: Fri, 21 Jul 2023 16:46:41 +0800
Subject: [PATCH] [shardformer] whisper support jit operator

---
 colossalai/shardformer/modeling/whisper.py  | 147 ++++++++++++++++++
 colossalai/shardformer/policies/whisper.py  |  18 ++-
 tests/kit/model_zoo/transformers/whisper.py |   4 +-
 .../test_model/test_shard_whisper.py        |   6 +-
 4 files changed, 170 insertions(+), 5 deletions(-)

diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index 5558efc803f4..5599b4d4de5d 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -101,3 +101,150 @@ def forward(
 
     def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
         return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
+
+
+def get_jit_fused_whisper_encoder_layer_forward():
+
+    from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer
+
+    def forward(
+        self: WhisperEncoderLayer,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_head_mask: torch.Tensor,
+        output_attentions: bool = False,
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states, attn_weights, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+        if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any()
+                                                     or torch.isnan(hidden_states).any()):
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+    return forward
+
+
+def get_jit_fused_whisper_decoder_layer_forward():
+
+    from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer
+
+    def forward(
+        self: WhisperDecoderLayer,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            encoder_hidden_states (`torch.FloatTensor`):
+                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+                size `(decoder_attention_heads,)`.
+            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Self Attention
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        # add present self-attn cache to positions 1,2 of present_key_value tuple
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_attn_past_key_value,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+        # Cross-Attention Block
+        cross_attn_present_key_value = None
+        cross_attn_weights = None
+        if encoder_hidden_states is not None:
+            residual = hidden_states
+            hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                layer_head_mask=cross_attn_layer_head_mask,
+                past_key_value=cross_attn_past_key_value,
+                output_attentions=output_attentions,
+            )
+            hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+            # add cross-attn to positions 3,4 of present_key_value tuple
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights, cross_attn_weights)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+    return forward
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index 73fe0fa0319e..8f728b57efc5 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -3,7 +3,12 @@
 import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
-from ..modeling.whisper import get_whisper_flash_attention_forward
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from ..modeling.whisper import (
+    get_jit_fused_whisper_decoder_layer_forward,
+    get_jit_fused_whisper_encoder_layer_forward,
+    get_whisper_flash_attention_forward,
+)
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = [
@@ -190,6 +195,17 @@ def module_policy(self):
                 'forward': get_whisper_flash_attention_forward(),
             })
 
+        # use jit fused operator
+        if self.shard_config.enable_jit_fused:
+            policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_whisper_encoder_layer_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+            policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_whisper_decoder_layer_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+
         return policy
 
     def add_lm_head_policy(self, base_policy):
diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py
index b58716217cb5..c359a7ec664a 100644
--- a/tests/kit/model_zoo/transformers/whisper.py
+++ b/tests/kit/model_zoo/transformers/whisper.py
@@ -76,14 +76,14 @@ def data_gen_for_audio_classification():
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 
-model_zoo.register(name='transformers_whisperForConditionalGeneration',
+model_zoo.register(name='transformers_whisper_for_conditional_generation',
                    model_fn=lambda: transformers.WhisperForConditionalGeneration(config),
                    data_gen_fn=data_gen_for_conditional_generation,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn_attr,
                    model_attribute=ModelAttribute(has_control_flow=True))
 
-model_zoo.register(name='transformers_whisperWhisperForAudioClassification',
+model_zoo.register(name='transformers_whisper_for_audio_classification',
                    model_fn=lambda: transformers.WhisperForAudioClassification(config),
                    data_gen_fn=data_gen_for_audio_classification,
                    output_transform_fn=output_transform_fn,
diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py
index 9ebda3d67c40..16957b324313 100644
--- a/tests/test_shardformer/test_model/test_shard_whisper.py
+++ b/tests/test_shardformer/test_model/test_shard_whisper.py
@@ -74,13 +74,15 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
 @parameterize('enable_flash_attention', [True, False])
-def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
+@parameterize('enable_jit_fused', [True, False])
+def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(model_fn,
                                                enable_fused_normalization=enable_fused_normalization,
                                                enable_tensor_parallelism=enable_tensor_parallelism,
-                                               enable_flash_attention=enable_flash_attention)
+                                               enable_flash_attention=enable_flash_attention,
+                                               enable_jit_fused=enable_jit_fused)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
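
The forwards added above call a `dropout_add` method that the policy binds via `get_jit_fused_dropout_add_func()` from `..modeling.jit`, a helper that is not part of this diff. The snippet below is only a hypothetical sketch of the idea, not the actual ColossalAI implementation: dropout and the residual add are scripted together so the two element-wise ops can be fused instead of running as separate eager kernels.

# Hypothetical sketch; the real helper lives in colossalai/shardformer/modeling/jit.py
# and may differ in signature and fusion strategy.
import torch
from torch import Tensor


@torch.jit.script
def _fused_dropout_add(x: Tensor, residual: Tensor, prob: float, training: bool) -> Tensor:
    # Dropout followed by the residual add inside one TorchScript graph, so the two
    # element-wise ops can be fused rather than dispatched as separate eager kernels.
    return residual + torch.nn.functional.dropout(x, p=prob, training=training)


def get_jit_fused_dropout_add_func_sketch():
    # The returned callable would be bound as `dropout_add` on WhisperEncoderLayer /
    # WhisperDecoderLayer, matching the call sites
    # `self.dropout_add(hidden_states, residual, self.dropout, self.training)` above.
    def dropout_add(self, hidden_states: Tensor, residual: Tensor, prob: float, training: bool) -> Tensor:
        return _fused_dropout_add(hidden_states, residual, prob, training)

    return dropout_add

On the usage side, the test changes indicate the feature is toggled like the other shardformer optimizations: build the sharded model with `enable_jit_fused=True`, which reaches the policy as `self.shard_config.enable_jit_fused`.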