diff --git a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
index 30ebf003145e6c..8c1cdd0b1a5557 100644
--- a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
+++ b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
@@ -1348,7 +1348,7 @@ def _unpack_router_logits(self, router_outputs):
         total_router_logits = []
         total_expert_indexes = []
         for router_output in router_outputs:
-            if router_output[0] is not None:
+            if len(router_output[0].shape) > 1:
                 router_logits, expert_indexes = router_output
                 total_router_logits.append(router_logits)
                 total_expert_indexes.append(expert_indexes)
diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index 3b36b7732b0951..b8a90a6c1538f1 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -798,7 +798,7 @@ def forward(
         if isinstance(hidden_states, tuple):
             hidden_states, router_tuple = hidden_states
         else:
-            router_tuple = (None,)
+            router_tuple = (torch.tensor([0]),)
 
         # clamp inf values to enable fp16 training
         if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
@@ -1683,50 +1683,45 @@ def forward(
         decoder_z_loss = None
         decoder_aux_loss = None
 
-        if labels is not None:
-            loss_fct = CrossEntropyLoss(ignore_index=-100)
-            # todo check in the config if router loss enables
-
-            if output_router_logits:
-                # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
-                encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(
-                    encoder_outputs.router_probs
-                )
+        if output_router_logits:
+            # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
+            if self.encoder.config.encoder_sparse_step > 1:
+                encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_outputs[-1])
                 encoder_z_loss = router_z_loss_func(encoder_router_logits)
                 encoder_router_probs = nn.Softmax(dim=-1)(encoder_router_logits)
                 encoder_aux_loss = load_balancing_loss_func(encoder_router_probs, encoder_expert_indexes)
+            else:
+                encoder_z_loss = 0
+                encoder_aux_loss = 0
 
-                decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits(
-                    decoder_outputs.router_probs
-                )
+            if self.decoder.config.decoder_sparse_step > 1:
+                decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits(decoder_outputs[-1])
                 decoder_z_loss = router_z_loss_func(decoder_router_logits)
                 decoder_router_probs = nn.Softmax(dim=-1)(decoder_router_logits)
                 decoder_aux_loss = load_balancing_loss_func(decoder_router_probs, decoder_expert_indexes)
+            else:
+                decoder_z_loss = 0
+                decoder_aux_loss = 0
 
+        if labels is not None:
+            loss_fct = CrossEntropyLoss(ignore_index=-100)
             # move labels to correct device to enable PP
             labels = labels.to(lm_logits.device)
             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
 
-            if output_router_logits and labels is not None:
+            if output_router_logits:
                 z_loss = self.router_z_loss_coef * (encoder_z_loss + decoder_z_loss)
                 aux_loss = self.router_aux_loss_coef * (encoder_aux_loss + decoder_aux_loss)
                 loss = loss + z_loss + aux_loss
 
         if not return_dict:
             output = (lm_logits,)
-            if output_router_logits:  # only return the loss if they are not None
-                output += (
-                    encoder_z_loss,
-                    encoder_aux_loss,
-                    decoder_z_loss,
-                    decoder_aux_loss,
-                    *decoder_outputs[1:],
-                    *encoder_outputs,
-                )
-            else:
-                output += (*decoder_outputs[1:], *encoder_outputs)
+            if output_router_logits:
+                output += (encoder_z_loss, encoder_aux_loss, decoder_z_loss, decoder_aux_loss)
+            output += (*decoder_outputs[1:], *encoder_outputs)
 
             return ((loss,) + output) if loss is not None else output
+
         return Seq2SeqMoEOutput(
             loss=loss,
             logits=lm_logits,
@@ -1738,18 +1733,18 @@ def forward(
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
+            decoder_router_logits=decoder_outputs.router_probs,
             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
             encoder_router_logits=encoder_outputs.router_probs,
-            decoder_router_logits=decoder_outputs.router_probs,
         )
 
     def _unpack_router_logits(self, router_outputs):
         total_router_logits = []
         total_expert_indexes = []
         for router_output in router_outputs:
-            if router_output[0] is not None:
+            if len(router_output[0].shape) > 1:
                 router_logits, expert_indexes = router_output
                 total_router_logits.append(router_logits)
                 total_expert_indexes.append(expert_indexes)
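
For context, here is a minimal standalone sketch (not part of the diff) of how the changed _unpack_router_logits check behaves: dense layers now emit a dummy 1-D router tuple (torch.tensor([0]),) instead of (None,), and the len(router_output[0].shape) > 1 test skips those placeholders while keeping real router logits, which are at least 2-D. The function name, the example shapes, and the concatenation in the return statement are illustrative assumptions rather than an excerpt from the patched file.

import torch

def unpack_router_logits(router_outputs):
    # Keep only outputs from sparse (MoE) layers; dense layers contribute a
    # 1-D placeholder tensor, whose shape has length 1 and is therefore skipped.
    total_router_logits, total_expert_indexes = [], []
    for router_output in router_outputs:
        if len(router_output[0].shape) > 1:
            router_logits, expert_indexes = router_output
            total_router_logits.append(router_logits)
            total_expert_indexes.append(expert_indexes)
    # Assumed concatenation over the sequence dimension, for illustration only.
    return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1)

# Hypothetical inputs: one dense-layer placeholder and one sparse-layer output.
dense_placeholder = (torch.tensor([0]),)                              # shape (1,) -> skipped
sparse_output = (torch.randn(2, 5, 8), torch.randint(0, 8, (2, 5)))   # (logits, expert indexes)
logits, indexes = unpack_router_logits([dense_placeholder, sparse_output])
print(logits.shape, indexes.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5])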