From 57d4fabdf036ca9b6c59f1fe53d41b3d15d70b64 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 30 Jun 2023 19:19:24 +0800 Subject: [PATCH 01/17] add pipeline policy and bert forward to be done --- colossalai/pipeline/policy/__init__.py | 20 ++ colossalai/pipeline/policy/base.py | 108 +++++++ colossalai/pipeline/policy/bert.py | 295 +++++++++++++++++++ colossalai/pipeline/policy/llama.py | 258 ++++++++++++++++ tests/test_pipeline/test_policy/test_bert.py | 57 ++++ tests/test_pipeline/test_stage_manager.py | 2 +- 6 files changed, 739 insertions(+), 1 deletion(-) create mode 100644 colossalai/pipeline/policy/__init__.py create mode 100644 colossalai/pipeline/policy/base.py create mode 100644 colossalai/pipeline/policy/bert.py create mode 100644 colossalai/pipeline/policy/llama.py create mode 100644 tests/test_pipeline/test_policy/test_bert.py diff --git a/colossalai/pipeline/policy/__init__.py b/colossalai/pipeline/policy/__init__.py new file mode 100644 index 000000000000..cd372a28b79c --- /dev/null +++ b/colossalai/pipeline/policy/__init__.py @@ -0,0 +1,20 @@ +from typing import Any, Dict, List, Optional, Tuple, Type + +from torch import Tensor +from torch.nn import Module, Parameter + +from colossalai.pipeline.stage_manager import PipelineStageManager + +from .base import Policy +from .llama import LlamaForCausalLM, LlamaForCausalLMPolicy + +POLICY_MAP: Dict[Type[Module], Type[Policy]] = { + LlamaForCausalLM: LlamaForCausalLMPolicy, +} + + +def pipeline_parallelize(model: Module, stage_manager: PipelineStageManager) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: + if type(model) not in POLICY_MAP: + raise NotImplementedError(f"Policy for {type(model)} not implemented") + policy = POLICY_MAP[type(model)](stage_manager) + return policy.parallelize_model(model) diff --git a/colossalai/pipeline/policy/base.py b/colossalai/pipeline/policy/base.py new file mode 100644 index 000000000000..ad595a04b1b0 --- /dev/null +++ b/colossalai/pipeline/policy/base.py @@ -0,0 +1,108 @@ +from typing import Any, Dict, List, Optional, Tuple + +from colossalai.lazy import LazyTensor +from torch import Tensor +from torch.nn import Module, Parameter + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class Policy: + def __init__(self, stage_manager: PipelineStageManager) -> None: + self.stage_manager = stage_manager + + def setup_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor]]: + """Setup model for pipeline parallel + + Args: + module (Module): Module to be setup + + Returns: + Tuple[Dict[str, Parameter], Dict[str, Tensor]]: Hold parameters and buffers + """ + hold_params = set() + hold_buffers = set() + + def init_layer(layer: Module): + for p in layer.parameters(): + if isinstance(p, LazyTensor): + p.materialize() + p.data = p.cuda() + hold_params.add(p) + for b in layer.buffers(): + if isinstance(b, LazyTensor): + b.materialize() + b.data = b.cuda() + hold_buffers.add(b) + + hold_layers = self.get_hold_layers(module) + + for layer in hold_layers: + init_layer(layer) + + hold_params_dict = {} + hold_buffers_dict = {} + + # release other tensors + for n, p in module.named_parameters(): + if p in hold_params: + hold_params_dict[n] = p + else: + if isinstance(p, LazyTensor): + p.materialize() + p.data = p.cuda() + p.storage().resize_(0) + for n, b in module.named_buffers(): + if b in hold_buffers: + hold_buffers_dict[n] = b + else: + if isinstance(b, LazyTensor): + b.materialize() + b.data = b.cuda() + # FIXME(ver217): use meta tensor 
may be better + b.storage().resize_(0) + return hold_params_dict, hold_buffers_dict + + def replace_forward(self, module: Module) -> None: + """Replace module forward in place. This method should be implemented by subclass. The output of internal layers must be a dict + + Args: + module (Module): _description_ + """ + raise NotImplementedError + + def get_hold_layers(self, module: Module) -> List[Module]: + """Get layers that should be hold in current stage. This method should be implemented by subclass. + + Args: + module (Module): Module to be setup + + Returns: + List[Module]: List of layers that should be hold in current stage + """ + raise NotImplementedError + + def get_shared_params(self, module: Module) -> List[Dict[int, Tensor]]: + """Get parameters that should be shared across stages. This method should be implemented by subclass. + + Args: + module (Module): Module to be setup + + Returns: + List[Module]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] + """ + raise NotImplementedError + + def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: + """Parallelize model for pipeline parallel + + Args: + module (Module): Module to be setup + + Returns: + Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: Hold parameters, buffers and shared parameters + """ + hold_params, hold_buffers = self.setup_model(module) + self.replace_forward(module) + shared_params = self.get_shared_params(module) + return hold_params, hold_buffers, shared_params diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py new file mode 100644 index 000000000000..00aabf3984ef --- /dev/null +++ b/colossalai/pipeline/policy/bert.py @@ -0,0 +1,295 @@ +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.modeling_outputs import (BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions) +from transformers.models.bert.modeling_bert import BertModel +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + +from .base import Policy + +logger = logging.get_logger(__name__) + +def bert_model_forward(self:BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + #labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, #this is from the previous stage + ) : + #TODO: add explaination of the output here. + + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. 
Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + # preprocess: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # assure that the input is embedding_output and is the hidden_states of previous stages. + + hidden_states = input_ids if input_ids is not None else None + if stage_manager.is_first_stage(): + hidden_states= self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + + encoder_outputs = None + #inherit from bert_layer + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.encoder.gradient_checkpointing and self.encoder.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + next_decoder_cache = () if use_cache else None + + #calculate the num_layers + num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages + start_layer = stage_manager.stage * num_layers_per_stage + end_layer = (stage_manager.stage + 1) * num_layers_per_stage + + for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): + if stage_manager.is_first_stage() and idx == 0: + attention_mask = extended_attention_mask + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[idx] if head_mask is not None else None + past_key_value = past_key_values[idx] if past_key_values is not None else None + + ### + print('where is the model now',start_layer,idx,end_layer) + print('what stage is now',stage_manager.stage) + + if self.encoder.gradient_checkpointing and self.encoder.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + if stage_manager.stage == 1: + if hidden_states is not None : + print('shape of hidden_states',hidden_states.shape) + if attention_mask is not None : + print('shape of attention_mask',attention_mask.shape) + ## TODO: check for this layer_head_mask + if layer_head_mask is not None : + print('shape of layer_head_mask',layer_head_mask.shape) + if encoder_hidden_states is not None : + print('shape of encoder_hidden_states',encoder_hidden_states.shape) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + #end of a stage loop + sequence_output = layer_outputs[0] if layer_outputs is not None else None + + if stage_manager.is_last_stage(): + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + + #output of non-first and non-last stages: + if not return_dict: + return tuple(v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + + #return dict is not supported at this moment + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# class BertModelPolicy(Policy): +# def get_hold_layers(self, module: BertModel) -> List[Module]: +# # get pipeline layers for curerent stage +# hold_layers = [] +# if self.stage_manager.is_first_stage(): +# hold_layers.append(module.embeddings) +# #Fix: num_layers_per_stage should be calculated based on the number of layers in the model +# num_layers_per_stage = len(module.encoder.layer) // 
self.stage_manager.num_stages + +# hold_layers.extend(module.encoder.layer[self.stage_manager.stage* +# num_layers_per_stage : (self.stage_manager.stage+1)* num_layers_per_stage]) +# if self.stage_manager.is_last_stage(): +# hold_layers.append(module.pooler) + +# return hold_layers + +# def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: +# if id(module.embeddings.parameters) == id(module.pooler.parameters) +# return [dict(module.embeddings.named_parameters())] +# return [] +# def replace_forward(self, module: Module) -> None: +# return super().replace_forward(module) + +''' +def bert_pretraining_model_forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.LongTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + + ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: + pass +''' \ No newline at end of file diff --git a/colossalai/pipeline/policy/llama.py b/colossalai/pipeline/policy/llama.py new file mode 100644 index 000000000000..d83683ccb264 --- /dev/null +++ b/colossalai/pipeline/policy/llama.py @@ -0,0 +1,258 @@ +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutput, + CausalLMOutputWithPast) +from transformers.models.llama.modeling_llama import (LlamaForCausalLM, + LlamaModel) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + +from .base import Policy + +logger = logging.get_logger(__name__) + + +def llama_model_forward(self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, # this is set by partial + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage + ) -> Union[CausalLMOutput, Tuple]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if output_attentions: + logger.warning_once('`output_attentions=True` is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('`output_hidden_states=True` is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + 
logger.warning_once('`use_cache=True` is not supported for pipeline models at the moment.') + use_cache = False + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + if stage_manager.is_first_stage(): + inputs_embeds = self.embed_tokens(input_ids) + else: + inputs_embeds = hidden_states + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + # this function only uses inputs_embeds' device, dtype, and shape, it's safe to use hidden_state + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + num_layers_per_stage = len(self.layers) // stage_manager.num_stages + start_layer = stage_manager.stage * num_layers_per_stage + end_layer = (stage_manager.stage + 1) * num_layers_per_stage + + for idx, decoder_layer in enumerate(self.layers[start_layer:end_layer], start=start_layer): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + # TODO(ver217): return_dict is not supported for pipeline models at the moment. 
+ return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +def llama_for_causal_lm_forward(self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, # this is set by partial + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + ) + + hidden_states = outputs[0] + if not stage_manager.is_last_stage(): + return dict(hidden_states=hidden_states) + + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + ) + + +class LlamaForCausalLMPolicy(Policy): + def get_hold_layers(self, module: LlamaForCausalLM) -> List[Module]: + hold_layers = [] + + if self.stage_manager.is_first_stage(): + hold_layers.append(module.model.embed_tokens) + num_layers_per_stage = len(module.model.layers) // self.stage_manager.num_stages + hold_layers.extend(module.model.layers[self.stage_manager.stage * + num_layers_per_stage: (self.stage_manager.stage + 1) * num_layers_per_stage]) + if self.stage_manager.is_last_stage(): + hold_layers.append(module.model.norm) + hold_layers.append(module.lm_head) + + return hold_layers + + def get_shared_params(self, module: LlamaForCausalLM) -> List[Dict[int, Tensor]]: + if id(module.model.embed_tokens.weight) == id(module.lm_head.weight): + # tie weights + return [{0: module.model.embed_tokens.weight, self.stage_manager.num_stages - 1: module.lm_head.weight}] + return [] + + def replace_forward(self, module: LlamaForCausalLM) -> None: + module.model.forward = MethodType(partial(llama_model_forward, stage_manager=self.stage_manager), module.model) + module.forward = 
MethodType(partial(llama_for_causal_lm_forward, stage_manager=self.stage_manager), module) diff --git a/tests/test_pipeline/test_policy/test_bert.py b/tests/test_pipeline/test_policy/test_bert.py new file mode 100644 index 000000000000..0e27802da13e --- /dev/null +++ b/tests/test_pipeline/test_policy/test_bert.py @@ -0,0 +1,57 @@ +import torch +import pytest +import torch.distributed as dist +from colossalai.cluster import ProcessGroupMesh +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn + +from colossalai.pipeline.policy.bert import bert_model_forward +from colossalai.pipeline.stage_manager import PipelineStageManager +from transformers.models.bert.modeling_bert import BertModel + +def check_bert_model_forward(): + model = BertModel.from_pretrained('bert-base-uncased') + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + #print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + #print(rank) + + x = torch.randint(0, 1000, (2, 3)) + attention_mask = torch.ones_like(x) + + output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, + stage_manager=stage_manager) + print(output) + assert output[0].shape == (2, 3, 768) + # assert output[1].shape == (2, 768) + + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_model_forward() + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_model_forward(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_bert_model_forward() diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py index be4591d58f74..67a2e90532e2 100644 --- a/tests/test_pipeline/test_stage_manager.py +++ b/tests/test_pipeline/test_stage_manager.py @@ -21,7 +21,7 @@ def check_stage_manager(): 1: [0, 1], 2: [2, 3], 3: [2, 3], - } + } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() From 8300f451863649375db9fdb062e9cfe0990e5a6f Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 3 Jul 2023 14:53:32 +0800 Subject: [PATCH 02/17] add bertmodel pipeline forward and make tests --- colossalai/pipeline/policy/bert.py | 97 ++++--- colossalai/pipeline/policy/llama.py | 258 ------------------- tests/test_pipeline/test_policy/test_bert.py | 23 +- 3 files changed, 61 insertions(+), 317 deletions(-) delete mode 100644 colossalai/pipeline/policy/llama.py diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 00aabf3984ef..1b9cdaecf9eb 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -57,6 +57,7 @@ def bert_model_forward(self:BertModel, If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
""" + # debugging # preprocess: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -69,15 +70,26 @@ def bert_model_forward(self:BertModel, else: use_cache = False - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + attention_mask = extended_attention_mask else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False @@ -88,8 +100,7 @@ def bert_model_forward(self:BertModel, logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') use_cache = False - batch_size, seq_length = input_shape - device = input_ids.device if input_ids is not None else inputs_embeds.device + # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 @@ -105,10 +116,24 @@ def bert_model_forward(self:BertModel, else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + hidden_states = hidden_states if hidden_states is not None else None + if stage_manager.is_first_stage(): + hidden_states= self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.config.is_decoder and encoder_hidden_states is not None: @@ -120,27 +145,7 @@ def bert_model_forward(self:BertModel, else: encoder_extended_attention_mask = None - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - # assure that the input is embedding_output and is the hidden_states of previous stages. - hidden_states = input_ids if input_ids is not None else None - if stage_manager.is_first_stage(): - hidden_states= self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) - - - encoder_outputs = None #inherit from bert_layer all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -159,22 +164,19 @@ def bert_model_forward(self:BertModel, start_layer = stage_manager.stage * num_layers_per_stage end_layer = (stage_manager.stage + 1) * num_layers_per_stage + #layer_outputs + layer_outputs = hidden_states if hidden_states is not None else None for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): if stage_manager.is_first_stage() and idx == 0: - attention_mask = extended_attention_mask + encoder_attention_mask=encoder_extended_attention_mask if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[idx] if head_mask is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None - - ### - print('where is the model now',start_layer,idx,end_layer) - print('what stage is now',stage_manager.stage) - - if self.encoder.gradient_checkpointing and self.encoder.training: - + + if self.encoder.gradient_checkpointing and self.encoder.training: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions) @@ -190,16 +192,6 @@ def custom_forward(*inputs): encoder_attention_mask, ) else: - if stage_manager.stage == 1: - if hidden_states is not None : - print('shape of hidden_states',hidden_states.shape) - if attention_mask is not None : - print('shape of attention_mask',attention_mask.shape) - ## TODO: check for this layer_head_mask - if 
layer_head_mask is not None : - print('shape of layer_head_mask',layer_head_mask.shape) - if encoder_hidden_states is not None : - print('shape of encoder_hidden_states',encoder_hidden_states.shape) layer_outputs = encoder_layer( hidden_states, attention_mask, @@ -226,9 +218,8 @@ def custom_forward(*inputs): if stage_manager.is_last_stage(): pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] + return (sequence_output, pooled_output) + layer_outputs[1:] - #output of non-first and non-last stages: if not return_dict: return tuple(v diff --git a/colossalai/pipeline/policy/llama.py b/colossalai/pipeline/policy/llama.py deleted file mode 100644 index d83683ccb264..000000000000 --- a/colossalai/pipeline/policy/llama.py +++ /dev/null @@ -1,258 +0,0 @@ -from functools import partial -from types import MethodType -from typing import Dict, List, Optional, Tuple, Union - -import torch -from torch import Tensor -from torch.nn import CrossEntropyLoss, Module -from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutput, - CausalLMOutputWithPast) -from transformers.models.llama.modeling_llama import (LlamaForCausalLM, - LlamaModel) -from transformers.utils import logging - -from colossalai.pipeline.stage_manager import PipelineStageManager - -from .base import Policy - -logger = logging.get_logger(__name__) - - -def llama_model_forward(self: LlamaModel, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, # this is set by partial - hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage - ) -> Union[CausalLMOutput, Tuple]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if output_attentions: - logger.warning_once('`output_attentions=True` is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('`output_hidden_states=True` is not supported for pipeline models at the moment.') - output_hidden_states = False - if use_cache: - logger.warning_once('`use_cache=True` is not supported for pipeline models at the moment.') - use_cache = False - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - 
if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - if stage_manager.is_first_stage(): - inputs_embeds = self.embed_tokens(input_ids) - else: - inputs_embeds = hidden_states - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - # this function only uses inputs_embeds' device, dtype, and shape, it's safe to use hidden_state - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - num_layers_per_stage = len(self.layers) // stage_manager.num_stages - start_layer = stage_manager.stage * num_layers_per_stage - end_layer = (stage_manager.stage + 1) * num_layers_per_stage - - for idx, decoder_layer in enumerate(self.layers[start_layer:end_layer], start=start_layer): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if stage_manager.is_last_stage(): - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - # TODO(ver217): return_dict is not supported for pipeline models at the moment. 
- return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -def llama_for_causal_lm_forward(self: LlamaForCausalLM, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, # this is set by partial - hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage - ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - ) - - hidden_states = outputs[0] - if not stage_manager.is_last_stage(): - return dict(hidden_states=hidden_states) - - logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - ) - - -class LlamaForCausalLMPolicy(Policy): - def get_hold_layers(self, module: LlamaForCausalLM) -> List[Module]: - hold_layers = [] - - if self.stage_manager.is_first_stage(): - hold_layers.append(module.model.embed_tokens) - num_layers_per_stage = len(module.model.layers) // self.stage_manager.num_stages - hold_layers.extend(module.model.layers[self.stage_manager.stage * - num_layers_per_stage: (self.stage_manager.stage + 1) * num_layers_per_stage]) - if self.stage_manager.is_last_stage(): - hold_layers.append(module.model.norm) - hold_layers.append(module.lm_head) - - return hold_layers - - def get_shared_params(self, module: LlamaForCausalLM) -> List[Dict[int, Tensor]]: - if id(module.model.embed_tokens.weight) == id(module.lm_head.weight): - # tie weights - return [{0: module.model.embed_tokens.weight, self.stage_manager.num_stages - 1: module.lm_head.weight}] - return [] - - def replace_forward(self, module: LlamaForCausalLM) -> None: - module.model.forward = MethodType(partial(llama_model_forward, stage_manager=self.stage_manager), module.model) - module.forward = 
MethodType(partial(llama_for_causal_lm_forward, stage_manager=self.stage_manager), module) diff --git a/tests/test_pipeline/test_policy/test_bert.py b/tests/test_pipeline/test_policy/test_bert.py index 0e27802da13e..4f9af46c485e 100644 --- a/tests/test_pipeline/test_policy/test_bert.py +++ b/tests/test_pipeline/test_policy/test_bert.py @@ -30,15 +30,26 @@ def check_bert_model_forward(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() - #print(rank) + # print(rank) x = torch.randint(0, 1000, (2, 3)) - attention_mask = torch.ones_like(x) + hidden_states = torch.randint(0,1000,(2,3,768)).to(torch.float32) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x) + output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 768) + print('start the training') + else: + attention_mask = torch.ones((2,12,3,3)) + output = bert_model_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 768) + print('end the training') + print(output) - output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, - stage_manager=stage_manager) - print(output) - assert output[0].shape == (2, 3, 768) # assert output[1].shape == (2, 768) From 246b6d3b7a3eba548f88eb4be604b85ff516bf60 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 3 Jul 2023 16:27:18 +0800 Subject: [PATCH 03/17] add Bert_Policy and test for policy --- colossalai/pipeline/policy/__init__.py | 5 +- colossalai/pipeline/policy/bert.py | 78 +++++++++++++++----- tests/test_pipeline/test_policy/test_bert.py | 42 ++++++++++- 3 files changed, 100 insertions(+), 25 deletions(-) diff --git a/colossalai/pipeline/policy/__init__.py b/colossalai/pipeline/policy/__init__.py index cd372a28b79c..cb4b99803119 100644 --- a/colossalai/pipeline/policy/__init__.py +++ b/colossalai/pipeline/policy/__init__.py @@ -6,10 +6,9 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from .base import Policy -from .llama import LlamaForCausalLM, LlamaForCausalLMPolicy - +from .bert import BertModel,BertModelPolicy POLICY_MAP: Dict[Type[Module], Type[Policy]] = { - LlamaForCausalLM: LlamaForCausalLMPolicy, + BertModel: BertModelPolicy, } diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 1b9cdaecf9eb..d9ee53748126 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -240,29 +240,69 @@ def custom_forward(*inputs): cross_attentions=all_cross_attentions, ) +# The layer partition policy for bertmodel +class BertModelPolicy(Policy): + def __init__(self, stage_manager: PipelineStageManager, num_layers: int,num_stages: int): + self.stage_manager = stage_manager + self.layers_per_stage = self.distribute_layers(num_layers,num_stages) -# class BertModelPolicy(Policy): -# def get_hold_layers(self, module: BertModel) -> List[Module]: -# # get pipeline layers for curerent stage -# hold_layers = [] -# if self.stage_manager.is_first_stage(): -# hold_layers.append(module.embeddings) -# #Fix: num_layers_per_stage should be calculated based on the number of layers in the model -# num_layers_per_stage = len(module.encoder.layer) // self.stage_manager.num_stages + def get_hold_layers(self, module: BertModel) -> List[Module]: + # get pipeline layers for current stage + hold_layers = [] + if 
self.stage_manager.is_first_stage(): + hold_layers.append(module.embeddings) + num_layers_per_stage_accumulated = self.convert_into_accumulated() + hold_layers.extend(module.encoder.layer[num_layers_per_stage_accumulated \ + [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: + num_layers_per_stage_accumulated[self.stage_manager.stage]]) -# hold_layers.extend(module.encoder.layer[self.stage_manager.stage* -# num_layers_per_stage : (self.stage_manager.stage+1)* num_layers_per_stage]) -# if self.stage_manager.is_last_stage(): -# hold_layers.append(module.pooler) + if self.stage_manager.is_last_stage(): + hold_layers.append(module.pooler) -# return hold_layers + return hold_layers -# def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: -# if id(module.embeddings.parameters) == id(module.pooler.parameters) -# return [dict(module.embeddings.named_parameters())] -# return [] -# def replace_forward(self, module: Module) -> None: -# return super().replace_forward(module) + def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: + '''no shared params in bertmodel''' + pass + def replace_forward(self, module: Module) -> None: + module.model.forward = MethodType(partial(bert_model_forward,stage_manager=self.stage_manager), module.model) + + # divide layers into stages + def distribute_layers(self, num, stage_num) -> List[int]: + quotient = num // stage_num + remainder = num % stage_num + + # calculate the num_layers per stage + layers_per_stage = [quotient] * stage_num + + # deal with the rest layers + if remainder > 0: + middle_stages = (stage_num-1) // 2 + right_extra = remainder // 2 + left_extra = remainder - right_extra + + #divide the rest part + left=0 + right=0 + while left_extra > 0: + layers_per_stage[middle_stages - left] += 1 + left_extra -= 1 + left+= 1 + while right_extra > 0 : + layers_per_stage[middle_stages + right + 1] += 1 + right_extra -= 1 + right+=1 + return layers_per_stage + def convert_into_accumulated(self) -> List[int]: + '''convert a array into accumulated array''' + acc = 0 + layers_per_stage_accumulated=[] + for num in self.layers_per_stage: + acc += num + layers_per_stage_accumulated.append(acc) + return layers_per_stage_accumulated + + ''' def bert_pretraining_model_forward( diff --git a/tests/test_pipeline/test_policy/test_bert.py b/tests/test_pipeline/test_policy/test_bert.py index 4f9af46c485e..4545bc795d40 100644 --- a/tests/test_pipeline/test_policy/test_bert.py +++ b/tests/test_pipeline/test_policy/test_bert.py @@ -5,7 +5,7 @@ import colossalai from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.pipeline.policy.bert import bert_model_forward +from colossalai.pipeline.policy.bert import bert_model_forward,BertModelPolicy from colossalai.pipeline.stage_manager import PipelineStageManager from transformers.models.bert.modeling_bert import BertModel @@ -52,17 +52,53 @@ def check_bert_model_forward(): # assert output[1].shape == (2, 768) +def check_bert_model_policy(): + model = BertModel.from_pretrained('bert-base-uncased') + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + #print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + model_policy = BertModelPolicy(stage_manager,len(model.encoder.layer),2) + assert 
model_policy.layers_per_stage == [6,6] + layers=model_policy.get_hold_layers(model) + for layer in layers: + print(layer) -def run_dist(rank, world_size, port): +def run_dist_model(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') check_bert_model_forward() +def run_dist_policy(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_model_policy() + @pytest.mark.dist @rerun_if_address_is_in_use() def test_bert_model_forward(): - spawn(run_dist, 4) + spawn(run_dist_model, 4) +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_model_policy(): + spawn(run_dist_policy, 4) if __name__ == "__main__": test_bert_model_forward() + test_bert_model_policy() \ No newline at end of file From db0a1f150df079c6027f430b32293f12ca3018ac Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 3 Jul 2023 17:05:17 +0800 Subject: [PATCH 04/17] update formatting --- colossalai/pipeline/policy/__init__.py | 7 +- colossalai/pipeline/policy/bert.py | 405 ++++++++++--------- tests/test_pipeline/test_policy/test_bert.py | 48 ++- 3 files changed, 237 insertions(+), 223 deletions(-) diff --git a/colossalai/pipeline/policy/__init__.py b/colossalai/pipeline/policy/__init__.py index cb4b99803119..fd9e6e04588e 100644 --- a/colossalai/pipeline/policy/__init__.py +++ b/colossalai/pipeline/policy/__init__.py @@ -6,13 +6,16 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from .base import Policy -from .bert import BertModel,BertModelPolicy +from .bert import BertModel, BertModelPolicy + POLICY_MAP: Dict[Type[Module], Type[Policy]] = { BertModel: BertModelPolicy, } -def pipeline_parallelize(model: Module, stage_manager: PipelineStageManager) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: +def pipeline_parallelize( + model: Module, + stage_manager: PipelineStageManager) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: if type(model) not in POLICY_MAP: raise NotImplementedError(f"Policy for {type(model)} not implemented") policy = POLICY_MAP[type(model)](stage_manager) diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index d9ee53748126..9fab35241767 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -5,10 +5,12 @@ import torch from torch import Tensor from torch.nn import CrossEntropyLoss, Module -from transformers.modeling_outputs import (BaseModelOutputWithPast, - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions) -from transformers.models.bert.modeling_bert import BertModel +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from transformers.models.bert.modeling_bert import BertModel from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -17,7 +19,9 @@ logger = logging.get_logger(__name__) -def bert_model_forward(self:BertModel, + +def bert_model_forward( + self: BertModel, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, @@ -27,17 +31,16 @@ def bert_model_forward(self:BertModel, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, past_key_values: 
Optional[List[torch.FloatTensor]] = None, - #labels: Optional[torch.LongTensor] = None, + #labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, #this is from the previous stage - ) : - #TODO: add explaination of the output here. - - r""" + hidden_states: Optional[torch.FloatTensor] = None, #this is from the previous stage +): + #TODO: add explaination of the output here. + r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. @@ -57,197 +60,195 @@ def bert_model_forward(self:BertModel, If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). """ - # debugging - # preprocess: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.config.is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache + # debugging + # preprocess: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] else: - use_cache = False - - if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - batch_size, seq_length = input_shape - device = input_ids.device if input_ids is not None else inputs_embeds.device - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
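For reference, a minimal sketch of the mask extension performed on the first pipeline stage, assuming the standard Hugging Face `get_extended_attention_mask` behaviour; the shapes mirror the test file in this series and the concrete numbers are illustrative assumptions only, not part of the patch:

    # Sketch (assumption): first-stage mask extension for a (2, 3) token batch.
    import torch
    from transformers.models.bert.modeling_bert import BertModel

    model = BertModel.from_pretrained('bert-base-uncased')
    attention_mask = torch.ones((2, 3))
    extended = model.get_extended_attention_mask(attention_mask, (2, 3))
    # extended.shape == (2, 1, 1, 3): 0.0 where tokens are visible, a large
    # negative value where they are masked, broadcastable over heads/queries.
    # Non-first stages skip this step and instead receive an already
    # broadcastable mask (e.g. the (2, 12, 3, 3) mask used in the test).
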
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) - attention_mask = extended_attention_mask + raise ValueError("You have to specify either input_ids or inputs_embeds") + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + attention_mask = extended_attention_mask + else: + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded else: - input_shape = hidden_states.size()[:-1] - batch_size, seq_length = input_shape - device = hidden_states.device - - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') - use_cache = False - - - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - - if token_type_ids is None: - if hasattr(self.embeddings, "token_type_ids"): - buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) - token_type_ids = buffered_token_type_ids_expanded - else: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - - - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - hidden_states = hidden_states if hidden_states is not None else None - if stage_manager.is_first_stage(): - hidden_states= self.embeddings( + token_type_ids = 
torch.zeros(input_shape, dtype=torch.long, device=device) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + hidden_states = hidden_states if hidden_states is not None else None + if stage_manager.is_first_stage(): + hidden_states = self.embeddings( input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + #inherit from bert_layer + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.encoder.gradient_checkpointing and self.encoder.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + next_decoder_cache = () if use_cache else None + #calculate the num_layers + num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages + start_layer = stage_manager.stage * num_layers_per_stage + end_layer = (stage_manager.stage + 1) * num_layers_per_stage - #inherit from bert_layer - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + #layer_outputs + layer_outputs = hidden_states if hidden_states is not None else None + for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): + if stage_manager.is_first_stage() and idx == 0: + encoder_attention_mask = encoder_extended_attention_mask - if self.encoder.gradient_checkpointing and self.encoder.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - next_decoder_cache = () if use_cache else None - - #calculate the num_layers - num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages - start_layer = stage_manager.stage * num_layers_per_stage - end_layer = (stage_manager.stage + 1) * num_layers_per_stage - - #layer_outputs - layer_outputs = hidden_states if hidden_states is not None else None - for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): - if stage_manager.is_first_stage() and idx == 0: - encoder_attention_mask=encoder_extended_attention_mask - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[idx] if head_mask is not None else None - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.encoder.gradient_checkpointing and self.encoder.training: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[idx] if head_mask is not None else None + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.encoder.gradient_checkpointing and self.encoder.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) - #end of a stage loop - sequence_output = layer_outputs[0] if layer_outputs is not None else None + return custom_forward - if stage_manager.is_last_stage(): - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - if not return_dict: - return (sequence_output, pooled_output) + layer_outputs[1:] - - #output of non-first and non-last stages: + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + #end of a stage loop + sequence_output = layer_outputs[0] if layer_outputs is not None else None + + if stage_manager.is_last_stage(): + pooled_output = 
self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: - return tuple(v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] if v is not None) - - #return dict is not supported at this moment - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) + return (sequence_output, pooled_output) + layer_outputs[1:] + + #output of non-first and non-last stages: + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + + #return dict is not supported at this moment + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + # The layer partition policy for bertmodel class BertModelPolicy(Policy): - def __init__(self, stage_manager: PipelineStageManager, num_layers: int,num_stages: int): + + def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers,num_stages) + self.layers_per_stage = self.distribute_layers(num_layers, num_stages) def get_hold_layers(self, module: BertModel) -> List[Module]: - # get pipeline layers for current stage + """ + get pipeline layers for current stage + """ hold_layers = [] if self.stage_manager.is_first_stage(): hold_layers.append(module.embeddings) @@ -255,53 +256,55 @@ def get_hold_layers(self, module: BertModel) -> List[Module]: hold_layers.extend(module.encoder.layer[num_layers_per_stage_accumulated \ [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: num_layers_per_stage_accumulated[self.stage_manager.stage]]) - + if self.stage_manager.is_last_stage(): hold_layers.append(module.pooler) return hold_layers - + def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: '''no shared params in bertmodel''' pass + def replace_forward(self, module: Module) -> None: - module.model.forward = MethodType(partial(bert_model_forward,stage_manager=self.stage_manager), module.model) + module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model) - # divide layers into stages def distribute_layers(self, num, stage_num) -> List[int]: - quotient = num // stage_num - remainder = num % stage_num + """ + divide layers into stages + """ + quotient = num // stage_num + remainder = num % stage_num # calculate the num_layers per stage layers_per_stage = [quotient] * stage_num # deal with the rest layers if remainder > 0: - middle_stages = (stage_num-1) // 2 - right_extra = remainder // 2 - left_extra = remainder - right_extra - + middle_stages = (stage_num - 1) // 2 + right_extra = remainder // 2 + left_extra = remainder - right_extra + #divide the rest part - left=0 - right=0 + left = 0 + right = 0 while left_extra > 0: layers_per_stage[middle_stages - left] += 1 left_extra -= 1 - left+= 1 - while right_extra > 0 : - layers_per_stage[middle_stages + right + 1] += 1 + left += 1 + while right_extra > 0: + layers_per_stage[middle_stages + right + 1] += 1 right_extra -= 1 - right+=1 + right += 1 return 
layers_per_stage + def convert_into_accumulated(self) -> List[int]: - '''convert a array into accumulated array''' acc = 0 - layers_per_stage_accumulated=[] + layers_per_stage_accumulated = [] for num in self.layers_per_stage: acc += num layers_per_stage_accumulated.append(acc) return layers_per_stage_accumulated - ''' @@ -323,4 +326,4 @@ def bert_pretraining_model_forward( ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: pass -''' \ No newline at end of file +''' diff --git a/tests/test_pipeline/test_policy/test_bert.py b/tests/test_pipeline/test_policy/test_bert.py index 4545bc795d40..c92f7f6c34c0 100644 --- a/tests/test_pipeline/test_policy/test_bert.py +++ b/tests/test_pipeline/test_policy/test_bert.py @@ -1,13 +1,14 @@ -import torch import pytest +import torch import torch.distributed as dist -from colossalai.cluster import ProcessGroupMesh +from transformers.models.bert.modeling_bert import BertModel + import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.policy.bert import BertModelPolicy, bert_model_forward +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.pipeline.policy.bert import bert_model_forward,BertModelPolicy -from colossalai.pipeline.stage_manager import PipelineStageManager -from transformers.models.bert.modeling_bert import BertModel def check_bert_model_forward(): model = BertModel.from_pretrained('bert-base-uncased') @@ -24,34 +25,36 @@ def check_bert_model_forward(): 1: [0, 1], 2: [2, 3], 3: [2, 3], - } + } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) #print(pg_mesh) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() # print(rank) - - x = torch.randint(0, 1000, (2, 3)) - hidden_states = torch.randint(0,1000,(2,3,768)).to(torch.float32) + + x = torch.randint(0, 1000, (2, 3)) + hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) if stage_manager.stage == 0: attention_mask = torch.ones_like(x) - output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, - stage_manager=stage_manager) + output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) print(output[0].shape) assert output[0].shape == (2, 3, 768) print('start the training') else: - attention_mask = torch.ones((2,12,3,3)) - output = bert_model_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, + attention_mask = torch.ones((2, 12, 3, 3)) + output = bert_model_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, stage_manager=stage_manager) print(output[0].shape) assert output[0].shape == (2, 3, 768) print('end the training') print(output) - + # assert output[1].shape == (2, 768) + def check_bert_model_policy(): model = BertModel.from_pretrained('bert-base-uncased') DP_DIM, PP_DIM = 0, 1 @@ -67,16 +70,16 @@ def check_bert_model_policy(): 1: [0, 1], 2: [2, 3], 3: [2, 3], - } + } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) #print(pg_mesh) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() - model_policy = BertModelPolicy(stage_manager,len(model.encoder.layer),2) - assert model_policy.layers_per_stage == [6,6] - layers=model_policy.get_hold_layers(model) + model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer), 2) + assert model_policy.layers_per_stage == [6, 6] + layers = model_policy.get_hold_layers(model) for layer in layers: print(layer) @@ 
-85,20 +88,25 @@ def run_dist_model(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') check_bert_model_forward() + def run_dist_policy(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_model_policy() + check_bert_model_policy() + @pytest.mark.dist @rerun_if_address_is_in_use() def test_bert_model_forward(): spawn(run_dist_model, 4) + @pytest.mark.dist @rerun_if_address_is_in_use() def test_bert_model_policy(): spawn(run_dist_policy, 4) + if __name__ == "__main__": + """test the bert model forward and bert model policy""" test_bert_model_forward() - test_bert_model_policy() \ No newline at end of file + test_bert_model_policy() From 9f57067d72a39f95ebbea01e6a2208bcd532caa2 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 3 Jul 2023 17:12:31 +0800 Subject: [PATCH 05/17] update formatting --- colossalai/pipeline/policy/bert.py | 35 ++++++++++++++---------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 9fab35241767..c862e9297044 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -10,7 +10,7 @@ BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, ) -from transformers.models.bert.modeling_bert import BertModel +from transformers.models.bert.modeling_bert import BertForPreTrainingOutput, BertModel from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -307,23 +307,20 @@ def convert_into_accumulated(self) -> List[int]: return layers_per_stage_accumulated -''' def bert_pretraining_model_forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - next_sentence_label: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.LongTensor] = None, - stage_manager: Optional[PipelineStageManager] = None, - - ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.LongTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, +) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: pass -''' From dac6427377aadd62632834fdecb51a534284c28f Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 3 Jul 2023 18:19:13 +0800 Subject: [PATCH 06/17] update the code --- colossalai/pipeline/policy/bert.py | 37 ++++++++---------------------- 1 file changed, 9 insertions(+), 28 deletions(-) diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 
c862e9297044..15be48b47b4e 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -2,6 +2,7 @@ from types import MethodType from typing import Dict, List, Optional, Tuple, Union +import numpy as np import torch from torch import Tensor from torch.nn import CrossEntropyLoss, Module @@ -252,7 +253,7 @@ def get_hold_layers(self, module: BertModel) -> List[Module]: hold_layers = [] if self.stage_manager.is_first_stage(): hold_layers.append(module.embeddings) - num_layers_per_stage_accumulated = self.convert_into_accumulated() + num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) hold_layers.extend(module.encoder.layer[num_layers_per_stage_accumulated \ [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: num_layers_per_stage_accumulated[self.stage_manager.stage]]) @@ -269,43 +270,23 @@ def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: def replace_forward(self, module: Module) -> None: module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model) - def distribute_layers(self, num, stage_num) -> List[int]: + def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: """ divide layers into stages """ - quotient = num // stage_num - remainder = num % stage_num + quotient = num_layers // num_stages + remainder = num_layers % num_stages # calculate the num_layers per stage - layers_per_stage = [quotient] * stage_num + layers_per_stage = [quotient] * num_stages # deal with the rest layers if remainder > 0: - middle_stages = (stage_num - 1) // 2 - right_extra = remainder // 2 - left_extra = remainder - right_extra - - #divide the rest part - left = 0 - right = 0 - while left_extra > 0: - layers_per_stage[middle_stages - left] += 1 - left_extra -= 1 - left += 1 - while right_extra > 0: - layers_per_stage[middle_stages + right + 1] += 1 - right_extra -= 1 - right += 1 + start_position = num_layers // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 return layers_per_stage - def convert_into_accumulated(self) -> List[int]: - acc = 0 - layers_per_stage_accumulated = [] - for num in self.layers_per_stage: - acc += num - layers_per_stage_accumulated.append(acc) - return layers_per_stage_accumulated - def bert_pretraining_model_forward( self, From 8b30a0223088f9f4b1efdd577aee7d8f6e604aba Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 4 Jul 2023 10:22:05 +0800 Subject: [PATCH 07/17] fix bugs --- colossalai/pipeline/policy/bert.py | 91 ++++++++++++++++- colossalai/pipeline/policy/bloom.py | 153 ++++++++++++++++++++++++++++ 2 files changed, 240 insertions(+), 4 deletions(-) create mode 100644 colossalai/pipeline/policy/bloom.py diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 15be48b47b4e..6f912d2c6b80 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -11,7 +11,7 @@ BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, ) -from transformers.models.bert.modeling_bert import BertForPreTrainingOutput, BertModel +from transformers.models.bert.modeling_bert import BertForPreTraining, BertForPreTrainingOutput, BertModel from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -288,8 +288,8 @@ def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: return layers_per_stage -def bert_pretraining_model_forward( - 
self, +def bert_for_pretraining_forward( + self: BertForPreTraining, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, @@ -304,4 +304,87 @@ def bert_pretraining_model_forward( hidden_states: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: - pass + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BertForPreTrainingPolicy(Policy): + + def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + self.stage_manager = stage_manager + self.layers_per_stage = self.distribute_layers(num_layers, num_stages) + + def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: + """ + get pipeline layers for current stage + """ + hold_layers = [] + if self.stage_manager.is_first_stage(): + hold_layers.append(module.bert.embeddings) + num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) + hold_layers.extend(module.bert.encoder.layer[num_layers_per_stage_accumulated \ + [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: + num_layers_per_stage_accumulated[self.stage_manager.stage]]) + if self.stage_manager.is_last_stage(): + hold_layers.append(module.cls) + + return hold_layers + + def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor]]: + '''no shared params in bertmodel''' + pass + + def replace_forward(self, module: Module) -> None: + module.model.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), + module.model) + + def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: + """ + divide layers into stages + """ + quotient = num_layers // num_stages + remainder = num_layers % num_stages + + # calculate the num_layers per stage + layers_per_stage = [quotient] * num_stages + + # deal with the rest layers + if remainder > 0: + start_position = num_layers // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 + return layers_per_stage diff --git a/colossalai/pipeline/policy/bloom.py b/colossalai/pipeline/policy/bloom.py new file mode 100644 index 000000000000..8dffcd8f9af5 --- /dev/null +++ 
b/colossalai/pipeline/policy/bloom.py @@ -0,0 +1,153 @@ +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.bloom.modeling_bloom import BloomModel +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + +from .base import Policy + + +def bloom_model_forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, +) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + # Compute alibi tensor: check build_alibi_tensor documentation + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) From 585eb9d9470d3f56a6d6c84c02bca406187cffb0 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 4 Jul 2023 10:42:03 +0800 Subject: [PATCH 08/17] fix name confilt --- .../test_policy/{test_bert.py => test_bert_model.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/test_pipeline/test_policy/{test_bert.py => test_bert_model.py} (100%) diff --git a/tests/test_pipeline/test_policy/test_bert.py b/tests/test_pipeline/test_policy/test_bert_model.py similarity index 100% rename from tests/test_pipeline/test_policy/test_bert.py rename to tests/test_pipeline/test_policy/test_bert_model.py From 27fb80409570f3aa3c7ae4e96544e2b3c0e53c43 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 4 Jul 2023 16:53:20 +0800 Subject: [PATCH 09/17] add bloom model and policy ,revise the base class of policy --- colossalai/pipeline/policy/base.py | 23 +++- colossalai/pipeline/policy/bert.py | 86 ++++++------- colossalai/pipeline/policy/bloom.py | 110 ++++++++++++---- .../test_policy/test_bert_model.py | 4 +- .../test_policy/test_bloom_model.py | 119 ++++++++++++++++++ 5 files changed, 268 insertions(+), 74 deletions(-) create mode 100644 tests/test_pipeline/test_policy/test_bloom_model.py diff --git a/colossalai/pipeline/policy/base.py b/colossalai/pipeline/policy/base.py index 
ad595a04b1b0..9bfce15a83ab 100644 --- a/colossalai/pipeline/policy/base.py +++ b/colossalai/pipeline/policy/base.py @@ -1,13 +1,14 @@ from typing import Any, Dict, List, Optional, Tuple -from colossalai.lazy import LazyTensor from torch import Tensor from torch.nn import Module, Parameter +from colossalai.lazy import LazyTensor from colossalai.pipeline.stage_manager import PipelineStageManager class Policy: + def __init__(self, stage_manager: PipelineStageManager) -> None: self.stage_manager = stage_manager @@ -93,7 +94,8 @@ def get_shared_params(self, module: Module) -> List[Dict[int, Tensor]]: """ raise NotImplementedError - def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: + def parallelize_model(self, + module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: """Parallelize model for pipeline parallel Args: @@ -106,3 +108,20 @@ def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[ self.replace_forward(module) shared_params = self.get_shared_params(module) return hold_params, hold_buffers, shared_params + + def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: + """ + divide layers into stages + """ + quotient = num_layers // num_stages + remainder = num_layers % num_stages + + # calculate the num_layers per stage + layers_per_stage = [quotient] * num_stages + + # deal with the rest layers + if remainder > 0: + start_position = num_layers // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 + return layers_per_stage diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 6f912d2c6b80..002814e9014e 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -22,25 +22,26 @@ def bert_model_forward( - self: BertModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - #labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, #this is from the previous stage + self: BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + # labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + # this is from the previous stage + hidden_states: 
Optional[torch.FloatTensor] = None, ): - #TODO: add explaination of the output here. + # TODO: add explaination of the output here. r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if @@ -93,6 +94,7 @@ def bert_model_forward( batch_size, seq_length = input_shape device = hidden_states.device + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False @@ -144,7 +146,7 @@ def bert_model_forward( else: encoder_extended_attention_mask = None - #inherit from bert_layer + # inherit from bert_layer all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None @@ -156,12 +158,12 @@ def bert_model_forward( use_cache = False next_decoder_cache = () if use_cache else None - #calculate the num_layers + # calculate the num_layers num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages start_layer = stage_manager.stage * num_layers_per_stage end_layer = (stage_manager.stage + 1) * num_layers_per_stage - #layer_outputs + # layer_outputs layer_outputs = hidden_states if hidden_states is not None else None for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): if stage_manager.is_first_stage() and idx == 0: @@ -206,12 +208,13 @@ def custom_forward(*inputs): if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + all_cross_attentions = all_cross_attentions + \ + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - #end of a stage loop + # end of a stage loop sequence_output = layer_outputs[0] if layer_outputs is not None else None if stage_manager.is_last_stage(): @@ -219,7 +222,7 @@ def custom_forward(*inputs): if not return_dict: return (sequence_output, pooled_output) + layer_outputs[1:] - #output of non-first and non-last stages: + # output of non-first and non-last stages: if not return_dict: return tuple(v for v in [ hidden_states, @@ -229,7 +232,7 @@ def custom_forward(*inputs): all_cross_attentions, ] if v is not None) - #return dict is not supported at this moment + # return dict is not supported at this moment return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, @@ -243,8 +246,9 @@ def custom_forward(*inputs): class BertModelPolicy(Policy): def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, num_stages) + self.layers_per_stage = super().distribute_layers(num_layers, num_stages) def get_hold_layers(self, module: BertModel) -> List[Module]: """ @@ -254,9 +258,9 @@ def get_hold_layers(self, module: BertModel) -> List[Module]: if self.stage_manager.is_first_stage(): hold_layers.append(module.embeddings) num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) - 
hold_layers.extend(module.encoder.layer[num_layers_per_stage_accumulated \ - [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: - num_layers_per_stage_accumulated[self.stage_manager.stage]]) + hold_layers.extend( + module.encoder.layer[num_layers_per_stage_accumulated[self.stage_manager.stage - 1] if self.stage_manager. + stage > 0 else 0:num_layers_per_stage_accumulated[self.stage_manager.stage]]) if self.stage_manager.is_last_stage(): hold_layers.append(module.pooler) @@ -270,23 +274,6 @@ def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: def replace_forward(self, module: Module) -> None: module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model) - def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: - """ - divide layers into stages - """ - quotient = num_layers // num_stages - remainder = num_layers % num_stages - - # calculate the num_layers per stage - layers_per_stage = [quotient] * num_stages - - # deal with the rest layers - if remainder > 0: - start_position = num_layers // 2 - remainder // 2 - for i in range(start_position, start_position + remainder): - layers_per_stage[i] += 1 - return layers_per_stage - def bert_for_pretraining_forward( self: BertForPreTraining, @@ -356,9 +343,10 @@ def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: if self.stage_manager.is_first_stage(): hold_layers.append(module.bert.embeddings) num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) - hold_layers.extend(module.bert.encoder.layer[num_layers_per_stage_accumulated \ - [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: - num_layers_per_stage_accumulated[self.stage_manager.stage]]) + hold_layers.extend( + module.bert.encoder.layer[num_layers_per_stage_accumulated[self.stage_manager.stage - + 1] if self.stage_manager. 
+ stage > 0 else 0:num_layers_per_stage_accumulated[self.stage_manager.stage]]) if self.stage_manager.is_last_stage(): hold_layers.append(module.cls) diff --git a/colossalai/pipeline/policy/bloom.py b/colossalai/pipeline/policy/bloom.py index 8dffcd8f9af5..25b5039760bf 100644 --- a/colossalai/pipeline/policy/bloom.py +++ b/colossalai/pipeline/policy/bloom.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from types import MethodType from typing import Dict, List, Optional, Tuple, Union @@ -14,6 +15,8 @@ from .base import Policy +logger = logging.get_logger(__name__) + def bloom_model_forward( self: BloomModel, @@ -26,6 +29,8 @@ def bloom_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: if deprecated_arguments.pop("position_ids", False) is not False: @@ -44,29 +49,45 @@ def bloom_model_forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - + # add warnings here + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) + # case: First stage of training + if stage_manager.is_first_stage(): + # check input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) - hidden_states = self.word_embeddings_layernorm(inputs_embeds) + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + # initialize in the first stage and then pass to the next stage + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + # extra recording tensor should be generated in the first stage presents = () if use_cache else None all_self_attentions = () if 
output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -77,11 +98,13 @@ def bloom_model_forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") use_cache = False - # Compute alibi tensor: check build_alibi_tensor documentation + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] + past_key_values_length = past_key_values[0][0].shape[2] # source_len seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) @@ -90,13 +113,19 @@ def bloom_model_forward( alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + # causal_mask is constructed every stage and its input is passed through different stages causal_mask = self._prepare_attn_mask( attention_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length, ) - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # calculate the num_layers + num_layers_per_stage = len(self.h) // stage_manager.num_stages + start_layer = stage_manager.stage * num_layers_per_stage + end_layer = (stage_manager.stage + 1) * num_layers_per_stage + + for i, (block, layer_past) in enumerate(zip(self.h[start_layer:end_layer], past_key_values[start_layer:end_layer])): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -130,24 +159,63 @@ def custom_forward(*inputs): ) hidden_states = outputs[0] + if use_cache is True: presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + \ + (outputs[2 if use_cache else 1],) - # Add last hidden state - hidden_states = self.ln_f(hidden_states) + if stage_manager.is_last_stage(): + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + # TODO: deal with all_hidden_states, all_self_attentions, presents if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + # attention_mask is not returned ; presents = past_key_values + return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions, ) + + +class BloomModelPolicy(Policy): + + def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + super().__init__(stage_manager=stage_manager) + self.stage_manager = stage_manager + self.layers_per_stage = super().distribute_layers(num_layers, num_stages) + + def get_hold_layers(self, module: BloomModel) -> List[Module]: + """ + get pipeline layers for current stage + """ + hold_layers = [] + if self.stage_manager.is_first_stage(): + hold_layers.append(module.word_embeddings) + hold_layers.append(module.word_embeddings_layernorm) + num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) + hold_layers.extend(module.h[num_layers_per_stage_accumulated[self.stage_manager.stage - + 1] if self.stage_manager. 
+ stage > 0 else 0:num_layers_per_stage_accumulated[self.stage_manager.stage]]) + + if self.stage_manager.is_last_stage(): + hold_layers.append(module.ln_f) + + return hold_layers + + def get_shared_params(self, module: BloomModel) -> List[Dict[int, Tensor]]: + '''no shared params in bloommodel''' + pass + + def replace_forward(self, module: Module) -> None: + module.forward = MethodType(partial(bloom_model_forward, stage_manager=self.stage_manager), module.model) diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py index c92f7f6c34c0..b757f6813153 100644 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -27,7 +27,7 @@ def check_bert_model_forward(): 3: [2, 3], } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - #print(pg_mesh) + # print(pg_mesh) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() @@ -72,7 +72,7 @@ def check_bert_model_policy(): 3: [2, 3], } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - #print(pg_mesh) + # print(pg_mesh) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() diff --git a/tests/test_pipeline/test_policy/test_bloom_model.py b/tests/test_pipeline/test_policy/test_bloom_model.py new file mode 100644 index 000000000000..5ba92d734590 --- /dev/null +++ b/tests/test_pipeline/test_policy/test_bloom_model.py @@ -0,0 +1,119 @@ +import pytest +import torch +import torch.distributed as dist +from transformers.models.bloom import BloomConfig, BloomModel + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.policy.bloom import BloomModelPolicy, bloom_model_forward +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_bloom_model_forward(): + # create a BloomModel + configuration = BloomConfig() + model = BloomModel(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + # print(rank) + + x = torch.randint(0, 1000, (2, 3)) + hidden_states = torch.randint(0, 1000, (2, 3, 64)).to(torch.float32) + if stage_manager.is_first_stage(): + attention_mask = torch.ones_like(x) + output = bloom_model_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 64) + print('start the training') + else: + attention_mask = torch.ones((2, 3)) + output = bloom_model_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 64) + print('end the training') + print(output) + + # assert output[1].shape == (2, 768) + + +def check_bloom_model_policy(): + # create a BloomModel + configuration = BloomConfig() + model = BloomModel(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = 
PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + model_policy = BloomModelPolicy(stage_manager=stage_manager, num_layers=len(model.h), num_stages=2) + assert model_policy.layers_per_stage == [1, 1] + layers = model_policy.get_hold_layers(model) + for layer in layers: + print(layer) + + +def run_dist_model(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bloom_model_forward() + + +def run_dist_policy(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bloom_model_policy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bloom_model_forward(): + spawn(run_dist_model, 4) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bloom_model_policy(): + spawn(run_dist_policy, 4) + + +if __name__ == "__main__": + """test the bloom model forward and bloom model policy""" + test_bloom_model_forward() + test_bloom_model_policy() From 3ea0ba4627e218224f79dff6ab5aa47683d616e3 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 4 Jul 2023 18:06:18 +0800 Subject: [PATCH 10/17] revise --- colossalai/pipeline/policy/base.py | 3 ++- colossalai/pipeline/policy/bert.py | 11 ++++++----- colossalai/pipeline/policy/bloom.py | 11 ++++++----- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/colossalai/pipeline/policy/base.py b/colossalai/pipeline/policy/base.py index 9bfce15a83ab..8da70dd43362 100644 --- a/colossalai/pipeline/policy/base.py +++ b/colossalai/pipeline/policy/base.py @@ -109,7 +109,8 @@ def parallelize_model(self, shared_params = self.get_shared_params(module) return hold_params, hold_buffers, shared_params - def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: + @staticmethod + def distribute_layers(num_layers: int, num_stages: int) -> List[int]: """ divide layers into stages """ diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 002814e9014e..0ec30d41129c 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -248,7 +248,7 @@ class BertModelPolicy(Policy): def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager - self.layers_per_stage = super().distribute_layers(num_layers, num_stages) + self.layers_per_stage = self.distribute_layers(num_layers, num_stages) def get_hold_layers(self, module: BertModel) -> List[Module]: """ @@ -257,11 +257,12 @@ def get_hold_layers(self, module: BertModel) -> List[Module]: hold_layers = [] if self.stage_manager.is_first_stage(): hold_layers.append(module.embeddings) - num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) - hold_layers.extend( - module.encoder.layer[num_layers_per_stage_accumulated[self.stage_manager.stage - 1] if self.stage_manager. 
- stage > 0 else 0:num_layers_per_stage_accumulated[self.stage_manager.stage]]) + num_layers_per_stage_accumulated = np.insert(np.cumsum(self.layers_per_stage), 0, 0) + + start_idx = num_layers_per_stage_accumulated[self.stage_manager.stage] + end_idx = num_layers_per_stage_accumulated[self.stage_manager.stage + 1] + hold_layers.extend(module.encoder.layer[start_idx:end_idx]) if self.stage_manager.is_last_stage(): hold_layers.append(module.pooler) diff --git a/colossalai/pipeline/policy/bloom.py b/colossalai/pipeline/policy/bloom.py index 25b5039760bf..56337b26f333 100644 --- a/colossalai/pipeline/policy/bloom.py +++ b/colossalai/pipeline/policy/bloom.py @@ -193,7 +193,7 @@ class BloomModelPolicy(Policy): def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager - self.layers_per_stage = super().distribute_layers(num_layers, num_stages) + self.layers_per_stage = self.distribute_layers(num_layers, num_stages) def get_hold_layers(self, module: BloomModel) -> List[Module]: """ @@ -203,10 +203,11 @@ def get_hold_layers(self, module: BloomModel) -> List[Module]: if self.stage_manager.is_first_stage(): hold_layers.append(module.word_embeddings) hold_layers.append(module.word_embeddings_layernorm) - num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) - hold_layers.extend(module.h[num_layers_per_stage_accumulated[self.stage_manager.stage - - 1] if self.stage_manager. - stage > 0 else 0:num_layers_per_stage_accumulated[self.stage_manager.stage]]) + num_layers_per_stage_accumulated = np.insert(np.cumsum(self.layers_per_stage), 0, 0) + + start_idx = num_layers_per_stage_accumulated[self.stage_manager.stage] + end_idx = num_layers_per_stage_accumulated[self.stage_manager.stage + 1] + hold_layers.extend(module.h[start_idx:end_idx]) if self.stage_manager.is_last_stage(): hold_layers.append(module.ln_f) From edb02b268df05529669eff4dc8b9f97c0ca99be3 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 4 Jul 2023 18:36:13 +0800 Subject: [PATCH 11/17] revision --- colossalai/pipeline/policy/base.py | 15 ++++++++++++++- colossalai/pipeline/policy/bert.py | 13 ++----------- colossalai/pipeline/policy/bloom.py | 9 +++------ 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/colossalai/pipeline/policy/base.py b/colossalai/pipeline/policy/base.py index c390e436b5a1..9736f1004fe4 100644 --- a/colossalai/pipeline/policy/base.py +++ b/colossalai/pipeline/policy/base.py @@ -1,14 +1,15 @@ from typing import Any, Dict, List, Optional, Tuple +import numpy as np from torch import Tensor from torch.nn import Module, Parameter from colossalai.lazy import LazyTensor - from colossalai.pipeline.stage_manager import PipelineStageManager class Policy: + def __init__(self, stage_manager: PipelineStageManager) -> None: self.stage_manager = stage_manager @@ -126,3 +127,15 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]: for i in range(start_position, start_position + remainder): layers_per_stage[i] += 1 return layers_per_stage + + @staticmethod + def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]: + """ + get the start index and end index of layers for each stage. 
+ """ + num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) + + start_idx = num_layers_per_stage_accumulated[stage] + end_idx = num_layers_per_stage_accumulated[stage + 1] + + return [start_idx, end_idx] diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index bc4d6e549762..a1efe238573c 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -22,7 +22,6 @@ def bert_model_forward( - self: BertModel, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -95,7 +94,6 @@ def bert_model_forward( batch_size, seq_length = input_shape device = hidden_states.device - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') @@ -213,7 +211,6 @@ def custom_forward(*inputs): all_cross_attentions = all_cross_attentions + \ (layer_outputs[2],) - if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -225,7 +222,6 @@ def custom_forward(*inputs): if not return_dict: return (sequence_output, pooled_output) + layer_outputs[1:] - # output of non-first and non-last stages: if not return_dict: return tuple(v for v in [ @@ -236,7 +232,6 @@ def custom_forward(*inputs): all_cross_attentions, ] if v is not None) - # return dict is not supported at this moment return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, @@ -262,11 +257,7 @@ def get_hold_layers(self, module: BertModel) -> List[Module]: hold_layers = [] if self.stage_manager.is_first_stage(): hold_layers.append(module.embeddings) - num_layers_per_stage_accumulated = np.insert(np.cumsum(self.layers_per_stage), 0, 0) - - start_idx = num_layers_per_stage_accumulated[self.stage_manager.stage] - end_idx = num_layers_per_stage_accumulated[self.stage_manager.stage + 1] - + start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) hold_layers.extend(module.encoder.layer[start_idx:end_idx]) if self.stage_manager.is_last_stage(): hold_layers.append(module.pooler) @@ -280,6 +271,7 @@ def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: def replace_forward(self, module: Module) -> None: module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model) + def bert_for_pretraining_forward( self: BertForPreTraining, input_ids: Optional[torch.Tensor] = None, @@ -352,7 +344,6 @@ def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: module.bert.encoder.layer[num_layers_per_stage_accumulated[self.stage_manager.stage - 1] if self.stage_manager. 
stage > 0 else 0:num_layers_per_stage_accumulated[self.stage_manager.stage]]) - if self.stage_manager.is_last_stage(): hold_layers.append(module.cls) diff --git a/colossalai/pipeline/policy/bloom.py b/colossalai/pipeline/policy/bloom.py index ebd086df67a8..71d2913fc3aa 100644 --- a/colossalai/pipeline/policy/bloom.py +++ b/colossalai/pipeline/policy/bloom.py @@ -15,9 +15,9 @@ from .base import Policy - logger = logging.get_logger(__name__) + def bloom_model_forward( self: BloomModel, input_ids: Optional[torch.LongTensor] = None, @@ -187,7 +187,7 @@ def custom_forward(*inputs): attentions=all_self_attentions, ) - + class BloomModelPolicy(Policy): def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): @@ -203,10 +203,8 @@ def get_hold_layers(self, module: BloomModel) -> List[Module]: if self.stage_manager.is_first_stage(): hold_layers.append(module.word_embeddings) hold_layers.append(module.word_embeddings_layernorm) - num_layers_per_stage_accumulated = np.insert(np.cumsum(self.layers_per_stage), 0, 0) - start_idx = num_layers_per_stage_accumulated[self.stage_manager.stage] - end_idx = num_layers_per_stage_accumulated[self.stage_manager.stage + 1] + start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) hold_layers.extend(module.h[start_idx:end_idx]) if self.stage_manager.is_last_stage(): @@ -220,4 +218,3 @@ def get_shared_params(self, module: BloomModel) -> List[Dict[int, Tensor]]: def replace_forward(self, module: Module) -> None: module.forward = MethodType(partial(bloom_model_forward, stage_manager=self.stage_manager), module.model) - From 369df2cf4fef581e08f1da501f7ae070bb8f57d8 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 4 Jul 2023 19:20:10 +0800 Subject: [PATCH 12/17] add bert_for_pretraining --- colossalai/pipeline/policy/bert.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index a1efe238573c..8cd0fadd167f 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -290,8 +290,8 @@ def bert_for_pretraining_forward( ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( + outputs = bert_model_forward( + self.bert, input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -304,7 +304,8 @@ def bert_for_pretraining_forward( ) sequence_output, pooled_output = outputs[:2] - prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + if stage_manager.is_last_stage(): + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) total_loss = None if labels is not None and next_sentence_label is not None: @@ -339,12 +340,12 @@ def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: hold_layers = [] if self.stage_manager.is_first_stage(): hold_layers.append(module.bert.embeddings) - num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) - hold_layers.extend( - module.bert.encoder.layer[num_layers_per_stage_accumulated[self.stage_manager.stage - - 1] if self.stage_manager. 
- stage > 0 else 0:num_layers_per_stage_accumulated[self.stage_manager.stage]]) + + start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) + hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if self.stage_manager.is_last_stage(): + hold_layers.append(module.bert.pooler) hold_layers.append(module.cls) return hold_layers From 0319c8bc4f803578cedbfa54726ee9fec9eae650 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 5 Jul 2023 12:23:57 +0800 Subject: [PATCH 13/17] add bert_for_pretraining forward and policy --- colossalai/pipeline/policy/bert.py | 112 +++++++++-------- .../test_bert_for_pretraining_model.py | 118 ++++++++++++++++++ 2 files changed, 178 insertions(+), 52 deletions(-) create mode 100644 tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 8cd0fadd167f..d8b665ec6c24 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -285,51 +285,76 @@ def bert_for_pretraining_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - hidden_states: Optional[torch.LongTensor] = None, + hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output, pooled_output = outputs[:2] + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. 
+ if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + hidden_states = outputs[0] if stage_manager.is_last_stage(): + sequence_output, pooled_output = outputs[:2] prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + # the last stage for pretraining model + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss - total_loss = None - if labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - total_loss = masked_lm_loss + next_sentence_loss + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) - if not return_dict: - output = (prediction_scores, seq_relationship_score) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return BertForPreTrainingOutput( - loss=total_loss, - prediction_logits=prediction_scores, - seq_relationship_logits=seq_relationship_score, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) + else: + if not return_dict: + return tuple(v for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) class BertForPreTrainingPolicy(Policy): def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager self.layers_per_stage = self.distribute_layers(num_layers, num_stages) @@ -355,22 +380,5 @@ def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor pass def replace_forward(self, module: Module) -> None: - module.model.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), - module.model) - - def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: - """ - divide layers into 
stages - """ - quotient = num_layers // num_stages - remainder = num_layers % num_stages - - # calculate the num_layers per stage - layers_per_stage = [quotient] * num_stages - - # deal with the rest layers - if remainder > 0: - start_position = num_layers // 2 - remainder // 2 - for i in range(start_position, start_position + remainder): - layers_per_stage[i] += 1 - return layers_per_stage + module.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), + module.forward) diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py new file mode 100644 index 000000000000..4d764704ccba --- /dev/null +++ b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py @@ -0,0 +1,118 @@ +import pytest +import torch +import torch.distributed as dist +from transformers.models.bert import BertConfig +from transformers.models.bert.modeling_bert import BertForPreTraining + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.policy.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_bert_for_pretraining_forward(): + configuration = BertConfig() + model = BertForPreTraining(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + # print(rank) + + x = torch.randint(0, 1000, (2, 3)) + hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) + if stage_manager.stage == 2: + attention_mask = torch.ones_like(x) + output = bert_for_pretraining_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 768) + print('start the training') + elif stage_manager.stage == 1: + attention_mask = torch.ones((2, 12, 3, 3)) + output = bert_for_pretraining_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 30522) + print('end the training') + print(output) + + # assert output[1].shape == (2, 768) + + +def check_bert_for_pretraining_policy(): + configuration = BertConfig() + model = BertForPreTraining(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + model_policy = BertForPreTrainingPolicy(stage_manager, len(model.bert.encoder.layer), 2) + assert model_policy.layers_per_stage == [6, 6] + layers = model_policy.get_hold_layers(model) + for layer in layers: + print(layer) + + +def run_dist_model(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_for_pretraining_forward() + + +def run_dist_policy(rank, world_size, port): + 
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_for_pretraining_policy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_for_pretraining_forward(): + spawn(run_dist_model, 4) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_for_pretraining_policy(): + spawn(run_dist_policy, 4) + + +if __name__ == "__main__": + """test the bert for pretraining model forward and bert for pretraining model policy""" + test_bert_for_pretraining_forward() + test_bert_for_pretraining_policy() From 29ef3807accadd42023f1bc8ba2880756fe8e858 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 6 Jul 2023 10:35:38 +0800 Subject: [PATCH 14/17] fix typos --- colossalai/pipeline/policy/bert.py | 280 ++++++++++++++---- .../test_bert_for_pretraining_model.py | 8 +- .../test_policy/test_bert_lmhead_model.py | 118 ++++++++ .../test_policy/test_bert_model.py | 4 +- 4 files changed, 340 insertions(+), 70 deletions(-) create mode 100644 tests/test_pipeline/test_policy/test_bert_lmhead_model.py diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index d8b665ec6c24..85cb0b0af585 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -10,9 +10,15 @@ BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, ) -from transformers.models.bert.modeling_bert import BertForPreTraining, BertForPreTrainingOutput, BertModel -from transformers.utils import logging +from transformers.models.bert.modeling_bert import ( + BertForPreTraining, + BertForPreTrainingOutput, + BertLMHeadModel, + BertModel, +) +from transformers.utils import ModelOutput, logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -21,25 +27,38 @@ logger = logging.get_logger(__name__) +class BertModelIntermediateOutput(ModelOutput): + """ + Class for the intermediate output of bert model and bert-based model + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the previous stage. + NOTE: This is different from the base model. 
+ """ + + hidden_states: torch.FloatTensor = None + attention_mask: Optional[torch.Tensor] = None + + def bert_model_forward( - self: BertModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + self: BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, # labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - # this is from the previous stage - hidden_states: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage ): # TODO: add explaination of the output here. r""" @@ -85,10 +104,6 @@ def bert_model_forward( raise ValueError("You have to specify either input_ids or inputs_embeds") batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) - attention_mask = extended_attention_mask else: input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape @@ -119,14 +134,29 @@ def bert_model_forward( else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + attention_mask = extended_attention_mask + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - hidden_states = hidden_states if hidden_states is not None else None + if stage_manager.is_first_stage(): hidden_states = self.embeddings( input_ids=input_ids, @@ -135,18 +165,8 @@ def bert_model_forward( inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - # inherit from bert_layer + # inherit from bert_layer,this should be changed when we add the feature to record hidden_states all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None @@ -221,34 +241,34 @@ def custom_forward(*inputs): pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: return (sequence_output, pooled_output) + layer_outputs[1:] + # return dict is not supported at this moment + else: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) - # output of non-first and non-last stages: + # output of non-first and non-last stages: must be a dict if not return_dict: - return tuple(v for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] if v is not None) - - # return dict is not supported at this moment - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) + logger.warning_once('The output of intermediate stage should always be a dict') + + 
return BertModelIntermediateOutput(hidden_states=hidden_states,) # The layer partition policy for bertmodel class BertModelPolicy(Policy): - def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + def __init__( + self, + stage_manager: PipelineStageManager, + num_layers: int, + ): super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, num_stages) + self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) def get_hold_layers(self, module: BertModel) -> List[Module]: """ @@ -287,7 +307,7 @@ def bert_for_pretraining_forward( return_dict: Optional[bool] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, -) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: +): return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: @@ -317,6 +337,7 @@ def bert_for_pretraining_forward( all_self_attentions = None all_cross_attentions = None hidden_states = outputs[0] + if stage_manager.is_last_stage(): sequence_output, pooled_output = outputs[:2] prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) @@ -342,21 +363,16 @@ def bert_for_pretraining_forward( else: if not return_dict: - return tuple(v for v in [ - hidden_states, - past_key_values, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] if v is not None) + logger.warning_once('The output of intermediate stage should always be a dict') + return BertModelIntermediateOutput(hidden_states=hidden_states,) class BertForPreTrainingPolicy(Policy): - def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + def __init__(self, stage_manager: PipelineStageManager, num_layers: int): super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, num_stages) + self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: """ @@ -382,3 +398,139 @@ def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor def replace_forward(self, module: Module) -> None: module.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), module.forward) + + +def bert_lmhead_forward(self: BertLMHeadModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_manager: Optional[PipelineStageManager] = None): + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the 
output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + use_cache = False + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + hidden_states = outputs[0] + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None 
else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + else: + if not return_dict: + return BertModelIntermediateOutput(hidden_states=hidden_states) + + +class BertLMHeadModelPolicy(Policy): + + def __init__(self, stage_manager: PipelineStageManager, num_layers: int): + super().__init__(stage_manager=stage_manager) + self.stage_manager = stage_manager + self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) + + def get_hold_layers(self, module: BertLMHeadModel) -> List[Module]: + """ + get pipeline layers for current stage + """ + hold_layers = [] + if self.stage_manager.is_first_stage(): + hold_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) + hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if self.stage_manager.is_last_stage(): + hold_layers.append(module.bert.pooler) + hold_layers.append(module.cls) + + return hold_layers + + def get_shared_params(self, module: BertLMHeadModel) -> List[Dict[int, Tensor]]: + '''no shared params in bertmodel''' + pass + + def replace_forward(self, module: Module) -> None: + module.forward = MethodType(partial(bert_lmhead_forward, stage_manager=self.stage_manager), module) diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py index 4d764704ccba..b170b52163c3 100644 --- a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py +++ b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py @@ -37,7 +37,7 @@ def check_bert_for_pretraining_forward(): x = torch.randint(0, 1000, (2, 3)) hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) - if stage_manager.stage == 2: + if stage_manager.stage == 0: attention_mask = torch.ones_like(x) output = bert_for_pretraining_forward(self=model, input_ids=x, @@ -46,8 +46,8 @@ def check_bert_for_pretraining_forward(): print(output[0].shape) assert output[0].shape == (2, 3, 768) print('start the training') - elif stage_manager.stage == 1: - attention_mask = torch.ones((2, 12, 3, 3)) + else: + attention_mask = torch.ones((2, 3)) output = bert_for_pretraining_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, @@ -83,7 +83,7 @@ def check_bert_for_pretraining_policy(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() - model_policy = BertForPreTrainingPolicy(stage_manager, len(model.bert.encoder.layer), 2) + model_policy = BertForPreTrainingPolicy(stage_manager, len(model.bert.encoder.layer)) assert model_policy.layers_per_stage == [6, 6] layers = model_policy.get_hold_layers(model) for layer in layers: diff --git a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py b/tests/test_pipeline/test_policy/test_bert_lmhead_model.py new file mode 100644 index 000000000000..04a6aff80ff1 --- /dev/null +++ b/tests/test_pipeline/test_policy/test_bert_lmhead_model.py @@ -0,0 +1,118 @@ +import pytest +import torch +import torch.distributed as dist +from transformers.models.bert import BertConfig +from transformers.models.bert.modeling_bert import BertLMHeadModel + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.policy.bert import BertLMHeadModelPolicy, 
bert_lmhead_forward +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_bert_lmhead_forward(): + configuration = BertConfig() + model = BertLMHeadModel(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + # print(rank) + + x = torch.randint(0, 1000, (2, 3)) + hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x) + output = bert_lmhead_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 768) + print('start the training') + else: + attention_mask = torch.ones((2, 3)) + output = bert_lmhead_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 30522) + print('end the training') + print(output) + + # assert output[1].shape == (2, 768) + + +def check_bert_lmhead_policy(): + configuration = BertConfig() + model = BertLMHeadModel(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + model_policy = BertLMHeadModelPolicy(stage_manager, len(model.bert.encoder.layer)) + assert model_policy.layers_per_stage == [6, 6] + layers = model_policy.get_hold_layers(model) + for layer in layers: + print(layer) + + +def run_dist_model(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_lmhead_forward() + + +def run_dist_policy(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_lmhead_policy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_lmhead_forward(): + spawn(run_dist_model, 4) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_lmhead_policy(): + spawn(run_dist_policy, 4) + + +if __name__ == "__main__": + """test the bert for pretraining model forward and bert for pretraining model policy""" + test_bert_lmhead_forward() + test_bert_lmhead_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py index cf5dc95feb8c..5903434d97b8 100644 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -43,7 +43,7 @@ def check_bert_model_forward(): assert output[0].shape == (2, 3, 768) print('start the training') else: - attention_mask = torch.ones((2, 12, 3, 3)) + attention_mask = torch.ones((2, 3)) output = bert_model_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, @@ -78,7 +78,7 @@ def check_bert_model_policy(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() - model_policy = 
BertModelPolicy(stage_manager, len(model.encoder.layer), 2) + model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer)) assert model_policy.layers_per_stage == [6, 6] layers = model_policy.get_hold_layers(model) for layer in layers: From 5cd2478db4f6159e8a12ad66a630c7d24d2b0395 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 6 Jul 2023 12:10:15 +0800 Subject: [PATCH 15/17] cancel warning --- colossalai/pipeline/policy/bert.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 85cb0b0af585..85bd35962386 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -252,10 +252,9 @@ def custom_forward(*inputs): ) # output of non-first and non-last stages: must be a dict - if not return_dict: - logger.warning_once('The output of intermediate stage should always be a dict') - - return BertModelIntermediateOutput(hidden_states=hidden_states,) + else: + # intermediate stage always return dict + return BertModelIntermediateOutput(hidden_states=hidden_states,) # The layer partition policy for bertmodel @@ -362,8 +361,7 @@ def bert_for_pretraining_forward( ) else: - if not return_dict: - logger.warning_once('The output of intermediate stage should always be a dict') + # intermediate stage always return dict return BertModelIntermediateOutput(hidden_states=hidden_states,) @@ -502,8 +500,8 @@ def bert_lmhead_forward(self: BertLMHeadModel, cross_attentions=outputs.cross_attentions, ) else: - if not return_dict: - return BertModelIntermediateOutput(hidden_states=hidden_states) + # intermediate stage always return dict + return BertModelIntermediateOutput(hidden_states=hidden_states) class BertLMHeadModelPolicy(Policy): From ef528e6d7d61f416ff47dab0aa082462f3e73b64 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 6 Jul 2023 14:17:49 +0800 Subject: [PATCH 16/17] change the imediate output to default dict --- colossalai/pipeline/policy/bert.py | 33 +++++++------------ .../test_bert_for_pretraining_model.py | 4 +-- .../test_policy/test_bert_lmhead_model.py | 4 +-- .../test_policy/test_bert_model.py | 4 +-- 4 files changed, 17 insertions(+), 28 deletions(-) diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 85bd35962386..ec6ab91b9365 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -27,20 +27,6 @@ logger = logging.get_logger(__name__) -class BertModelIntermediateOutput(ModelOutput): - """ - Class for the intermediate output of bert model and bert-based model - - Args: - hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the previous stage. - NOTE: This is different from the base model. 
- """ - - hidden_states: torch.FloatTensor = None - attention_mask: Optional[torch.Tensor] = None - - def bert_model_forward( self: BertModel, input_ids: Optional[torch.Tensor] = None, @@ -254,7 +240,9 @@ def custom_forward(*inputs): # output of non-first and non-last stages: must be a dict else: # intermediate stage always return dict - return BertModelIntermediateOutput(hidden_states=hidden_states,) + return { + 'hidden_states': hidden_states, + } # The layer partition policy for bertmodel @@ -288,7 +276,7 @@ def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: pass def replace_forward(self, module: Module) -> None: - module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model) + module.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module) def bert_for_pretraining_forward( @@ -335,8 +323,6 @@ def bert_for_pretraining_forward( all_hidden_states = None all_self_attentions = None all_cross_attentions = None - hidden_states = outputs[0] - if stage_manager.is_last_stage(): sequence_output, pooled_output = outputs[:2] prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) @@ -359,10 +345,13 @@ def bert_for_pretraining_forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - else: + hidden_states = outputs.get('hidden_states') + # intermediate stage always return dict - return BertModelIntermediateOutput(hidden_states=hidden_states,) + return { + 'hidden_states': hidden_states, + } class BertForPreTrainingPolicy(Policy): @@ -473,7 +462,6 @@ def bert_lmhead_forward(self: BertLMHeadModel, all_hidden_states = None all_self_attentions = None all_cross_attentions = None - hidden_states = outputs[0] if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -500,8 +488,9 @@ def bert_lmhead_forward(self: BertLMHeadModel, cross_attentions=outputs.cross_attentions, ) else: + hidden_states = outputs.get('hidden_states') # intermediate stage always return dict - return BertModelIntermediateOutput(hidden_states=hidden_states) + return {'hidden_states': hidden_states} class BertLMHeadModelPolicy(Policy): diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py index b170b52163c3..afbea49c1829 100644 --- a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py +++ b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py @@ -43,8 +43,8 @@ def check_bert_for_pretraining_forward(): input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - print(output[0].shape) - assert output[0].shape == (2, 3, 768) + print(output['hidden_states'].shape) + assert output['hidden_states'].shape == (2, 3, 768) print('start the training') else: attention_mask = torch.ones((2, 3)) diff --git a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py b/tests/test_pipeline/test_policy/test_bert_lmhead_model.py index 04a6aff80ff1..d41eddc74dff 100644 --- a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py +++ b/tests/test_pipeline/test_policy/test_bert_lmhead_model.py @@ -43,8 +43,8 @@ def check_bert_lmhead_forward(): input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - print(output[0].shape) - assert output[0].shape == (2, 3, 768) + print(output['hidden_states'].shape) + assert output['hidden_states'].shape == (2, 3, 768) print('start the training') else: attention_mask = torch.ones((2, 3)) diff --git 
a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py index 5903434d97b8..92485072a5e4 100644 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -39,8 +39,8 @@ def check_bert_model_forward(): if stage_manager.stage == 0: attention_mask = torch.ones_like(x) output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - print(output[0].shape) - assert output[0].shape == (2, 3, 768) + print(output['hidden_states'].shape) + assert output['hidden_states'].shape == (2, 3, 768) print('start the training') else: attention_mask = torch.ones((2, 3)) From e3e6c3bd6ae42646658a91b3571c408e1916b78e Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 6 Jul 2023 14:23:02 +0800 Subject: [PATCH 17/17] change the default output of get_shared_params --- colossalai/pipeline/policy/bert.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index ec6ab91b9365..abce504e9d61 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -273,7 +273,7 @@ def get_hold_layers(self, module: BertModel) -> List[Module]: def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: '''no shared params in bertmodel''' - pass + return [] def replace_forward(self, module: Module) -> None: module.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module) @@ -380,7 +380,7 @@ def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor]]: '''no shared params in bertmodel''' - pass + return [] def replace_forward(self, module: Module) -> None: module.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), @@ -517,7 +517,7 @@ def get_hold_layers(self, module: BertLMHeadModel) -> List[Module]: def get_shared_params(self, module: BertLMHeadModel) -> List[Dict[int, Tensor]]: '''no shared params in bertmodel''' - pass + return [] def replace_forward(self, module: Module) -> None: module.forward = MethodType(partial(bert_lmhead_forward, stage_manager=self.stage_manager), module)
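
The layer-partitioning helpers threaded through these patches (Policy.distribute_layers and Policy.get_stage_index) are easiest to follow in isolation. The sketch below is a standalone, illustrative re-implementation rather than the colossalai API itself; in particular, where the leftover layers land when num_layers does not divide evenly (the start_position choice) is an assumption here. It reproduces the numbers the tests assert: 12 BERT encoder layers over 2 pipeline stages give layers_per_stage == [6, 6], and get_stage_index maps stage 0 to encoder.layer[0:6] and stage 1 to encoder.layer[6:12], which is exactly the slice that get_hold_layers() extends with.

    # Illustrative sketch of the partitioning helpers added in this series.
    # Not the colossalai implementation: the placement of leftover layers is an assumption.
    from typing import List, Tuple

    import numpy as np


    def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
        """Split num_layers across num_stages as evenly as possible."""
        quotient, remainder = divmod(num_layers, num_stages)
        layers_per_stage = [quotient] * num_stages
        # assumption: hand any leftover layers to the stages around the middle
        start_position = num_stages // 2 - remainder // 2
        for i in range(start_position, start_position + remainder):
            layers_per_stage[i] += 1
        return layers_per_stage


    def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
        """Return the [start, end) slice of transformer blocks owned by `stage`."""
        accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
        return int(accumulated[stage]), int(accumulated[stage + 1])


    if __name__ == "__main__":
        # BertModel has 12 encoder layers by default, split over 2 pipeline stages,
        # which is why the tests assert layers_per_stage == [6, 6].
        layers_per_stage = distribute_layers(num_layers=12, num_stages=2)
        assert layers_per_stage == [6, 6]

        # Stage 0 holds encoder.layer[0:6] (plus the embeddings); stage 1 holds
        # encoder.layer[6:12] (plus pooler/cls), mirroring get_hold_layers().
        for stage in range(2):
            start, end = get_stage_index(layers_per_stage, stage)
            print(f"stage {stage}: encoder.layer[{start}:{end}]")

Keeping this arithmetic in shared static helpers is what lets BertModelPolicy, BertForPreTrainingPolicy, BertLMHeadModelPolicy and BloomModelPolicy differ only in which embedding, pooler and head modules they pin to the first and last stages.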