From c08815966bc1e04debaa652933007718b6b78e12 Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Fri, 23 Aug 2024 03:54:18 +0000
Subject: [PATCH] attempt to simplify code

---
 colossalai/shardformer/modeling/gpt2.py  | 16 +++++++---------
 colossalai/shardformer/modeling/llama.py | 13 +++++++------
 2 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 019adc9095aa..16b2526cf9d2 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -23,7 +23,6 @@
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import ColoAttention, RingAttention
 from colossalai.shardformer.layer._operation import gather_sp_output, split_forward_gather_backward
-from colossalai.shardformer.layer.attn import AttnMaskType
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig
 
@@ -133,7 +132,7 @@ def gpt2_model_forward(
     hidden_states: Optional[torch.FloatTensor] = None,
     stage_index: Optional[List[int]] = None,
     shard_config: ShardConfig = None,
-    force_sp_output_gather: Optional[bool] = True,
+    force_sp_gather: Optional[bool] = True,
 ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
     # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
     # Please refer to original code of transformers for more details.
@@ -229,15 +228,14 @@ def gpt2_model_forward(
         # Ring Attention's special zigzag batch processing
         if sp_mode == "ring_attn":
             assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
-            # Get cu_seqlens
-            if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
+            if not attention_mask.bool().all():
                 hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
                     attention_mask, sp_group, hidden_states, position_ids
                 )
             else:
                 hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
         # Other sp modes
-    elif disable_pp or stage_manager.is_first_stage():
+    else:
         if sp_mode == "split_gather":
             hidden_states = split_forward_gather_backward(
                 hidden_states,
@@ -247,7 +245,6 @@ def gpt2_model_forward(
         elif sp_mode == "ring_attn":
             # Later stages already received split hidden states
             _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
-            del attention_mask
 
     # Going through held blocks.
@@ -301,8 +298,9 @@ def gpt2_model_forward(
             all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
 
     # When sequence parallelism is done, gather the output tensor in forward and split it in backward
-    if (disable_pp or stage_manager.is_last_stage()) and shard_config.enable_sequence_parallelism:
-        if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
+    gather_output = (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode)
+    if disable_pp or stage_manager.is_last_stage():
+        if gather_output:
             hidden_states = gather_sp_output(
                 hidden_states,
                 sp_dim=1,
@@ -399,7 +397,7 @@ def gpt2_lmhead_model_forward(
         hidden_states=hidden_states,
         stage_index=stage_index,
         shard_config=shard_config,
-        force_sp_output_gather=False,
+        force_sp_gather=False,
     )
 
     # If not at the last stage, return hidden_states as in GPT2Model
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 1aed7d9d4906..8b71b5962588 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -25,7 +25,6 @@
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer import AttnMaskType
 from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig
@@ -58,7 +57,7 @@ def llama_model_forward(
     hidden_states: Optional[torch.FloatTensor] = None,
     stage_index: Optional[List[int]] = None,
     shard_config: ShardConfig = None,
-    force_sp_output_gather: bool = True,  # Set to false only when computing cross entropy
+    force_sp_gather: bool = True,  # Set to false only when computing cross entropy
 ):
     logger = logging.get_logger(__name__)
 
@@ -144,11 +143,12 @@ def llama_model_forward(
         attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)
 
     # Support SP + PP. Later stages have already received the split input.
-    if disable_pp or stage_manager.is_first_stage():
+    split_input = disable_pp or stage_manager.is_first_stage()
+    if split_input:
         # Ring Attention zigzag batch processing
         if sp_mode == "ring_attn":
             assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
-            if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
+            if not attention_mask.bool().all():
                 hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
                     attention_mask, sp_group, hidden_states, position_ids
                 )
@@ -218,9 +218,10 @@ def llama_model_forward(
         if output_attentions:
             all_self_attns += (layer_outputs[1],)
 
+    gather_output = (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode)
     if disable_pp or stage_manager.is_last_stage():
         hidden_states = self.norm(hidden_states)
-        if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
+        if gather_output:
             hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)
 
     # add hidden states from the last decoder layer
@@ -333,7 +334,7 @@ def llama_for_causal_lm_forward(
         hidden_states=hidden_states,
         stage_index=stage_index,
         shard_config=shard_config,
-        force_sp_output_gather=False,
+        force_sp_gather=False,
    )
 
     past_key_values = None
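Note (not part of the patch): a minimal standalone sketch of the gather decision the patch hoists out of the pipeline-stage check. ShardCfg, is_share_sp_tp, and should_gather_sp_output below are simplified stand-ins for illustration, not the actual ColossalAI APIs.

    from dataclasses import dataclass


    def is_share_sp_tp(sp_mode: str) -> bool:
        # Stand-in: sequence-parallel modes that reuse the tensor-parallel group.
        return sp_mode in ("split_gather", "ring")


    @dataclass
    class ShardCfg:
        parallel_output: bool = True


    def should_gather_sp_output(cfg: ShardCfg, force_sp_gather: bool, sp_mode: str) -> bool:
        # Mirrors the hoisted expression in the patch:
        # gather unless the caller explicitly keeps the output split
        # (the cross-entropy path passes force_sp_gather=False).
        return (not cfg.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode)


    if __name__ == "__main__":
        # ring_attn with parallel_output enabled and force_sp_gather=False: keep output split.
        print(should_gather_sp_output(ShardCfg(parallel_output=True), False, "ring_attn"))  # False
        # Default callers pass force_sp_gather=True: always gather.
        print(should_gather_sp_output(ShardCfg(parallel_output=True), True, "ring_attn"))  # True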