Commit

[feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv;
duanjunwen committed Oct 10, 2024
1 parent 9ee80fc commit 72b507a
Showing 1 changed file with 212 additions and 42 deletions.
254 changes: 212 additions & 42 deletions colossalai/shardformer/modeling/mixtral.py
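For orientation before the diff: the new forward gates the embedding lookup, the final norm, and the stage return value on the active pipeline schedule (1f1b, interleaved, or zbv). Below is a minimal sketch of the first-stage dispatch, assuming a duck-typed stage_manager with the same attributes referenced in the changed code; first_stage_should_embed is a hypothetical helper for illustration, not part of the commit.

def first_stage_should_embed(stage_manager):
    """Return True when this rank/chunk should run the embedding lookup (illustrative sketch)."""
    if stage_manager.is_interleave:
        if stage_manager.use_zbv:
            # zbv: the first rank also owns the final chunk, so check the chunk id too.
            return stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 0
        # interleaved: gate on the first pipeline rank, ignoring the chunk index.
        return stage_manager.is_first_stage(ignore_chunk=True)
    # 1f1b (or no pipeline schedule): one chunk per rank, so the plain check suffices.
    return stage_manager.is_first_stage()

The same three-way branch reappears in the diff for the final norm and for the per-stage return value.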
@@ -267,26 +267,98 @@ def mixtral_model_forward(
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# retrieve input_ids and inputs_embeds
print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}")
if stage_manager.is_first_stage():
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
if stage_manager.is_interleave:
if stage_manager.use_zbv:
# zbv
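# zbv places two model chunks on each rank in a V shape: the first pipeline rank owns
# chunk 0 (the embedding path handled below) and also the final chunk, so the chunk id
# must be checked in addition to the stage index.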
if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 0:
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
# interleaved
if stage_manager.is_first_stage(ignore_chunk=True):
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
# 1f1b or None
if stage_manager.is_first_stage(): # No ignore_chunk=True for 1f1b
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device

#######
# Note: the branching above covers all supported pipeline schedules: 1f1b, interleaved, and zbv
#######

seq_length_with_past = seq_length
past_key_values_length = 0
@@ -390,8 +462,22 @@ def custom_forward(*inputs):
if output_router_logits:
all_router_logits += (layer_outputs[-1],)

if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
#######
# Note: apply the final norm according to the active pipeline schedule: 1f1b, interleaved, or zbv
#######
if stage_manager.is_interleave:
if stage_manager.use_zbv:
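# zbv: the final model chunk (chunk 1) lives on the first pipeline rank, so the
# closing norm is applied there rather than on the last rank.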
if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
hidden_states = self.norm(hidden_states)
else:
if stage_manager.is_last_stage(ignore_chunk=True):
hidden_states = self.norm(hidden_states)
else:
if stage_manager.is_last_stage(): # No ignore_chunk=True for 1f1b
hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
@@ -400,30 +486,114 @@ def custom_forward(*inputs):

if output_router_logits and past_router_logits is not None:
all_router_logits = past_router_logits + all_router_logits
if stage_manager.is_last_stage():
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
if v is not None
)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)

#######
# Note: build the stage output according to the active pipeline schedule: 1f1b, interleaved, or zbv
#######
if stage_manager.is_interleave:
if stage_manager.use_zbv:
# zbv
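# zbv: chunk 1 on the first pipeline rank emits the model's final output, so the full
# tuple / MoeModelOutputWithPast is returned here; other chunks pass a dict with
# hidden_states (plus accumulated router logits when requested) to the next stage.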
if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
if v is not None
)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)
else:
if output_router_logits:
return {
"hidden_states": hidden_states,
"past_router_logits": all_router_logits,
}
else:
return {
"hidden_states": hidden_states,
}
else:
# interleaved
if stage_manager.is_last_stage(ignore_chunk=True):
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
if v is not None
)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)
else:
if output_router_logits:
return {
"hidden_states": hidden_states,
"past_router_logits": all_router_logits,
}
else:
return {
"hidden_states": hidden_states,
}
else:
if output_router_logits:
return {
"hidden_states": hidden_states,
"past_router_logits": all_router_logits,
}
# 1f1b or other
if stage_manager.is_last_stage():
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
if v is not None
)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)
else:
return {
"hidden_states": hidden_states,
}
if output_router_logits:
return {
"hidden_states": hidden_states,
"past_router_logits": all_router_logits,
}
else:
return {
"hidden_states": hidden_states,
}

@staticmethod
def mixtral_for_causal_lm_forward(
