some fixes
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Nov 14, 2024
1 parent 0aab2c6 commit d58960c
Showing 4 changed files with 25 additions and 181 deletions.
@@ -30,22 +30,6 @@
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
_CONFIG_FOR_DOC,
LLAMA_INPUTS_DOCSTRING,
)
from transformers.models.mixtral.modeling_mixtral import (
_CONFIG_FOR_DOC,
MIXTRAL_INPUTS_DOCSTRING,
)
from transformers.modeling_outputs import (
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)

from .cross_entropy import (
element_mul_kernel,
@@ -297,11 +281,6 @@ def forward(self, lin_weight, _input, target, bias=None):
self.reduction,
)

# TODO: how to add diff docstrings for diff model types? what if the loss functions aren't the same across models?
# @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
self,
input_ids: torch.LongTensor = None,
Expand Down Expand Up @@ -435,143 +414,4 @@ def lce_forward(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

# TODO: is adding a separate copy of lce_forward() the right path or should the additional logic for Moe models be in the single lce_forward?
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
# Ignore copy
def lce_forward_mixtral(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, MixtralForCausalLM
>>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)

output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
return_dict=return_dict,
cache_position=cache_position,
)

hidden_states = outputs[0]

loss = None
logits = None

# patch change
if self.training and (labels is not None):
shift_hidden_states = hidden_states[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# flatten tokens
shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
shift_labels = shift_labels.view(-1)

lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
else:
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

# TODO: unique differing part to mixtral model forward
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits if return_dict else outputs[-1],
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
# TODO: should this loss manipulation be indented in?? or should it be added to even the liger loss?
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device

if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output

return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
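
For context, the Mixtral-specific forward removed above paired a fused linear cross-entropy loss with Mixtral's router load-balancing auxiliary loss. A minimal sketch of that training-path pattern, with the fused-loss callable, tensors, and coefficient passed in as hypothetical arguments (an illustration of the lines deleted above, not part of this commit):

```python
from typing import Optional

import torch


def mixtral_fused_lce_loss(
    lm_head_weight: torch.Tensor,          # (vocab_size, hidden_size)
    hidden_states: torch.Tensor,           # (batch, seq_len, hidden_size) from the decoder
    labels: torch.LongTensor,              # (batch, seq_len)
    lce,                                   # fused linear cross-entropy callable, e.g. LigerFusedLinearCrossEntropyLoss()
    aux_loss: Optional[torch.Tensor] = None,  # router load-balancing loss when output_router_logits is set
    router_aux_loss_coef: float = 0.0,
) -> torch.Tensor:
    # Shift so that tokens < n predict token n.
    shift_hidden = hidden_states[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten tokens; the fused kernel consumes the lm_head weight directly,
    # so the full (tokens, vocab_size) logits tensor is never materialized.
    hidden_size = hidden_states.size(-1)
    loss = lce(lm_head_weight, shift_hidden.view(-1, hidden_size), shift_labels.view(-1))

    # Mixtral-specific step: fold in the scaled router auxiliary loss.
    if aux_loss is not None:
        loss = loss + router_aux_loss_coef * aux_loss.to(loss.device)
    return loss
```

The non-training path in the removed code still materialized logits via `self.lm_head` and fell back to the standard `CrossEntropyLoss`.
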
@@ -23,7 +23,6 @@
combine_triggers,
)
from transformers.models.mixtral.modeling_mixtral import (
MixtralForCausalLM,
MixtralAttention,
MixtralRMSNorm,
)
@@ -32,7 +31,6 @@
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward_mixtral
from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops


@@ -95,11 +93,6 @@ def get_mp_rules(base_type):
"transformers.models.mixtral.modeling_mixtral",
),
),
ModelPatcherRule(
rule_id="mixtral-fused-lce",
trigger=ModelPatcherTrigger(check=MixtralForCausalLM),
forward=lce_forward_mixtral,
),
ModelPatcherRule(
rule_id="mixtral-rope",
import_and_maybe_reload=(
scripts/benchmarks/benchmark.py (2 changes: 1 addition & 1 deletion)
@@ -723,7 +723,7 @@ def prepare_arguments(args, benchmark_dataset: BenchmarkDataset):

if (
not args.run_only_scenarios
and scenarios.slow
and scenario.slow
):
# unfiltered runs omit all "slow" marked scenarios
print(f"Skipping slow scenario '{_scn_name}' beacuse run_only_scenarios=None.")
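
The one-line change above fixes a variable-name bug: `scenarios` is the whole collection, while the `slow` flag lives on the current `scenario`. A self-contained sketch of the intended filtering behaviour (the `Scenario` dataclass and helper below are illustrative assumptions, not the benchmark's actual classes):

```python
from dataclasses import dataclass


@dataclass
class Scenario:
    name: str
    slow: bool = False


def scenarios_to_run(scenarios: list[Scenario], run_only_scenarios=None) -> list[Scenario]:
    selected = []
    for scenario in scenarios:
        # On unfiltered runs (run_only_scenarios is None), skip anything marked slow.
        # The old code read `scenarios.slow`, i.e. an attribute on the collection
        # rather than on the current item, which fails on a plain list.
        if not run_only_scenarios and scenario.slow:
            print(f"Skipping slow scenario '{scenario.name}' because run_only_scenarios=None.")
            continue
        selected.append(scenario)
    return selected


# Example: only 'quick' survives an unfiltered run.
print([s.name for s in scenarios_to_run([Scenario("quick"), Scenario("big-70b", slow=True)])])
```
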
scripts/benchmarks/scenarios-liger.yaml (37 changes: 24 additions & 13 deletions)
@@ -38,22 +38,18 @@
scenarios:
- name: full-finetuning
framework_config:
-
- foak-fast-kernels
- foak-fast-kernels-liger
arguments:
learning_rate: 2e-5
model_name_or_path:
- 'bigcode/gpt_bigcode-santacoder'
- 'mistralai/Mistral-7B-v0.1'
- 'mistralai/Mixtral-8x7B-Instruct-v0.1'
- 'NousResearch/Llama-2-70b-hf'
# - 'mistralai/Mistral-7B-v0.1'
- 'meta-llama/Meta-Llama-3-8B'
torch_dtype: bfloat16
bf16: True

- name: standard-peft
framework_config:
-
- foak-fast-kernels
- foak-fast-kernels-liger
arguments:
@@ -66,13 +62,29 @@ scenarios:
lora_dropout: 0.1
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
model_name_or_path:
- 'mistralai/Mistral-7B-v0.1'
- 'mistralai/Mixtral-8x7B-Instruct-v0.1'
- 'NousResearch/Llama-2-70b-hf'
# - 'mistralai/Mistral-7B-v0.1'
- 'meta-llama/Meta-Llama-3-8B'

- name: accelerated-peft-bnb
framework_config:
- accelerated-peft-bnb-foak
- accelerated-peft-bnb-liger
arguments:
bf16: True
learning_rate: 2e-4
torch_dtype: bfloat16
peft_method: lora
r: 16
lora_alpha: 16
lora_dropout: 0.1
per_device_train_batch_size:
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
model_name_or_path:
# - 'mistralai/Mistral-7B-v0.1'
- 'meta-llama/Meta-Llama-3-8B'

- name: accelerated-peft-gptq
framework_config:
- accelerated-peft-autogptq
- accelerated-peft-autogptq-foak
- accelerated-peft-autogptq-foak-liger
arguments:
@@ -85,6 +97,5 @@
lora_dropout: 0.1
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
model_name_or_path:
- 'TheBloke/Mistral-7B-v0.1-GPTQ'
- 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ'
- 'TheBloke/Llama-2-70B-GPTQ'
# - 'TheBloke/Mistral-7B-v0.1-GPTQ'
- 'meta-llama/Meta-Llama-3-8B'
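
Each entry in scenarios-liger.yaml names a scenario, the framework_config variants to benchmark it under, and argument sweeps such as the model_name_or_path list. A rough sketch of how those entries enumerate into individual runs, assuming runs are the cross product of framework_config and model_name_or_path (an assumption about the harness, not code from this repository):

```python
import itertools

import yaml

# Load the scenario definitions shown in the diff above.
with open("scripts/benchmarks/scenarios-liger.yaml") as f:
    config = yaml.safe_load(f)

for scenario in config["scenarios"]:
    args = scenario.get("arguments", {})
    models = args.get("model_name_or_path", [])
    # Sketch: pair every framework_config with every model for this scenario.
    for fw, model in itertools.product(scenario.get("framework_config", []), models):
        print(f"{scenario['name']}: framework_config={fw}, model={model}")
```
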
