[feat] support MixtralPipelineForwards --> mixtral_for_causal_lm_forward for zbv
duanjunwen committed Oct 10, 2024
1 parent 72b507a commit e234dfa
Showing 2 changed files with 194 additions and 43 deletions.
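The change boils down to one question per microbatch: does this pipeline rank and model chunk hold the LM head? Below is a minimal sketch of that predicate (not part of the commit), using only the stage-manager attributes that appear in the diff: `is_interleave`, `use_zbv`, `is_first_stage`, `is_last_stage`, and `model_chunk_id`. Under the V-shaped zbv schedule the final model chunk sits on pipeline stage 0 as chunk 1, while 1f1b and plain interleaved schedules keep the head on the last stage.

```python
# Sketch only: mirrors the branching added in this commit and assumes the
# stage_manager attributes behave as they are used in the diff below.
def holds_lm_head(stage_manager) -> bool:
    if stage_manager.is_interleave:
        if stage_manager.use_zbv:
            # zbv: the final model chunk (chunk 1) lives on the first pipeline stage
            return stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1
        # interleaved: the last stage holds the head
        return stage_manager.is_last_stage(ignore_chunk=True)
    # 1f1b and other non-interleaved schedules
    return stage_manager.is_last_stage()
```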
235 changes: 192 additions & 43 deletions colossalai/shardformer/modeling/mixtral.py
@@ -679,52 +679,201 @@ def mixtral_for_causal_lm_forward(
)
past_key_values = None

if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()

loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
#######
# Attention: we need to handle the 1f1b, interleaved, and zbv schedules here
#######
if stage_manager.is_interleave:
if stage_manager.use_zbv:
# zbv
if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()

loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss

if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output

return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=None,
hidden_states=outputs[0],
attentions=None,
router_logits=outputs[-1],
)
else:
out = {}
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
if output_router_logits:
out["past_router_logits"] = outputs["past_router_logits"]
return out
else:
# interleaved
if stage_manager.is_last_stage(ignore_chunk=True):
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()

loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss

if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output

return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=None,
hidden_states=outputs[0],
attentions=None,
router_logits=outputs[-1],
)
else:
out = {}
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
if output_router_logits:
out["past_router_logits"] = outputs["past_router_logits"]
return out
else:
# 1f1b or other non-interleaved schedules
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()

loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss

if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output

return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=None,
hidden_states=outputs[0],
attentions=None,
router_logits=outputs[-1],
)
else:
out = {}
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
if output_router_logits:
out["past_router_logits"] = outputs["past_router_logits"]
return out

# if stage_manager.is_last_stage():
# hidden_states = outputs[0]
# logits = self.lm_head(hidden_states)
# logits = logits.float()

# loss = None
# if labels is not None:
# # Shift so that tokens < n predict n
# shift_logits = logits[..., :-1, :].contiguous()
# shift_labels = labels[..., 1:].contiguous()
# # Flatten the tokens
# loss_fct = CrossEntropyLoss()
# shift_logits = shift_logits.view(-1, self.config.vocab_size)
# shift_labels = shift_labels.view(-1)
# # Enable model parallelism
# shift_labels = shift_labels.to(shift_logits.device)
# loss = loss_fct(shift_logits, shift_labels)

# aux_loss = None
# if output_router_logits:
# aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
# if labels is not None:
# loss += self.router_aux_loss_coef * aux_loss

# if not return_dict:
# output = (logits,) + outputs[1:]
# if output_router_logits:
# output = (aux_loss,) + output
# return (loss,) + output if loss is not None else output

# return MoeCausalLMOutputWithPast(
# loss=loss,
# aux_loss=aux_loss,
# logits=logits,
# past_key_values=None,
# hidden_states=outputs[0],
# attentions=None,
# router_logits=outputs[-1],
# )
# else:
# out = {}
# hidden_states = outputs.get("hidden_states")
# out["hidden_states"] = hidden_states
# if output_router_logits:
# out["past_router_logits"] = outputs["past_router_logits"]
# return out


def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
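Each branch above repeats the standard causal-LM loss: logits at position t are scored against the token at position t+1, and the MoE load-balancing loss is then added with a coefficient. A self-contained illustration with made-up shapes (not part of the commit):

```python
import torch
from torch.nn import CrossEntropyLoss

# Hypothetical shapes: batch=2, seq_len=5, vocab_size=11.
logits = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))

# Shift so that tokens < n predict n, exactly as in the diff above.
shift_logits = logits[..., :-1, :].contiguous()  # predictions for positions 0..3
shift_labels = labels[..., 1:].contiguous()      # targets are tokens 1..4
loss = CrossEntropyLoss()(shift_logits.view(-1, 11), shift_labels.view(-1))

# When router logits are returned, the aux loss is folded in as
#   loss = loss + router_aux_loss_coef * aux_loss
# using the names from the diff above.
print(loss)
```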
2 changes: 2 additions & 0 deletions tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -786,6 +786,8 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
seed_all(10086)

torch_model = MixtralModel(config).to(dtype).cuda()
# TODO: Support MixtralForCausalLM
# torch_model = MixtralForCausalLM(config).to(dtype).cuda()
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
# init schedule
h, a, s = config.hidden_size, config.num_attention_heads, 1024
