diff --git a/poetry.lock b/poetry.lock index acfbeeff..87ddfb20 100644 --- a/poetry.lock +++ b/poetry.lock @@ -500,32 +500,32 @@ files = [ [[package]] name = "pathspec" -version = "0.11.1" +version = "0.11.2" description = "Utility library for gitignore style pattern matching of file paths." optional = false python-versions = ">=3.7" files = [ - {file = "pathspec-0.11.1-py3-none-any.whl", hash = "sha256:d8af70af76652554bd134c22b3e8a1cc46ed7d91edcdd721ef1a0c51a84a5293"}, - {file = "pathspec-0.11.1.tar.gz", hash = "sha256:2798de800fa92780e33acca925945e9a19a133b715067cf165b8866c15a31687"}, + {file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"}, + {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"}, ] [[package]] name = "platformdirs" -version = "3.9.1" +version = "3.10.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." optional = false python-versions = ">=3.7" files = [ - {file = "platformdirs-3.9.1-py3-none-any.whl", hash = "sha256:ad8291ae0ae5072f66c16945166cb11c63394c7a3ad1b1bc9828ca3162da8c2f"}, - {file = "platformdirs-3.9.1.tar.gz", hash = "sha256:1b42b450ad933e981d56e59f1b97495428c9bd60698baab9f3eb3d00d5822421"}, + {file = "platformdirs-3.10.0-py3-none-any.whl", hash = "sha256:d7c24979f292f916dc9cbf8648319032f551ea8c49a4c9bf2fb556a02070ec1d"}, + {file = "platformdirs-3.10.0.tar.gz", hash = "sha256:b45696dab2d7cc691a3226759c0d3b00c47c8b6e293d96f6436f733303f77f6d"}, ] [package.dependencies] -typing-extensions = {version = ">=4.6.3", markers = "python_version < \"3.8\""} +typing-extensions = {version = ">=4.7.1", markers = "python_version < \"3.8\""} [package.extras] -docs = ["furo (>=2023.5.20)", "proselint (>=0.13)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)"] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"] [[package]] name = "pluggy" @@ -588,28 +588,28 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale [[package]] name = "ruff" -version = "0.0.278" +version = "0.0.284" description = "An extremely fast Python linter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.0.278-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:1a90ebd8f2a554db1ee8d12b2f3aa575acbd310a02cd1a9295b3511a4874cf98"}, - {file = "ruff-0.0.278-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:38ca1c0c8c1221fe64c0a66784c91501d09a8ed02a4dbfdc117c0ce32a81eefc"}, - {file = "ruff-0.0.278-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c62a0bde4d20d087cabce2fa8b012d74c2e985da86d00fb3359880469b90e31"}, - {file = "ruff-0.0.278-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7545bb037823cd63dca19280f75a523a68bd3e78e003de74609320d6822b5a52"}, - {file = "ruff-0.0.278-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb380d2d6fdb60656a0b5fa78305535db513fc72ce11f4532cc1641204ef380"}, - {file = "ruff-0.0.278-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d11149c7b186f224f2055e437a030cd83b164a43cc0211314c33ad1553ed9c4c"}, - {file = "ruff-0.0.278-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:666e739fb2685277b879d493848afe6933e3be30d40f41fe0e571ad479d57d77"}, - {file = "ruff-0.0.278-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ec8b0469b54315803aaf1fbf9a37162a3849424cab6182496f972ad56e0ea702"}, - {file = "ruff-0.0.278-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c25b96602695a147d62a572865b753ef56aff1524abab13b9436724df30f9bd7"}, - {file = "ruff-0.0.278-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:a48621f5f372d5019662db5b3dbfc5f1450f927683d75f1153fe0ebf20eb9698"}, - {file = "ruff-0.0.278-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:1078125123a3c68e92463afacedb7e41b15ccafc09e510c6c755a23087afc8de"}, - {file = "ruff-0.0.278-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3ce0d620e257b4cad16e2f0c103b2f43a07981668a3763380542e8a131d11537"}, - {file = "ruff-0.0.278-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:1cae4c07d334eb588f171f1363fa89a8911047eb93184276be11a24dbbc996c7"}, - {file = "ruff-0.0.278-py3-none-win32.whl", hash = "sha256:70d39f5599d8449082ab8ce542fa98e16413145eb411dd1dc16575b44565d52d"}, - {file = "ruff-0.0.278-py3-none-win_amd64.whl", hash = "sha256:e131595ab7f4ce61a1650463bd2fe304b49e7d0deb0dfa664b92817c97cdba5f"}, - {file = "ruff-0.0.278-py3-none-win_arm64.whl", hash = "sha256:737a0cfb6c36aaa92d97a46957dfd5e55329299074ad06ed12663b98e0c6fc82"}, - {file = "ruff-0.0.278.tar.gz", hash = "sha256:1a9f1d925204cfba81b18368b7ac943befcfccc3a41e170c91353b674c6b7a66"}, + {file = "ruff-0.0.284-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:8b949084941232e2c27f8d12c78c5a6a010927d712ecff17231ee1a8371c205b"}, + {file = "ruff-0.0.284-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:a3930d66b35e4dc96197422381dff2a4e965e9278b5533e71ae8474ef202fab0"}, + {file = "ruff-0.0.284-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d1f7096038961d8bc3b956ee69d73826843eb5b39a5fa4ee717ed473ed69c95"}, + {file = "ruff-0.0.284-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bcaf85907fc905d838f46490ee15f04031927bbea44c478394b0bfdeadc27362"}, + {file = "ruff-0.0.284-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3660b85a9d84162a055f1add334623ae2d8022a84dcd605d61c30a57b436c32"}, + {file = "ruff-0.0.284-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0a3218458b140ea794da72b20ea09cbe13c4c1cdb7ac35e797370354628f4c05"}, + {file = "ruff-0.0.284-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b2fe880cff13fffd735387efbcad54ba0ff1272bceea07f86852a33ca71276f4"}, + {file = "ruff-0.0.284-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d1d098ea74d0ce31478765d1f8b4fbdbba2efc532397b5c5e8e5ea0c13d7e5ae"}, + {file = "ruff-0.0.284-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4c79ae3308e308b94635cd57a369d1e6f146d85019da2fbc63f55da183ee29b"}, + {file = "ruff-0.0.284-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:f86b2b1e7033c00de45cc176cf26778650fb8804073a0495aca2f674797becbb"}, + {file = "ruff-0.0.284-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:e37e086f4d623c05cd45a6fe5006e77a2b37d57773aad96b7802a6b8ecf9c910"}, + {file = "ruff-0.0.284-py3-none-musllinux_1_2_i686.whl", hash = "sha256:d29dfbe314e1131aa53df213fdfea7ee874dd96ea0dd1471093d93b59498384d"}, + {file = "ruff-0.0.284-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:88295fd649d0aa1f1271441df75bf06266a199497afd239fd392abcfd75acd7e"}, + {file = "ruff-0.0.284-py3-none-win32.whl", hash = "sha256:735cd62fccc577032a367c31f6a9de7c1eb4c01fa9a2e60775067f44f3fc3091"}, + {file = "ruff-0.0.284-py3-none-win_amd64.whl", hash = "sha256:f67ed868d79fbcc61ad0fa034fe6eed2e8d438d32abce9c04b7c4c1464b2cf8e"}, + {file = "ruff-0.0.284-py3-none-win_arm64.whl", hash = "sha256:1292cfc764eeec3cde35b3a31eae3f661d86418b5e220f5d5dba1c27a6eccbb6"}, + {file = "ruff-0.0.284.tar.gz", hash = "sha256:ebd3cc55cd499d326aac17a331deaea29bea206e01c08862f9b5c6e93d77a491"}, ] [[package]] @@ -777,4 +777,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.7.2" -content-hash = "9500ef58413a994b1cf0f2d988e7b3e9db8ddf65228e9bbc10d0213fec860bbf" +content-hash = "d87271d564554c366418efb8ef44ed31917a23a2ae0af7182fd37a8172e4c5a7" diff --git a/pyproject.toml b/pyproject.toml index 015ce7ab..7e0aba90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ black = [ { version = "==23.3.0", python = ">=3.7,<3.8" }, { version = "^23.7.0", python = ">=3.8"} ] -ruff = "^0.0.278" +ruff = "^0.0.284" pytest = "^7.4.0" pytest-cov = "^4.1.0" diff --git a/pytorch_optimizer/optimizer/lomo.py b/pytorch_optimizer/optimizer/lomo.py index 6334df36..784998c9 100644 --- a/pytorch_optimizer/optimizer/lomo.py +++ b/pytorch_optimizer/optimizer/lomo.py @@ -46,7 +46,7 @@ def __init__( self.grad_norms: List[torch.Tensor] = [] self.clip_coef: Optional[float] = None - p0: torch.Tensor = list(self.model.parameters())[0] + p0: torch.Tensor = next(iter(self.model.parameters())) self.grad_func: Callable[[Any], Any] = ( self.fuse_update_zero3() if hasattr(p0, 'ds_tensor') else self.fuse_update() diff --git a/pytorch_optimizer/optimizer/rotograd.py b/pytorch_optimizer/optimizer/rotograd.py index 1d0f2920..3d3e9f12 100644 --- a/pytorch_optimizer/optimizer/rotograd.py +++ b/pytorch_optimizer/optimizer/rotograd.py @@ -219,7 +219,7 @@ def to(self, *args, **kwargs): self.backbone.to(*args, **kwargs) for head in self.heads: head.to(*args, **kwargs) - return super(RotateOnly, self).to(*args, **kwargs) + return super().to(*args, **kwargs) def train(self, mode: bool = True) -> nn.Module: super().train(mode) @@ -284,7 +284,7 @@ def backward(self, losses: Sequence[torch.Tensor], backbone_loss=None, **kwargs) if not self.training: raise AssertionError('Backward should only be called when training') - if self.iteration_counter == 0 or self.iteration_counter == self.burn_in_period: + if self.iteration_counter in (0, self.burn_in_period): for i, loss in enumerate(losses): self.initial_losses[i] = loss.item() diff --git a/requirements-dev.txt b/requirements-dev.txt index 8c55d441..340f4675 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -19,12 +19,12 @@ networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0" numpy==1.21.1 ; python_full_version >= "3.7.2" and python_version < "3.8" numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0" packaging==23.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0" -pathspec==0.11.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0" -platformdirs==3.9.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0" +pathspec==0.11.2 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0" +platformdirs==3.10.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0" pluggy==1.2.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0" pytest-cov==4.1.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0" pytest==7.4.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0" -ruff==0.0.278 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0" +ruff==0.0.284 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0" sympy==1.12 ; python_version >= "3.8" and python_full_version < "4.0.0" tomli==2.0.1 ; python_full_version >= "3.7.2" and python_version < "3.11" torch==1.13.1+cpu ; python_full_version >= "3.7.2" and python_version < "3.8" diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 81e8e5c8..c8346b37 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -476,6 +476,6 @@ def test_lomo_optimizer(precision, environment): if precision == 16: optimizer.clip_coef = 0.9 - loss = sphere_loss(list(model.parameters())[0]) + loss = sphere_loss(next(iter(model.parameters()))) optimizer.grad_norm(loss) optimizer.fused_backward(loss, lr=0.1)