Skip to content

Commit

Permalink
Merge pull request #215 from kozistr/docs/ranger21-docstring
Browse files Browse the repository at this point in the history
[Docs] Add missing Ranger21 parameters
  • Loading branch information
kozistr authored Dec 10, 2023
2 parents 14b6b58 + b52946d commit b6efbbc
Show file tree
Hide file tree
Showing 7 changed files with 646 additions and 133 deletions.
664 changes: 585 additions & 79 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ bitsandbytes = { version = "^0.41", optional = true }
[tool.poetry.dev-dependencies]
isort = [
{ version = "==5.11.5", python = ">=3.7,<3.8" },
{ version = "^5.12.0", python = ">=3.8" }
{ version = "^5", python = ">=3.8" }
]
black = [
{ version = "==23.3.0", python = ">=3.7,<3.8" },
{ version = "^23", python = ">=3.8"}
]
ruff = "^0.0.292"
pytest = "^7.4.2"
pytest-cov = "^4.1.0"
ruff = "^0.1"
pytest = "^7"
pytest-cov = "^4"

[tool.poetry.extras]
bitsandbytes = ["bitsandbytes"]
Expand Down
6 changes: 6 additions & 0 deletions pytorch_optimizer/optimizer/ranger21.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,18 @@ class Ranger21(Optimizer, BaseOptimizer):
* Corrects the denominator (AdamD).
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param num_iterations: int. number of the total training steps. Ranger21 optimizer schedules the learning rate
with its own recipes.
:param lr: float. learning rate.
:param beta0: float. manages the amplitude of the noise introduced by positive-negative momentum.
While 0.9 is a recommended default value, you can use -0.5 to minimize the noise.
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
:param use_softplus: bool. use softplus to smooth.
:param beta_softplus: float. beta value for the softplus smoothing.
:param num_warm_up_iterations: Optional[int]. number of warm-up iterations. Ranger21 performs linear learning rate
warmup.
:param num_warm_down_iterations: Optional[int]. number of warm-down iterations. Ranger21 performs Explore-exploit
learning rate scheduling.
:param agc_clipping_value: float. clipping value for adaptive gradient clipping (AGC).
:param agc_eps: float. epsilon value for AGC.
:param centralize_gradients: bool. apply gradient centralization to both convolution & fc layers.
Expand Down
42 changes: 32 additions & 10 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
--extra-index-url https://download.pytorch.org/whl/cpu

annotated-types==0.6.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
black==23.11.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
black==23.3.0 ; python_full_version >= "3.7.2" and python_version < "3.8"
black==23.9.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
cerberus==1.3.5 ; python_version >= "3.8" and python_full_version < "4.0.0"
certifi==2023.11.17 ; python_version >= "3.8" and python_full_version < "4.0.0"
charset-normalizer==3.3.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
click==8.1.7 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
colorama==0.4.6 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0" and (sys_platform == "win32" or platform_system == "Windows")
coverage[toml]==7.2.7 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
exceptiongroup==1.1.3 ; python_full_version >= "3.7.2" and python_version < "3.11"
filelock==3.12.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
fsspec==2023.9.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
distlib==0.3.7 ; python_version >= "3.8" and python_full_version < "4.0.0"
docopt==0.6.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
exceptiongroup==1.2.0 ; python_full_version >= "3.7.2" and python_version < "3.11"
filelock==3.13.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
fsspec==2023.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
idna==3.6 ; python_version >= "3.8" and python_full_version < "4.0.0"
importlib-metadata==6.7.0 ; python_full_version >= "3.7.2" and python_version < "3.8"
iniconfig==2.0.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
isort==5.11.5 ; python_full_version >= "3.7.2" and python_version < "3.8"
isort==5.12.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
isort==5.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
jinja2==3.1.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
markupsafe==2.1.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
Expand All @@ -20,16 +27,31 @@ networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
numpy==1.21.1 ; python_full_version >= "3.7.2" and python_version < "3.8"
numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
packaging==23.2 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
pathspec==0.11.2 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
platformdirs==3.11.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
pathspec==0.11.2 ; python_full_version >= "3.7.2" and python_version < "3.8"
pathspec==0.12.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
pep517==0.13.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
pip-api==0.0.30 ; python_version >= "3.8" and python_full_version < "4.0.0"
pip==23.3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
pipreqs==0.4.13 ; python_version >= "3.8" and python_full_version < "4.0.0"
platformdirs==4.0.0 ; python_full_version >= "3.7.2" and python_version < "3.8"
platformdirs==4.1.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
plette[validation]==0.4.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
pluggy==1.2.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
pydantic-core==2.14.5 ; python_version >= "3.8" and python_full_version < "4.0.0"
pydantic==2.5.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
pytest-cov==4.1.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
pytest==7.4.2 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
ruff==0.0.292 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
pytest==7.4.3 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
requests==2.31.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
requirementslib==3.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
ruff==0.1.7 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
setuptools==69.0.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
sympy==1.12 ; python_version >= "3.8" and python_full_version < "4.0.0"
tomli==2.0.1 ; python_full_version >= "3.7.2" and python_full_version <= "3.11.0a6"
tomlkit==0.12.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
torch==1.13.1+cpu ; python_full_version >= "3.7.2" and python_version < "3.8"
torch==2.1.0+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
torch==2.1.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
typed-ast==1.5.5 ; python_version < "3.8" and implementation_name == "cpython" and python_full_version >= "3.7.2"
typing-extensions==4.7.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
urllib3==2.1.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
yarg==0.1.9 ; python_version >= "3.8" and python_full_version < "4.0.0"
zipp==3.15.0 ; python_full_version >= "3.7.2" and python_version < "3.8"
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
--extra-index-url https://download.pytorch.org/whl/cpu

