all tests passed
Edenzzzz committed Aug 1, 2024
1 parent 59e533d commit 05017d3
Showing 7 changed files with 288 additions and 220 deletions.
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1229,7 +1229,7 @@ def configure(
zero_stage = 0

if not isinstance(model, ModelWrapper):
# Can't use pp (frequent grad accumulation) with torch ddp
# Shouldn't use pp (frequent grad accumulation) with torch ddp
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1 and self.pp_size == 1
)
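The condition above decides when the plugin wraps the model in torch DDP. A minimal standalone sketch of that decision, assuming plain integer sizes instead of the plugin's attributes (should_use_ddp is a hypothetical helper, not the plugin API):

def should_use_ddp(dp_size: int, pp_size: int, zero_stage: int) -> bool:
    # Torch DDP only handles pure data parallelism; pipeline parallelism
    # (frequent grad accumulation) and ZeRO manage gradients themselves.
    return (dp_size > 1 and pp_size == 1 and zero_stage == 0) or (dp_size == 1 and pp_size == 1)

assert should_use_ddp(4, 1, 0)        # plain DP -> wrap with DDP
assert not should_use_ddp(2, 2, 0)    # PP enabled -> no DDP
assert not should_use_ddp(4, 1, 2)    # ZeRO handles grads -> no DDP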
263 changes: 137 additions & 126 deletions colossalai/shardformer/layer/attn.py

Large diffs are not rendered by default.

45 changes: 33 additions & 12 deletions colossalai/shardformer/layer/loss.py
@@ -1,6 +1,5 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.autograd import Function
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
@@ -151,7 +150,7 @@ def cross_entropy_1d(


def dist_cross_entropy(
labels: torch.Tensor, # [B, S]
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
logits: torch.Tensor, # [B, S, Vocab_size]
shard_config: ShardConfig,
out_features: int,
@@ -169,35 +168,54 @@ def dist_cross_entropy(
sp_mode = shard_config.sequence_parallelism_mode
parallel_output = shard_config.parallel_output
is_tp = shard_config.enable_tensor_parallelism

bs, seq_len = labels.shape
is_packed = labels.dim() == 2
if is_packed:
bs, seq_len = labels.shape
else:
# padded sequence
seq_len = labels.shape[-1]
logits = logits.reshape(-1, *logits.shape[2:])
seq_dim = 0

# Shift labels to predict the next token, and remove the tail logit predicting <EOS>
is_sp = sp_size > 1 and (not is_share_sp_tp(sp_mode))
split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward
if is_sp:
# shift only once
# shift only once: either before splitting or on the last rank without splitting
if split_labels_here or (sp_rank == sp_size - 1):
labels = labels[..., 1:]
# Split labels when logits are split
if split_labels_here:
labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank]

# Pad to the same shape across all ranks in TP all_reduce
if sp_rank == sp_size - 1:
logits = logits[..., :-1, :]
# Pad logits and labels to the same shape across all ranks for TP all_reduce
if is_tp and parallel_output:
pad_shape = [0] * logits.dim() * 2
pad_shape[-3] = 1 # Right side, dim = -2
logits = F.pad(logits, pad_shape, value=_IGNORE_IDX)
labels = F.pad(labels, (0, 1, 0, 0), value=_IGNORE_IDX)
# For packed sequences, each sequence already has its end label token padded.
# NOTE: torch.cat is faster than F.pad
pad_shape = (logits.shape[0], 1, *logits.shape[2:]) if is_packed else (1, *logits.shape[1:])
padding = torch.full(pad_shape, _IGNORE_IDX, dtype=logits.dtype, device=logits.device)
logits = torch.cat([logits, padding], dim=seq_dim)

pad_shape = (labels.shape[0], 1) if is_packed else (1,)
padding = torch.full(pad_shape, _IGNORE_IDX, dtype=labels.dtype, device=labels.device)
labels = torch.cat([labels, padding], dim=seq_dim)
else:
labels = labels[..., 1:]
logits = logits[..., :-1, :]
labels = labels.contiguous()
logits = logits.contiguous()
num_nonzero = (labels != _IGNORE_IDX).sum()
assert (
labels.shape == logits.shape[:-1]
), f"label shape {labels.shape} does not match logit shape {logits.shape}"

# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum")
Expand All @@ -218,7 +236,10 @@ def dist_cross_entropy(
else:
# NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D
logits = logits.view(-1, vocab_size)
loss = loss_fct(logits, labels)

# Reduce loss instead of gathering logits over seq dim for savings
if split_labels_here or sp_mode == "ring_attn":
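The shift-and-pad logic above is the core of dist_cross_entropy under sequence/tensor parallelism: labels are shifted once for next-token prediction, then logits and labels are re-padded with the ignore index so every rank keeps matching shapes for the TP all_reduce. A single-process sketch of that pattern with dummy shapes (not the library API; _IGNORE_IDX, B, S, V are assumed here):

import torch

_IGNORE_IDX = -100
B, S, V = 2, 8, 32
logits = torch.randn(B, S, V)
labels = torch.randint(0, V, (B, S))

# Shift: position t predicts token t+1; drop the tail logit that has no target.
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]

# Re-pad one position along the sequence dim with the ignore index so all ranks
# keep the same shape for the TP all_reduce (torch.cat, as in the patch).
pad_logits = torch.full((B, 1, V), _IGNORE_IDX, dtype=logits.dtype)
pad_labels = torch.full((B, 1), _IGNORE_IDX, dtype=labels.dtype)
shift_logits = torch.cat([shift_logits, pad_logits], dim=1)
shift_labels = torch.cat([shift_labels, pad_labels], dim=1)

assert shift_labels.shape == shift_logits.shape[:-1]
loss = torch.nn.CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum")(
    shift_logits.reshape(-1, V), shift_labels.reshape(-1)
)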
27 changes: 17 additions & 10 deletions colossalai/shardformer/layer/utils.py
@@ -291,9 +291,7 @@ def create_randomizer_with_offset(
return Randomizer(seed=base_seed)


def split_batch_zigzag(
batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim=1, varlen: bool = False
):
def split_batch_zigzag(batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim=1):
"""
Split the input along the sequence dimension for Ring Attention. Naively splitting the attention mask
in the causal setting will result in the preceding ranks having much less workload.
@@ -304,9 +302,7 @@ def split_batch_zigzag(
batch (List[torch.Tensor] or Tensor): The input tensor(s) to split.
sp_group (ProcessGroup): The process group for sequence parallelism.
seq_dim (int): The sequence dimension to split.
varlen (bool): If the input is padded (aka "packing" mode), such that
sequences in a batch have different lengths, and we need to unpad and
split each sequence evenly by sp_size.
"""
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
@@ -329,7 +325,7 @@ def split_batch_zigzag(
indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device)
tensor = tensor.index_select(seq_dim, indices).contiguous()
# (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...)
batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :]).contiguous()
batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :])

if len(batch) == 1:
return batch[0]
@@ -342,6 +338,7 @@ def split_varlen_zigzag(
sp_group: ProcessGroup,
max_seqlen: int = 0,
is_2d: bool = False,
is_label: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
"""Split each sequence in a batch of packed sequences in a zigzag fashion.
For each tensor in batch, return packed sequences if is_2d is False;
@@ -353,6 +350,7 @@
sp_group (ProcessGroup): The process group for sequence parallelism.
max_seqlen (int): The maximum sequence length in the batch before splitting.
is_2d (bool): If True, then input has batch size and sequence length split into two dimensions.
is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>).
Returns:
batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size)
@@ -373,7 +371,10 @@
assert max_seqlen % (sp_size * 2) == 0
# Recreate a padded tensor with the new max seqlen
shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
local_seq = torch.zeros(shape, dtype=dtype, device=device)
if is_label:
local_seq = torch.full(shape, -100, dtype=dtype, device=device)
else:
local_seq = torch.zeros(shape, dtype=dtype, device=device)
else:
total_seqlen = cu_seqlens[-1]
assert (
@@ -389,12 +390,18 @@
), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"

if is_2d:
seq = packed_seq[j][:seqlen].chunk(2 * sp_size, dim=0)
seq = packed_seq[j][:seqlen]
if is_label:
seq[0] = -100
seq = seq.chunk(2 * sp_size, dim=0)
half = seqlen // sp_size // 2
local_seq[j][:half] = seq[sp_rank]
local_seq[j][half : seqlen // sp_size] = seq[2 * sp_size - 1 - sp_rank]
else:
seq = packed_seq[start:end].chunk(sp_size * 2)
seq = packed_seq[start:end]
if is_label:
seq[0] = -100
seq = seq.chunk(sp_size * 2)
local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])

if is_2d:
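split_batch_zigzag above gives each rank one chunk from the front of the sequence and its mirror from the back, so causal-attention work is balanced across ranks. A single-process sketch of the chunk pairing, with sp_size and seq_len chosen arbitrarily and no process group involved:

import torch

sp_size, seq_len = 4, 16
seq = torch.arange(seq_len)
chunks = seq.chunk(2 * sp_size)  # 8 chunks of 2 tokens each

for sp_rank in range(sp_size):
    # Pair the sp_rank-th chunk with its mirror (2 * sp_size - 1 - sp_rank)
    # so early and late ranks see comparable causal-attention workloads.
    local = torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]])
    print(sp_rank, local.tolist())
# rank 0 -> [0, 1, 14, 15], rank 1 -> [2, 3, 12, 13], ...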
88 changes: 55 additions & 33 deletions colossalai/shardformer/modeling/llama.py
@@ -30,7 +30,7 @@
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
from colossalai.shardformer.shard import ShardConfig

from ..layer import ColoAttention, RingAttention, dist_cross_entropy, get_pad_info
from ..layer import ColoAttention, RingAttention, dist_cross_entropy

_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]

@@ -132,33 +132,37 @@ def llama_model_forward(
position_ids = cache_position.unsqueeze(0)
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if shard_config.enable_flash_attention:
if not stage_manager.is_first_stage() and sp_mode == "ring_attn":
_, attention_mask, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
elif shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
attn_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
invert=(sp_mode != "ring_attn"),
)
else:
attn_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)

# Support SP + PP
# TODO: support padded causal cu_seqlens across stages
if stage_manager.is_first_stage():
# Ring Attention zigzag batch processing
if sp_mode == "ring_attn":
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
if attn_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
attn_mask["cu_seqlens"], attn_mask["max_seqlen"], attn_mask["indices"] = get_pad_info(
attn_mask["attention_mask"].squeeze(1).any(dim=-1)
) # [B, 1, Sq, Skv] -> [B, Sq]

batch = [hidden_states, position_ids]
# inputs_embeds, attention_mask["attention_mask"], position_ids = zigzag_split_batch(batch, sp_group)
hidden_states, position_ids = split_batch_zigzag(batch, sp_group)
if attention_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
hidden_states, attention_mask, position_ids = RingAttention.prepare_varlen_batch(
attention_mask["attention_mask"].squeeze(1).any(-1), sp_group, hidden_states, position_ids
)
else:
hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)

elif is_share_sp_tp(sp_mode):
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
@@ -199,7 +203,7 @@ def llama_model_forward(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attn_mask,
attention_mask,
position_ids,
past_key_values,
output_attentions,
Expand All @@ -209,7 +213,7 @@ def llama_model_forward(
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attn_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
@@ -312,9 +316,13 @@ def llama_for_causal_lm_forward(
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False

if stage_manager.is_first_stage():
if shard_config.sequence_parallelism_mode == "ring_attn":
labels = split_batch_zigzag(labels, shard_config.sequence_parallel_process_group)
if shard_config.sequence_parallelism_mode == "ring_attn":
sp_group = shard_config.sequence_parallel_process_group
if attention_mask.bool().all():
labels = split_batch_zigzag(labels, sp_group, seq_dim=1)
else:
# [B, max_seqlen // sp_size]
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = LlamaPipelineForwards.llama_model_forward(
@@ -545,8 +553,12 @@ def forward(
)

kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -560,6 +572,7 @@
attn_output = RingAttention.attention(
query_states, key_states, value_states, sp_group, shard_config.sp_stream, **attention_mask
)

elif shard_config.enable_flash_attention:
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
@@ -670,27 +683,30 @@ def forward(

if shard_config.enable_flash_attention:
mask_shape = (batch_size, 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
mask_info: dict = ColoAttention.prepare_attn_kwargs(
attention_mask: dict = ColoAttention.prepare_attn_kwargs(
mask_shape,
inputs_embeds.dtype,
inputs_embeds.device,
q_padding_mask=attention_mask,
is_causal=True,
invert=(sp_mode != "ring_attn"),
)

else:
mask_info: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
attention_mask: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

# Ring Attention zigzag batch processing
if sp_mode == "ring_attn":
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
if mask_info["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
inputs_embeds, position_ids, mask_info = RingAttention.prepare_varlen_batch(
inputs_embeds, mask_info["attention_mask"], sp_group, position_ids
if attention_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
inputs_embeds, attention_mask, position_ids = RingAttention.prepare_varlen_batch(
attention_mask["attention_mask"].squeeze(1).any(-1), sp_group, inputs_embeds, position_ids
)
else:
inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group)
mask_info = {"attention_mask_type": mask_info["attention_mask_type"]} # drop redundant tensors
attention_mask = {
"attention_mask_type": attention_mask["attention_mask_type"]
} # drop redundant tensors

elif is_share_sp_tp(sp_mode):
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
@@ -710,7 +726,7 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
mask_info,
attention_mask,
position_ids,
past_key_values,
output_attentions,
@@ -721,7 +737,7 @@
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=mask_info,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
@@ -813,7 +829,13 @@ def forward(
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if shard_config.sequence_parallelism_mode == "ring_attn":
labels = split_batch_zigzag(labels, shard_config.sequence_parallel_process_group)
sp_group = shard_config.sequence_parallel_process_group
if attention_mask.bool().all():
labels = split_batch_zigzag(labels, sp_group, seq_dim=1)
else:
# [B, max_seqlen // sp_size]
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
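In llama.py above, ring-attention label preprocessing branches on whether the batch contains any padding: a fully unpadded batch takes the plain zigzag split, while a padded batch goes through RingAttention.prepare_varlen_batch with is_label=True. A hedged single-process sketch of just that branch condition (is_fully_unpadded is a hypothetical helper name; the two library calls named in the comments are the ones used in the patch):

import torch

def is_fully_unpadded(attention_mask: torch.Tensor) -> bool:
    # attention_mask: [B, S] with 1 for real tokens, 0 for padding
    return bool(attention_mask.bool().all())

dense = torch.ones(2, 8, dtype=torch.long)
padded = torch.cat([torch.ones(2, 6, dtype=torch.long), torch.zeros(2, 2, dtype=torch.long)], dim=1)

print(is_fully_unpadded(dense))   # True  -> split_batch_zigzag(labels, sp_group, seq_dim=1)
print(is_fully_unpadded(padded))  # False -> RingAttention.prepare_varlen_batch(..., is_label=True)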