attempt to simplify code
Edenzzzz committed Aug 23, 2024
1 parent 051590d commit c088159
Showing 2 changed files with 14 additions and 15 deletions.
16 changes: 7 additions & 9 deletions colossalai/shardformer/modeling/gpt2.py
@@ -23,7 +23,6 @@
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention, RingAttention
from colossalai.shardformer.layer._operation import gather_sp_output, split_forward_gather_backward
from colossalai.shardformer.layer.attn import AttnMaskType
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
from colossalai.shardformer.shard import ShardConfig

@@ -133,7 +132,7 @@ def gpt2_model_forward(
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
force_sp_output_gather: Optional[bool] = True,
force_sp_gather: Optional[bool] = True,
) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
# Please refer to original code of transformers for more details.
@@ -229,15 +228,14 @@ def gpt2_model_forward(
# Ring Attention's special zigzag batch processing
if sp_mode == "ring_attn":
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
# Get cu_seqlens
if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
if not attention_mask.bool().all():
hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
attention_mask, sp_group, hidden_states, position_ids
)
else:
hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
# Other sp modes
elif disable_pp or stage_manager.is_first_stage():
else:
if sp_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states,
@@ -247,7 +245,6 @@
elif sp_mode == "ring_attn":
# Later stages already received split hidden states
_, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)

del attention_mask

# Going through held blocks.
@@ -301,8 +298,9 @@
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

# When sequence parallelism is done, gather the output tensor in forward and split it in backward
if (disable_pp or stage_manager.is_last_stage()) and shard_config.enable_sequence_parallelism:
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
gather_output = (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode)
if disable_pp or stage_manager.is_last_stage():
if gather_output:
hidden_states = gather_sp_output(
hidden_states,
sp_dim=1,
@@ -399,7 +397,7 @@ def gpt2_lmhead_model_forward(
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
force_sp_output_gather=False,
force_sp_gather=False,
)

# If not at the last stage, return hidden_states as in GPT2Model
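The gpt2.py changes amount to two simplifications: padding is now detected directly from the attention mask instead of via AttnMaskType.PADDED_CAUSAL (so that import is dropped), and the sequence-parallel output-gather decision is hoisted into a single gather_output flag. Below is a minimal standalone sketch of the new padding check; the helper name needs_varlen_prep is invented for illustration and does not exist in the repository.

import torch

def needs_varlen_prep(attention_mask: torch.Tensor) -> bool:
    # Mirrors the new check in the diff: any zero in the mask means at least
    # one sequence is padded, so Ring Attention has to build a variable-length
    # (cu_seqlens-style) batch; otherwise the batch can be zigzag-split as-is.
    return not attention_mask.bool().all()

# Padded batch -> RingAttention.prepare_varlen_batch path
padded = torch.tensor([[1, 1, 1, 1],
                       [1, 1, 0, 0]])
print(needs_varlen_prep(padded))                               # True

# Fully dense batch -> split_batch_zigzag path
print(needs_varlen_prep(torch.ones(2, 4, dtype=torch.long)))   # False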
13 changes: 7 additions & 6 deletions colossalai/shardformer/modeling/llama.py
@@ -25,7 +25,6 @@
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import AttnMaskType
from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
from colossalai.shardformer.shard import ShardConfig
@@ -58,7 +57,7 @@ def llama_model_forward(
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
force_sp_output_gather: bool = True, # Set to false only when computing cross entropy
force_sp_gather: bool = True, # Set to false only when computing cross entropy
):
logger = logging.get_logger(__name__)

@@ -144,11 +143,12 @@ def llama_model_forward(
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)

# Support SP + PP. Later stages have already received the split input.
if disable_pp or stage_manager.is_first_stage():
split_input = disable_pp or stage_manager.is_first_stage()
if split_input:
# Ring Attention zigzag batch processing
if sp_mode == "ring_attn":
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
if not attention_mask.bool().all():
hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
attention_mask, sp_group, hidden_states, position_ids
)
@@ -218,9 +218,10 @@
if output_attentions:
all_self_attns += (layer_outputs[1],)

gather_output = (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode)
if disable_pp or stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
if gather_output:
hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)

# add hidden states from the last decoder layer
@@ -333,7 +334,7 @@ def llama_for_causal_lm_forward(
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
force_sp_output_gather=False,
force_sp_gather=False,
)
past_key_values = None

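The llama.py hunks follow the same pattern. The other half of the simplification, shared by both files, is the consolidated gather condition. The sketch below is an illustrative reconstruction, not the library's code: should_gather_sp_output and its boolean arguments stand in for shard_config.parallel_output, the force_sp_gather parameter, and the result of is_share_sp_tp(sp_mode).

def should_gather_sp_output(parallel_output: bool, force_sp_gather: bool, shares_tp_group: bool) -> bool:
    # Consolidated condition from the diff: gather the sequence-split hidden
    # states on the last stage unless the caller explicitly wants them kept
    # split (parallel_output=True, force_sp_gather=False) and the SP mode does
    # not share its process group with tensor parallelism.
    return (not parallel_output) or force_sp_gather or shares_tp_group

# The model forward defaults to force_sp_gather=True, so standalone use still
# returns full-length hidden states; the causal-LM forward passes
# force_sp_gather=False so cross entropy can run on the split sequence chunks.
assert should_gather_sp_output(parallel_output=True, force_sp_gather=True, shares_tp_group=False)
assert not should_gather_sp_output(parallel_output=True, force_sp_gather=False, shares_tp_group=False)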