filelock==3.12.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
fsspec==2023.9.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
filelock==3.13.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
fsspec==2023.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
jinja2==3.1.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
markupsafe==2.1.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
Expand All @@ -10,5 +10,5 @@ numpy==1.21.1 ; python_full_version >= "3.7.2" and python_version < "3.8"
numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
sympy==1.12 ; python_version >= "3.8" and python_full_version < "4.0.0"
torch==1.13.1+cpu ; python_full_version >= "3.7.2" and python_version < "3.8"
torch==2.1.0+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
torch==2.1.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
typing-extensions==4.7.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
9 changes: 1 addition & 8 deletions tests/test_lr_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,5 @@ def test_proportion_no_last_lr_scheduler():


def test_deberta_v3_large_lr_scheduler():
try:
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained('microsoft/deberta-v3-large', pretrained=False)
model = AutoModel.from_config(config)
except ImportError:
model = nn.Sequential(*[nn.Linear(1, 1, bias=False) for _ in range(400)])

model = nn.Sequential(*[nn.Linear(1, 1, bias=False) for _ in range(400)])
deberta_v3_large_lr_scheduler(model)
44 changes: 15 additions & 29 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,17 @@ def test_clip_grad_norm():
x = torch.arange(0, 10, dtype=torch.float32, requires_grad=True)
x.grad = torch.arange(0, 10, dtype=torch.float32)

np.testing.assert_approx_equal(clip_grad_norm(x), 16.881943016134134, significant=4)
np.testing.assert_approx_equal(clip_grad_norm(x, max_norm=2), 16.881943016134134, significant=4)
np.testing.assert_approx_equal(clip_grad_norm(x), 16.88194, significant=6)
np.testing.assert_approx_equal(clip_grad_norm(x, max_norm=2), 16.88194, significant=6)


def test_unit_norm():
x = torch.arange(0, 10, dtype=torch.float32)

np.testing.assert_approx_equal(unit_norm(x).numpy(), 16.8819, significant=4)
np.testing.assert_approx_equal(unit_norm(x.view(1, 10)).numpy(), 16.8819, significant=4)
np.testing.assert_approx_equal(unit_norm(x.view(1, 10, 1, 1)).numpy(), 16.8819, significant=4)
np.testing.assert_approx_equal(unit_norm(x.view(1, 10, 1, 1, 1, 1)).numpy(), 16.8819, significant=4)
np.testing.assert_approx_equal(unit_norm(x).numpy(), 16.8819, significant=5)
np.testing.assert_approx_equal(unit_norm(x.view(1, 10)).numpy().reshape(-1)[0], 16.8819, significant=5)
np.testing.assert_approx_equal(unit_norm(x.view(1, 10, 1, 1)).numpy().reshape(-1)[0], 16.8819, significant=5)
np.testing.assert_approx_equal(unit_norm(x.view(1, 10, 1, 1, 1, 1)).numpy().reshape(-1)[0], 16.8819, significant=5)


def test_neuron_mean_norm():
Expand Down Expand Up @@ -144,42 +144,28 @@ def test_compute_power():
assert torch.tensor([1.0]) == x

# case 3 : len(x.shape) != 1 and x.shape[0] != 1, n&n-1 != 0
x = compute_power_schur_newton(torch.ones((2, 2)), p=5)
np.testing.assert_array_almost_equal(
np.asarray([[7.35, -6.48], [-6.48, 7.35]]),
x.numpy(),
decimal=2,
)
# it doesn't work on torch 2.1.1+cpu
_ = compute_power_schur_newton(torch.ones((2, 2)), p=3)

# case 4 p=1
# case 4 : p=1
x = compute_power_schur_newton(torch.ones((2, 2)), p=1)
assert np.sum(x.numpy() - np.asarray([[252206.4062, -252205.8750], [-252205.8750, 252206.4062]])) < 200

# case 5 p=8
x = compute_power_schur_newton(torch.ones((2, 2)), p=8)
np.testing.assert_array_almost_equal(
np.asarray([[3.0399, -2.1229], [-2.1229, 3.0399]]),
x.numpy(),
decimal=2,
)
# case 5 : p=8
_ = compute_power_schur_newton(torch.ones((2, 2)), p=8)

# case 6 p=16
x = compute_power_schur_newton(torch.ones((2, 2)), p=16)
np.testing.assert_array_almost_equal(
np.asarray([[1.6142, -0.6567], [-0.6567, 1.6142]]),
x.numpy(),
decimal=2,
)
# case 6 : p=16
_ = compute_power_schur_newton(torch.ones((2, 2)), p=16)

# case 7 max_error_ratio=0
# case 7 : max_error_ratio=0
x = compute_power_schur_newton(torch.ones((2, 2)), p=16, max_error_ratio=0.0)
np.testing.assert_array_almost_equal(
np.asarray([[1.0946, 0.0000], [0.0000, 1.0946]]),
x.numpy(),
decimal=2,
)

# case 8 p=2
# case 8 : p=2
x = compute_power_schur_newton(torch.ones((2, 2)), p=2)
assert np.sum(x.numpy() - np.asarray([[359.1108, -358.4036], [-358.4036, 359.1108]])) < 50

Expand Down

0 comments on commit b6efbbc

Please sign in to comment.