diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..4d9b7498
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,162 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
diff --git a/tests/fused_cel/README.md b/tests/fused_cel/README.md
new file mode 100644
index 00000000..14386952
--- /dev/null
+++ b/tests/fused_cel/README.md
@@ -0,0 +1,425 @@
+## Efficient Fused Cross Entropy Loss
+
+Memory-efficient cross entropy implementation that only materializes the derivatives of the language modeling head layer without storing the logits and chunks the computation of the logits such that the full logits tensor is never realized.
+
+This is a direct adaptation of this [repo](https://github.com/mgmalek/efficient_cross_entropy/tree/main).
+
+## Contents
+
+- [Overview](#overview)
+- [Changes](#changes)
+- [Tests](#tests)
+- [Benchmarks](#benchmarks)
+- [Profiling](#profiling)
+- [Next Steps](#next-steps)
+
+## Overview
+
+In short:
+
+- the logits, derivative with respect to the hidden state inputs to the language modeling head layer (`dX` hereafter), and the derivative with respect to the logits projection weights (`dW` hereafter) are computed in chunks
+- the logits are overwritten by its derivatives within a custom loss kernel to avoid additional memory allocations.
+
+See the original [repo](https://github.com/mgmalek/efficient_cross_entropy/tree/main) for an excellent explanation of the design.
+
+## Changes
+
+The following changes were made to the original kernel:
+
+- Reshape inputs and labels to adapt the `3-D` language modeling tensors with the required shapes of the kernel.
+- Upcast `loss` to `float32`, which in the original kernel was initialized to the autocasted / in-feat dtype.
+- Add `torch.cuda.amp.{custom_fwd,custom_bwd}` to the `autograd.Function`.
+
+All changes are enumerated in `unsloth/kernels/fused_cel.py`.
+
+Additionally, adapter layers and configs in `fused_cel.py` enable integration with `transformers` and `unsloth`.
+
+## Tests
+
+See `tests/test_CEL.py` for correctness checks.
+
+The comments in the tests describe numerical edge cases.
+
+## Benchmarks
+
+Following are results from preliminary benchmarking / testing on a `L4` NVIDIA GPU for a small `llama-like` [model](https://huggingface.co/hf-internal-testing/tiny-random-LlamaForCausalLM) with and without the `fused CEL` layer.
+
+The takeaway is that the memory efficiency claims of the original `repo` are evident, with overall memory usage lower, decreasing linearly with the number of loop iterations.
+
+Can be reproduced by passing the provided options to `benchmark_hf_test_cel.py` (run with `--help` to see all options).
+
+Below is the overall config, followed by `training losses` / `grad norms` and overall `training metrics` for `float32` and `bfloat16`.
+
+`Test config`:
+
+- `max_steps=50`
+- `model_id=hf-internal-testing/tiny-random-LlamaForCausalLM`
+- `batch_size=2`
+- `max_seq_len=256`
+- `packing=True`
+- `grad_accum_steps=1`
+- `load_in_4bit=False`
+- `use_lora=False`
+- `fused_cel_n_loop_iters=[1, 2, 4]`
+
+`float32`
+
+- _n_loop_it=1_
+
+| | loss | | | grad_norm | | |
+| --- | --------- | --------- | -------- | --------- | -------- | -------- |
+| | fused_cel | no-fused | absdiff | fused_cel | no-fused | absdiff |
+| 1 | 10.369300 | 10.369300 | 0.000000 | 0.375981 | 0.375981 | 0.000000 |
+| 2 | 10.383600 | 10.383600 | 0.000000 | 0.409343 | 0.409344 | 0.000000 |
+| 3 | 10.374800 | 10.374800 | 0.000000 | 0.411205 | 0.411205 | 0.000000 |
+| 4 | 10.380000 | 10.380000 | 0.000000 | 0.337345 | 0.337345 | 0.000000 |
+| 5 | 10.376800 | 10.376800 | 0.000000 | 0.354001 | 0.354001 | 0.000000 |
+| 6 | 10.363800 | 10.363800 | 0.000000 | 0.457850 | 0.457851 | 0.000000 |
+| 7 | 10.379100 | 10.379100 | 0.000000 | 0.327099 | 0.327099 | 0.000000 |
+| 8 | 10.372200 | 10.372200 | 0.000000 | 0.324939 | 0.324939 | 0.000000 |
+| 9 | 10.360500 | 10.360500 | 0.000000 | 0.463365 | 0.463365 | 0.000000 |
+| 10 | 10.369700 | 10.369700 | 0.000000 | 0.345713 | 0.345714 | 0.000000 |
+| 11 | 10.377000 | 10.377000 | 0.000000 | 0.323786 | 0.323786 | 0.000000 |
+| 12 | 10.363000 | 10.363000 | 0.000000 | 0.366833 | 0.366833 | 0.000000 |
+| 13 | 10.358700 | 10.358700 | 0.000000 | 0.386118 | 0.386118 | 0.000000 |
+| 14 | 10.362500 | 10.362500 | 0.000000 | 0.345925 | 0.345925 | 0.000000 |
+| 15 | 10.368100 | 10.368100 | 0.000000 | 0.339570 | 0.339571 | 0.000000 |
+| 16 | 10.360500 | 10.360500 | 0.000000 | 0.382450 | 0.382450 | 0.000000 |
+| 17 | 10.367800 | 10.367800 | 0.000000 | 0.328462 | 0.328463 | 0.000000 |
+| 18 | 10.362700 | 10.362700 | 0.000000 | 0.567761 | 0.567761 | 0.000000 |
+| 19 | 10.359300 | 10.359300 | 0.000000 | 0.344158 | 0.344158 | 0.000000 |
+| 20 | 10.363500 | 10.363500 | 0.000000 | 0.337636 | 0.337636 | 0.000000 |
+| 21 | 10.352300 | 10.352300 | 0.000000 | 0.382984 | 0.382984 | 0.000000 |
+| 22 | 10.364700 | 10.364700 | 0.000000 | 0.330023 | 0.330023 | 0.000000 |
+| 23 | 10.365200 | 10.365200 | 0.000000 | 0.366450 | 0.366450 | 0.000000 |
+| 24 | 10.351900 | 10.351900 | 0.000000 | 0.366239 | 0.366240 | 0.000000 |
+| 25 | 10.345900 | 10.345900 | 0.000000 | 0.454505 | 0.454506 | 0.000000 |
+| 26 | 10.353900 | 10.353900 | 0.000000 | 0.372731 | 0.372731 | 0.000000 |
+| 27 | 10.351000 | 10.351000 | 0.000000 | 0.386128 | 0.386128 | 0.000000 |
+| 28 | 10.362900 | 10.362900 | 0.000000 | 0.362428 | 0.362428 | 0.000000 |
+| 29 | 10.356200 | 10.356200 | 0.000000 | 0.362041 | 0.362041 | 0.000000 |
+| 30 | 10.361400 | 10.361400 | 0.000000 | 0.345147 | 0.345147 | 0.000000 |
+| 31 | 10.357700 | 10.357700 | 0.000000 | 0.353345 | 0.353345 | 0.000000 |
+| 32 | 10.358000 | 10.358000 | 0.000000 | 0.338220 | 0.338219 | 0.000001 |
+| 33 | 10.357200 | 10.357200 | 0.000000 | 0.346525 | 0.346525 | 0.000000 |
+| 34 | 10.338500 | 10.338500 | 0.000000 | 0.429826 | 0.429826 | 0.000001 |
+| 35 | 10.338200 | 10.338200 | 0.000000 | 0.410369 | 0.410370 | 0.000000 |
+| 36 | 10.362200 | 10.362200 | 0.000000 | 0.308196 | 0.308197 | 0.000001 |
+| 37 | 10.338700 | 10.338700 | 0.000000 | 0.406986 | 0.406987 | 0.000001 |
+| 38 | 10.355800 | 10.355800 | 0.000000 | 0.347940 | 0.347942 | 0.000002 |
+| 39 | 10.337200 | 10.337200 | 0.000000 | 0.484625 | 0.484626 | 0.000001 |
+| 40 | 10.355100 | 10.355100 | 0.000000 | 0.419877 | 0.419879 | 0.000002 |
+| 41 | 10.357300 | 10.357300 | 0.000000 | 0.355641 | 0.355643 | 0.000001 |
+| 42 | 10.361700 | 10.361700 | 0.000000 | 0.338817 | 0.338817 | 0.000001 |
+| 43 | 10.327000 | 10.327000 | 0.000000 | 0.466670 | 0.466672 | 0.000001 |
+| 44 | 10.351100 | 10.351100 | 0.000000 | 0.365030 | 0.365031 | 0.000001 |
+| 45 | 10.360800 | 10.360800 | 0.000000 | 0.347445 | 0.347447 | 0.000001 |
+| 46 | 10.315900 | 10.315900 | 0.000000 | 0.495173 | 0.495069 | 0.000104 |
+| 47 | 10.345500 | 10.345500 | 0.000000 | 0.373585 | 0.373586 | 0.000001 |
+| 48 | 10.339500 | 10.339500 | 0.000000 | 0.367941 | 0.367942 | 0.000001 |
+| 49 | 10.318600 | 10.318600 | 0.000000 | 0.495867 | 0.495869 | 0.000001 |
+| 50 | 10.368600 | 10.368600 | 0.000000 | 0.427715 | 0.427713 | 0.000001 |
+
+- _n_loop_it=2_
+
+| | loss | | | grad_norm | | |
+| --- | --------- | --------- | -------- | --------- | -------- | -------- |
+| | fused_cel | no-fused | absdiff | fused_cel | no-fused | absdiff |
+| 1 | 10.369300 | 10.369300 | 0.000000 | 0.375981 | 0.375981 | 0.000000 |
+| 2 | 10.383600 | 10.383600 | 0.000000 | 0.409343 | 0.409344 | 0.000000 |
+| 3 | 10.374800 | 10.374800 | 0.000000 | 0.411205 | 0.411205 | 0.000000 |
+| 4 | 10.380000 | 10.380000 | 0.000000 | 0.337345 | 0.337345 | 0.000000 |
+| 5 | 10.376800 | 10.376800 | 0.000000 | 0.354001 | 0.354001 | 0.000000 |
+| 6 | 10.363800 | 10.363800 | 0.000000 | 0.457850 | 0.457851 | 0.000000 |
+| 7 | 10.379100 | 10.379100 | 0.000000 | 0.327099 | 0.327099 | 0.000000 |
+| 8 | 10.372200 | 10.372200 | 0.000000 | 0.324939 | 0.324939 | 0.000000 |
+| 9 | 10.360500 | 10.360500 | 0.000000 | 0.463365 | 0.463365 | 0.000000 |
+| 10 | 10.369700 | 10.369700 | 0.000000 | 0.345713 | 0.345714 | 0.000000 |
+| 11 | 10.377000 | 10.377000 | 0.000000 | 0.323786 | 0.323786 | 0.000000 |
+| 12 | 10.363000 | 10.363000 | 0.000000 | 0.366833 | 0.366833 | 0.000000 |
+| 13 | 10.358700 | 10.358700 | 0.000000 | 0.386118 | 0.386118 | 0.000000 |
+| 14 | 10.362500 | 10.362500 | 0.000000 | 0.345925 | 0.345925 | 0.000000 |
+| 15 | 10.368100 | 10.368100 | 0.000000 | 0.339570 | 0.339571 | 0.000000 |
+| 16 | 10.360500 | 10.360500 | 0.000000 | 0.382450 | 0.382450 | 0.000000 |
+| 17 | 10.367800 | 10.367800 | 0.000000 | 0.328462 | 0.328463 | 0.000000 |
+| 18 | 10.362700 | 10.362700 | 0.000000 | 0.567761 | 0.567761 | 0.000000 |
+| 19 | 10.359300 | 10.359300 | 0.000000 | 0.344158 | 0.344158 | 0.000000 |
+| 20 | 10.363500 | 10.363500 | 0.000000 | 0.337636 | 0.337636 | 0.000001 |
+| 21 | 10.352300 | 10.352300 | 0.000000 | 0.382984 | 0.382984 | 0.000000 |
+| 22 | 10.364700 | 10.364700 | 0.000000 | 0.330023 | 0.330023 | 0.000000 |
+| 23 | 10.365200 | 10.365200 | 0.000000 | 0.366450 | 0.366450 | 0.000000 |
+| 24 | 10.351900 | 10.351900 | 0.000000 | 0.366239 | 0.366240 | 0.000000 |
+| 25 | 10.345900 | 10.345900 | 0.000000 | 0.454505 | 0.454506 | 0.000000 |
+| 26 | 10.353900 | 10.353900 | 0.000000 | 0.372731 | 0.372731 | 0.000000 |
+| 27 | 10.351000 | 10.351000 | 0.000000 | 0.386128 | 0.386128 | 0.000000 |
+| 28 | 10.362900 | 10.362900 | 0.000000 | 0.362428 | 0.362428 | 0.000000 |
+| 29 | 10.356200 | 10.356200 | 0.000000 | 0.362041 | 0.362041 | 0.000000 |
+| 30 | 10.361400 | 10.361400 | 0.000000 | 0.345147 | 0.345147 | 0.000000 |
+| 31 | 10.357700 | 10.357700 | 0.000000 | 0.353345 | 0.353345 | 0.000000 |
+| 32 | 10.358000 | 10.358000 | 0.000000 | 0.338220 | 0.338219 | 0.000001 |
+| 33 | 10.357200 | 10.357200 | 0.000000 | 0.346525 | 0.346525 | 0.000000 |
+| 34 | 10.338500 | 10.338500 | 0.000000 | 0.429826 | 0.429826 | 0.000000 |
+| 35 | 10.338200 | 10.338200 | 0.000000 | 0.410370 | 0.410370 | 0.000000 |
+| 36 | 10.362200 | 10.362200 | 0.000000 | 0.308196 | 0.308197 | 0.000000 |
+| 37 | 10.338700 | 10.338700 | 0.000000 | 0.406987 | 0.406987 | 0.000000 |
+| 38 | 10.355800 | 10.355800 | 0.000000 | 0.347942 | 0.347942 | 0.000000 |
+| 39 | 10.337200 | 10.337200 | 0.000000 | 0.484625 | 0.484626 | 0.000000 |
+| 40 | 10.355100 | 10.355100 | 0.000000 | 0.419878 | 0.419879 | 0.000000 |
+| 41 | 10.357300 | 10.357300 | 0.000000 | 0.355642 | 0.355643 | 0.000001 |
+| 42 | 10.361700 | 10.361700 | 0.000000 | 0.338817 | 0.338817 | 0.000000 |
+| 43 | 10.327000 | 10.327000 | 0.000000 | 0.466671 | 0.466672 | 0.000000 |
+| 44 | 10.351100 | 10.351100 | 0.000000 | 0.365031 | 0.365031 | 0.000000 |
+| 45 | 10.360800 | 10.360800 | 0.000000 | 0.347446 | 0.347447 | 0.000001 |
+| 46 | 10.315900 | 10.315900 | 0.000000 | 0.495084 | 0.495069 | 0.000015 |
+| 47 | 10.345500 | 10.345500 | 0.000000 | 0.373585 | 0.373586 | 0.000001 |
+| 48 | 10.339500 | 10.339500 | 0.000000 | 0.367942 | 0.367942 | 0.000000 |
+| 49 | 10.318600 | 10.318600 | 0.000000 | 0.495868 | 0.495869 | 0.000000 |
+| 50 | 10.368600 | 10.368600 | 0.000000 | 0.427714 | 0.427713 | 0.000001 |
+
+- _n_loop_it=4_
+
+| | loss | | | grad_norm | | |
+| --- | --------- | --------- | -------- | --------- | -------- | -------- |
+| | fused_cel | no-fused | absdiff | fused_cel | no-fused | absdiff |
+| 1 | 10.369300 | 10.369300 | 0.000000 | 0.375981 | 0.375981 | 0.000000 |
+| 2 | 10.383600 | 10.383600 | 0.000000 | 0.409343 | 0.409344 | 0.000000 |
+| 3 | 10.374800 | 10.374800 | 0.000000 | 0.411205 | 0.411205 | 0.000000 |
+| 4 | 10.380000 | 10.380000 | 0.000000 | 0.337345 | 0.337345 | 0.000000 |
+| 5 | 10.376800 | 10.376800 | 0.000000 | 0.354001 | 0.354001 | 0.000000 |
+| 6 | 10.363800 | 10.363800 | 0.000000 | 0.457850 | 0.457851 | 0.000000 |
+| 7 | 10.379100 | 10.379100 | 0.000000 | 0.327099 | 0.327099 | 0.000000 |
+| 8 | 10.372200 | 10.372200 | 0.000000 | 0.324939 | 0.324939 | 0.000000 |
+| 9 | 10.360500 | 10.360500 | 0.000000 | 0.463365 | 0.463365 | 0.000000 |
+| 10 | 10.369700 | 10.369700 | 0.000000 | 0.345713 | 0.345714 | 0.000000 |
+| 11 | 10.377000 | 10.377000 | 0.000000 | 0.323786 | 0.323786 | 0.000000 |
+| 12 | 10.363000 | 10.363000 | 0.000000 | 0.366833 | 0.366833 | 0.000000 |
+| 13 | 10.358700 | 10.358700 | 0.000000 | 0.386118 | 0.386118 | 0.000000 |
+| 14 | 10.362500 | 10.362500 | 0.000000 | 0.345925 | 0.345925 | 0.000000 |
+| 15 | 10.368100 | 10.368100 | 0.000000 | 0.339570 | 0.339571 | 0.000000 |
+| 16 | 10.360500 | 10.360500 | 0.000000 | 0.382450 | 0.382450 | 0.000000 |
+| 17 | 10.367800 | 10.367800 | 0.000000 | 0.328462 | 0.328463 | 0.000000 |
+| 18 | 10.362700 | 10.362700 | 0.000000 | 0.567761 | 0.567761 | 0.000000 |
+| 19 | 10.359300 | 10.359300 | 0.000000 | 0.344158 | 0.344158 | 0.000000 |
+| 20 | 10.363500 | 10.363500 | 0.000000 | 0.337636 | 0.337636 | 0.000001 |
+| 21 | 10.352300 | 10.352300 | 0.000000 | 0.382984 | 0.382984 | 0.000000 |
+| 22 | 10.364700 | 10.364700 | 0.000000 | 0.330023 | 0.330023 | 0.000000 |
+| 23 | 10.365200 | 10.365200 | 0.000000 | 0.366450 | 0.366450 | 0.000000 |
+| 24 | 10.351900 | 10.351900 | 0.000000 | 0.366239 | 0.366240 | 0.000000 |
+| 25 | 10.345900 | 10.345900 | 0.000000 | 0.454506 | 0.454506 | 0.000000 |
+| 26 | 10.353900 | 10.353900 | 0.000000 | 0.372731 | 0.372731 | 0.000000 |
+| 27 | 10.351000 | 10.351000 | 0.000000 | 0.386128 | 0.386128 | 0.000000 |
+| 28 | 10.362900 | 10.362900 | 0.000000 | 0.362428 | 0.362428 | 0.000000 |
+| 29 | 10.356200 | 10.356200 | 0.000000 | 0.362041 | 0.362041 | 0.000000 |
+| 30 | 10.361400 | 10.361400 | 0.000000 | 0.345147 | 0.345147 | 0.000000 |
+| 31 | 10.357700 | 10.357700 | 0.000000 | 0.353345 | 0.353345 | 0.000000 |
+| 32 | 10.358000 | 10.358000 | 0.000000 | 0.338220 | 0.338219 | 0.000001 |
+| 33 | 10.357200 | 10.357200 | 0.000000 | 0.346525 | 0.346525 | 0.000000 |
+| 34 | 10.338500 | 10.338500 | 0.000000 | 0.429826 | 0.429826 | 0.000000 |
+| 35 | 10.338200 | 10.338200 | 0.000000 | 0.410370 | 0.410370 | 0.000001 |
+| 36 | 10.362200 | 10.362200 | 0.000000 | 0.308197 | 0.308197 | 0.000000 |
+| 37 | 10.338700 | 10.338700 | 0.000000 | 0.406987 | 0.406987 | 0.000000 |
+| 38 | 10.355800 | 10.355800 | 0.000000 | 0.347942 | 0.347942 | 0.000000 |
+| 39 | 10.337200 | 10.337200 | 0.000000 | 0.484626 | 0.484626 | 0.000001 |
+| 40 | 10.355100 | 10.355100 | 0.000000 | 0.419879 | 0.419879 | 0.000000 |
+| 41 | 10.357300 | 10.357300 | 0.000000 | 0.355643 | 0.355643 | 0.000000 |
+| 42 | 10.361700 | 10.361700 | 0.000000 | 0.338818 | 0.338817 | 0.000000 |
+| 43 | 10.327000 | 10.327000 | 0.000000 | 0.466672 | 0.466672 | 0.000000 |
+| 44 | 10.351100 | 10.351100 | 0.000000 | 0.365031 | 0.365031 | 0.000000 |
+| 45 | 10.360800 | 10.360800 | 0.000000 | 0.347446 | 0.347447 | 0.000001 |
+| 46 | 10.315900 | 10.315900 | 0.000000 | 0.495063 | 0.495069 | 0.000006 |
+| 47 | 10.345500 | 10.345500 | 0.000000 | 0.373586 | 0.373586 | 0.000000 |
+| 48 | 10.339500 | 10.339500 | 0.000000 | 0.367942 | 0.367942 | 0.000000 |
+| 49 | 10.318600 | 10.318600 | 0.000000 | 0.495869 | 0.495869 | 0.000000 |
+| 50 | 10.368600 | 10.368600 | 0.000000 | 0.427715 | 0.427713 | 0.000001 |
+
+`Training metrics` for `float32`:
+
+| | step | trainable_params | total_params | n_loop_iters | total_flos | train_loss | train_mem_gpu_peaked_delta | train_samples_per_second | train_steps_per_second | train_runtime |
+| --------- | ---- | ---------------- | ------------ | ------------ | ---------- | ---------- | -------------------------- | ------------------------ | ---------------------- | ------------- |
+| no-fused | 50 | 1032272 | 1032272 | 1 | 74GF | 10.3577 | 188MB | 27.031 | 13.516 | 0:00:03.69 |
+| fused_cel | 50 | 1032272 | 1032272 | 1 | 74GF | 10.3577 | 66MB | 27.321 | 13.66 | 0:00:03.66 |
+| fused_cel | 50 | 1032272 | 1032272 | 2 | 74GF | 10.3577 | 35MB | 34.413 | 17.207 | 0:00:02.90 |
+| fused_cel | 50 | 1032272 | 1032272 | 4 | 74GF | 10.3577 | 19MB | 34.124 | 17.062 | 0:00:02.93 |
+
+`bfloat16`
+
+- _n_loop_it=1_
+
+| | loss | | | grad_norm | | |
+| --- | --------- | --------- | -------- | --------- | -------- | -------- |
+| | fused_cel | no-fused | absdiff | fused_cel | no-fused | absdiff |
+| 1 | 10.369300 | 10.369300 | 0.000000 | 0.375000 | 0.375000 | 0.000000 |
+| 2 | 10.383600 | 10.383600 | 0.000000 | 0.408203 | 0.408203 | 0.000000 |
+| 3 | 10.374700 | 10.374800 | 0.000100 | 0.408203 | 0.408203 | 0.000000 |
+| 4 | 10.379900 | 10.379900 | 0.000000 | 0.335938 | 0.335938 | 0.000000 |
+| 5 | 10.376600 | 10.376600 | 0.000000 | 0.353516 | 0.353516 | 0.000000 |
+| 6 | 10.363300 | 10.363300 | 0.000000 | 0.457031 | 0.457031 | 0.000000 |
+| 7 | 10.378900 | 10.378900 | 0.000000 | 0.326172 | 0.326172 | 0.000000 |
+| 8 | 10.372000 | 10.372000 | 0.000000 | 0.324219 | 0.324219 | 0.000000 |
+| 9 | 10.360000 | 10.360000 | 0.000000 | 0.460938 | 0.460938 | 0.000000 |
+| 10 | 10.369300 | 10.369300 | 0.000000 | 0.343750 | 0.343750 | 0.000000 |
+| 11 | 10.377000 | 10.377000 | 0.000000 | 0.322266 | 0.322266 | 0.000000 |
+| 12 | 10.362600 | 10.362600 | 0.000000 | 0.365234 | 0.365234 | 0.000000 |
+| 13 | 10.358700 | 10.358700 | 0.000000 | 0.384766 | 0.384766 | 0.000000 |
+| 14 | 10.362900 | 10.362900 | 0.000000 | 0.345703 | 0.345703 | 0.000000 |
+| 15 | 10.368100 | 10.368100 | 0.000000 | 0.337891 | 0.337891 | 0.000000 |
+| 16 | 10.360100 | 10.360100 | 0.000000 | 0.378906 | 0.378906 | 0.000000 |
+| 17 | 10.367600 | 10.367700 | 0.000100 | 0.326172 | 0.326172 | 0.000000 |
+| 18 | 10.362000 | 10.362100 | 0.000100 | 0.566406 | 0.566406 | 0.000000 |
+| 19 | 10.359200 | 10.359100 | 0.000100 | 0.345703 | 0.345703 | 0.000000 |
+| 20 | 10.362900 | 10.362900 | 0.000000 | 0.335938 | 0.335938 | 0.000000 |
+| 21 | 10.352200 | 10.352300 | 0.000100 | 0.380859 | 0.380859 | 0.000000 |
+| 22 | 10.365100 | 10.365000 | 0.000100 | 0.330078 | 0.330078 | 0.000000 |
+| 23 | 10.365000 | 10.365000 | 0.000000 | 0.363281 | 0.363281 | 0.000000 |
+| 24 | 10.352400 | 10.352500 | 0.000100 | 0.365234 | 0.365234 | 0.000000 |
+| 25 | 10.346100 | 10.346100 | 0.000000 | 0.451172 | 0.451172 | 0.000000 |
+| 26 | 10.353900 | 10.353800 | 0.000100 | 0.371094 | 0.371094 | 0.000000 |
+| 27 | 10.350900 | 10.350800 | 0.000100 | 0.384766 | 0.384766 | 0.000000 |
+| 28 | 10.363000 | 10.363300 | 0.000300 | 0.359375 | 0.359375 | 0.000000 |
+| 29 | 10.355400 | 10.355300 | 0.000100 | 0.361328 | 0.361328 | 0.000000 |
+| 30 | 10.361300 | 10.360500 | 0.000800 | 0.341797 | 0.341797 | 0.000000 |
+| 31 | 10.358800 | 10.358900 | 0.000100 | 0.351562 | 0.349609 | 0.001953 |
+| 32 | 10.358800 | 10.358900 | 0.000100 | 0.333984 | 0.333984 | 0.000000 |
+| 33 | 10.358200 | 10.358300 | 0.000100 | 0.343750 | 0.343750 | 0.000000 |
+| 34 | 10.339200 | 10.339300 | 0.000100 | 0.425781 | 0.425781 | 0.000000 |
+| 35 | 10.339200 | 10.339200 | 0.000000 | 0.408203 | 0.408203 | 0.000000 |
+| 36 | 10.364000 | 10.364000 | 0.000000 | 0.304688 | 0.304688 | 0.000000 |
+| 37 | 10.340300 | 10.340100 | 0.000200 | 0.402344 | 0.402344 | 0.000000 |
+| 38 | 10.356800 | 10.356700 | 0.000100 | 0.343750 | 0.345703 | 0.001953 |
+| 39 | 10.338900 | 10.339200 | 0.000300 | 0.478516 | 0.478516 | 0.000000 |
+| 40 | 10.355800 | 10.356000 | 0.000200 | 0.414062 | 0.414062 | 0.000000 |
+| 41 | 10.359100 | 10.358800 | 0.000300 | 0.351562 | 0.349609 | 0.001953 |
+| 42 | 10.363100 | 10.362700 | 0.000400 | 0.335938 | 0.335938 | 0.000000 |
+| 43 | 10.329000 | 10.329400 | 0.000400 | 0.458984 | 0.460938 | 0.001953 |
+| 44 | 10.352700 | 10.353000 | 0.000300 | 0.357422 | 0.359375 | 0.001953 |
+| 45 | 10.362200 | 10.361900 | 0.000300 | 0.343750 | 0.341797 | 0.001953 |
+| 46 | 10.319600 | 10.319500 | 0.000100 | 0.488281 | 0.488281 | 0.000000 |
+| 47 | 10.348700 | 10.348500 | 0.000200 | 0.367188 | 0.367188 | 0.000000 |
+| 48 | 10.342400 | 10.342000 | 0.000400 | 0.359375 | 0.361328 | 0.001953 |
+| 49 | 10.321900 | 10.322000 | 0.000100 | 0.486328 | 0.486328 | 0.000000 |
+| 50 | 10.368800 | 10.368500 | 0.000300 | 0.417969 | 0.417969 | 0.000000 |
+
+- _n_loop_it=2_
+
+| | loss | | | grad_norm | | |
+| --- | --------- | --------- | -------- | --------- | -------- | -------- |
+| | fused_cel | no-fused | absdiff | fused_cel | no-fused | absdiff |
+| 1 | 10.369300 | 10.369300 | 0.000000 | 0.375000 | 0.375000 | 0.000000 |
+| 2 | 10.383600 | 10.383600 | 0.000000 | 0.408203 | 0.408203 | 0.000000 |
+| 3 | 10.374700 | 10.374800 | 0.000100 | 0.408203 | 0.408203 | 0.000000 |
+| 4 | 10.379800 | 10.379900 | 0.000100 | 0.335938 | 0.335938 | 0.000000 |
+| 5 | 10.376600 | 10.376600 | 0.000000 | 0.353516 | 0.353516 | 0.000000 |
+| 6 | 10.363300 | 10.363300 | 0.000000 | 0.457031 | 0.457031 | 0.000000 |
+| 7 | 10.378900 | 10.378900 | 0.000000 | 0.326172 | 0.326172 | 0.000000 |
+| 8 | 10.372100 | 10.372000 | 0.000100 | 0.324219 | 0.324219 | 0.000000 |
+| 9 | 10.359900 | 10.360000 | 0.000100 | 0.460938 | 0.460938 | 0.000000 |
+| 10 | 10.369400 | 10.369300 | 0.000100 | 0.343750 | 0.343750 | 0.000000 |
+| 11 | 10.377400 | 10.377000 | 0.000400 | 0.322266 | 0.322266 | 0.000000 |
+| 12 | 10.362600 | 10.362600 | 0.000000 | 0.365234 | 0.365234 | 0.000000 |
+| 13 | 10.358400 | 10.358700 | 0.000300 | 0.384766 | 0.384766 | 0.000000 |
+| 14 | 10.363000 | 10.362900 | 0.000100 | 0.345703 | 0.345703 | 0.000000 |
+| 15 | 10.367900 | 10.368100 | 0.000200 | 0.337891 | 0.337891 | 0.000000 |
+| 16 | 10.360100 | 10.360100 | 0.000000 | 0.378906 | 0.378906 | 0.000000 |
+| 17 | 10.367700 | 10.367700 | 0.000000 | 0.326172 | 0.326172 | 0.000000 |
+| 18 | 10.362300 | 10.362100 | 0.000200 | 0.562500 | 0.566406 | 0.003906 |
+| 19 | 10.359400 | 10.359100 | 0.000300 | 0.343750 | 0.345703 | 0.001953 |
+| 20 | 10.363100 | 10.362900 | 0.000200 | 0.335938 | 0.335938 | 0.000000 |
+| 21 | 10.352100 | 10.352300 | 0.000200 | 0.380859 | 0.380859 | 0.000000 |
+| 22 | 10.365000 | 10.365000 | 0.000000 | 0.328125 | 0.330078 | 0.001953 |
+| 23 | 10.364900 | 10.365000 | 0.000100 | 0.363281 | 0.363281 | 0.000000 |
+| 24 | 10.352200 | 10.352500 | 0.000300 | 0.365234 | 0.365234 | 0.000000 |
+| 25 | 10.346000 | 10.346100 | 0.000100 | 0.451172 | 0.451172 | 0.000000 |
+| 26 | 10.354100 | 10.353800 | 0.000300 | 0.371094 | 0.371094 | 0.000000 |
+| 27 | 10.351000 | 10.350800 | 0.000200 | 0.382812 | 0.384766 | 0.001953 |
+| 28 | 10.363100 | 10.363300 | 0.000200 | 0.359375 | 0.359375 | 0.000000 |
+| 29 | 10.355300 | 10.355300 | 0.000000 | 0.359375 | 0.361328 | 0.001953 |
+| 30 | 10.361700 | 10.360500 | 0.001200 | 0.341797 | 0.341797 | 0.000000 |
+| 31 | 10.358700 | 10.358900 | 0.000200 | 0.351562 | 0.349609 | 0.001953 |
+| 32 | 10.358700 | 10.358900 | 0.000200 | 0.337891 | 0.333984 | 0.003906 |
+| 33 | 10.357800 | 10.358300 | 0.000500 | 0.343750 | 0.343750 | 0.000000 |
+| 34 | 10.339400 | 10.339300 | 0.000100 | 0.425781 | 0.425781 | 0.000000 |
+| 35 | 10.339500 | 10.339200 | 0.000300 | 0.408203 | 0.408203 | 0.000000 |
+| 36 | 10.363700 | 10.364000 | 0.000300 | 0.304688 | 0.304688 | 0.000000 |
+| 37 | 10.339900 | 10.340100 | 0.000200 | 0.402344 | 0.402344 | 0.000000 |
+| 38 | 10.356700 | 10.356700 | 0.000000 | 0.345703 | 0.345703 | 0.000000 |
+| 39 | 10.339200 | 10.339200 | 0.000000 | 0.480469 | 0.478516 | 0.001953 |
+| 40 | 10.355300 | 10.356000 | 0.000700 | 0.414062 | 0.414062 | 0.000000 |
+| 41 | 10.359000 | 10.358800 | 0.000200 | 0.351562 | 0.349609 | 0.001953 |
+| 42 | 10.362900 | 10.362700 | 0.000200 | 0.333984 | 0.335938 | 0.001953 |
+| 43 | 10.328600 | 10.329400 | 0.000800 | 0.460938 | 0.460938 | 0.000000 |
+| 44 | 10.353200 | 10.353000 | 0.000200 | 0.359375 | 0.359375 | 0.000000 |
+| 45 | 10.362200 | 10.361900 | 0.000300 | 0.343750 | 0.341797 | 0.001953 |
+| 46 | 10.319600 | 10.319500 | 0.000100 | 0.486328 | 0.488281 | 0.001953 |
+| 47 | 10.348400 | 10.348500 | 0.000100 | 0.365234 | 0.367188 | 0.001953 |
+| 48 | 10.342500 | 10.342000 | 0.000500 | 0.361328 | 0.361328 | 0.000000 |
+| 49 | 10.321700 | 10.322000 | 0.000300 | 0.486328 | 0.486328 | 0.000000 |
+| 50 | 10.369700 | 10.368500 | 0.001200 | 0.419922 | 0.417969 | 0.001953 |
+
+- _n_loop_it=4_
+
+| | loss | | | grad_norm | | |
+| --- | --------- | --------- | -------- | --------- | -------- | -------- |
+| | fused_cel | no-fused | absdiff | fused_cel | no-fused | absdiff |
+| 1 | 10.369300 | 10.369300 | 0.000000 | 0.375000 | 0.375000 | 0.000000 |
+| 2 | 10.383600 | 10.383600 | 0.000000 | 0.406250 | 0.408203 | 0.001953 |
+| 3 | 10.374700 | 10.374800 | 0.000100 | 0.408203 | 0.408203 | 0.000000 |
+| 4 | 10.379900 | 10.379900 | 0.000000 | 0.335938 | 0.335938 | 0.000000 |
+| 5 | 10.376600 | 10.376600 | 0.000000 | 0.353516 | 0.353516 | 0.000000 |
+| 6 | 10.363300 | 10.363300 | 0.000000 | 0.457031 | 0.457031 | 0.000000 |
+| 7 | 10.378900 | 10.378900 | 0.000000 | 0.326172 | 0.326172 | 0.000000 |
+| 8 | 10.372100 | 10.372000 | 0.000100 | 0.324219 | 0.324219 | 0.000000 |
+| 9 | 10.360000 | 10.360000 | 0.000000 | 0.460938 | 0.460938 | 0.000000 |
+| 10 | 10.369400 | 10.369300 | 0.000100 | 0.343750 | 0.343750 | 0.000000 |
+| 11 | 10.377300 | 10.377000 | 0.000300 | 0.322266 | 0.322266 | 0.000000 |
+| 12 | 10.362500 | 10.362600 | 0.000100 | 0.365234 | 0.365234 | 0.000000 |
+| 13 | 10.358500 | 10.358700 | 0.000200 | 0.384766 | 0.384766 | 0.000000 |
+| 14 | 10.362900 | 10.362900 | 0.000000 | 0.345703 | 0.345703 | 0.000000 |
+| 15 | 10.367800 | 10.368100 | 0.000300 | 0.337891 | 0.337891 | 0.000000 |
+| 16 | 10.360000 | 10.360100 | 0.000100 | 0.380859 | 0.378906 | 0.001953 |
+| 17 | 10.367800 | 10.367700 | 0.000100 | 0.326172 | 0.326172 | 0.000000 |
+| 18 | 10.362200 | 10.362100 | 0.000100 | 0.562500 | 0.566406 | 0.003906 |
+| 19 | 10.359300 | 10.359100 | 0.000200 | 0.343750 | 0.345703 | 0.001953 |
+| 20 | 10.363000 | 10.362900 | 0.000100 | 0.335938 | 0.335938 | 0.000000 |
+| 21 | 10.352000 | 10.352300 | 0.000300 | 0.380859 | 0.380859 | 0.000000 |
+| 22 | 10.364900 | 10.365000 | 0.000100 | 0.330078 | 0.330078 | 0.000000 |
+| 23 | 10.364800 | 10.365000 | 0.000200 | 0.363281 | 0.363281 | 0.000000 |
+| 24 | 10.352200 | 10.352500 | 0.000300 | 0.365234 | 0.365234 | 0.000000 |
+| 25 | 10.346400 | 10.346100 | 0.000300 | 0.451172 | 0.451172 | 0.000000 |
+| 26 | 10.354200 | 10.353800 | 0.000400 | 0.371094 | 0.371094 | 0.000000 |
+| 27 | 10.351000 | 10.350800 | 0.000200 | 0.384766 | 0.384766 | 0.000000 |
+| 28 | 10.363000 | 10.363300 | 0.000300 | 0.359375 | 0.359375 | 0.000000 |
+| 29 | 10.355300 | 10.355300 | 0.000000 | 0.361328 | 0.361328 | 0.000000 |
+| 30 | 10.361400 | 10.360500 | 0.000900 | 0.341797 | 0.341797 | 0.000000 |
+| 31 | 10.358500 | 10.358900 | 0.000400 | 0.351562 | 0.349609 | 0.001953 |
+| 32 | 10.358900 | 10.358900 | 0.000000 | 0.339844 | 0.333984 | 0.005859 |
+| 33 | 10.358000 | 10.358300 | 0.000300 | 0.343750 | 0.343750 | 0.000000 |
+| 34 | 10.339300 | 10.339300 | 0.000000 | 0.425781 | 0.425781 | 0.000000 |
+| 35 | 10.339300 | 10.339200 | 0.000100 | 0.408203 | 0.408203 | 0.000000 |
+| 36 | 10.363800 | 10.364000 | 0.000200 | 0.304688 | 0.304688 | 0.000000 |
+| 37 | 10.340000 | 10.340100 | 0.000100 | 0.402344 | 0.402344 | 0.000000 |
+| 38 | 10.356500 | 10.356700 | 0.000200 | 0.345703 | 0.345703 | 0.000000 |
+| 39 | 10.338800 | 10.339200 | 0.000400 | 0.478516 | 0.478516 | 0.000000 |
+| 40 | 10.356000 | 10.356000 | 0.000000 | 0.416016 | 0.414062 | 0.001953 |
+| 41 | 10.358800 | 10.358800 | 0.000000 | 0.349609 | 0.349609 | 0.000000 |
+| 42 | 10.362800 | 10.362700 | 0.000100 | 0.335938 | 0.335938 | 0.000000 |
+| 43 | 10.328900 | 10.329400 | 0.000500 | 0.460938 | 0.460938 | 0.000000 |
+| 44 | 10.353000 | 10.353000 | 0.000000 | 0.359375 | 0.359375 | 0.000000 |
+| 45 | 10.361400 | 10.361900 | 0.000500 | 0.343750 | 0.341797 | 0.001953 |
+| 46 | 10.320000 | 10.319500 | 0.000500 | 0.486328 | 0.488281 | 0.001953 |
+| 47 | 10.348200 | 10.348500 | 0.000300 | 0.365234 | 0.367188 | 0.001953 |
+| 48 | 10.342200 | 10.342000 | 0.000200 | 0.361328 | 0.361328 | 0.000000 |
+| 49 | 10.322400 | 10.322000 | 0.000400 | 0.486328 | 0.486328 | 0.000000 |
+| 50 | 10.369200 | 10.368500 | 0.000700 | 0.419922 | 0.417969 | 0.001953 |
+
+`Training metrics` for `bfloat16`
+| | step | trainable_params | total_params | n_loop_iters | total_flos | train_loss | train_mem_gpu_peaked_delta | train_samples_per_second | train_steps_per_second | train_runtime |
+|--------------|------|------------------|--------------|--------------|------------|------------|----------------------------|--------------------------|------------------------|---------------|
+| no-fused | 50 | 1032272 | 1032272 | 1 | 74GF | 10.3582 | 188MB | 24.8 | 12.4 | 0:00:04.03 |
+| fused_cel | 50 | 1032272 | 1032272 | 1 | 74GF | 10.3582 | 128MB | 24.564 | 12.282 | 0:00:04.07 |
+| fused_cel | 50 | 1032272 | 1032272 | 2 | 74GF | 10.3582 | 98MB | 29.51 | 14.755 | 0:00:03.38 |
+| fused_cel | 50 | 1032272 | 1032272 | 4 | 74GF | 10.3582 | 49MB | 31.764 | 15.882 | 0:00:03.14 |
+
+## Next Steps
+
+- [ ] Integrate with `FastLanguageModel`
+- [ ] Run tests / benchmarks on `LoRA` and `QLoRA` configs
diff --git a/tests/fused_cel/benchmark_hf_test_cel.py b/tests/fused_cel/benchmark_hf_test_cel.py
new file mode 100644
index 00000000..1e609b09
--- /dev/null
+++ b/tests/fused_cel/benchmark_hf_test_cel.py
@@ -0,0 +1,229 @@
+import argparse
+import os
+from pathlib import Path
+
+import pandas as pd
+import torch
+from cel_analysis import load_log_diffs
+from cel_test_utils import (
+ get_model,
+ get_peft_config,
+ get_quant_config,
+ get_sft_trainer,
+ get_tokenizer,
+ get_trainer_args,
+)
+
+# from transformers.trainer_utils import enable_full_determinism
+from transformers.trainer_utils import set_seed as hf_set_seed
+
+import unsloth.utils.data as data_utils
+from unsloth.kernels.fused_cel import patch_model as patch_model_fused_cel
+from unsloth.models._utils import patch_tokenizer, prepare_model_for_kbit_training
+from unsloth.utils.memory import empty_cache
+
+parent_dir = Path(__file__).parent.absolute()
+SEED = 3407
+hf_set_seed(SEED)
+torch.autograd.set_detect_anomaly(True)
+
+import logging
+
+logging.basicConfig(level=logging.INFO)
+
+
+def run_train_loop(
+ model,
+ tokenizer,
+ dataset,
+ peft_config,
+ training_args,
+ cli_args,
+ use_fused_cel=False,
+ n_loop_iters=1,
+):
+ model = patch_model_fused_cel(
+ model,
+ use_fused_cel=use_fused_cel,
+ fused_cel_n_loop_iters=n_loop_iters,
+ # these are defaults
+ fused_cel_ignore_index=-100,
+ fused_cel_reduction="mean",
+ )
+ file_prefix = (
+ ("fused" if use_fused_cel else "base")
+ + "_"
+ + cli_args.dtype
+ + "_"
+ + str(n_loop_iters)
+ )
+
+ trainer = get_sft_trainer(
+ model=model,
+ tokenizer=tokenizer,
+ dataset=dataset,
+ peft_config=peft_config,
+ trainer_args=training_args,
+ max_seq_len=cli_args.max_seq_len,
+ packing=cli_args.packing,
+ file_prefix=file_prefix,
+ )
+ _ = trainer.train()
+
+
+def get_model_and_tokenizer(args):
+ dtype = getattr(torch, args.dtype)
+ model_id = args.model_id
+
+ quant_config = (
+ get_quant_config(args.load_in_4bit, dtype) if args.load_in_4bit else None
+ )
+ model = get_model(
+ model_id=model_id,
+ dtype=dtype,
+ use_fused_cel_layer=True,
+ quant_config=quant_config,
+ )
+ tokenizer = get_tokenizer(model_id, args.max_seq_len)
+ model, tokenizer = patch_tokenizer(model, tokenizer)
+
+ return model, tokenizer
+
+
+def run_benchmark(args):
+ dtype = getattr(torch, args.dtype)
+ model, tokenizer = get_model_and_tokenizer(args)
+
+ if args.overwrite_output_dir:
+ import shutil
+
+ shutil.rmtree(args.output_dir, ignore_errors=True)
+
+ training_args = get_trainer_args(
+ batch_size=args.batch_size,
+ max_steps=args.max_steps,
+ grad_accum_steps=args.grad_accum_steps,
+ dtype=dtype,
+ seed=SEED,
+ output_dir=args.output_dir,
+ )
+ peft_config = get_peft_config() if args.use_lora or args.load_in_4bit else None
+ if args.load_in_4bit:
+ model = prepare_model_for_kbit_training(
+ model, use_gradient_checkpointing=training_args.gradient_checkpointing
+ )
+
+ dataset = data_utils.get_alpaca(tokenizer)
+
+ formatted_args = "\n ".join([f"{k}={v}" for k, v in vars(args).items()])
+
+ print(f"Running with:\n {formatted_args}")
+
+ losses, metrics = [], []
+ # Run reference once
+ run_train_loop(
+ model=model,
+ tokenizer=tokenizer,
+ dataset=dataset,
+ peft_config=peft_config,
+ training_args=training_args,
+ cli_args=args,
+ use_fused_cel=False,
+ )
+ del model
+ del tokenizer
+ empty_cache()
+
+ for n_loop_iters in args.fused_cel_n_loop_iters:
+ # Run with fused CEL
+ model, tokenizer = get_model_and_tokenizer(args)
+ run_train_loop(
+ model=model,
+ tokenizer=tokenizer,
+ dataset=dataset,
+ peft_config=peft_config,
+ training_args=training_args,
+ cli_args=args,
+ use_fused_cel=True,
+ n_loop_iters=n_loop_iters,
+ )
+ loss_df, metrics_df = load_log_diffs(args.output_dir)
+ loss_df.columns.names = [
+ loss_df.columns.names[0] + ", n_loop_it=" + str(n_loop_iters),
+ loss_df.columns.names[1],
+ ]
+ losses.append(loss_df)
+ # No fused always has n_loop_iters = 1
+ metrics_df.loc["n_loop_iters"] = [1, n_loop_iters]
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ metrics_df.loc["trainable_params"] = [
+ trainable_params,
+ trainable_params,
+ ]
+ metrics_df.loc["total_params"] = [
+ total_params,
+ total_params,
+ ]
+ metrics.append(metrics_df)
+ if args.print_accuracy:
+ print(loss_df.to_string(float_format="%.6f", justify="left"))
+
+ consolidated_metrics = pd.concat(metrics, axis=1).T.drop_duplicates()
+ COL_ORDER = [
+ "step",
+ "trainable_params",
+ "total_params",
+ "n_loop_iters",
+ "total_flos",
+ "train_loss",
+ "train_mem_gpu_peaked_delta",
+ "train_samples_per_second",
+ "train_steps_per_second",
+ "train_runtime",
+ ]
+ consolidated_metrics = consolidated_metrics[COL_ORDER]
+ consolidated_metrics.to_csv(os.path.join(args.output_dir, "metrics.csv"))
+ if args.print_metrics:
+ print(consolidated_metrics.to_string())
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument(
+ "--max_steps", type=int, default=10, help="Number of training steps"
+ )
+ parser.add_argument(
+ "--dtype", type=str, default="bfloat16", help="torch compute type"
+ )
+ parser.add_argument(
+ "--model_id",
+ type=str,
+ default="hf-internal-testing/tiny-random-LlamaForCausalLM",
+ help="Path to the model, passed to huggingface `from_pretrained` method",
+ )
+ parser.add_argument("--batch_size", type=int, default=2)
+ parser.add_argument("--max_seq_len", type=int, default=256)
+ parser.add_argument("--packing", action="store_true", default=True)
+ parser.add_argument("--grad_accum_steps", type=int, default=1)
+ parser.add_argument("--load_in_4bit", action="store_true", default=False)
+ parser.add_argument("--use_lora", action="store_true", default=False)
+ parser.add_argument("--output_dir", type=str, default="outputs")
+ parser.add_argument("--overwrite_output_dir", action="store_true", default=True)
+ parser.add_argument("--print_accuracy", action="store_true", default=True)
+ parser.add_argument("--print_metrics", action="store_true", default=True)
+
+ parser.add_argument(
+ "--fused_cel_n_loop_iters",
+ type=int,
+ nargs="+",
+ default=[1, 2, 4],
+ help="""Number of loop iterations for fused CEL.
+ E.g., `n_loop_iters=4` will calculate the logits / loss in 4 chunks along sequence length.
+ `batch_size * seqlen` must be divisible by `n_loop_iters`
+ """,
+ )
+ args = parser.parse_args()
+ run_benchmark(args)
diff --git a/tests/fused_cel/cel_analysis.py b/tests/fused_cel/cel_analysis.py
new file mode 100644
index 00000000..c808224d
--- /dev/null
+++ b/tests/fused_cel/cel_analysis.py
@@ -0,0 +1,68 @@
+import json
+import os
+
+import pandas as pd
+
+idx = pd.IndexSlice
+
+
+def load_log_history(
+ file, return_df=True, return_cols=["model", "step", "loss", "grad_norm"]
+):
+ """
+ Load log history from json file
+ """
+ log_history = json.load(open(file))["log_history"]
+
+ losses, metrics = log_history[:-1], log_history[-1]
+
+ if return_df:
+ loss_df = pd.DataFrame(losses)
+ model_type = "fused_cel" if "fused" in file else "no-fused"
+ loss_df["model"] = model_type
+ loss_df = loss_df[return_cols]
+ metrics_df = pd.Series(metrics).to_frame(name=model_type).loc[idx["step":], :]
+ return (loss_df, metrics_df) if return_df else (losses, metrics)
+
+
+def get_diff(pivoted_df, col, diff1, diff2):
+ return abs(pivoted_df[col][diff1] - pivoted_df[col][diff2])
+
+
+def get_pivoted_df(df):
+ pivot = df.pivot(index="step", columns="model", values=["loss", "grad_norm"])
+ pivot.columns.names = ["metric", "model"]
+ loss_diff = get_diff(pivot, "loss", "no-fused", "fused_cel")
+ grad_diff = get_diff(pivot, "grad_norm", "no-fused", "fused_cel")
+ pivot.loc[:, ("loss", "absdiff")] = loss_diff
+ pivot.loc[:, ("grad_norm", "absdiff")] = grad_diff
+ return pd.concat(
+ [pivot.loc[:, idx["loss", :]], pivot.loc[:, idx["grad_norm", :]]], axis=1
+ )
+
+
+def load_log_diffs(
+ trace_dir, return_df=True, return_cols=["model", "step", "loss", "grad_norm"]
+):
+ traces = [
+ os.path.join(trace_dir, trace)
+ for trace in sorted(os.listdir(trace_dir), reverse=True) # Load most recent
+ if trace.endswith(".json")
+ ]
+ base_traces = [trace for trace in traces if "fused" not in trace]
+ fused_traces = [trace for trace in traces if "fused" in trace]
+
+ losses = []
+ metrics = []
+ for trace in [base_traces[0], fused_traces[0]]:
+ loss, metric = load_log_history(
+ trace, return_df=return_df, return_cols=return_cols
+ )
+ losses.append(loss)
+ metrics.append(metric)
+
+ if return_df:
+ losses = pd.concat(losses)
+ losses = get_pivoted_df(losses)
+ metrics = pd.concat(metrics, axis=1)
+ return (losses, metrics)
diff --git a/tests/fused_cel/cel_test_utils.py b/tests/fused_cel/cel_test_utils.py
new file mode 100644
index 00000000..51d2e049
--- /dev/null
+++ b/tests/fused_cel/cel_test_utils.py
@@ -0,0 +1,146 @@
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from cel_analysis import load_log_diffs
+from peft import LoraConfig
+from tabulate import tabulate
+from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ BitsAndBytesConfig,
+ TrainingArguments,
+)
+from transformers.models.llama import LlamaConfig
+from transformers.trainer_callback import ProgressCallback
+from transformers.trainer_utils import enable_full_determinism
+from transformers.trainer_utils import set_seed as hf_set_seed
+from trl import SFTTrainer
+
+import unsloth.utils.data as data_utils
+from unsloth.kernels.fused_cel import LlamaForCausalLMFusedCEL
+from unsloth.kernels.fused_cel import patch_model as patch_model_fused_cel
+from unsloth.models._utils import patch_tokenizer, prepare_model_for_kbit_training
+from unsloth.utils.data import get_data_loader
+from unsloth.utils.profiling import MetricsCallBack
+
+parent_dir = Path(__file__).parent.absolute()
+# SEED = 3407
+# enable_full_determinism(SEED)
+torch.autograd.set_detect_anomaly(True)
+
+
+def get_quant_config(load_in_4bit, dtype):
+ return BitsAndBytesConfig(
+ load_in_4bit=load_in_4bit,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_use_double_quant=False,
+ bnb_4bit_compute_dtype=dtype,
+ )
+
+
+def get_model(model_id, dtype, use_fused_cel_layer=True, quant_config=None):
+ model_cls = (
+ LlamaForCausalLMFusedCEL if use_fused_cel_layer else AutoModelForCausalLM
+ )
+
+ model = model_cls.from_pretrained(
+ model_id,
+ quantization_config=quant_config,
+ torch_dtype=dtype,
+ )
+ return model
+
+
+def get_tokenizer(model_id, max_seq_len):
+ tokenizer = AutoTokenizer.from_pretrained(
+ "meta-llama/Llama-2-7b-chat-hf",
+ model_max_length=max_seq_len,
+ padding_side="right",
+ )
+ return tokenizer
+
+
+def get_trainer_args(batch_size, grad_accum_steps, max_steps, dtype, seed, output_dir):
+ training_args = TrainingArguments(
+ per_device_train_batch_size=batch_size,
+ gradient_accumulation_steps=grad_accum_steps,
+ warmup_steps=1,
+ max_steps=max_steps,
+ learning_rate=2e-4,
+ fp16=dtype == "float16",
+ bf16=dtype == "bfloat16",
+ logging_steps=1,
+ optim="adamw_8bit",
+ weight_decay=0.01,
+ lr_scheduler_type="linear",
+ seed=seed,
+ data_seed=seed,
+ output_dir=output_dir,
+ overwrite_output_dir=True,
+ report_to="none",
+ # Metrics
+ skip_memory_metrics=False,
+ )
+ return training_args
+
+
+def get_peft_config(
+ target_modules="all-linear",
+ lora_alpha=8,
+ lora_dropout=0.0,
+ bias="none",
+ task_type="CAUSAL_LM",
+):
+ # accepted_modules = frozenset(
+ # (
+ # "q_proj",
+ # "k_proj",
+ # "v_proj",
+ # "o_proj",
+ # "gate_proj",
+ # "up_proj",
+ # "down_proj",
+ # ),
+ # )
+
+ peft_config = LoraConfig(
+ target_modules=target_modules,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ bias=bias,
+ task_type=task_type,
+ )
+ return peft_config
+
+
+def get_sft_trainer(
+ model,
+ tokenizer,
+ dataset,
+ peft_config,
+ trainer_args,
+ max_seq_len,
+ file_prefix,
+ packing=False,
+):
+ trainer = SFTTrainer(
+ model=model,
+ tokenizer=tokenizer,
+ train_dataset=dataset,
+ peft_config=peft_config,
+ dataset_text_field="text",
+ max_seq_length=max_seq_len,
+ dataset_num_proc=2,
+ packing=packing,
+ args=trainer_args,
+ )
+
+ # Remove default callbacks, make less verbose
+ trainer.remove_callback(ProgressCallback)
+ trainer.model.enable_input_require_grads()
+ # file_prefix = "fused_cel" if use_fused_cel else ""
+ # file_prefix += "_" + args.dtype
+ _ = trainer.add_callback(MetricsCallBack(name=file_prefix, verbose=False))
+ return trainer
diff --git a/tests/fused_cel/conftest.py b/tests/fused_cel/conftest.py
new file mode 100644
index 00000000..e9994986
--- /dev/null
+++ b/tests/fused_cel/conftest.py
@@ -0,0 +1,2 @@
+def pytest_make_parametrize_id(config, val, argname):
+ return f"{argname}={val}"
diff --git a/tests/fused_cel/llama-small.json b/tests/fused_cel/llama-small.json
new file mode 100644
index 00000000..58624945
--- /dev/null
+++ b/tests/fused_cel/llama-small.json
@@ -0,0 +1,18 @@
+{
+ "architectures": ["LLaMAForCausalLM"],
+ "bos_token_id": 0,
+ "eos_token_id": 1,
+ "hidden_act": "silu",
+ "hidden_size": 128,
+ "intermediate_size": 352,
+ "initializer_range": 0.02,
+ "max_sequence_length": 1024,
+ "model_type": "llama",
+ "num_attention_heads": 4,
+ "num_hidden_layers": 4,
+ "pad_token_id": -1,
+ "rms_norm_eps": 1e-6,
+ "transformers_version": "4.28.1",
+ "use_cache": true,
+ "vocab_size": 320000
+}
diff --git a/tests/fused_cel/pytest.ini b/tests/fused_cel/pytest.ini
new file mode 100644
index 00000000..c24fe5bb
--- /dev/null
+++ b/tests/fused_cel/pytest.ini
@@ -0,0 +1,3 @@
+[pytest]
+filterwarnings =
+ ignore::DeprecationWarning
diff --git a/tests/fused_cel/test_CEL.py b/tests/fused_cel/test_CEL.py
new file mode 100644
index 00000000..a2d7d776
--- /dev/null
+++ b/tests/fused_cel/test_CEL.py
@@ -0,0 +1,196 @@
+import itertools
+import types
+from pathlib import Path
+
+import pytest
+import torch
+from transformers.models.llama import LlamaConfig, LlamaForCausalLM
+
+import unsloth.utils.testing as test_utils
+from unsloth.kernels.fused_cel import fused_cel_layer
+from unsloth.utils.memory import empty_cache
+
+torch.manual_seed(0)
+
+
+@pytest.fixture
+def model_path():
+ PARENT_DIR = Path(__file__).parent.absolute()
+ MODEL_CONFIG_PATH = PARENT_DIR / "llama-small.json"
+
+ return MODEL_CONFIG_PATH
+
+
+@pytest.fixture(scope="module")
+def tensors(bs, seqlen, hidden_size, vocab_size, dtype):
+ dtype = getattr(torch, dtype)
+ hidden_states = torch.randn(
+ bs, seqlen, hidden_size, dtype=dtype, device="cuda", requires_grad=True
+ )
+ lm_head_weight = torch.randn(
+ (vocab_size, hidden_size), dtype=dtype, device="cuda", requires_grad=True
+ )
+ labels = torch.randint(0, vocab_size, size=(bs, seqlen), device="cuda")
+ yield hidden_states, labels, lm_head_weight
+ # Cleanup
+ del hidden_states, labels, lm_head_weight
+ empty_cache()
+
+
+def ref_cel(hidden_states, lm_head_weight, labels):
+ vocab_size = lm_head_weight.shape[0]
+ logits = hidden_states @ lm_head_weight.T
+ logits = logits.float()
+
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ loss_fct = torch.nn.CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, vocab_size)
+ shift_labels = shift_labels.view(-1)
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ return loss
+
+
+def run_cel(fn, hidden_states, lm_head_weight, labels, **kwargs):
+ loss = fn(hidden_states, lm_head_weight, labels, **kwargs)
+ dX, dW = torch.autograd.grad(loss, [hidden_states, lm_head_weight])
+ return loss, dX, dW
+
+
+BATCH_SIZES = [1]
+SEQ_LENS = [256]
+HIDDEN_SIZES = [128, 4096]
+VOCAB_SIZES = [32000, 128256]
+DTYPE = ["float16", "bfloat16", "float32"]
+N_LOOP_ITERS = [1, 2]
+TEST_CONFIGS = list(
+ itertools.product(
+ BATCH_SIZES, SEQ_LENS, HIDDEN_SIZES, VOCAB_SIZES, DTYPE, N_LOOP_ITERS
+ )
+)
+
+
+# Test will fail for dX when hidden_size > 4096 and n_loop_iters > 1 and vocab_size == 32000
+# Comment out the pytest.skip if you want to run
+# Also, the pytest does not release all resources, leading to memory leaks and OOM, which is why we comment out
+# some test configs
+@pytest.mark.parametrize(
+ "bs, seqlen, hidden_size, vocab_size, dtype, n_loop_iters", TEST_CONFIGS
+)
+def test_cel(bs, seqlen, hidden_size, vocab_size, dtype, n_loop_iters):
+ dtype = getattr(torch, dtype)
+
+ # Accuracy failure case
+ # if not (hidden_size >= 4096 and n_loop_iters > 1 and vocab_size == 32000):
+ # pytest.skip("Skipping, failure case for dX, uncomment to run only these cases")
+
+ if vocab_size > 32000 and dtype == "float32":
+ pytest.skip("No need for float32 for large vocabs")
+ hidden_states = torch.randn(
+ bs, seqlen, hidden_size, dtype=dtype, device="cuda", requires_grad=True
+ )
+
+ lm_head_weight = torch.randn(
+ (vocab_size, hidden_size), dtype=dtype, device="cuda", requires_grad=True
+ )
+
+ # Input ids aren't actually used, but we need to pass them to the model
+ labels = torch.randint(0, vocab_size, size=(bs, seqlen), device="cuda")
+
+ # Reference loss, dX, dW where dX is the gradients wrt to the hidden states and dW is the gradients wrt to the LM head weight
+ loss, dX, dW = run_cel(ref_cel, hidden_states, lm_head_weight, labels)
+ fused_loss, dX_fused, dW_fused = run_cel(
+ fused_cel_layer,
+ hidden_states,
+ lm_head_weight,
+ labels,
+ n_loop_iters=n_loop_iters,
+ ignore_index=-100,
+ reduction="mean",
+ )
+ if dtype == torch.bfloat16:
+ atol, rtol = 1e-3, 1e-3
+ elif dtype == torch.float16:
+ atol, rtol = 1e-4, 1e-4
+ else:
+ atol, rtol = 1e-6, 1e-6
+
+ test_utils.check_all(
+ [loss, dX, dW],
+ [fused_loss, dX_fused, dW_fused],
+ ["loss", "dX", "dW"],
+ atol=atol,
+ rtol=rtol,
+ )
+ del loss, dX, dW
+ del fused_loss, dX_fused, dW_fused
+ del hidden_states, lm_head_weight, labels
+ empty_cache()
+
+
+# @pytest.mark.parametrize("bs", [1]) # , 2, 4])
+# @pytest.mark.parametrize("seqlen", [256]) # , 512, 1024])
+# @pytest.mark.parametrize("hidden_size", [4096])
+# @pytest.mark.parametrize(
+# "vocab_size",
+# [32000, 128256], # , 256000]
+# ) # llama-2, llama-3, gemma
+# @pytest.mark.parametrize("dtype", ["float16", "bfloat16", "float32"])
+# @pytest.mark.parametrize("n_loop_iters", [1, 2]) # , 2, 4])
+# def test_cel(bs, seqlen, hidden_size, vocab_size, dtype, n_loop_iters, model_path):
+# dtype = getattr(torch, dtype)
+
+# model_config = LlamaConfig.from_pretrained(model_path)
+# model_config.update({"vocab_size": vocab_size})
+# model_config.update({"hidden_size": hidden_size})
+
+# model = LlamaForCausalLM(model_config).to(dtype).to("cuda")
+
+# # Mock LlamaModel.forward so that we can directly test the CEL loss and derivatives wrt the hidden states (input to the LM head)
+# hidden_states = torch.randn(
+# bs, seqlen, hidden_size, dtype=dtype, device="cuda", requires_grad=True
+# )
+# model.model.forward = types.MethodType(
+# lambda *args, **kwargs: (hidden_states,), model.model
+# )
+
+# # Input ids aren't actually used, but we need to pass them to the model
+# input_ids = torch.randint(0, vocab_size, size=(bs, seqlen), device="cuda")
+# labels = input_ids.detach().clone()
+# attention_mask = torch.ones((bs, seqlen), device="cuda")
+
+# # Reference loss, dX, dW where dX is the gradients wrt to the hidden states and dW is the gradients wrt to the LM head weight
+# loss, *_ = model(
+# input_ids, labels=labels, attention_mask=attention_mask, return_dict=False
+# )
+# dX, dW = torch.autograd.grad(loss, [hidden_states, model.lm_head.weight])
+
+# # Patch the model to use fused CEL
+# fused_model = patch_model_fused_cel(
+# model, use_fused_cel=True, fused_cel_n_loop_iters=n_loop_iters
+# )
+# fused_loss, *_ = fused_model(
+# input_ids, labels=labels, attention_mask=attention_mask, return_dict=False
+# )
+# dX_fused, dW_fused = torch.autograd.grad(
+# fused_loss, [hidden_states, fused_model.lm_head.weight]
+# )
+
+# if dtype == torch.bfloat16:
+# atol, rtol = 1e-3, 1e-3 # Fails if < 1e-3
+# elif dtype == torch.float16:
+# atol, rtol = 1e-4, 1e-4 # Fails if < 1e-4
+# else:
+# atol, rtol = 1e-6, 1e-6
+
+# test_utils.check_all(
+# [loss, dX, dW],
+# [fused_loss, dX_fused, dW_fused],
+# ["loss", "dX", "dW"],
+# atol=atol,
+# rtol=rtol,
+# )
+# del fused_model
+# empty_cache()
diff --git a/unsloth/__init__.py b/unsloth/__init__.py
index d4ca45d7..9405ce07 100644
--- a/unsloth/__init__.py
+++ b/unsloth/__init__.py
@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib
import os
import warnings
-import importlib
# Currently only supports 1 GPU, or else seg faults will occur.
if "CUDA_VISIBLE_DEVICES" in os.environ:
@@ -22,8 +22,8 @@
if not devices.isdigit():
first_id = devices.split(",")[0]
warnings.warn(
- f"Unsloth: 'CUDA_VISIBLE_DEVICES' is currently {devices} \n"\
- "Multiple CUDA devices detected but we require a single device.\n"\
+ f"Unsloth: 'CUDA_VISIBLE_DEVICES' is currently {devices} \n"
+ "Multiple CUDA devices detected but we require a single device.\n"
f"We will override CUDA_VISIBLE_DEVICES to first device: {first_id}."
)
os.environ["CUDA_VISIBLE_DEVICES"] = str(first_id)
@@ -39,17 +39,21 @@
try:
import torch
except:
- raise ImportError("Pytorch is not installed. Go to https://pytorch.org/.\n"\
- "We have some installation instructions on our Github page.")
+ raise ImportError(
+ "Pytorch is not installed. Go to https://pytorch.org/.\n"
+ "We have some installation instructions on our Github page."
+ )
# We support Pytorch 2
# Fixes https://github.com/unslothai/unsloth/issues/38
torch_version = torch.__version__.split(".")
major_torch, minor_torch = torch_version[0], torch_version[1]
major_torch, minor_torch = int(major_torch), int(minor_torch)
-if (major_torch < 2):
- raise ImportError("Unsloth only supports Pytorch 2 for now. Please update your Pytorch to 2.1.\n"\
- "We have some installation instructions on our Github page.")
+if major_torch < 2:
+ raise ImportError(
+ "Unsloth only supports Pytorch 2 for now. Please update your Pytorch to 2.1.\n"
+ "We have some installation instructions on our Github page."
+ )
elif (major_torch == 2) and (minor_torch < 2):
# Disable expandable_segments
del os.environ["PYTORCH_CUDA_ALLOC_CONF"]
@@ -59,25 +63,36 @@
# Try loading bitsandbytes and triton
import bitsandbytes as bnb
import triton
-from triton.common.build import libcuda_dirs
+
+triton_version = triton.__version__.split(".")
+triton_major = int(triton_version[0])
+if triton_major >= 3:
+ from triton.backends.nvidia.driver import libcuda_dirs
+else:
+ from triton.common.build import libcuda_dirs
+
+
import os
import re
-import numpy as np
import subprocess
+import numpy as np
+
try:
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
libcuda_dirs()
except:
- warnings.warn(
- "Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA."\
- )
+ warnings.warn("Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA.")
if os.path.exists("/usr/lib64-nvidia"):
os.system("ldconfig /usr/lib64-nvidia")
elif os.path.exists("/usr/local"):
# Sometimes bitsandbytes cannot be linked properly in Runpod for example
- possible_cudas = subprocess.check_output(["ls", "-al", "/usr/local"]).decode("utf-8").split("\n")
+ possible_cudas = (
+ subprocess.check_output(["ls", "-al", "/usr/local"])
+ .decode("utf-8")
+ .split("\n")
+ )
find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$")
possible_cudas = [find_cuda.search(x) for x in possible_cudas]
possible_cudas = [x.group(1) for x in possible_cudas if x is not None]
@@ -87,7 +102,9 @@
os.system(f"ldconfig /usr/local/")
else:
find_number = re.compile(r"([\d\.]{2,})")
- latest_cuda = np.argsort([float(find_number.search(x).group(1)) for x in possible_cudas])[::-1][0]
+ latest_cuda = np.argsort(
+ [float(find_number.search(x).group(1)) for x in possible_cudas]
+ )[::-1][0]
latest_cuda = possible_cudas[latest_cuda]
os.system(f"ldconfig /usr/local/{latest_cuda}")
pass
@@ -97,20 +114,21 @@
try:
import bitsandbytes as bnb
from triton.common.build import libcuda_dirs
+
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
libcuda_dirs()
except:
warnings.warn(
- "Unsloth: CUDA is not linked properly.\n"\
- "Try running `python -m bitsandbytes` then `python -m xformers.info`\n"\
- "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\
- "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\
- "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"\
+ "Unsloth: CUDA is not linked properly.\n"
+ "Try running `python -m bitsandbytes` then `python -m xformers.info`\n"
+ "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"
+ "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"
+ "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"
"Unsloth will still run for now, but maybe it might crash - let's hope it works!"
)
pass
+from .chat_templates import *
from .models import *
from .save import *
-from .chat_templates import *
from .tokenizer_utils import *
diff --git a/unsloth/kernels/fused_cel.py b/unsloth/kernels/fused_cel.py
new file mode 100644
index 00000000..f38a97c8
--- /dev/null
+++ b/unsloth/kernels/fused_cel.py
@@ -0,0 +1,459 @@
+import logging
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+import triton
+import triton.language as tl
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.llama import LlamaForCausalLM
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class FusedCELConfig(dict):
+ use_fused_cel: bool = True
+ n_loop_iters: int = 1
+ ignore_index: int = -100
+ reduction: str = "mean"
+
+ def __post_init__(self):
+ self.update(self.__dict__)
+
+
+# Efficient Cross Entropy Fused Kernel
+# credit: https://github.com/mgmalek/efficient_cross_entropy
+
+
+@triton.jit
+def fused_cross_entropy_fwd_bwd_kernel(
+ output_loss_ptr,
+ output_logit_grad_ptr,
+ input_logit_ptr,
+ input_targ_ptr,
+ input_divisor_ptr,
+ output_loss_stride,
+ output_logit_grad_stride,
+ input_logit_stride,
+ input_targ_stride,
+ n_cols,
+ ignore_index,
+ BLOCK_SIZE: tl.constexpr,
+):
+ # Get pointers to current row for all inputs/outputs
+ row_idx = tl.program_id(0)
+ logit_grad_row_start_ptr = (
+ output_logit_grad_ptr + row_idx * output_logit_grad_stride
+ )
+ logit_row_start_ptr = input_logit_ptr + row_idx * input_logit_stride
+ targ_ptr = input_targ_ptr + row_idx * input_targ_stride
+ loss_ptr = output_loss_ptr + row_idx * output_loss_stride
+
+ col_offsets = tl.arange(0, BLOCK_SIZE)
+ logit_row_ptrs = logit_row_start_ptr + col_offsets
+ logit_grad_row_ptrs = logit_grad_row_start_ptr + col_offsets
+
+ # Load data into SRAM
+ logit_row_unnormalized = tl.load(
+ logit_row_ptrs, mask=col_offsets < n_cols, other=float("-Inf")
+ )
+ targ = tl.load(targ_ptr)
+ divisor = tl.load(input_divisor_ptr)
+
+ # Normalize logits and compute some useful intermediate values
+ logit_row = logit_row_unnormalized - tl.max(
+ logit_row_unnormalized, axis=0
+ ) # Subtract max value for numerical stability
+ exp_logit_row = tl.exp(logit_row)
+ sum_exp_logit_row = tl.sum(exp_logit_row, axis=0)
+
+ # Compute loss
+ log_sum_exp_logit_row = tl.log(sum_exp_logit_row)
+ logit_gt_logit = tl.sum(tl.where(targ == col_offsets, logit_row, 0.0))
+ loss = log_sum_exp_logit_row - logit_gt_logit
+ loss = loss / divisor
+ loss = tl.where(targ == ignore_index, 0.0, loss)
+ tl.store(loss_ptr, loss)
+
+ # Compute gradients
+ targ_one_hot = tl.where(targ == col_offsets, 1.0, 0.0)
+ grad = exp_logit_row / sum_exp_logit_row - targ_one_hot
+ grad = grad / divisor
+ grad = tl.where(targ == ignore_index, 0.0, grad)
+ tl.store(logit_grad_row_ptrs, grad, mask=col_offsets < n_cols)
+
+
+"""
+NOTE: Changes from original implementation:
+- Reshape inputs within forward from bs x seqlen x hidden_dim to (bs * seqlen) x hidden_dim per kernel requirement
+- Reshape labels within forward from bs x seqlen to (bs * seqlen)
+- Upcast `loss` to float32 (originally initialized to autocast / in-feat dtype)
+- Upcast 'divisor' to float32 (originally initialized to autocast / in-feat dtype)
+- Reshape dX from `(bs * seqlen) x hidden_dim` to `bs x seqlen x hidden_dim`
+- Add custom_fwd / custom_bwd
+- Handle torch.float16 scaling in backward
+
+TODO:
+- Revisit float16 scaling in `backward`
+- Investigate why in_feat.view(-1, in_feat.shape[-1]) sometimes changes in `in_feat.requires_grad` to False
+"""
+
+
+class FusedCrossEntropyLossFunction(torch.autograd.Function):
+ @staticmethod
+ @torch.cuda.amp.custom_fwd
+ def forward(
+ ctx,
+ in_feat: torch.Tensor,
+ proj_weight: torch.Tensor,
+ targ: torch.Tensor,
+ n_loop_iters: int,
+ ignore_index: int,
+ reduction: str,
+ ):
+ bs, seqlen, hidden_dim = in_feat.shape
+
+ in_feat = in_feat.view(-1, in_feat.shape[-1])
+ # print(
+ # f"in_feat shape, contiguity, grad: {in_feat.shape}, {in_feat.is_contiguous(), in_feat.requires_grad}"
+ # )
+
+ in_feat.requires_grad_(True)
+ targ = targ.view(-1)
+
+ n_tokens = in_feat.shape[0]
+ n_classes = proj_weight.shape[0]
+ # print(
+ # f"proj_weight shape, contiguity, grad: {proj_weight.shape}, {proj_weight.is_contiguous(), proj_weight.requires_grad}"
+ # )
+
+ # print(f"n_tokens: {n_tokens}, n_classes: {n_classes}")
+ assert in_feat.ndim == 2, in_feat.ndim
+ assert proj_weight.ndim == 2, proj_weight.ndim
+ assert targ.ndim == 1, targ.shape
+ assert (
+ in_feat.shape[0] == targ.shape[0]
+ ), f"Number of tokens in in_feat and targ is not equal: {(in_feat.shape, targ.shape) = }"
+ assert reduction in ("mean", "sum"), reduction
+ assert n_loop_iters > 0, n_loop_iters
+ assert (
+ n_tokens % n_loop_iters == 0
+ ), f"Number of tokens must be divisible by n_loop_iters {(n_tokens, n_loop_iters)}"
+ NUM_WARPS = 16
+
+ BLOCK_SIZE = triton.next_power_of_2(n_classes)
+
+ # Change loss from in_feat.dtype to float32
+ loss = torch.empty(n_tokens, dtype=torch.float32, device=in_feat.device)
+ dtype = (
+ torch.get_autocast_gpu_dtype()
+ if torch.is_autocast_enabled()
+ else in_feat.dtype
+ )
+
+ if proj_weight.requires_grad:
+ grad_proj_weight = torch.zeros_like(proj_weight, dtype=dtype)
+ else:
+ grad_proj_weight = None
+
+ if in_feat.requires_grad:
+ grad_in_feat = torch.zeros_like(in_feat)
+ else:
+ grad_in_feat = None
+
+ # Change divisor from in_feat.dtype to float32
+ divisor = (
+ (targ != ignore_index).sum().to(torch.float32)
+ if reduction == "mean"
+ else torch.ones(1, dtype=torch.float32, device=in_feat.device)
+ )
+
+ # Divide the input into chunks of size num_tokens // n_loop_iters, then compute the loss for each of these groups
+ proj_weight_cast = proj_weight.to(dtype)
+
+ loop_chunk_size = triton.cdiv(n_tokens, n_loop_iters)
+ logits_chunk_cast = torch.zeros(
+ (loop_chunk_size, n_classes), dtype=dtype, device=in_feat.device
+ )
+
+ for i, in_feat_chunk in enumerate(torch.split(in_feat, loop_chunk_size)):
+ token_start_idx = i * loop_chunk_size
+ token_end_idx = (i + 1) * loop_chunk_size
+
+ in_feat_chunk = in_feat_chunk.to(dtype)
+
+ # Compute logits
+ torch.matmul(in_feat_chunk, proj_weight_cast.T, out=logits_chunk_cast)
+ logits_chunk = logits_chunk_cast.float()
+ # print(f"Fused cel logits_chunk: {logits_chunk.mean()}")
+ # Compute loss
+ loss_chunk = loss[token_start_idx:token_end_idx]
+ targ_chunk = targ[token_start_idx:token_end_idx]
+
+ n_tokens_chunk = logits_chunk.shape[0]
+ grad_logits_chunk = (
+ logits_chunk # NOTE: we override the logits with their gradients
+ )
+
+ fused_cross_entropy_fwd_bwd_kernel[(n_tokens_chunk,)](
+ loss_chunk,
+ grad_logits_chunk,
+ logits_chunk,
+ targ_chunk,
+ divisor,
+ loss_chunk.stride(0),
+ grad_logits_chunk.stride(0),
+ logits_chunk.stride(0),
+ targ_chunk.stride(0),
+ n_classes,
+ ignore_index,
+ num_warps=NUM_WARPS,
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+
+ grad_logits_chunk = grad_logits_chunk.to(dtype)
+
+ if in_feat.requires_grad:
+ grad_in_feat[token_start_idx:token_end_idx] = (
+ grad_logits_chunk @ proj_weight_cast
+ )
+
+ if proj_weight.requires_grad:
+ torch.addmm(
+ grad_proj_weight,
+ grad_logits_chunk.T,
+ in_feat_chunk,
+ out=grad_proj_weight,
+ )
+
+ # NOTE: if reduction == "mean" we already divide by an appropriate normalization factor in the kernel so we can alway sum here
+ # print("Loss before sum: ", loss)
+ loss = loss.sum()
+ # print("Loss after sum: ", loss.item())
+ # Save data for backward
+ ctx.in_feat_requires_grad = in_feat.requires_grad
+ ctx.proj_weight_requires_grad = proj_weight.requires_grad
+
+ if proj_weight.requires_grad and in_feat.requires_grad:
+ grad_in_feat = grad_in_feat.view(bs, seqlen, hidden_dim)
+ ctx.save_for_backward(grad_in_feat, grad_proj_weight)
+ elif proj_weight.requires_grad and not in_feat.requires_grad:
+ ctx.save_for_backward(grad_proj_weight)
+ elif not proj_weight.requires_grad and in_feat.requires_grad:
+ grad_in_feat = grad_in_feat.view(bs, seqlen, hidden_dim)
+ ctx.save_for_backward(grad_in_feat)
+
+ return loss
+
+ @staticmethod
+ @torch.cuda.amp.custom_bwd
+ def backward(ctx, grad_output):
+ grad_in_feat = grad_proj_weight = None
+ if ctx.in_feat_requires_grad and ctx.proj_weight_requires_grad:
+ grad_in_feat, grad_proj_weight = ctx.saved_tensors
+ elif not ctx.in_feat_requires_grad and ctx.proj_weight_requires_grad:
+ (grad_proj_weight,) = ctx.saved_tensors
+ elif ctx.in_feat_requires_grad and not ctx.proj_weight_requires_grad:
+ (grad_in_feat,) = ctx.saved_tensors
+
+ # Needed for gradient scaling?
+ if grad_in_feat is not None and grad_in_feat.dtype == torch.float16:
+ grad_in_feat = (grad_in_feat.to(torch.float32) * grad_output).to(
+ torch.float16
+ )
+ if grad_proj_weight is not None and grad_proj_weight.dtype == torch.float16:
+ grad_proj_weight = (grad_proj_weight.to(torch.float32) * grad_output).to(
+ torch.float16
+ )
+
+ return grad_in_feat, grad_proj_weight, None, None, None, None
+
+
+def fused_cel_linear(
+ x, proj_weight, labels, n_loop_iters=1, ignore_index=-100, reduction="mean"
+):
+ """
+ x: (bs, seqlen, hidden_dim)
+ proj_weight: (vocab_size, hidden_dim)
+ labels: (bs, seqlen)
+
+ """
+ return FusedCrossEntropyLossFunction.apply(
+ x, proj_weight, labels, n_loop_iters, ignore_index, reduction
+ )
+
+
+def fused_cel_layer(
+ hidden_states,
+ lm_head_weight,
+ labels,
+ n_loop_iters=1,
+ ignore_index=-100,
+ reduction="mean",
+):
+ # if n_loop_iters > 1, we pad the labels with an extra ignore_index token
+ # so that we when chunking logits, we can divide the original number of tokens rather than number of tokens - 1
+ # This doesn't guarantee that the number of tokens (batch_size * seqlen) is a multiple of n_loop_iters
+ # but it makes simpler in the case for packed sequences, where the number of tokens is more often a multiple of 2
+ if n_loop_iters > 1:
+ labels = labels[..., 1:].contiguous()
+ # Pad labels
+ place_holder = torch.full(
+ (labels.shape[0], 1),
+ ignore_index,
+ dtype=labels.dtype,
+ device=labels.device,
+ )
+ labels = (
+ torch.hstack([labels, place_holder]).to(hidden_states.device).contiguous()
+ )
+
+ loss = fused_cel_linear(
+ hidden_states,
+ lm_head_weight,
+ labels,
+ n_loop_iters=n_loop_iters,
+ ignore_index=ignore_index,
+ reduction=reduction,
+ )
+ else:
+ # Need to shift, since kernel assumes labels and hidden states have same bs * seqlen
+ shift_hidden_states = hidden_states[
+ ..., :-1, :
+ ].contiguous() # This is important -- MUST call contiguous, otherwise will cause downstream reshaping issues
+ shift_labels = labels[..., 1:].contiguous()
+ shift_labels = shift_labels.to(shift_hidden_states.device)
+
+ loss = fused_cel_linear(
+ shift_hidden_states,
+ lm_head_weight,
+ shift_labels,
+ n_loop_iters=n_loop_iters,
+ ignore_index=ignore_index,
+ reduction=reduction,
+ )
+
+ return loss
+
+
+class LlamaForCausalLMFusedCEL(LlamaForCausalLM):
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ causal_mask: Optional[torch.Tensor] = None, # this is deprecated
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ):
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ if self.config.pretraining_tp > 1:
+ lm_head_slices = self.lm_head.weight.split(
+ self.vocab_size // self.config.pretraining_tp, dim=0
+ )
+ logits = [
+ F.linear(hidden_states, lm_head_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ]
+ logits = torch.cat(logits, dim=-1)
+ elif hasattr(self.config, "fused_cel") and self.config.fused_cel.use_fused_cel:
+ logger.warning_once(
+ "Using fused cross entropy loss, output logits will be in None"
+ )
+ assert labels is not None, "labels must not be None to use fused CEL"
+
+ loss = fused_cel_layer(
+ hidden_states,
+ self.lm_head.weight,
+ labels,
+ n_loop_iters=self.config.fused_cel.n_loop_iters,
+ ignore_index=self.config.fused_cel.ignore_index,
+ reduction="mean",
+ )
+
+ logits = None
+
+ else:
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ # print("No fused shift logits", shift_logits.mean().item())
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+def patch_model(
+ model,
+ use_fused_cel=True,
+ fused_cel_n_loop_iters=1,
+ fused_cel_ignore_index=-100,
+ fused_cel_reduction="mean",
+):
+ fused_config = FusedCELConfig(
+ use_fused_cel=use_fused_cel,
+ n_loop_iters=fused_cel_n_loop_iters,
+ ignore_index=fused_cel_ignore_index,
+ reduction=fused_cel_reduction,
+ )
+ model.config.update({"fused_cel": fused_config})
+
+ return model
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 136ceb2c..18df54aa 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -1605,6 +1605,7 @@ def patch_peft_model(
if model_type == "llama": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "mistral": apply_lora_mlp = apply_lora_mlp_swiglu
+ elif model_type == "qwen2": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "gemma": apply_lora_mlp = apply_lora_mlp_geglu_approx
else:
raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!")
@@ -1833,5 +1834,4 @@ def for_training(model, use_gradient_checkpointing = True):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
pass
-pass
-
+pass
\ No newline at end of file
diff --git a/unsloth/utils/data.py b/unsloth/utils/data.py
new file mode 100644
index 00000000..87ad7fc4
--- /dev/null
+++ b/unsloth/utils/data.py
@@ -0,0 +1,110 @@
+from functools import partial
+
+from datasets import load_dataset
+from torch.utils.data import DataLoader, RandomSampler
+from transformers.data.data_collator import DataCollatorForLanguageModeling
+
+ALPACA_PROMPT = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+### Instruction:
+{}
+
+### Input:
+{}
+
+### Response:
+{}"""
+
+
+def _formatting_prompts_func_alpaca(examples, eos_token):
+ instructions = examples["instruction"]
+ inputs = examples["input"]
+ outputs = examples["output"]
+ texts = []
+ for instruction, input, output in zip(instructions, inputs, outputs):
+ # Must add EOS_TOKEN, otherwise your generation will go on forever!
+ text = ALPACA_PROMPT.format(instruction, input, output) + eos_token
+ texts.append(text)
+ return {
+ "text": texts,
+ }
+
+
+pass
+
+FORMATTING_FUNCS: dict = {"ALPACA": _formatting_prompts_func_alpaca}
+
+
+def get_alpaca(tokenizer, batched=True, split="train"):
+ dataset = load_dataset("yahma/alpaca-cleaned", split=split)
+ dataset = dataset.map(
+ partial(_formatting_prompts_func_alpaca, eos_token=tokenizer.eos_token),
+ batched=batched,
+ )
+ return dataset
+
+
+def prepare_non_packed_dataloader(
+ tokenizer,
+ dataset,
+ dataset_text_field,
+ max_seq_length,
+ batch_size=1,
+ num_proc=1,
+ formatting_func=None,
+ add_special_tokens=True,
+ remove_unused_columns=True,
+ num_examples=10,
+):
+ use_formatting_func = formatting_func is not None and dataset_text_field is None
+ dataset = dataset.select(range(num_examples))
+
+ # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
+ def tokenize(element):
+ outputs = tokenizer(
+ element[dataset_text_field]
+ if not use_formatting_func
+ else formatting_func(element),
+ add_special_tokens=add_special_tokens,
+ truncation=True,
+ padding=False,
+ max_length=max_seq_length,
+ return_overflowing_tokens=False,
+ return_length=False,
+ )
+
+ return {
+ "input_ids": outputs["input_ids"],
+ "attention_mask": outputs["attention_mask"],
+ }
+
+ tokenized_dataset = dataset.map(
+ tokenize,
+ batched=True,
+ remove_columns=dataset.column_names if remove_unused_columns else None,
+ num_proc=num_proc,
+ batch_size=batch_size,
+ )
+
+ return tokenized_dataset
+
+
+def get_data_loader(
+ dataset, tokenizer, max_seq_length, batch_size=1, num_proc=1, num_examples=10
+):
+ tokenized_dataset = prepare_non_packed_dataloader(
+ tokenizer,
+ dataset,
+ dataset_text_field="text",
+ max_seq_length=max_seq_length,
+ batch_size=batch_size,
+ num_proc=num_proc,
+ num_examples=num_examples * batch_size,
+ )
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+ return DataLoader(
+ tokenized_dataset,
+ batch_size=batch_size,
+ collate_fn=data_collator,
+ sampler=RandomSampler(tokenized_dataset),
+ )
diff --git a/unsloth/utils/memory.py b/unsloth/utils/memory.py
new file mode 100644
index 00000000..5157a2da
--- /dev/null
+++ b/unsloth/utils/memory.py
@@ -0,0 +1,37 @@
+import gc
+
+import torch
+
+
+def get_memory_stats():
+ print(torch.cuda.memory_summary())
+
+
+def get_max_memory_reserved():
+ mem = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
+ print(f"Peak reserved memory = {mem} GB.")
+ return mem
+
+
+def get_mem_stat_keys():
+ mem_stats = torch.cuda.memory_stats_as_nested_dict()
+ return mem_stats.keys()
+
+
+def get_mem_stat(key):
+ mem_stats = torch.cuda.memory_stats_as_nested_dict()
+ try:
+ return mem_stats[key]
+ except:
+ print(
+ f"Key {key} not found in memory stats, run `get_mem_stat_keys()` to see all keys."
+ )
+ return None
+
+
+def empty_cache():
+ # Clean memory up first
+ for _ in range(3):
+ torch.cuda.empty_cache()
+ gc.collect()
+ pass
diff --git a/unsloth/utils/profiling.py b/unsloth/utils/profiling.py
new file mode 100644
index 00000000..e3dd1cd6
--- /dev/null
+++ b/unsloth/utils/profiling.py
@@ -0,0 +1,84 @@
+import dataclasses
+import json
+import os
+from datetime import datetime
+
+from transformers.trainer_callback import ProgressCallback, TrainerCallback
+from transformers.trainer_pt_utils import _secs2timedelta
+
+# Prints filename and line number when logging
+LOG_FORMAT_STR = (
+ "%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s"
+)
+
+TRAINER_PERF_ARGS = {
+ "skip_memory_metrics": False,
+ "include_num_input_tokens_seen": True,
+ "include_tokens_per_second": True,
+}
+
+
+class MetricsCallBack(TrainerCallback):
+ def __init__(self, name, verbose=False):
+ self.name = name
+ self.verbose = verbose
+
+ def metrics_format(self, metrics):
+ """
+ Reformat Trainer metrics values to a human-readable format
+
+ Args:
+ metrics (`Dict[str, float]`):
+ The metrics returned from train/evaluate/predict
+
+ Returns:
+ metrics (`Dict[str, float]`): The reformatted metrics
+ """
+
+ metrics_copy = metrics.copy()
+ for k, v in metrics_copy.items():
+ if "_mem_" in k:
+ metrics_copy[k] = f"{ v >> 20 }MB"
+ elif "_runtime" in k:
+ metrics_copy[k] = _secs2timedelta(v)
+ elif k == "total_flos":
+ metrics_copy[k] = f"{ int(v) >> 30 }GF"
+ elif isinstance(metrics_copy[k], float):
+ metrics_copy[k] = round(v, 4)
+
+ return metrics_copy
+
+ def save_state(self, output_dir, state, append_step=False):
+ # Format metrics (last entry of log_history)
+ log_history = state.log_history
+ metrics = self.metrics_format(log_history[-1])
+ log_history[-1] = metrics
+ state.log_history = log_history
+
+ # Save state
+ json_string = (
+ json.dumps(dataclasses.asdict(state), indent=2, sort_keys=True) + "\n"
+ )
+
+ date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+ step = "-" + str(state.global_step) if append_step else ""
+ json_path = os.path.join(output_dir, f"{date_str}-{self.name}{step}.json")
+ with open(json_path, "w", encoding="utf-8") as f:
+ f.write(json_string)
+
+ def on_log(self, args, state, control, logs=None, **kwargs):
+ if self.verbose:
+ logs_formatted = self.metrics_format(logs)
+ k_width = max(len(str(x)) for x in logs_formatted.keys())
+ v_width = max(len(str(x)) for x in logs_formatted.values())
+ print("Global Step: ", state.global_step)
+ for key in sorted(logs_formatted.keys()):
+ print(f" {key: <{k_width}} = {logs_formatted[key]:>{v_width}}")
+ else:
+ return
+
+ def on_train_end(self, args, state, control, **kwargs):
+ self.save_state(args.output_dir, state)
+
+
+# super().on_train_end(args, state, control, **kwargs)
diff --git a/unsloth/utils/testing.py b/unsloth/utils/testing.py
new file mode 100644
index 00000000..e4cff568
--- /dev/null
+++ b/unsloth/utils/testing.py
@@ -0,0 +1,14 @@
+import torch
+
+
+def check_all(expected, actual, names, atol=1e-6, rtol=1e-6, verbose=True):
+ if verbose:
+ print()
+ for name, e, a in zip(names, expected, actual):
+ if verbose:
+ print(
+ f"{name}: {torch.allclose(e, a, atol=atol, rtol=rtol)}: {(e - a).abs().max()}"
+ )
+ assert torch.allclose(
+ e, a, atol=atol, rtol=rtol
+ ), f"{name}: {(e - a).abs().max()}"