From e234dfa236e9f94f250c5858efdd0cd607326fdd Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Thu, 10 Oct 2024 06:57:35 +0000
Subject: [PATCH] [feat] support MixtralPipelineForwards.mixtral_for_causal_lm_forward for zbv

---
 colossalai/shardformer/modeling/mixtral.py    | 235 ++++++++++++++----
 .../test_schedule/test_zerobubble_pp.py       |   2 +
 2 files changed, 194 insertions(+), 43 deletions(-)

diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index d1e44aa5bebb..3709af54c486 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -679,52 +679,201 @@ def mixtral_for_causal_lm_forward(
         )
         past_key_values = None

-        if stage_manager.is_last_stage():
-            hidden_states = outputs[0]
-            logits = self.lm_head(hidden_states)
-            logits = logits.float()
-
-            loss = None
-            if labels is not None:
-                # Shift so that tokens < n predict n
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
-                # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                shift_labels = shift_labels.view(-1)
-                # Enable model parallelism
-                shift_labels = shift_labels.to(shift_logits.device)
-                loss = loss_fct(shift_logits, shift_labels)
-
-            aux_loss = None
-            if output_router_logits:
-                aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
+        #######
+        # Attention: we need to handle three pipeline schedules here: 1f1b, interleaved, and zbv
+        #######
+        if stage_manager.is_interleave:
+            if stage_manager.use_zbv:
+                # zbv: the V-shaped schedule places the last model chunk back on the first
+                # pipeline rank, so chunk 1 of the first stage computes the logits and the loss
+                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
+                    hidden_states = outputs[0]
+                    logits = self.lm_head(hidden_states)
+                    logits = logits.float()
+
+                    loss = None
+                    if labels is not None:
+                        # Shift so that tokens < n predict n
+                        shift_logits = logits[..., :-1, :].contiguous()
+                        shift_labels = labels[..., 1:].contiguous()
+                        # Flatten the tokens
+                        loss_fct = CrossEntropyLoss()
+                        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                        shift_labels = shift_labels.view(-1)
+                        # Enable model parallelism
+                        shift_labels = shift_labels.to(shift_logits.device)
+                        loss = loss_fct(shift_logits, shift_labels)
+
+                    aux_loss = None
+                    if output_router_logits:
+                        aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
+                        if labels is not None:
+                            loss += self.router_aux_loss_coef * aux_loss
+
+                    if not return_dict:
+                        output = (logits,) + outputs[1:]
+                        if output_router_logits:
+                            output = (aux_loss,) + output
+                        return (loss,) + output if loss is not None else output
+
+                    return MoeCausalLMOutputWithPast(
+                        loss=loss,
+                        aux_loss=aux_loss,
+                        logits=logits,
+                        past_key_values=None,
+                        hidden_states=outputs[0],
+                        attentions=None,
+                        router_logits=outputs[-1],
+                    )
+                else:
+                    out = {}
+                    hidden_states = outputs.get("hidden_states")
+                    out["hidden_states"] = hidden_states
+                    if output_router_logits:
+                        out["past_router_logits"] = outputs["past_router_logits"]
+                    return out
+            else:
+                # interleaved: the last pipeline stage (regardless of chunk) computes the loss
+                if stage_manager.is_last_stage(ignore_chunk=True):
+                    hidden_states = outputs[0]
+                    logits = self.lm_head(hidden_states)
+                    logits = logits.float()
+
+                    loss = None
+                    if labels is not None:
+                        # Shift so that tokens < n predict n
+                        shift_logits = logits[..., :-1, :].contiguous()
+                        shift_labels = labels[..., 1:].contiguous()
+                        # Flatten the tokens
+                        loss_fct = CrossEntropyLoss()
+                        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                        shift_labels = shift_labels.view(-1)
+                        # Enable model parallelism
+                        shift_labels = shift_labels.to(shift_logits.device)
+                        loss = loss_fct(shift_logits, shift_labels)
+
+                    aux_loss = None
+                    if output_router_logits:
+                        aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
+                        if labels is not None:
+                            loss += self.router_aux_loss_coef * aux_loss
+
+                    if not return_dict:
+                        output = (logits,) + outputs[1:]
+                        if output_router_logits:
+                            output = (aux_loss,) + output
+                        return (loss,) + output if loss is not None else output
+
+                    return MoeCausalLMOutputWithPast(
+                        loss=loss,
+                        aux_loss=aux_loss,
+                        logits=logits,
+                        past_key_values=None,
+                        hidden_states=outputs[0],
+                        attentions=None,
+                        router_logits=outputs[-1],
+                    )
+                else:
+                    out = {}
+                    hidden_states = outputs.get("hidden_states")
+                    out["hidden_states"] = hidden_states
+                    if output_router_logits:
+                        out["past_router_logits"] = outputs["past_router_logits"]
+                    return out
+        else:
+            # 1f1b or otherwise: the last pipeline stage computes the loss
+            if stage_manager.is_last_stage():
+                hidden_states = outputs[0]
+                logits = self.lm_head(hidden_states)
+                logits = logits.float()
+
+                loss = None
                 if labels is not None:
-                loss += self.router_aux_loss_coef * aux_loss
+                    # Shift so that tokens < n predict n
+                    shift_logits = logits[..., :-1, :].contiguous()
+                    shift_labels = labels[..., 1:].contiguous()
+                    # Flatten the tokens
+                    loss_fct = CrossEntropyLoss()
+                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                    shift_labels = shift_labels.view(-1)
+                    # Enable model parallelism
+                    shift_labels = shift_labels.to(shift_logits.device)
+                    loss = loss_fct(shift_logits, shift_labels)
+
+                aux_loss = None
+                if output_router_logits:
+                    aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
+                    if labels is not None:
+                        loss += self.router_aux_loss_coef * aux_loss
-        if not return_dict:
-            output = (logits,) + outputs[1:]
+                if not return_dict:
+                    output = (logits,) + outputs[1:]
+                    if output_router_logits:
+                        output = (aux_loss,) + output
+                    return (loss,) + output if loss is not None else output
+
+                return MoeCausalLMOutputWithPast(
+                    loss=loss,
+                    aux_loss=aux_loss,
+                    logits=logits,
+                    past_key_values=None,
+                    hidden_states=outputs[0],
+                    attentions=None,
+                    router_logits=outputs[-1],
+                )
+            else:
+                out = {}
+                hidden_states = outputs.get("hidden_states")
+                out["hidden_states"] = hidden_states
                 if output_router_logits:
-                output = (aux_loss,) + output
-            return (loss,) + output if loss is not None else output
-
-        return MoeCausalLMOutputWithPast(
-            loss=loss,
-            aux_loss=aux_loss,
-            logits=logits,
-            past_key_values=None,
-            hidden_states=outputs[0],
-            attentions=None,
-            router_logits=outputs[-1],
-        )
-    else:
-        out = {}
-        hidden_states = outputs.get("hidden_states")
-        out["hidden_states"] = hidden_states
-        if output_router_logits:
-            out["past_router_logits"] = outputs["past_router_logits"]
-        return out
+                    out["past_router_logits"] = outputs["past_router_logits"]
+                return out
+
+        # if stage_manager.is_last_stage():
+        #     hidden_states = outputs[0]
+        #     logits = self.lm_head(hidden_states)
+        #     logits = logits.float()
+
+        #     loss = None
+        #     if labels is not None:
+        #         # Shift so that tokens < n predict n
+        #         shift_logits = logits[..., :-1, :].contiguous()
+        #         shift_labels = labels[..., 1:].contiguous()
+        #         # Flatten the tokens
+        #         loss_fct = CrossEntropyLoss()
+        #         shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        #         shift_labels = shift_labels.view(-1)
+        #         # Enable model parallelism
+        #         shift_labels = shift_labels.to(shift_logits.device)
+        #         loss = loss_fct(shift_logits, shift_labels)
+
+        #     aux_loss = None
+        #     if output_router_logits:
+        #         aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
+        #         if labels is not None:
+        #             loss += self.router_aux_loss_coef * aux_loss
+
+        #     if not return_dict:
+        #         output = (logits,) + outputs[1:]
+        #         if output_router_logits:
+        #             output = (aux_loss,) + output
+        #         return (loss,) + output if loss is not None else output
+
+        #     return MoeCausalLMOutputWithPast(
+        #         loss=loss,
+        #         aux_loss=aux_loss,
+        #         logits=logits,
+        #         past_key_values=None,
+        #         hidden_states=outputs[0],
+        #         attentions=None,
+        #         router_logits=outputs[-1],
+        #     )
+        # else:
+        #     out = {}
+        #     hidden_states = outputs.get("hidden_states")
+        #     out["hidden_states"] = hidden_states
+        #     if output_router_logits:
+        #         out["past_router_logits"] = outputs["past_router_logits"]
+        #     return out


 def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 384ed649055c..1e8f1392e470 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -786,6 +786,8 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
     seed_all(10086)
     torch_model = MixtralModel(config).to(dtype).cuda()
+    # TODO: Support MixtralForCausalLM
+    # torch_model = MixtralForCausalLM(config).to(dtype).cuda()
     torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
     # init schedule
     h, a, s = config.hidden_size, config.num_attention_heads, 1024
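For reference, the three branches added in mixtral_for_causal_lm_forward differ only in which virtual stage owns lm_head and therefore computes the loss; every other stage just passes hidden_states (and past_router_logits when requested) forward. Below is a minimal sketch of that predicate, assuming a stage_manager with the same attributes the patch uses (is_interleave, use_zbv, is_first_stage, is_last_stage, model_chunk_id); the helper name holds_lm_head is hypothetical and not part of the patch.

def holds_lm_head(stage_manager) -> bool:
    # Hypothetical helper mirroring the schedule branching in the patch.
    if stage_manager.is_interleave:
        if stage_manager.use_zbv:
            # zbv: the V-shaped schedule folds the last model chunk back onto the
            # first pipeline rank, so chunk 1 on the first stage computes the loss.
            return stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1
        # interleaved: the last pipeline stage (any chunk) computes the loss.
        return stage_manager.is_last_stage(ignore_chunk=True)
    # plain 1f1b: the last pipeline stage computes the loss.
    return stage_manager.is_last_stage()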