Commit 20ed84f

Merge pull request #281 from kozistr/feature/support-4bit-optimizer

[Feature] Support 8/4bit, fp8 optimizers

kozistr authored Oct 17, 2024
2 parents: 1687e37 + f3dcf8e
Showing 5 changed files with 38 additions and 9 deletions.
README.md (6 changes: 3 additions & 3 deletions)
@@ -10,7 +10,7 @@
 
 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **77 optimizers (+ `bitsandbytes`, `qgalore`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+Currently, **77 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 
 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).
@@ -27,8 +27,8 @@ So, please double-check the license before using it at your work.
 $ pip3 install pytorch-optimizer
 ```
 
-From `v2.12.0`, `v3.1.0`, you can use `bitsandbytes`, `q-galore-torch` optimizers respectively!
-please check [the bnb requirements](https://github.com/TimDettmers/bitsandbytes?tab=readme-ov-file#tldr), [q-galore-torch installation](https://github.com/VITA-Group/Q-GaLore?tab=readme-ov-file#install-q-galore-optimizer)
+From `v2.12.0`, `v3.1.0`, you can use `bitsandbytes`, `q-galore-torch`, `torchao` optimizers respectively!
+please check [the bnb requirements](https://github.com/TimDettmers/bitsandbytes?tab=readme-ov-file#tldr), [q-galore-torch installation](https://github.com/VITA-Group/Q-GaLore?tab=readme-ov-file#install-q-galore-optimizer), [torchao installation](https://github.com/pytorch/ao?tab=readme-ov-file#installation)
 before installing it.
 
 From `v3.0.0`, drop `Python 3.7` support. However, you can still use this package with `Python 3.7` by installing with `--ignore-requires-python` option.
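A minimal usage sketch of the new torchao-backed optimizers (illustrative only, not part of the diff; it assumes `torchao` is installed and a CUDA device is available). Note that `load_optimizer` returns the optimizer class, which you instantiate yourself:

```python
import torch

from pytorch_optimizer import load_optimizer

# 'torchao_adamw8bit' routes to torchao.prototype.low_bit_optim.AdamW8bit
optimizer_class = load_optimizer('torchao_adamw8bit')

model = torch.nn.Linear(128, 10).cuda()
optimizer = optimizer_class(model.parameters(), lr=1e-3)

# one illustrative training step
loss = model(torch.randn(4, 128, device='cuda')).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```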
docs/changelogs/v3.2.0.md (2 changes: 2 additions & 0 deletions)
@@ -6,6 +6,8 @@
 * [SOAP: Improving and Stabilizing Shampoo using Adam](https://arxiv.org/abs/2409.11321)
 * Support `AdEMAMix` variants. (#276)
     * `bnb_ademamix8bit`, `bnb_ademamix32bit`, `bnb_paged_ademamix8bit`, `bnb_paged_ademamix32bit`
+* Support 8/4bit, fp8 optimizers. (#208, #281)
+    * `torchao_adamw8bit`, `torchao_adamw4bit`, `torchao_adamwfp8`.
 
 ### Bug
 
docs/index.md (6 changes: 3 additions & 3 deletions)
@@ -10,7 +10,7 @@
 
 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **77 optimizers (+ `bitsandbytes`, `qgalore`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+Currently, **77 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 
 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).
@@ -27,8 +27,8 @@ So, please double-check the license before using it at your work.
 $ pip3 install pytorch-optimizer
 ```
 
-From `v2.12.0`, `v3.1.0`, you can use `bitsandbytes`, `q-galore-torch` optimizers respectively!
-please check [the bnb requirements](https://github.com/TimDettmers/bitsandbytes?tab=readme-ov-file#tldr), [q-galore-torch installation](https://github.com/VITA-Group/Q-GaLore?tab=readme-ov-file#install-q-galore-optimizer)
+From `v2.12.0`, `v3.1.0`, you can use `bitsandbytes`, `q-galore-torch`, `torchao` optimizers respectively!
+please check [the bnb requirements](https://github.com/TimDettmers/bitsandbytes?tab=readme-ov-file#tldr), [q-galore-torch installation](https://github.com/VITA-Group/Q-GaLore?tab=readme-ov-file#install-q-galore-optimizer), [torchao installation](https://github.com/pytorch/ao?tab=readme-ov-file#installation)
 before installing it.
 
 From `v3.0.0`, drop `Python 3.7` support. However, you can still use this package with `Python 3.7` by installing with `--ignore-requires-python` option.
pytorch_optimizer/__init__.py (28 changes: 25 additions & 3 deletions)
@@ -134,6 +134,7 @@
 
 HAS_BNB: bool = find_spec('bitsandbytes') is not None
 HAS_Q_GALORE: bool = find_spec('q-galore-torch') is not None
+HAS_TORCHAO: bool = find_spec('torchao') is not None
 
 OPTIMIZER_LIST: List[OPTIMIZER] = [
     AdamW,
@@ -323,19 +324,40 @@ def load_q_galore_optimizer(optimizer: str) -> OPTIMIZER:  # pragma: no cover
     raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
 
 
+def load_ao_optimizer(optimizer: str) -> OPTIMIZER:  # pragma: no cover
+    r"""load TorchAO optimizer instance."""
+    from torchao.prototype import low_bit_optim
+
+    if 'adamw8bit' in optimizer:
+        return low_bit_optim.AdamW8bit
+    if 'adamw4bit' in optimizer:
+        return low_bit_optim.AdamW4bit
+    if 'adamwfp8' in optimizer:
+        return low_bit_optim.AdamWFp8
+
+    raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
+
+
 def load_optimizer(optimizer: str) -> OPTIMIZER:
     optimizer: str = optimizer.lower()
 
     if optimizer.startswith('bnb'):
         if HAS_BNB and torch.cuda.is_available():
             return load_bnb_optimizer(optimizer)  # pragma: no cover
-        raise ImportError(f'[-] bitsandbytes and CUDA required for the optimizer {optimizer}')
+        raise ImportError(f'bitsandbytes and CUDA required for the optimizer {optimizer}')
     if optimizer.startswith('q_galore'):
         if HAS_Q_GALORE and torch.cuda.is_available():
             return load_q_galore_optimizer(optimizer)  # pragma: no cover
-        raise ImportError(f'[-] bitsandbytes, q-galore-torch, and CUDA required for the optimizer {optimizer}')
+        raise ImportError(f'bitsandbytes, q-galore-torch, and CUDA required for the optimizer {optimizer}')
+    if optimizer.startswith('torchao'):
+        if HAS_TORCHAO and torch.cuda.is_available():
+            return load_ao_optimizer(optimizer)  # pragma: no cover
+        raise ImportError(
+            f'torchao required for the optimizer {optimizer}. '
+            'usage: https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#usage'
+        )
     if optimizer not in OPTIMIZERS:
-        raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
+        raise NotImplementedError(f'not implemented optimizer : {optimizer}')
 
     return OPTIMIZERS[optimizer]
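The dispatch is purely prefix-based: any name starting with `torchao` goes through `load_ao_optimizer`, and without `torchao` plus CUDA the call raises the `ImportError` added above. A small illustrative sketch of the observable behavior:

```python
from pytorch_optimizer import load_optimizer

# the three names added in this PR; each maps onto torchao.prototype.low_bit_optim
for name in ('torchao_adamw8bit', 'torchao_adamw4bit', 'torchao_adamwfp8'):
    try:
        optimizer_class = load_optimizer(name)
        print(name, '->', optimizer_class.__name__)  # AdamW8bit / AdamW4bit / AdamWFp8
    except ImportError as exc:
        print(name, '->', exc)  # torchao missing or CUDA unavailable
```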
tests/test_create_optimizer.py (5 changes: 5 additions & 0 deletions)
@@ -33,3 +33,8 @@ def test_bnb_optimizer():
 def test_q_galore_optimizer():
     with pytest.raises(ImportError):
         load_optimizer('q_galore_adamw8bit')
+
+
+def test_torchao_optimizer():
+    with pytest.raises(ImportError):
+        load_optimizer('torchao_adamw4bit')
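The new test only exercises `torchao_adamw4bit`. A parametrized variant covering all three new names could look like the following sketch; like the existing `bnb` and `q_galore` tests, it assumes the test environment lacks the optional dependency, so `ImportError` is the expected outcome (`test_torchao_optimizer_variants` is a hypothetical name, not part of the diff):

```python
import pytest

from pytorch_optimizer import load_optimizer


@pytest.mark.parametrize('optimizer_name', ['torchao_adamw8bit', 'torchao_adamw4bit', 'torchao_adamwfp8'])
def test_torchao_optimizer_variants(optimizer_name):
    # every torchao-prefixed name should raise when torchao (or CUDA) is absent
    with pytest.raises(ImportError):
        load_optimizer(optimizer_name)
```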
