From 72b507a7beeb01f8407c3a6ea76d49bf9e75f040 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 10 Oct 2024 06:19:51 +0000 Subject: [PATCH] [feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv; --- colossalai/shardformer/modeling/mixtral.py | 254 +++++++++++++++++---- 1 file changed, 212 insertions(+), 42 deletions(-) diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index df9b91da2559..d1e44aa5bebb 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -267,26 +267,98 @@ def mixtral_model_forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds - print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}") - if stage_manager.is_first_stage(): - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape + if stage_manager.is_interleave: + if stage_manager.use_zbv: + # zbv + if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 0: + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - device = input_ids.device if input_ids is not None else inputs_embeds.device - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds + # interleaved + if stage_manager.is_first_stage(ignore_chunk=True): + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device else: - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - device = hidden_states.device + # 1f1b or None + if stage_manager.is_first_stage(): # No ignore_chunk=True for 1f1b + # 
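+                # Plain 1f1b schedules a single model chunk per stage, so the default
+                # is_first_stage()/is_last_stage() checks need no chunk handling.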
+                # retrieve input_ids and inputs_embeds
+                if input_ids is not None and inputs_embeds is not None:
+                    raise ValueError(
+                        "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+                    )
+                elif input_ids is not None:
+                    batch_size, seq_length = input_ids.shape
+                elif inputs_embeds is not None:
+                    batch_size, seq_length, _ = inputs_embeds.shape
+                else:
+                    raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+                device = input_ids.device if input_ids is not None else inputs_embeds.device
+                if inputs_embeds is None:
+                    inputs_embeds = self.embed_tokens(input_ids)
+                hidden_states = inputs_embeds
+            else:
+                input_shape = hidden_states.shape[:-1]
+                batch_size, seq_length = input_shape
+                device = hidden_states.device
+
+        #######
+        # The branches above cover the 1f1b, interleaved, and zbv pipeline schedules.
+        #######

         seq_length_with_past = seq_length
         past_key_values_length = 0
@@ -390,8 +462,22 @@ def custom_forward(*inputs):
             if output_router_logits:
                 all_router_logits += (layer_outputs[-1],)

-        if stage_manager.is_last_stage():
-            hidden_states = self.norm(hidden_states)
+        #######
+        # Apply the final norm only on the stage/chunk that runs the end of the model;
+        # the condition differs across the 1f1b, interleaved, and zbv schedules.
+        #######
+        if stage_manager.is_interleave:
+            if stage_manager.use_zbv:
+                # zbv
+                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
+                    hidden_states = self.norm(hidden_states)
+            else:
+                # interleaved
+                if stage_manager.is_last_stage(ignore_chunk=True):
+                    hidden_states = self.norm(hidden_states)
+        else:
+            # 1f1b or None
+            if stage_manager.is_last_stage():  # No ignore_chunk=True for 1f1b
+                hidden_states = self.norm(hidden_states)

         # add hidden states from the last decoder layer
         if output_hidden_states:
@@ -400,30 +486,114 @@ def custom_forward(*inputs):
         if output_router_logits and past_router_logits is not None:
             all_router_logits = past_router_logits + all_router_logits

-        if stage_manager.is_last_stage():
-            if not return_dict:
-                return tuple(
-                    v
-                    for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
-                    if v is not None
-                )
-            return MoeModelOutputWithPast(
-                last_hidden_state=hidden_states,
-                past_key_values=next_cache,
-                hidden_states=all_hidden_states,
-                attentions=all_self_attns,
-                router_logits=all_router_logits,
-            )
+
+        #######
+        # Return the full model output only from the stage/chunk that runs the end of
+        # the model; intermediate stages pass hidden_states (and router logits) onward.
+        #######
+        if stage_manager.is_interleave:
+            if stage_manager.use_zbv:
+                # zbv
+                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
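+                    # Under the zbv (zero-bubble V) schedule, chunk 1 of the first
+                    # pipeline stage runs the end of the model, so the final output
+                    # is assembled and returned here instead of on the last stage.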
+                    if not return_dict:
+                        return tuple(
+                            v
+                            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+                            if v is not None
+                        )
+                    return MoeModelOutputWithPast(
+                        last_hidden_state=hidden_states,
+                        past_key_values=next_cache,
+                        hidden_states=all_hidden_states,
+                        attentions=all_self_attns,
+                        router_logits=all_router_logits,
+                    )
+                else:
+                    if output_router_logits:
+                        return {
+                            "hidden_states": hidden_states,
+                            "past_router_logits": all_router_logits,
+                        }
+                    else:
+                        return {
+                            "hidden_states": hidden_states,
+                        }
+            else:
+                # interleaved
+                if stage_manager.is_last_stage(ignore_chunk=True):
+                    if not return_dict:
+                        return tuple(
+                            v
+                            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+                            if v is not None
+                        )
+                    return MoeModelOutputWithPast(
+                        last_hidden_state=hidden_states,
+                        past_key_values=next_cache,
+                        hidden_states=all_hidden_states,
+                        attentions=all_self_attns,
+                        router_logits=all_router_logits,
+                    )
+                else:
+                    if output_router_logits:
+                        return {
+                            "hidden_states": hidden_states,
+                            "past_router_logits": all_router_logits,
+                        }
+                    else:
+                        return {
+                            "hidden_states": hidden_states,
+                        }
         else:
-            if output_router_logits:
-                return {
-                    "hidden_states": hidden_states,
-                    "past_router_logits": all_router_logits,
-                }
+            # 1f1b or None
+            if stage_manager.is_last_stage():
+                if not return_dict:
+                    return tuple(
+                        v
+                        for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+                        if v is not None
+                    )
+                return MoeModelOutputWithPast(
+                    last_hidden_state=hidden_states,
+                    past_key_values=next_cache,
+                    hidden_states=all_hidden_states,
+                    attentions=all_self_attns,
+                    router_logits=all_router_logits,
+                )
             else:
-                return {
-                    "hidden_states": hidden_states,
-                }
+                if output_router_logits:
+                    return {
+                        "hidden_states": hidden_states,
+                        "past_router_logits": all_router_logits,
+                    }
+                else:
+                    return {
+                        "hidden_states": hidden_states,
+                    }

     @staticmethod
     def mixtral_for_causal_lm_forward(