Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix #1249

Merged
merged 175 commits into from
Nov 6, 2024
Merged

Bug fix #1249

Changes from all commits
Commits
Show all changes
175 commits
Select commit Hold shift + click to select a range
f0aca90
Fix TRL
danielhanchen Oct 21, 2024
f4ae585
Update mistral.py
danielhanchen Oct 22, 2024
106f213
Patch processing_class
danielhanchen Oct 22, 2024
ef84212
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
4f7c527
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
aa2b207
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
101389d
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
c0f0fc9
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
b3e0033
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
aabb5ff
Installation guide (#1165)
timothelaborie Oct 23, 2024
30bf339
chore: update chat_templates.py (#1166)
eltociear Oct 23, 2024
2895839
Disable Flex Attention
danielhanchen Oct 23, 2024
06f5d75
Update tokenizer_utils.py
danielhanchen Oct 23, 2024
28e6eea
Update _utils.py
danielhanchen Oct 23, 2024
b821f20
n_items
danielhanchen Oct 24, 2024
e561366
Update cross_entropy_loss.py
danielhanchen Oct 24, 2024
4ff247a
Fix DPO, ORPO
danielhanchen Oct 24, 2024
2b858a5
Merge branch 'main' into nightly
danielhanchen Oct 24, 2024
1c063b4
Update _utils.py
danielhanchen Oct 24, 2024
f195ee1
Update _utils.py
danielhanchen Oct 24, 2024
faf2747
fix/transformers-unpack (#1180)
Erland366 Oct 24, 2024
5961c34
Update cross_entropy_loss.py
danielhanchen Oct 24, 2024
7308bb8
Update _utils.py
danielhanchen Oct 24, 2024
0096e5b
Update _utils.py
danielhanchen Oct 24, 2024
44b480f
Merge branch 'main' into nightly
danielhanchen Oct 24, 2024
6776055
donot upcast lm_head and embeddings to float32 (#1186)
Datta0 Oct 25, 2024
625209e
Cleanup upcast logs (#1188)
Datta0 Oct 25, 2024
2bc189f
Fix/phi-longrope (#1193)
Erland366 Oct 25, 2024
6f28d16
Update transformers
danielhanchen Oct 26, 2024
f94f7c1
Merge branch 'main' into nightly
danielhanchen Oct 26, 2024
bf3b175
Merge branch 'main' into nightly
danielhanchen Oct 27, 2024
7083a1d
Unk token issues
danielhanchen Oct 28, 2024
3acc5af
Update _utils.py
danielhanchen Oct 28, 2024
1c044da
Fix pad token
danielhanchen Oct 28, 2024
5286f19
Update llama.py
danielhanchen Oct 28, 2024
02437a8
Typo
danielhanchen Oct 28, 2024
9d07be0
ignored labels
danielhanchen Oct 28, 2024
a8b37a3
Revert "ignored labels"
danielhanchen Oct 28, 2024
2dfdba3
More patching
danielhanchen Oct 28, 2024
5541ab4
Update _utils.py
danielhanchen Oct 28, 2024
c6e9af2
Update _utils.py
danielhanchen Oct 28, 2024
cac56d1
Update cross_entropy_loss.py
danielhanchen Oct 28, 2024
5ee1189
Update cross_entropy_loss.py
danielhanchen Oct 28, 2024
85a5f60
Update cross_entropy_loss.py
danielhanchen Oct 28, 2024
20e38ed
Feat/all tmp (#1219)
danielhanchen Oct 30, 2024
7e1692a
Bug fixes
danielhanchen Oct 30, 2024
6bef8f1
Update pyproject.toml
danielhanchen Oct 30, 2024
9ccbc0e
Update _utils.py
danielhanchen Oct 30, 2024
95ecc57
Update __init__.py
danielhanchen Oct 30, 2024
5f5fef8
Update __init__.py
danielhanchen Oct 30, 2024
784dd13
Update _utils.py
danielhanchen Oct 30, 2024
5b75e21
Update _utils.py
danielhanchen Oct 30, 2024
74ab93c
Update _utils.py
danielhanchen Oct 30, 2024
526505c
Update _utils.py
danielhanchen Oct 30, 2024
251ba77
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
530c495
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
07394c3
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
6d7004b
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
d86b20a
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
9920950
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
9f926ce
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
30cdf65
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
54b901b
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
6db9d28
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
8aefcd0
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
7bf626b
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
d455751
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
055eeb8
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
8090b7c
Tied weights
danielhanchen Oct 31, 2024
7559efb
Revert "Tied weights"
danielhanchen Oct 31, 2024
ad63a32
Tied weights
danielhanchen Oct 31, 2024
35aa992
Utils
danielhanchen Nov 3, 2024
0172ee3
CE Loss patching
danielhanchen Nov 3, 2024
c228682
Update __init__.py
danielhanchen Nov 3, 2024
9aa221a
Update __init__.py
danielhanchen Nov 3, 2024
751413e
Patching
danielhanchen Nov 3, 2024
82db087
Update cross_entropy_loss.py
danielhanchen Nov 3, 2024
cf68202
CE Loss
danielhanchen Nov 3, 2024
63a1828
Update _utils.py
danielhanchen Nov 3, 2024
3f0e56f
Update _utils.py
danielhanchen Nov 3, 2024
1190ed4
CE Loss
danielhanchen Nov 3, 2024
607ac34
Update _utils.py
danielhanchen Nov 3, 2024
32eac0b
Update _utils.py
danielhanchen Nov 3, 2024
5b6d401
Layernorm
danielhanchen Nov 4, 2024
3d19a71
Update _utils.py
danielhanchen Nov 4, 2024
76da511
Update _utils.py
danielhanchen Nov 4, 2024
013ebaa
Post patch
danielhanchen Nov 4, 2024
608916a
Update _utils.py
danielhanchen Nov 4, 2024
19836e3
Update llama.py
danielhanchen Nov 4, 2024
0164087
Update _utils.py
danielhanchen Nov 4, 2024
205f7ad
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
2f1f393
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
05b8f66
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
8d205c0
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
a1e9e13
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
94655f8
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
085f998
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
c796fd9
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
e943d77
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
16a7df6
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
f65b064
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
1ff49b8
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
080e558
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
f6d50c7
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
fad4202
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
736b16a
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
eb76416
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
367e43f
typing
danielhanchen Nov 4, 2024
993df20
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
8f566b3
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
22bb46b
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
b5c9f81
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
c7b2220
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
2d0ab26
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
428f662
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
5023ce9
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
5ca3d4a
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
3b32d81
int64
danielhanchen Nov 4, 2024
9bae6e2
Update _utils.py
danielhanchen Nov 4, 2024
5123623
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
4b1d9e2
constexpr
danielhanchen Nov 4, 2024
7d5111a
constexpr
danielhanchen Nov 4, 2024
dff5a52
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
969d1bd
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
4b5847f
Update _utils.py
danielhanchen Nov 4, 2024
766bf1e
Update _utils.py
danielhanchen Nov 4, 2024
646f1b7
Update _utils.py
danielhanchen Nov 5, 2024
97f37ac
CE
danielhanchen Nov 5, 2024
cc563fa
Update cross_entropy_loss.py
danielhanchen Nov 5, 2024
f643148
Update _utils.py
danielhanchen Nov 5, 2024
f28d7f6
Update llama.py
danielhanchen Nov 5, 2024
d8103e1
Update _utils.py
danielhanchen Nov 5, 2024
b9e1a49
Update rms_layernorm.py
danielhanchen Nov 5, 2024
56af302
Update rms_layernorm.py
danielhanchen Nov 5, 2024
a3c84a3
Update rms_layernorm.py
danielhanchen Nov 5, 2024
f7d5c56
Update rms_layernorm.py
danielhanchen Nov 5, 2024
8496ff6
Update rms_layernorm.py
danielhanchen Nov 5, 2024
2909eaf
Update rms_layernorm.py
danielhanchen Nov 5, 2024
afc8af6
Update utils.py
danielhanchen Nov 5, 2024
2d8d1e1
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ecc1ad2
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ae7cb78
Update rms_layernorm.py
danielhanchen Nov 5, 2024
22da266
Update rms_layernorm.py
danielhanchen Nov 5, 2024
beb6854
Update rms_layernorm.py
danielhanchen Nov 5, 2024
14c3d2f
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ef4b079
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ef684f8
Update rms_layernorm.py
danielhanchen Nov 5, 2024
3e4c42f
Update rms_layernorm.py
danielhanchen Nov 5, 2024
8f825eb
Update rms_layernorm.py
danielhanchen Nov 5, 2024
bd4ac7b
Update rms_layernorm.py
danielhanchen Nov 5, 2024
6f38731
Update rms_layernorm.py
danielhanchen Nov 5, 2024
2df35d4
typing
danielhanchen Nov 5, 2024
74d89d1
Update rope_embedding.py
danielhanchen Nov 5, 2024
98927ee
types
danielhanchen Nov 5, 2024
f3e2bd6
Disable compiling
danielhanchen Nov 5, 2024
c30bd2a
Update _utils.py
danielhanchen Nov 5, 2024
813cbdd
Update _utils.py
danielhanchen Nov 5, 2024
34ce5d1
Forward hook
danielhanchen Nov 5, 2024
f84cf4b
Update _utils.py
danielhanchen Nov 5, 2024
745814c
Update llama.py
danielhanchen Nov 5, 2024
ab9f8e1
Update _utils.py
danielhanchen Nov 5, 2024
daa7909
Update llama.py
danielhanchen Nov 5, 2024
536a1a6
Update llama.py
danielhanchen Nov 5, 2024
648ca59
Update _utils.py
danielhanchen Nov 5, 2024
486d0d6
Update pyproject.toml
danielhanchen Nov 5, 2024
eb4da9d
Update _utils.py
danielhanchen Nov 5, 2024
da397f4
Update llama.py
danielhanchen Nov 5, 2024
70b65cf
CE Loss
danielhanchen Nov 5, 2024
aeec57e
Update cross_entropy_loss.py
danielhanchen Nov 5, 2024
fb393fc
Update _utils.py
danielhanchen Nov 5, 2024
cab1e72
Update cross_entropy_loss.py
danielhanchen Nov 6, 2024
51fea97
Update cross_entropy_loss.py
danielhanchen Nov 6, 2024
58e541b
Update cross_entropy_loss.py
danielhanchen Nov 6, 2024
0ed0532
Merge branch 'main' into nightly
danielhanchen Nov 6, 2024
ef2c56f
Update llama.py
danielhanchen Nov 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions unsloth/kernels/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,13 @@ def _cross_entropy_forward(
mask = col_offsets < VOCAB_SIZE

label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)

# Go logit scaling for Cohere: t * x
if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits.to(tl.float32) / SOFTCAP).to(logits.dtype)

logits = logits.to(tl.float32)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)

c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

Expand Down Expand Up @@ -152,14 +151,13 @@ def _chunked_cross_entropy_forward(
mask = col_offsets < VOCAB_SIZE

label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)

# Go logit scaling for Cohere: t * x
if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits.to(tl.float32) / SOFTCAP).to(logits.dtype)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)

logits = logits.to(tl.float32)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

Expand Down Expand Up @@ -229,7 +227,7 @@ def _cross_entropy_backward(
else:
dloss = 0.0

x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)

# Do logit scaling for Cohere
if DO_LOGIT_SCALING:
Expand All @@ -241,12 +239,12 @@ def _cross_entropy_backward(
partial = x
if DO_SOFTCAPPING:
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
partial = triton_tanh(x.to(tl.float32) / SOFTCAP).to(x.dtype)
partial = triton_tanh(x / SOFTCAP)
x = SOFTCAP * partial
pass

logsumexp = tl.load(logsumexp_ptr + row_idx)
y = tl.exp(x.to(tl.float32) - logsumexp)
y = tl.exp(x - logsumexp)
y = tl.where(
col_offsets == label_idx,
y - 1.0, # exp(x - logsumexp) - 1
Expand Down Expand Up @@ -337,6 +335,7 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling :
return losses
pass


@staticmethod
def backward(ctx, dlosses):
logits, logsumexp, labels = ctx.saved_tensors
Expand All @@ -345,6 +344,8 @@ def backward(ctx, dlosses):
n_rows, vocab_size = logits.shape

BLOCK_SIZE : int = 4096
div : int
mod : int
div, mod = divmod(vocab_size, BLOCK_SIZE)
n_blocks : int = div + (mod != 0)

Expand Down