
Cohere, Bug fixes (#984)
* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* update token retrieval logic (#952)

* Fix DPO (#947)

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* update hf token retrieval logic

---------

Co-authored-by: Daniel Han <[email protected]>

* Update llama.py

* get_token

* Update README.md

* Update gemma2.py

* Update rms_layernorm.py

* synchronize

* Update gemma2.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* layernorm

* Update rms_layernorm.py

* Update gemma2.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* revert

* Gemma

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update gemma2.py

* Change UnslothTrainingArguments base class to SFTConfig (#979)

* Cohere

* Update trainer.py

* Cohere

* Cohere

* New models

* Update llama.py

* Update llama.py

* Update cohere.py

* Update llama.py

* Update cohere.py

* retry

* Update fast_lora.py

* Update llama.py

* Update fast_lora.py

* Update llama.py

* Update llama.py

* Update cross_entropy_loss.py

* _apply_lora_mlp

* Update _utils.py

---------

Co-authored-by: Hafedh <[email protected]>
Co-authored-by: Tuan Pham <[email protected]>
3 people authored Sep 3, 2024
1 parent 976d11a commit c085a45
Showing 9 changed files with 690 additions and 67 deletions.
130 changes: 91 additions & 39 deletions unsloth/kernels/cross_entropy_loss.py
@@ -19,17 +19,22 @@
from transformers.models.llama.modeling_llama import logger


@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.heuristics({
"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ],
"DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"],
})
@triton.jit
def _cross_entropy_forward(
logits_ptr, logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
DO_LOGIT_SCALING: tl.constexpr,
LOGIT_SCALE : tl.constexpr,
):
"""
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
@@ -62,17 +67,22 @@ def _cross_entropy_forward(

label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))

    # Do logit scaling for Cohere: t * x
if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)

logits = logits.to(tl.float32)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

if label_idx != -100:
x = tl.load(logits_ptr + label_idx)
        # Do logit scaling for Cohere: t * x
if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
loss = logsumexp - x.to(tl.float32)
else:
loss = 0.0
@@ -81,18 +91,23 @@ def _cross_entropy_forward(
pass
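For reference, the per-row quantity the forward kernel computes — logsumexp over the (optionally scaled and softcapped) logits minus the transformed logit at the label — can be sketched in plain PyTorch. This is an illustrative helper for intuition only, not code from the commit:

```python
import torch

def reference_row_loss(logits_row, label, softcap=0.0, logit_scale=0.0):
    # Mirror the Triton kernel on a single row, in eager PyTorch.
    x = logits_row.float()
    if logit_scale != 0.0:              # Cohere-style logit scaling: s * x
        x = logit_scale * x
    if softcap != 0.0:                  # Gemma-2-style softcapping: t * tanh(x / t)
        x = softcap * torch.tanh(x / softcap)
    logsumexp = torch.logsumexp(x, dim=-1)
    if label == -100:                   # ignored position contributes zero loss
        return torch.zeros(())
    return logsumexp - x[label]
```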


@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.heuristics({
"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ],
"DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"],
})
@triton.jit
def _chunked_cross_entropy_forward(
logits_ptr, logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
N_CHUNKS : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
VOCAB_SIZE : tl.constexpr,
N_CHUNKS : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
DO_LOGIT_SCALING: tl.constexpr,
LOGIT_SCALE : tl.constexpr,
):
"""
        256K vocab divided into 4 chunks
@@ -130,8 +145,11 @@ def _chunked_cross_entropy_forward(

label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))

    # Do logit scaling for Cohere: t * x
if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)

logits = logits.to(tl.float32)
c = tl.max(logits, 0)
@@ -142,8 +160,10 @@
# Do the -x separately
if label_idx != -100:
x = tl.load(logits_ptr + label_idx).to(tl.float32)
        # Do logit scaling for Cohere: t * x
if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
loss = -1.0 * x.to(tl.float32)
else:
loss = 0.0
@@ -153,17 +173,22 @@
pass
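The chunked variant relies on logsumexp composing across chunks: a logsumexp of the per-chunk logsumexps reproduces the full-vocabulary value, so the -x term can be handled separately. A small self-contained check, with illustrative sizes (not code from the commit):

```python
import torch

logits = torch.randn(256_000)                 # e.g. one row of a Gemma-sized vocab
chunks = logits.chunk(4)
per_chunk = torch.stack([torch.logsumexp(c, dim=0) for c in chunks])
assert torch.allclose(torch.logsumexp(per_chunk, dim=0),
                      torch.logsumexp(logits, dim=0), atol=1e-5)
```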


@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.heuristics({
"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ],
"DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"],
})
@triton.jit
def _cross_entropy_backward(
logits_ptr, logits_row_stride,
dloss_ptr, dloss_row_stride,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
DO_LOGIT_SCALING: tl.constexpr,
LOGIT_SCALE : tl.constexpr,
):
"""
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
@@ -195,6 +220,13 @@ def _cross_entropy_backward(
dloss = 0.0

x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))

# Do logit scaling for Cohere
if DO_LOGIT_SCALING:
# d/dx [s * x] = s
x = x * LOGIT_SCALE
pass

# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING:
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
@@ -210,6 +242,11 @@ def _cross_entropy_backward(
y, # exp(x - logsumexp)
)

if DO_LOGIT_SCALING:
# d/dx [s * x] = s
y = y * LOGIT_SCALE
pass

if DO_SOFTCAPPING:
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
y = y * (1.0 - partial*partial)
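Putting the two new factors together: with scaling s and softcap t the kernel effectively differentiates z = t·tanh(s·x/t), and the multiplications by LOGIT_SCALE and (1 - partial*partial) implement the chain rule (a reference derivation, not part of the diff):

\[
\frac{\partial z}{\partial x} = s\left(1 - \tanh^2\!\frac{s x}{t}\right),
\qquad
\frac{\partial \mathrm{CE}}{\partial x_j} = \bigl(\mathrm{softmax}(z)_j - \mathbb{1}[j = \text{label}]\bigr)\,\frac{\partial z}{\partial x_j}.
\]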
@@ -224,14 +261,15 @@

class Fast_CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels, logit_softcapping = 0):
def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0):
n_rows, vocab_size = logits.shape

div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
n_chunks = div + (mod != 0)
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")

DO_SOFTCAPPING = (logit_softcapping != 0)
DO_SOFTCAPPING = (logit_softcapping != 0)
DO_LOGIT_SCALING = (logit_scaling != 0)

if n_chunks == 1:
            # For small vocabs <= 65536 like Llama, Mistral
@@ -243,11 +281,13 @@ def forward(ctx, logits, labels, logit_softcapping = 0):
losses,
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
num_warps = num_warps,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
LOGIT_SCALE = logit_scaling,
num_warps = num_warps,
)
else:
            # For large vocabs > 65536 like Gemma 256K
@@ -258,12 +298,14 @@ def forward(ctx, logits, labels, logit_softcapping = 0):
losses,
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
N_CHUNKS = n_chunks,
BLOCK_SIZE = MAX_FUSED_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
num_warps = 32,
VOCAB_SIZE = vocab_size,
N_CHUNKS = n_chunks,
BLOCK_SIZE = MAX_FUSED_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
LOGIT_SCALE = logit_scaling,
num_warps = 32,
)
# logsumexp(chunked_logsumexp) - x
# Do the -x separately
@@ -275,6 +317,8 @@ def forward(ctx, logits, labels, logit_softcapping = 0):
ctx.save_for_backward(logits, logsumexp, labels)
ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
ctx.logit_softcapping = logit_softcapping
ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING
ctx.logit_scaling = logit_scaling
return losses
pass

@@ -292,19 +336,26 @@ def backward(ctx, dlosses):
dlosses, dlosses.stride(0),
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
SOFTCAP = ctx.logit_softcapping,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
SOFTCAP = ctx.logit_softcapping,
DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
LOGIT_SCALE = ctx.logit_scaling,
num_warps = 8,
)
return logits, None, None,
return logits, None, None, None,
pass
pass


@torch._disable_dynamo
def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0):
def fast_cross_entropy_loss(
logits,
labels,
logit_softcapping = 0,
logit_scaling = 0,
):
"""
Arguments:
logits: (batch, seq_len, vocab_size)
@@ -319,6 +370,7 @@ def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0):
logits.view(batch*seq_len, d),
labels.view(-1),
logit_softcapping,
logit_scaling,
)
n_items = torch.count_nonzero(labels != -100)
return loss.sum() / n_items
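A hedged usage sketch of the extended entry point (requires a CUDA GPU; the softcap and scale values are illustrative stand-ins for Gemma-2-style softcapping and Cohere-style scaling, and the import path simply mirrors this file's location):

```python
import torch
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss

batch, seq_len, vocab = 2, 8, 256
logits = torch.randn(batch, seq_len, vocab, device="cuda",
                     dtype=torch.float16, requires_grad=True)
labels = torch.randint(0, vocab, (batch, seq_len), device="cuda")
labels[:, -1] = -100                      # ignored positions are skipped

loss_plain  = fast_cross_entropy_loss(logits, labels)                          # both transforms off
loss_capped = fast_cross_entropy_loss(logits, labels, logit_softcapping=30.0)  # softcapping on
loss_scaled = fast_cross_entropy_loss(logits, labels, logit_scaling=0.125)     # scaling on
loss_scaled.backward()
```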
30 changes: 19 additions & 11 deletions unsloth/kernels/fast_lora.py
@@ -68,7 +68,8 @@ def forward(ctx, X : torch.Tensor,
gateW, gateW_quant, gateA, gateB, gateS,
upW, upW_quant, upA, upB, upS,
downW, downW_quant, downA, downB, downS,
_forward_function, _backward_function,):
_forward_function, _backward_function,
inplace = True,):
dtype = X.dtype

e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
@@ -84,6 +85,7 @@ def forward(ctx, X : torch.Tensor,
)
ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB,
X, e, g)
ctx.inplace = inplace
return i
pass

@@ -131,7 +133,7 @@ def backward(ctx, dY : torch.Tensor):
# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
upW = fast_dequantize(upW.t(), upW_quant)
dX = torch.matmul(df, upW.t(), out = X)
dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
del upW
dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())

@@ -147,35 +149,37 @@ def backward(ctx, dY : torch.Tensor):
None, None, d_gateA.t(), d_gateB.t(), None, \
None, None, d_upA.t(), d_upB.t(), None, \
None, None, d_downA.t(), d_downB.t(), None, \
None, None, # _backward and _forward
None, None, None, # _backward and _forward and inplace
pass
pass
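The new inplace flag only controls where the first dX matmul writes: with out = X the saved activation buffer is reused (one fewer allocation) but clobbered, while inplace = False leaves X intact at the cost of a fresh tensor. A minimal sketch of that difference, with illustrative shapes (not code from the commit):

```python
import torch

n_tokens, hidden, intermediate = 4 * 512, 4096, 14336           # illustrative sizes
X  = torch.randn(n_tokens, hidden,       device="cuda", dtype=torch.float16)
df = torch.randn(n_tokens, intermediate, device="cuda", dtype=torch.float16)
W  = torch.randn(intermediate, hidden,   device="cuda", dtype=torch.float16)

inplace = True
dX = torch.matmul(df, W, out=X if inplace else None)
# inplace=True  -> dX aliases X: no new [n_tokens, hidden] buffer, but X is overwritten
# inplace=False -> dX is a freshly allocated tensor and X is left untouched
assert (dX.data_ptr() == X.data_ptr()) == inplace
```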


from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
def apply_lora_mlp_swiglu(self, X):
def apply_lora_mlp_swiglu(self, X, inplace = True):
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(X,
gateW, gateW_quant, gateA, gateB, gateS,
upW, upW_quant, upA, upB, upS,
downW, downW_quant, downA, downB, downS,
swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,)
swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,
inplace,)
return out
pass


from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
def apply_lora_mlp_geglu_exact(self, X):
def apply_lora_mlp_geglu_exact(self, X, inplace = True):
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(X,
gateW, gateW_quant, gateA, gateB, gateS,
upW, upW_quant, upA, upB, upS,
downW, downW_quant, downA, downB, downS,
geglu_exact_forward_kernel, geglu_exact_backward_kernel,)
geglu_exact_forward_kernel, geglu_exact_backward_kernel,
inplace,)
return out
pass

@@ -229,7 +233,8 @@ class LoRA_QKV(torch.autograd.Function):
def forward(ctx, X : torch.Tensor,
QW, QW_quant, QA, QB, QS,
KW, KW_quant, KA, KB, KS,
VW, VW_quant, VA, VB, VS,):
VW, VW_quant, VA, VB, VS,
inplace = True):
dtype = X.dtype

Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
@@ -242,6 +247,7 @@ def forward(ctx, X : torch.Tensor,
VW, VW_quant, VS,
)
ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,)
ctx.inplace = inplace
return Q, K, V
pass

@@ -286,7 +292,7 @@ def backward(ctx, dQ, dK, dV):
# Combine derivatives to find dX
# dQ
QW = fast_dequantize(QW.t(), QW_quant)
dX = torch.matmul(dQ, QW.t(), out = X)
dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
del QW
dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))

@@ -308,19 +314,21 @@ def backward(ctx, dQ, dK, dV):
return dX.view(batch, seq_len, hd), \
None, None, d_QA.t(), d_QB.t(), None, \
None, None, d_KA.t(), d_KB.t(), None, \
None, None, d_VA.t(), d_VB.t(), None
None, None, d_VA.t(), d_VB.t(), None, \
None,
pass
pass


def apply_lora_qkv(self, X):
def apply_lora_qkv(self, X, inplace = True):
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
Q, K, V = LoRA_QKV.apply(X,
QW, QW_quant, QA, QB, QS,
KW, KW_quant, KA, KB, KS,
VW, VW_quant, VA, VB, VS,
inplace,
)
return Q, K, V
pass
