diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 8cd0fadd167f..abce504e9d61 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -10,9 +10,15 @@ BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, ) -from transformers.models.bert.modeling_bert import BertForPreTraining, BertForPreTrainingOutput, BertModel -from transformers.utils import logging +from transformers.models.bert.modeling_bert import ( + BertForPreTraining, + BertForPreTrainingOutput, + BertLMHeadModel, + BertModel, +) +from transformers.utils import ModelOutput, logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -22,24 +28,23 @@ def bert_model_forward( - self: BertModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + self: BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, # labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - # this is from the previous stage - hidden_states: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage ): # TODO: add explaination of the output here. r""" @@ -85,10 +90,6 @@ def bert_model_forward( raise ValueError("You have to specify either input_ids or inputs_embeds") batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) - attention_mask = extended_attention_mask else: input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape @@ -119,14 +120,29 @@ def bert_model_forward( else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + attention_mask = extended_attention_mask + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - hidden_states = hidden_states if hidden_states is not None else None + if stage_manager.is_first_stage(): hidden_states = self.embeddings( input_ids=input_ids, @@ -135,18 +151,8 @@ def bert_model_forward( inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - # inherit from bert_layer + # inherit from bert_layer,this should be changed when we add the feature to record hidden_states all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None @@ -221,34 +227,35 @@ def custom_forward(*inputs): pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: return (sequence_output, pooled_output) + layer_outputs[1:] + # return dict is not supported at this moment + else: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) - # output of non-first and non-last stages: - if not return_dict: - return tuple(v for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] if v is not None) - - # return dict is not supported at this moment - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) + # output of non-first and non-last stages: must be a dict + else: + # intermediate stage always return dict + return { + 'hidden_states': 
hidden_states, + } # The layer partition policy for bertmodel class BertModelPolicy(Policy): - def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + def __init__( + self, + stage_manager: PipelineStageManager, + num_layers: int, + ): super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, num_stages) + self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) def get_hold_layers(self, module: BertModel) -> List[Module]: """ @@ -266,10 +273,10 @@ def get_hold_layers(self, module: BertModel) -> List[Module]: def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: '''no shared params in bertmodel''' - pass + return [] def replace_forward(self, module: Module) -> None: - module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model) + module.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module) def bert_for_pretraining_forward( @@ -285,53 +292,74 @@ def bert_for_pretraining_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - hidden_states: Optional[torch.LongTensor] = None, + hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, -) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: - +): return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output, pooled_output = outputs[:2] + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. 
+ if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None if stage_manager.is_last_stage(): + sequence_output, pooled_output = outputs[:2] prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + # the last stage for pretraining model + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss - total_loss = None - if labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - total_loss = masked_lm_loss + next_sentence_loss - - if not return_dict: - output = (prediction_scores, seq_relationship_score) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') - return BertForPreTrainingOutput( - loss=total_loss, - prediction_logits=prediction_scores, - seq_relationship_logits=seq_relationship_score, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) + # intermediate stage always return dict + return { + 'hidden_states': hidden_states, + } class BertForPreTrainingPolicy(Policy): - def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + def __init__(self, stage_manager: PipelineStageManager, num_layers: int): + super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, num_stages) + self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: """ @@ -352,25 +380,144 @@ def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor]]: '''no shared params in bertmodel''' - 
pass + return [] def replace_forward(self, module: Module) -> None: - module.model.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), - module.model) + module.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), + module.forward) + + +def bert_lmhead_forward(self: BertLMHeadModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_manager: Optional[PipelineStageManager] = None): + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + use_cache = False + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None - def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + # intermediate stage always return dict + return {'hidden_states': hidden_states} + + +class BertLMHeadModelPolicy(Policy): + + def __init__(self, stage_manager: PipelineStageManager, num_layers: int): + super().__init__(stage_manager=stage_manager) + self.stage_manager = stage_manager + self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) + + def get_hold_layers(self, module: BertLMHeadModel) -> List[Module]: """ - divide layers into stages + get pipeline layers for current stage """ - quotient = num_layers // num_stages - remainder = num_layers % num_stages - - # calculate the num_layers per stage - layers_per_stage = [quotient] * num_stages - - # deal with the rest layers - if remainder > 0: - start_position = num_layers // 2 - remainder // 2 - for i in range(start_position, start_position + remainder): - layers_per_stage[i] += 1 - return layers_per_stage + hold_layers = [] + if self.stage_manager.is_first_stage(): + hold_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) + hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if self.stage_manager.is_last_stage(): + hold_layers.append(module.bert.pooler) + 
hold_layers.append(module.cls) + + return hold_layers + + def get_shared_params(self, module: BertLMHeadModel) -> List[Dict[int, Tensor]]: + '''no shared params in bertmodel''' + return [] + + def replace_forward(self, module: Module) -> None: + module.forward = MethodType(partial(bert_lmhead_forward, stage_manager=self.stage_manager), module) diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py new file mode 100644 index 000000000000..afbea49c1829 --- /dev/null +++ b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py @@ -0,0 +1,118 @@ +import pytest +import torch +import torch.distributed as dist +from transformers.models.bert import BertConfig +from transformers.models.bert.modeling_bert import BertForPreTraining + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.policy.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_bert_for_pretraining_forward(): + configuration = BertConfig() + model = BertForPreTraining(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + # print(rank) + + x = torch.randint(0, 1000, (2, 3)) + hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x) + output = bert_for_pretraining_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output['hidden_states'].shape) + assert output['hidden_states'].shape == (2, 3, 768) + print('start the training') + else: + attention_mask = torch.ones((2, 3)) + output = bert_for_pretraining_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 30522) + print('end the training') + print(output) + + # assert output[1].shape == (2, 768) + + +def check_bert_for_pretraining_policy(): + configuration = BertConfig() + model = BertForPreTraining(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + model_policy = BertForPreTrainingPolicy(stage_manager, len(model.bert.encoder.layer)) + assert model_policy.layers_per_stage == [6, 6] + layers = model_policy.get_hold_layers(model) + for layer in layers: + print(layer) + + +def run_dist_model(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_for_pretraining_forward() + + +def run_dist_policy(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_for_pretraining_policy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def 
test_bert_for_pretraining_forward(): + spawn(run_dist_model, 4) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_for_pretraining_policy(): + spawn(run_dist_policy, 4) + + +if __name__ == "__main__": + """test the bert for pretraining model forward and bert for pretraining model policy""" + test_bert_for_pretraining_forward() + test_bert_for_pretraining_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py b/tests/test_pipeline/test_policy/test_bert_lmhead_model.py new file mode 100644 index 000000000000..d41eddc74dff --- /dev/null +++ b/tests/test_pipeline/test_policy/test_bert_lmhead_model.py @@ -0,0 +1,118 @@ +import pytest +import torch +import torch.distributed as dist +from transformers.models.bert import BertConfig +from transformers.models.bert.modeling_bert import BertLMHeadModel + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.policy.bert import BertLMHeadModelPolicy, bert_lmhead_forward +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_bert_lmhead_forward(): + configuration = BertConfig() + model = BertLMHeadModel(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + # print(rank) + + x = torch.randint(0, 1000, (2, 3)) + hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x) + output = bert_lmhead_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output['hidden_states'].shape) + assert output['hidden_states'].shape == (2, 3, 768) + print('start the training') + else: + attention_mask = torch.ones((2, 3)) + output = bert_lmhead_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 30522) + print('end the training') + print(output) + + # assert output[1].shape == (2, 768) + + +def check_bert_lmhead_policy(): + configuration = BertConfig() + model = BertLMHeadModel(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + model_policy = BertLMHeadModelPolicy(stage_manager, len(model.bert.encoder.layer)) + assert model_policy.layers_per_stage == [6, 6] + layers = model_policy.get_hold_layers(model) + for layer in layers: + print(layer) + + +def run_dist_model(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_lmhead_forward() + + +def run_dist_policy(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_lmhead_policy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_lmhead_forward(): + spawn(run_dist_model, 4) + + 
+@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_lmhead_policy(): + spawn(run_dist_policy, 4) + + +if __name__ == "__main__": + """test the bert for pretraining model forward and bert for pretraining model policy""" + test_bert_lmhead_forward() + test_bert_lmhead_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py index cf5dc95feb8c..92485072a5e4 100644 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -39,11 +39,11 @@ def check_bert_model_forward(): if stage_manager.stage == 0: attention_mask = torch.ones_like(x) output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - print(output[0].shape) - assert output[0].shape == (2, 3, 768) + print(output['hidden_states'].shape) + assert output['hidden_states'].shape == (2, 3, 768) print('start the training') else: - attention_mask = torch.ones((2, 12, 3, 3)) + attention_mask = torch.ones((2, 3)) output = bert_model_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, @@ -78,7 +78,7 @@ def check_bert_model_policy(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() - model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer), 2) + model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer)) assert model_policy.layers_per_stage == [6, 6] layers = model_policy.get_hold_layers(model) for layer in layers:
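
Review note: the policies now take the stage count from `stage_manager.num_stages` and drop the subclass-level `distribute_layers` override, so the partition logic presumably lives in the base `Policy`. For readers of this diff, a minimal, self-contained sketch of the partitioning the tests rely on (`layers_per_stage == [6, 6]` for a 12-layer BERT on 2 stages): `distribute_layers` mirrors the helper deleted from `BertForPreTrainingPolicy` above, while `get_stage_index` is only an assumed implementation of the base-class method called in `get_hold_layers`, not the actual one.

```python
from typing import List, Tuple


def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # Evenly split encoder layers across stages; any remainder goes to the
    # middle stages, matching the helper removed from BertForPreTrainingPolicy.
    quotient, remainder = divmod(num_layers, num_stages)
    layers_per_stage = [quotient] * num_stages
    start_position = num_layers // 2 - remainder // 2
    for i in range(start_position, start_position + remainder):
        layers_per_stage[i] += 1
    return layers_per_stage


def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # Assumed behaviour of Policy.get_stage_index: half-open [start, end) range
    # of encoder layers held by `stage`.
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]


assert distribute_layers(12, 2) == [6, 6]       # what the tests assert for bert-base
assert get_stage_index([6, 6], 0) == (0, 6)     # first stage: embeddings + layers 0-5
assert get_stage_index([6, 6], 1) == (6, 12)    # last stage: layers 6-11 + pooler/cls
```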
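
A hedged sketch of how the stage-gated forwards are meant to chain: the first stage returns `{'hidden_states': ...}`, which the next stage consumes through its `hidden_states` argument. The tests above fake this hand-off with random tensors on separate ranks; the helper below is a hypothetical single-process walkthrough (the `stage_managers` pair stands in for the per-rank `PipelineStageManager` objects), not how the pipeline schedule actually drives communication.

```python
import torch
from transformers.models.bert import BertConfig
from transformers.models.bert.modeling_bert import BertForPreTraining

from colossalai.pipeline.policy.bert import bert_for_pretraining_forward


def two_stage_walkthrough(stage_managers):
    # stage_managers: hypothetical (stage0, stage1) PipelineStageManager stand-ins
    model = BertForPreTraining(BertConfig())
    input_ids = torch.randint(0, 1000, (2, 3))
    attention_mask = torch.ones_like(input_ids)

    # Stage 0: embeddings + its share of encoder layers -> intermediate dict.
    inter = bert_for_pretraining_forward(self=model,
                                         input_ids=input_ids,
                                         attention_mask=attention_mask,
                                         stage_manager=stage_managers[0])

    # Stage 1: remaining encoder layers + cls head
    # -> (prediction_scores, seq_relationship_score) tuple, as in the tests.
    return bert_for_pretraining_forward(self=model,
                                        hidden_states=inter['hidden_states'],
                                        attention_mask=attention_mask,
                                        stage_manager=stage_managers[1])
```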