Skip to content

Commit

Permalink
Merge pull request #276 from kozistr/update/bitsandbytes
Browse files Browse the repository at this point in the history
[Deps] Update `bitsandbytes` to 0.44.0
  • Loading branch information
kozistr authored Sep 24, 2024
2 parents 2b4562c + 02a135b commit 3cdf496
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 29 deletions.
2 changes: 2 additions & 0 deletions docs/changelogs/v3.2.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@

* Implement `SOAP` optimizer. (#275)
* [SOAP: Improving and Stabilizing Shampoo using Adam](https://arxiv.org/abs/2409.11321)
* Support `AdEMAMix` variants. (#276)
* `bnb_ademamix8bit`, `bnb_ademamix32bit`, `bnb_paged_ademamix8bit`, `bnb_paged_ademamix32bit`
102 changes: 78 additions & 24 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,12 @@ classifiers = [

[tool.poetry.dependencies]
python = ">=3.8,<4.0.0"
numpy = { version = "*", python = ">=3.8" }
numpy = [
{ version = ">1.24.4", python = ">=3.9" },
{ version = "<=1.24.4", python = "<3.9" },
]
torch = { version = ">=1.10", python = ">=3.8", source = "torch" }
bitsandbytes = { version = "^0.43", optional = true }
bitsandbytes = { version = "^0.44", optional = true }

[tool.poetry.dev-dependencies]
isort = { version = "^5", python = ">=3.8" }
Expand Down
9 changes: 9 additions & 0 deletions pytorch_optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,15 @@ def load_bnb_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover
return optim.RMSprop32bit
if 'sgd32bit' in optimizer:
return optim.SGD32bit
if 'ademamix8bit' in optimizer:
return optim.AdEMAMix8bit
if 'ademamix32bit' in optimizer:
return optim.AdEMAMix32bit
if 'paged_ademamix8bit' in optimizer:
return optim.PagedAdEMAMix8bit
if 'paged_ademamix32bit' in optimizer:
return optim.PagedAdEMAMix32bit

raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')


Expand Down
5 changes: 3 additions & 2 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ markupsafe==2.1.5 ; python_version >= "3.8" and python_full_version < "4.0.0"
mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
mypy-extensions==1.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
numpy==1.24.4 ; python_version >= "3.8" and python_version < "3.9"
numpy==2.0.2 ; python_version >= "3.9" and python_full_version < "4.0.0"
packaging==24.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
pathspec==0.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
platformdirs==4.3.6 ; python_version >= "3.8" and python_full_version < "4.0.0"
pluggy==1.5.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
pytest-cov==5.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
pytest==8.3.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
ruff==0.6.6 ; python_version >= "3.8" and python_full_version < "4.0.0"
ruff==0.6.7 ; python_version >= "3.8" and python_full_version < "4.0.0"
sympy==1.13.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6"
torch==2.4.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ jinja2==3.1.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
markupsafe==2.1.5 ; python_version >= "3.8" and python_full_version < "4.0.0"
mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
numpy==1.24.4 ; python_version >= "3.8" and python_version < "3.9"
numpy==2.0.2 ; python_version >= "3.9" and python_full_version < "4.0.0"
sympy==1.13.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
torch==2.4.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
typing-extensions==4.12.2 ; python_version >= "3.8" and python_full_version < "4.0.0"

0 comments on commit 3cdf496

Please sign in to comment.