From c085a4562c704b94e76c17df1363a8fc6cd07e85 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 3 Sep 2024 01:52:32 -0700 Subject: [PATCH] Cohere, Bug fixes (#984) * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * update token retrieval logic (#952) * Fix DPO (#947) * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * update hf token retrieval logic --------- Co-authored-by: Daniel Han * Update llama.py * get_token * Update README.md * Update gemma2.py * Update rms_layernorm.py * synchronize * Update gemma2.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * layernorm * Update rms_layernorm.py * Update gemma2.py * Update rms_layernorm.py * Update rms_layernorm.py * revert * Gemma * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update gemma2.py * Change UnslothTrainingArguments base class to SFTConfig (#979) * Cohere * Update trainer.py * Cohere * Cohere * New models * Update llama.py * Update llama.py * Update cohere.py * Update llama.py * Update cohere.py * retry * Update fast_lora.py * Update llama.py * Update fast_lora.py * Update llama.py * Update llama.py * Update cross_entropy_loss.py * _apply_lora_mlp * Update _utils.py --------- Co-authored-by: Hafedh <70411813+not-lain@users.noreply.github.com> Co-authored-by: Tuan Pham <82665400+vTuanpham@users.noreply.github.com> --- unsloth/kernels/cross_entropy_loss.py | 130 ++++--- unsloth/kernels/fast_lora.py | 30 +- unsloth/kernels/rms_layernorm.py | 4 +- unsloth/models/_utils.py | 2 +- unsloth/models/cohere.py | 473 ++++++++++++++++++++++++++ unsloth/models/llama.py | 81 ++++- unsloth/models/loader.py | 7 +- unsloth/models/mapper.py | 23 ++ unsloth/trainer.py | 7 +- 9 files changed, 690 insertions(+), 67 deletions(-) create mode 100644 unsloth/models/cohere.py diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index b8473e60..24e8002b 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -19,17 +19,22 @@ from transformers.models.llama.modeling_llama import logger -@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],}) +@triton.heuristics({ + "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], + "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], +}) @triton.jit def _cross_entropy_forward( logits_ptr, logits_row_stride, loss_ptr, logsumexp_ptr, labels_ptr, - VOCAB_SIZE : tl.constexpr, - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr, - SOFTCAP : tl.constexpr, + VOCAB_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, + DO_LOGIT_SCALING: tl.constexpr, + LOGIT_SCALE : tl.constexpr, ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) 
] @@ -62,8 +67,11 @@ def _cross_entropy_forward( label_idx = tl.load(labels_ptr).to(tl.int32) logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) + + # Go logit scaling for Cohere: t * x + if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits # Do logit softcapping for Gemma 2: t * tanh(1/t * x) - if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) + if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) logits = logits.to(tl.float32) c = tl.max(logits, 0) @@ -71,8 +79,10 @@ def _cross_entropy_forward( if label_idx != -100: x = tl.load(logits_ptr + label_idx) + # Go logit scaling for Cohere: t * x + if DO_LOGIT_SCALING: x = LOGIT_SCALE * x # Do logit softcapping for Gemma 2: t * tanh(1/t * x) - if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP) + if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP) loss = logsumexp - x.to(tl.float32) else: loss = 0.0 @@ -81,18 +91,23 @@ def _cross_entropy_forward( pass -@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],}) +@triton.heuristics({ + "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], + "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], +}) @triton.jit def _chunked_cross_entropy_forward( logits_ptr, logits_row_stride, loss_ptr, logsumexp_ptr, labels_ptr, - VOCAB_SIZE : tl.constexpr, - N_CHUNKS : tl.constexpr, - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr, - SOFTCAP : tl.constexpr, + VOCAB_SIZE : tl.constexpr, + N_CHUNKS : tl.constexpr, + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, + DO_LOGIT_SCALING: tl.constexpr, + LOGIT_SCALE : tl.constexpr, ): """ 256K vocab divided in 4 chunks @@ -130,8 +145,11 @@ def _chunked_cross_entropy_forward( label_idx = tl.load(labels_ptr).to(tl.int32) logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) + + # Go logit scaling for Cohere: t * x + if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits # Do logit softcapping for Gemma 2: t * tanh(1/t * x) - if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) + if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) logits = logits.to(tl.float32) c = tl.max(logits, 0) @@ -142,8 +160,10 @@ def _chunked_cross_entropy_forward( # Do the -x separately if label_idx != -100: x = tl.load(logits_ptr + label_idx).to(tl.float32) + # Go logit scaling for Cohere: t * x + if DO_LOGIT_SCALING: x = LOGIT_SCALE * x # Do logit softcapping for Gemma 2: t * tanh(1/t * x) - if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP) + if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP) loss = -1.0 * x.to(tl.float32) else: loss = 0.0 @@ -153,17 +173,22 @@ def _chunked_cross_entropy_forward( pass -@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],}) +@triton.heuristics({ + "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], + "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], +}) @triton.jit def _cross_entropy_backward( logits_ptr, logits_row_stride, dloss_ptr, dloss_row_stride, logsumexp_ptr, labels_ptr, - VOCAB_SIZE : tl.constexpr, - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr, - SOFTCAP : tl.constexpr, + VOCAB_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, + DO_LOGIT_SCALING: tl.constexpr, + LOGIT_SCALE : tl.constexpr, ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) @@ -195,6 +220,13 @@ def _cross_entropy_backward( dloss = 0.0 x = tl.load(logits_ptr + col_offsets, 
mask = mask, other = -float("inf")) + + # Do logit scaling for Cohere + if DO_LOGIT_SCALING: + # d/dx [s * x] = s + x = x * LOGIT_SCALE + pass + # Do logit softcapping for Gemma 2: t * tanh(1/t * x) if DO_SOFTCAPPING: # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x) @@ -210,6 +242,11 @@ def _cross_entropy_backward( y, # exp(x - logsumexp) ) + if DO_LOGIT_SCALING: + # d/dx [s * x] = s + y = y * LOGIT_SCALE + pass + if DO_SOFTCAPPING: # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x) y = y * (1.0 - partial*partial) @@ -224,14 +261,15 @@ def _cross_entropy_backward( class Fast_CrossEntropyLoss(torch.autograd.Function): @staticmethod - def forward(ctx, logits, labels, logit_softcapping = 0): + def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): n_rows, vocab_size = logits.shape div, mod = divmod(vocab_size, MAX_FUSED_SIZE) n_chunks = div + (mod != 0) losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - DO_SOFTCAPPING = (logit_softcapping != 0) + DO_SOFTCAPPING = (logit_softcapping != 0) + DO_LOGIT_SCALING = (logit_scaling != 0) if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral @@ -243,11 +281,13 @@ def forward(ctx, logits, labels, logit_softcapping = 0): losses, logsumexp, labels, - VOCAB_SIZE = vocab_size, - BLOCK_SIZE = BLOCK_SIZE, - DO_SOFTCAPPING = DO_SOFTCAPPING, - SOFTCAP = logit_softcapping, - num_warps = num_warps, + VOCAB_SIZE = vocab_size, + BLOCK_SIZE = BLOCK_SIZE, + DO_SOFTCAPPING = DO_SOFTCAPPING, + SOFTCAP = logit_softcapping, + DO_LOGIT_SCALING = DO_LOGIT_SCALING, + LOGIT_SCALE = logit_scaling, + num_warps = num_warps, ) else: # For large vocabs > 65336 like Gemma 256K @@ -258,12 +298,14 @@ def forward(ctx, logits, labels, logit_softcapping = 0): losses, logsumexp, labels, - VOCAB_SIZE = vocab_size, - N_CHUNKS = n_chunks, - BLOCK_SIZE = MAX_FUSED_SIZE, - DO_SOFTCAPPING = DO_SOFTCAPPING, - SOFTCAP = logit_softcapping, - num_warps = 32, + VOCAB_SIZE = vocab_size, + N_CHUNKS = n_chunks, + BLOCK_SIZE = MAX_FUSED_SIZE, + DO_SOFTCAPPING = DO_SOFTCAPPING, + SOFTCAP = logit_softcapping, + DO_LOGIT_SCALING = DO_LOGIT_SCALING, + LOGIT_SCALE = logit_scaling, + num_warps = 32, ) # logsumexp(chunked_logsumexp) - x # Do the -x separately @@ -275,6 +317,8 @@ def forward(ctx, logits, labels, logit_softcapping = 0): ctx.save_for_backward(logits, logsumexp, labels) ctx.DO_SOFTCAPPING = DO_SOFTCAPPING ctx.logit_softcapping = logit_softcapping + ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING + ctx.logit_scaling = logit_scaling return losses pass @@ -292,19 +336,26 @@ def backward(ctx, dlosses): dlosses, dlosses.stride(0), logsumexp, labels, - VOCAB_SIZE = vocab_size, - BLOCK_SIZE = BLOCK_SIZE, - DO_SOFTCAPPING = ctx.DO_SOFTCAPPING, - SOFTCAP = ctx.logit_softcapping, + VOCAB_SIZE = vocab_size, + BLOCK_SIZE = BLOCK_SIZE, + DO_SOFTCAPPING = ctx.DO_SOFTCAPPING, + SOFTCAP = ctx.logit_softcapping, + DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING, + LOGIT_SCALE = ctx.logit_scaling, num_warps = 8, ) - return logits, None, None, + return logits, None, None, None, pass pass @torch._disable_dynamo -def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0): +def fast_cross_entropy_loss( + logits, + labels, + logit_softcapping = 0, + logit_scaling = 0, +): """ Arguments: logits: (batch, seq_len, vocab_size) @@ -319,6 +370,7 @@ def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0): logits.view(batch*seq_len, d), labels.view(-1), logit_softcapping, + logit_scaling, ) n_items = torch.count_nonzero(labels != -100) return loss.sum() / n_items diff 
--git a/unsloth/kernels/fast_lora.py b/unsloth/kernels/fast_lora.py index 8f410179..2177b43b 100644 --- a/unsloth/kernels/fast_lora.py +++ b/unsloth/kernels/fast_lora.py @@ -68,7 +68,8 @@ def forward(ctx, X : torch.Tensor, gateW, gateW_quant, gateA, gateB, gateS, upW, upW_quant, upA, upB, upS, downW, downW_quant, downA, downB, downS, - _forward_function, _backward_function,): + _forward_function, _backward_function, + inplace = True,): dtype = X.dtype e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS) @@ -84,6 +85,7 @@ def forward(ctx, X : torch.Tensor, ) ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB, X, e, g) + ctx.inplace = inplace return i pass @@ -131,7 +133,7 @@ def backward(ctx, dY : torch.Tensor): # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS) # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS) upW = fast_dequantize(upW.t(), upW_quant) - dX = torch.matmul(df, upW.t(), out = X) + dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None) del upW dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t()) @@ -147,13 +149,13 @@ def backward(ctx, dY : torch.Tensor): None, None, d_gateA.t(), d_gateB.t(), None, \ None, None, d_upA.t(), d_upB.t(), None, \ None, None, d_downA.t(), d_downB.t(), None, \ - None, None, # _backward and _forward + None, None, None, # _backward and _forward and inplace pass pass from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel -def apply_lora_mlp_swiglu(self, X): +def apply_lora_mlp_swiglu(self, X, inplace = True): gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj) downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) @@ -161,13 +163,14 @@ def apply_lora_mlp_swiglu(self, X): gateW, gateW_quant, gateA, gateB, gateS, upW, upW_quant, upA, upB, upS, downW, downW_quant, downA, downB, downS, - swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,) + swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel, + inplace,) return out pass from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel -def apply_lora_mlp_geglu_exact(self, X): +def apply_lora_mlp_geglu_exact(self, X, inplace = True): gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) upW, upW_quant, upA, upB, upS = get_lora_parameters(self. 
up_proj) downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) @@ -175,7 +178,8 @@ def apply_lora_mlp_geglu_exact(self, X): gateW, gateW_quant, gateA, gateB, gateS, upW, upW_quant, upA, upB, upS, downW, downW_quant, downA, downB, downS, - geglu_exact_forward_kernel, geglu_exact_backward_kernel,) + geglu_exact_forward_kernel, geglu_exact_backward_kernel, + inplace,) return out pass @@ -229,7 +233,8 @@ class LoRA_QKV(torch.autograd.Function): def forward(ctx, X : torch.Tensor, QW, QW_quant, QA, QB, QS, KW, KW_quant, KA, KB, KS, - VW, VW_quant, VA, VB, VS,): + VW, VW_quant, VA, VB, VS, + inplace = True): dtype = X.dtype Q = matmul_lora(X, QW, QW_quant, QA, QB, QS) @@ -242,6 +247,7 @@ def forward(ctx, X : torch.Tensor, VW, VW_quant, VS, ) ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,) + ctx.inplace = inplace return Q, K, V pass @@ -286,7 +292,7 @@ def backward(ctx, dQ, dK, dV): # Combine derivatives to find dX # dQ QW = fast_dequantize(QW.t(), QW_quant) - dX = torch.matmul(dQ, QW.t(), out = X) + dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None) del QW dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t())) @@ -308,12 +314,13 @@ def backward(ctx, dQ, dK, dV): return dX.view(batch, seq_len, hd), \ None, None, d_QA.t(), d_QB.t(), None, \ None, None, d_KA.t(), d_KB.t(), None, \ - None, None, d_VA.t(), d_VB.t(), None + None, None, d_VA.t(), d_VB.t(), None, \ + None, pass pass -def apply_lora_qkv(self, X): +def apply_lora_qkv(self, X, inplace = True): QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj) KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj) VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj) @@ -321,6 +328,7 @@ def apply_lora_qkv(self, X): QW, QW_quant, QA, QB, QS, KW, KW_quant, KA, KB, KS, VW, VW_quant, VA, VB, VS, + inplace, ) return Q, K, V pass diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index f26e5965..ac5beb5a 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -186,7 +186,9 @@ def backward(ctx, dY): def fast_rms_layernorm(layernorm, X, gemma = False): W = layernorm.weight - eps = layernorm.variance_epsilon + eps = layernorm.variance_epsilon if \ + hasattr(layernorm, "variance_epsilon") \ + else layernorm.eps out = Fast_RMS_Layernorm.apply(X, W, eps, gemma) return out pass diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 1c48e8e5..ea9a0c53 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -295,7 +295,7 @@ def patch_mistral_nemo_config(config): send_to_device, ).replace("def send_to_device", "def _fixed_send_to_device") exec(send_to_device) - accelerate.utils.operations.send_to_device = _fixed_send_to_device + # accelerate.utils.operations.send_to_device = _fixed_send_to_device pass pass # ============================================= diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py new file mode 100644 index 00000000..aa0bcb55 --- /dev/null +++ b/unsloth/models/cohere.py @@ -0,0 +1,473 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
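For reference, the `fast_rms_layernorm` change earlier in this patch reads `variance_epsilon` when that attribute exists and falls back to `eps`, since not every layernorm implementation stores its epsilon under the same name. A minimal plain-PyTorch sketch of the same fallback; the `rms_norm` helper below is illustrative and not part of the patch:

import torch

def rms_norm(layernorm, X):
    # Llama-style RMSNorm modules expose `variance_epsilon`;
    # others (e.g. torch.nn.LayerNorm) expose `eps`.
    eps = layernorm.variance_epsilon if hasattr(layernorm, "variance_epsilon") else layernorm.eps
    variance = X.float().pow(2).mean(-1, keepdim = True)
    Xn = X.float() * torch.rsqrt(variance + eps)
    return (layernorm.weight.float() * Xn).to(X.dtype)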
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .llama import * +from ._utils import __version__ +try: + from transformers.models.cohere.modeling_cohere import ( + CohereAttention, + CohereDecoderLayer, + CohereModel, + CohereForCausalLM, + CohereRotaryEmbedding, + apply_rotary_pos_emb, + repeat_kv, + ) +except: + from packaging.version import Version + transformers_version = Version(transformers_version) + if not transformers_version >= Version("4.42"): + raise ImportError( + f"Unsloth: Your transformers version of {transformers_version} does not support Cohere.\n"\ + f"The minimum required version is 4.42.3.\n"\ + f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\ + f"to obtain the latest transformers build, then restart this session."\ + ) + pass +pass + +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask_for_sdpa, +) +# For Pytorch 2.1.1 +try: + from transformers.models.cohere.modeling_cohere import ( + CohereSdpaAttention, + CohereFlashAttention2, + ) +except: + CohereSdpaAttention = CohereAttention + CohereFlashAttention2 = CohereAttention +pass + + +def fast_layernorm_inference(self, X, out_weight = None): + XX = X.to(torch.float32, copy = True) + XX -= X.mean(-1, keepdim = True) + variance = XX.square().mean(-1, keepdim = True) + variance += self.variance_epsilon + XX *= variance.rsqrt_() + out_weight[:] = self.weight + XX *= out_weight + return XX.to(X.dtype) +pass + + +# QK norm in Cohere +def CohereAttention_fast_forward( + self, + hidden_states: torch.Tensor, + causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + *args, **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + # Clear inference + if hasattr(self, "paged_attention"): + del self.paged_attention_K + del self.paged_attention_V + del self.paged_attention + del self.temp_QA + del self.temp_KV + del self.RH_Q + del self.attention + del self.q_norm_out_weight + del self.k_norm_out_weight + pass + + bsz, q_len, _ = hidden_states.size() + + n_heads = self.num_heads + n_groups = self.num_key_value_groups + n_kv_heads = self.num_key_value_heads + head_dim = self.head_dim + assert(n_kv_heads * n_groups == n_heads) + + Q, K, V = self.apply_qkv(self, hidden_states) + Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) + K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + if self.use_qk_norm: + Q = fast_layernorm_compiled(self.q_norm, Q) + K = fast_layernorm_compiled(self.k_norm, K) + pass + + kv_seq_len = K.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + if position_ids is None: + cos = self.rotary_emb.cos_cached + sin = self.rotary_emb.sin_cached + Q, K = fast_rope_embedding(Q, K, cos, sin) + else: + cos, sin = self.rotary_emb(V, seq_len = kv_seq_len) + Q, K = 
inplace_rope_embedding(Q, K, cos, sin, position_ids) + pass + + if past_key_value is not None: + K = torch.cat([past_key_value[0], K], dim = 2) + V = torch.cat([past_key_value[1], V], dim = 2) + pass + past_key_value = (K, V) if use_cache else None + + # Attention module + if (not HAS_FLASH_ATTENTION and attention_mask is None): + # Xformers memory efficient attention + # Also has Flash Attention v2 dispatching + Q = Q.transpose(1, 2) + K = K.transpose(1, 2) + V = V.transpose(1, 2) + + # Group query attention + if n_groups != 1: + K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) + V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) + K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) + V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) + if hidden_states.requires_grad: + K = K.reshape(bsz, kv_seq_len, n_heads, head_dim) + V = V.reshape(bsz, kv_seq_len, n_heads, head_dim) + else: + Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) + pass + A = xformers_attention(Q, K, V, attn_bias = causal_mask) + A = A.view(bsz, q_len, n_heads, head_dim) + + elif HAS_FLASH_ATTENTION and attention_mask is None: + Q = Q.transpose(1, 2) + K = K.transpose(1, 2) + V = V.transpose(1, 2) + A = flash_attn_func(Q, K, V, causal = True) + else: + # Grouped query attention + if n_groups != 1: + K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) + V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) + K = K.reshape(bsz, n_heads, kv_seq_len, head_dim) + V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) + pass + # Must be contiguous or else results are False! + # https://github.com/pytorch/pytorch/issues/112577 + Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() + # Needs (batch_size, n_heads, seq_len, head_dim) + # is_casual and attention_mask must not be both set! 
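The call below hands the grouped-query tensors to PyTorch's SDPA with an explicit additive mask, so `is_causal` stays False (SDPA does not accept an `attn_mask` together with `is_causal = True`, as the comment above notes). For reference, a minimal eager-mode equivalent of this grouped-query attention path; the helper name and shapes are assumptions for illustration only:

import torch

def eager_gqa_attention(Q, K, V, attention_mask = None):
    # Q: (bsz, n_heads, q_len, head_dim); K, V: (bsz, n_kv_heads, kv_len, head_dim)
    bsz, n_heads, q_len, head_dim = Q.shape
    n_kv_heads, kv_len = K.shape[1], K.shape[2]
    n_groups = n_heads // n_kv_heads
    # Expand each KV head across its group of query heads (same idea as repeat_kv)
    K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_len, head_dim).reshape(bsz, n_heads, kv_len, head_dim)
    V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_len, head_dim).reshape(bsz, n_heads, kv_len, head_dim)
    scores = (Q @ K.transpose(2, 3)) / head_dim ** 0.5
    if attention_mask is not None:
        scores = scores + attention_mask  # the 4D additive mask already encodes causality
    A = torch.softmax(scores.float(), dim = -1).to(Q.dtype) @ V
    return A.transpose(1, 2)  # (bsz, q_len, n_heads, head_dim)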
+ A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False) + # Go back to (batch_size, seq_len, n_heads, head_dim) + A = A.transpose(1, 2).contiguous() + pass + attn_output = A.reshape(bsz, q_len, n_heads*head_dim) + attn_output = self.apply_o(self, attn_output) + attn_weights = None + return attn_output, attn_weights, past_key_value +pass + + +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590 +def CohereDecoderLayer_fast_forward( + self, + hidden_states: torch.Tensor, + causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, + *args, **kwargs, +): + if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None: + out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0") + + # Self Attention + residual = hidden_states + hidden_states = fast_layernorm_inference(self.input_layernorm, hidden_states, out_weight) + hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + causal_mask=causal_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + ) + + # Fully Connected + hidden_states_mlp = fast_swiglu_inference(self.mlp, hidden_states) + residual += hidden_states_attention + residual += hidden_states_mlp + hidden_states = residual + else: + residual = hidden_states + hidden_states = fast_layernorm_compiled(self.input_layernorm, hidden_states) + hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + causal_mask=causal_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + ) + + # Fully Connected + hidden_states_mlp = self.mlp(hidden_states) + hidden_states = residual + hidden_states_attention + hidden_states_mlp + pass + + outputs = (hidden_states,) + if output_attentions: outputs += (self_attn_weights,) + if use_cache: outputs += (present_key_value,) + return outputs +pass + + +from math import sqrt as math_sqrt +KV_CACHE_INCREMENT = 256 # KV Cache update size +torch_nn_functional_softmax = torch.nn.functional.softmax +torch_matmul = torch.matmul + +def CohereAttention_fast_forward_inference( + self, + hidden_states: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor]], + position_ids, + do_prefill = False, + attention_mask = None, +): + Xn = hidden_states + bsz, _, hd = hidden_states.size() + K1, V1 = past_key_value + dtype = Xn.dtype + + n_heads = self.num_heads + n_groups = self.num_key_value_groups + n_kv_heads = self.num_key_value_heads + head_dim = self.head_dim + attention_size = n_heads*head_dim + # assert(n_kv_heads * n_groups == n_heads) + seq_len = K1.shape[-2] + kv_seq_len = seq_len + 1 + + # Prefill phase + # if not hasattr(self, "paged_attention"): + if do_prefill: + self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0") + self.paged_attention_K = 
self.paged_attention[:,0] + self.paged_attention_V = self.paged_attention[:,1] + self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3) + self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3) + self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0") + self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0") + self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") + + # Mistral Nemo 12b has weird dimensions + if attention_size != self.hidden_size: + self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0") + else: + self.temp_O = self.temp_QA[1][:,:,:self.hidden_size] + pass + + self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") + self.scalar = 1.0 / math_sqrt(self.head_dim) + self.half_head_dim = head_dim // 2 + # Cohere has QK layernorms + if self.use_qk_norm: + self.q_norm_out_weight = torch.empty(self.q_norm.weight.shape, dtype = torch.float32, device = "cuda:0") + self.k_norm_out_weight = torch.empty(self.k_norm.weight.shape, dtype = torch.float32, device = "cuda:0") + else: + self.q_norm_out_weight = None + self.k_norm_out_weight = None + pass + elif kv_seq_len >= self.paged_attention.shape[0]: + self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim)) + self.paged_attention_K = self.paged_attention[:,0] + self.paged_attention_V = self.paged_attention[:,1] + self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT)) + pass + + Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0]) + Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0]) + Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1]) + Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2) + Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) + Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) + if self.use_qk_norm: + Q = fast_layernorm_inference(self.q_norm, Q, self.q_norm_out_weight) + K = fast_layernorm_inference(self.k_norm, K, self.k_norm_out_weight) + pass + + # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) + # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) + cos, sin = self.rotary_emb.get_cached(kv_seq_len) + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + h = self.half_head_dim + + RH_Q = self.RH_Q + RH_Q[:,:,:,:h] = Qn[:,:,:,h:] + RH_Q[:,:,:,h:] = Qn[:,:,:,:h] + torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h]) + Qn *= cos + Qn.addcmul_(RH_Q, sin) + + RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0") + RH_K[:,:,:,:h] = Kn[:,:,:,h:] + RH_K[:,:,:,h:] = Kn[:,:,:,:h] + torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) + Kn *= cos + Kn.addcmul_(RH_K, sin) + + # New KV cache + # Kn = torch.cat([K1, Kn], dim = 2) + # Vn = torch.cat([V1, Vn], dim = 2) + self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3) + self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3) + Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3) + Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3) + + # Handle sliding windows + sliding_window = getattr(self.config, "sliding_window", None) + if sliding_window is not None and kv_seq_len > sliding_window: + # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193 + slicing_tokens = 1 - 
sliding_window + Knn = Kn[:, :, slicing_tokens:, :]#.contiguous() + Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous() + else: + Knn, Vnn = Kn, Vn + pass + + # Grouped query attention + _, _, cached_len, _ = Knn.shape + if n_groups != 1: + Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) + Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) + Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim) + Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim) + pass + # else: + # Knn, Vnn = Knn, Vnn + # pass + + # Attention + if bsz == 1: + Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 + # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows + A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) + # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched + A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) + A = torch_matmul(A, Vnn, out = Qn) + else: + A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) + pass + A = A.transpose(1, 2) + A = A.reshape(bsz, 1, attention_size) + A = fast_linear_forward(self.o_proj, A, out = self.temp_O) + return A, (Kn, Vn) +pass + + +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825 +# @torch.inference_mode +def CohereModel_fast_forward_inference( + self, + input_ids, + past_key_values, + position_ids, + attention_mask = None, +): + out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0") + input_ids = input_ids[:,:self.max_seq_length] + hidden_states = self.model.embed_tokens(input_ids) + hidden_states = hidden_states.to(self.config.torch_dtype) + bsz, q_len, hd = hidden_states.shape + seq_len = past_key_values[0][0].shape[-2] + if bsz != 1: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (bsz, q_len), + hidden_states, + seq_len, + sliding_window = getattr(self.config, "sliding_window", None), + ) + else: + attention_mask = None + pass + + next_decoder_cache = [] + for idx, decoder_layer in enumerate(self.model.layers): + residual = hidden_states + hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight) + hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference( + decoder_layer.self_attn, + hidden_states = hidden_states, + past_key_value = past_key_values[idx], + position_ids = position_ids, + attention_mask = attention_mask, + do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), + ) + + hidden_states_mlp = fast_swiglu_inference(self.mlp, hidden_states) + residual += hidden_states_attention + residual += hidden_states_mlp + hidden_states = residual + + next_decoder_cache.append(present_key_value) + pass + hidden_states = fast_layernorm_inference(self.model.norm, hidden_states, out_weight) + + return BaseModelOutputWithPast( + last_hidden_state = hidden_states, + past_key_values = next_decoder_cache, + hidden_states = [], + attentions = [], + ) +pass + + +class FastCohereModel(FastLlamaModel): + + @staticmethod + def pre_patch(): + init_name, function = patch_linear_scaling( + model_name = "cohere", + rope_module = LlamaRotaryEmbedding, + scaled_rope_module = LlamaLinearScalingRotaryEmbedding, + attention_module = CohereAttention, + ) + if 
init_name is not None: + exec(function, globals()) + CohereAttention.__init__ = eval(init_name) + pass + CohereAttention .forward = CohereAttention_fast_forward + CohereSdpaAttention .forward = CohereAttention_fast_forward + CohereFlashAttention2.forward = CohereAttention_fast_forward + CohereDecoderLayer .forward = CohereDecoderLayer_fast_forward + CohereModel .forward = LlamaModel_fast_forward + CohereForCausalLM .forward = CausalLM_fast_forward(CohereModel_fast_forward_inference) + PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward + fix_prepare_inputs_for_generation(CohereForCausalLM) + + import transformers.models.cohere.modeling_cohere + transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = LlamaRotaryEmbedding + return + pass +pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index f62f0f11..5ccf906a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -305,6 +305,20 @@ def fast_rms_layernorm_inference_gemma(self, X, out_weight = None): pass +# Normal layernorm with mean removal +@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) +def fast_layernorm_compiled(layernorm, X): + old_dtype = X.dtype + X = X.float() + mean = X.mean(-1, keepdim = True) + Xbar = X - mean + X = Xbar * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + \ + layernorm.variance_epsilon) * \ + layernorm.weight.float() + return X.to(old_dtype) +pass + + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L320 def LlamaAttention_fast_forward( self, @@ -495,6 +509,16 @@ def LlamaDecoderLayer_fast_forward( pass +# https://github.com/unslothai/unsloth/issues/404#issuecomment-2323473452 +__DTYPE_MAP = { + "float32": torch.float32, + torch.float32: torch.float32, + "float16": torch.float16, + torch.float16: torch.float16, + "bfloat16": torch.bfloat16, + torch.bfloat16: torch.bfloat16, +} + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825 def LlamaModel_fast_forward( self, @@ -576,11 +600,18 @@ def LlamaModel_fast_forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds = inputs_embeds.to(self.config.torch_dtype) + # inputs_embeds = inputs_embeds.to(self.config.torch_dtype) + torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None) + if torch_dtype is not None: + inputs_embeds = inputs_embeds.to(torch_dtype) + else: + raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!") + pass # Normalized from Gemma IS_GEMMA = self.config.model_type.startswith("gemma") IS_GEMMA2 = self.config.model_type.startswith("gemma2") + IS_COHERE = self.config.model_type.startswith("cohere") train_embed_tokens = self.embed_tokens.weight.requires_grad if IS_GEMMA: @@ -786,8 +817,11 @@ def custom_forward(*inputs): # Final layernorm if use_cache: - hidden_states = (fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\ + hidden_states = \ + (fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\ (self.norm, hidden_states) + elif IS_COHERE: + hidden_states = fast_layernorm_compiled(self.norm, hidden_states) else: hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA) pass @@ -877,6 +911,7 @@ def _CausalLM_fast_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + num_logits_to_keep: Optional[int] = 0, *args, 
**kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: @@ -925,6 +960,7 @@ def _CausalLM_fast_forward( loss = None logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) + logit_scaling = getattr(self.config, "logit_scale", 0) if labels is not None: shift_logits = logits if not hasattr(self, "extra_ignored_labels"): @@ -937,16 +973,26 @@ def _CausalLM_fast_forward( logits = shift_logits, labels = shift_labels, logit_softcapping = logit_softcapping, + logit_scaling = logit_scaling, ) - elif logit_softcapping != 0: - if logits.requires_grad: - logits = (1.0 / logit_softcapping) * logits - logits = torch.tanh(logits) - logits = logit_softcapping * logits - else: - logits *= (1.0 / logit_softcapping) - torch.tanh(logits, out = logits) - logits *= logit_softcapping + else: + if logit_scaling != 0: + if logits.requires_grad: + logits = logit_scaling * logits + else: + logits *= logit_scaling + pass + pass + if logit_softcapping != 0: + if logits.requires_grad: + logits = (1.0 / logit_softcapping) * logits + logits = torch.tanh(logits) + logits = logit_softcapping * logits + else: + logits *= (1.0 / logit_softcapping) + torch.tanh(logits, out = logits) + logits *= logit_softcapping + pass pass pass @@ -978,6 +1024,7 @@ def PeftModelForCausalLM_fast_forward( output_hidden_states=None, return_dict=None, task_ids=None, + num_logits_to_keep=0, **kwargs, ): return self.base_model( @@ -989,6 +1036,7 @@ def PeftModelForCausalLM_fast_forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + num_logits_to_keep=num_logits_to_keep, **kwargs, ) pass @@ -2181,6 +2229,7 @@ def patch_peft_model( elif model_type == "qwen2": apply_lora_mlp = apply_lora_mlp_swiglu elif model_type == "gemma": apply_lora_mlp = apply_lora_mlp_geglu_approx elif model_type == "gemma2": apply_lora_mlp = apply_lora_mlp_geglu_approx + elif model_type == "cohere": apply_lora_mlp = apply_lora_mlp_swiglu else: raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!") pass @@ -2240,6 +2289,14 @@ def patch_peft_model( lora_dropout = model.peft_config[active_adapter].lora_dropout bias = model.peft_config[active_adapter].bias + # We also do not inplace edit QKV for Cohere! + from functools import partial + _apply_lora_mlp = \ + partial(apply_lora_mlp, inplace = False) \ + if model_type == "cohere" else \ + apply_lora_mlp + pass + if lora_dropout == 0 and bias == "none": for idx, layer in enumerate(model.model.model.layers): @@ -2259,7 +2316,7 @@ def patch_peft_model( (len(getattr(down_proj, "lora_magnitude_vector", []) or []) == 0): # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module - layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp) + layer.mlp.forward = types.MethodType(_apply_lora_mlp, layer.mlp) n_mlp += 1 else: logger.warning_once( diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index e1f17aca..13710eed 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -13,9 +13,10 @@ # limitations under the License. 
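Taken together, the `_CausalLM_fast_forward` changes above apply Cohere's `logit_scale` first and Gemma 2's `final_logit_softcapping` second whenever the fused loss path is not used. A minimal out-of-place sketch of that transformation; the function name is illustrative:

import torch

def scale_and_softcap(logits, logit_scaling = 0.0, logit_softcapping = 0.0):
    # Cohere: x -> s * x, where s = config.logit_scale
    if logit_scaling != 0:
        logits = logit_scaling * logits
    # Gemma 2: x -> t * tanh(x / t), where t = config.final_logit_softcapping
    if logit_softcapping != 0:
        logits = logit_softcapping * torch.tanh(logits / logit_softcapping)
    return logits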
from ._utils import is_bfloat16_supported, HAS_FLASH_ATTENTION, HAS_FLASH_ATTENTION_SOFTCAPPING -from .llama import FastLlamaModel, logger +from .llama import FastLlamaModel, logger from .mistral import FastMistralModel -from .qwen2 import FastQwen2Model +from .qwen2 import FastQwen2Model +from .cohere import FastCohereModel from transformers import AutoConfig from transformers import __version__ as transformers_version from peft import PeftConfig, PeftModel @@ -278,6 +279,8 @@ def from_pretrained( dispatch_model = FastGemma2Model elif model_type == "qwen2": dispatch_model = FastQwen2Model + elif model_type == "cohere": + dispatch_model = FastCohereModel else: raise NotImplementedError( f"Unsloth: {model_name} not supported yet!\n"\ diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 3f49c965..bff7f025 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -227,6 +227,7 @@ "meta-llama/Meta-Llama-3.1-8B-Instruct", ), "unsloth/Meta-Llama-3.1-70B-bnb-4bit" : ( + "unsloth/Meta-Llama-3.1-70B", "meta-llama/Meta-Llama-3.1-70B", ), "unsloth/Meta-Llama-3.1-405B-bnb-4bit" : ( @@ -236,6 +237,7 @@ "meta-llama/Meta-Llama-3.1-405B-Instruct", ), "unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit" : ( + "unsloth/Meta-Llama-3.1-70B-Instruct", "meta-llama/Meta-Llama-3.1-70B-Instruct", ), "unsloth/Mistral-Large-Instruct-2407-bnb-4bit" : ( @@ -253,6 +255,27 @@ "unsloth/Phi-3.5-mini-instruct", "microsoft/Phi-3.5-mini-instruct", ), + "unsloth/c4ai-command-r-08-2024-bnb-4bit" : ( + "CohereForAI/c4ai-command-r-08-2024", + ), + "unsloth/c4ai-command-r-plus-08-2024-bnb-4bit" : ( + "CohereForAI/c4ai-command-r-plus-08-2024", + ), + "unsloth/Llama-3.1-Storm-8B-bnb-4bit" : ( + "unsloth/Llama-3.1-Storm-8B", + "akjindal53244/Llama-3.1-Storm-8B", + ), + "unsloth/Hermes-3-Llama-3.1-8B-bnb-4bit" : ( + "unsloth/Hermes-3-Llama-3.1-8B", + "NousResearch/Hermes-3-Llama-3.1-8B", + ), + "unsloth/Hermes-3-Llama-3.1-70B-bnb-4bit" : ( + "unsloth/Hermes-3-Llama-3.1-70B", + "NousResearch/Hermes-3-Llama-3.1-70B", + ), + "unsloth/Hermes-3-Llama-3.1-405B-bnb-4bit" : ( + "NousResearch/Hermes-3-Llama-3.1-405B", + ), } INT_TO_FLOAT_MAPPER = {} diff --git a/unsloth/trainer.py b/unsloth/trainer.py index c8e00be2..45616ca6 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -14,8 +14,13 @@ from dataclasses import dataclass, field from typing import Optional -from transformers import TrainingArguments + from trl import SFTTrainer +try: + from trl import SFTConfig as TrainingArguments +except: + from transformers import TrainingArguments +pass from . import is_bfloat16_supported __all__ = [
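The fused `fast_cross_entropy_loss` entry point in this patch threads a `logit_scaling` argument through both Triton kernels alongside the existing `logit_softcapping`. As a sanity check, a reference unfused PyTorch version of the loss the kernels compute; the function name is illustrative, argument names mirror the patch:

import torch
import torch.nn.functional as F

def reference_cross_entropy_loss(logits, labels, logit_softcapping = 0.0, logit_scaling = 0.0):
    # logits: (batch, seq_len, vocab_size); labels: (batch, seq_len), -100 marks ignored tokens
    logits = logits.float()
    if logit_scaling != 0:       # Cohere: s * x
        logits = logit_scaling * logits
    if logit_softcapping != 0:   # Gemma 2: t * tanh(x / t)
        logits = logit_softcapping * torch.tanh(logits / logit_softcapping)
    return F.cross_entropy(
        logits.view(-1, logits.shape[-1]),
        labels.view(-1),
        ignore_index = -100,
        reduction = "mean",      # same normalisation as loss.sum() / count(labels != -100)
    )

Averaging with `ignore_index = -100` corresponds to the fused path's `loss.sum() / n_items`, where `n_items = torch.count_nonzero(labels != -100)`.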