Merge pull request #199 from kozistr/refactor/format
[Refactor] Fix styles
kozistr authored Aug 12, 2023
2 parents aaaf303 + a13a8c1 commit 9314b58
Showing 6 changed files with 36 additions and 36 deletions.
56 changes: 28 additions & 28 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -60,7 +60,7 @@ black = [
{ version = "==23.3.0", python = ">=3.7,<3.8" },
{ version = "^23.7.0", python = ">=3.8"}
]
ruff = "^0.0.278"
ruff = "^0.0.284"
pytest = "^7.4.0"
pytest-cov = "^4.1.0"

2 changes: 1 addition & 1 deletion pytorch_optimizer/optimizer/lomo.py
@@ -46,7 +46,7 @@ def __init__(
self.grad_norms: List[torch.Tensor] = []
self.clip_coef: Optional[float] = None

-p0: torch.Tensor = list(self.model.parameters())[0]
+p0: torch.Tensor = next(iter(self.model.parameters()))

self.grad_func: Callable[[Any], Any] = (
self.fuse_update_zero3() if hasattr(p0, 'ds_tensor') else self.fuse_update()
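
Note (not part of the diff): the change above swaps an eager list index for a lazy iterator lookup, so the first parameter is fetched without materializing the whole parameter list. A minimal sketch of the difference, using a hypothetical toy model rather than anything from this repository:

# Sketch only; nn.Linear(4, 2) is a stand-in model, not from this PR.
import torch
from torch import nn

model = nn.Linear(4, 2)

first_eager = list(model.parameters())[0]    # builds the full list, then indexes it
first_lazy = next(iter(model.parameters()))  # stops after yielding the first parameter

assert torch.equal(first_eager, first_lazy)  # both return the same tensor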
4 changes: 2 additions & 2 deletions pytorch_optimizer/optimizer/rotograd.py
@@ -219,7 +219,7 @@ def to(self, *args, **kwargs):
self.backbone.to(*args, **kwargs)
for head in self.heads:
head.to(*args, **kwargs)
-return super(RotateOnly, self).to(*args, **kwargs)
+return super().to(*args, **kwargs)

def train(self, mode: bool = True) -> nn.Module:
super().train(mode)
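
For context, `super(RotateOnly, self)` and the zero-argument `super()` resolve to the same call inside a method of `RotateOnly` under Python 3, so the shorter form is pure style. A small sketch with hypothetical classes, not taken from rotograd.py:

class Base:
    def to(self, device: str) -> str:
        return f'moved to {device}'


class RotateOnlyDemo(Base):  # hypothetical stand-in for RotateOnly
    def to(self, device: str) -> str:
        # zero-argument super() infers (RotateOnlyDemo, self) automatically
        return super().to(device)


assert RotateOnlyDemo().to('cpu') == 'moved to cpu'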
@@ -284,7 +284,7 @@ def backward(self, losses: Sequence[torch.Tensor], backbone_loss=None, **kwargs)
if not self.training:
raise AssertionError('Backward should only be called when training')

-if self.iteration_counter == 0 or self.iteration_counter == self.burn_in_period:
+if self.iteration_counter in (0, self.burn_in_period):
for i, loss in enumerate(losses):
self.initial_losses[i] = loss.item()

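Likewise, the tuple membership test above is just a more compact spelling of the chained comparison it replaces; a small sketch with made-up values:

iteration_counter, burn_in_period = 0, 200  # hypothetical values

old_style = iteration_counter == 0 or iteration_counter == burn_in_period
new_style = iteration_counter in (0, burn_in_period)
assert old_style == new_style  # the two conditions are equivalent
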
6 changes: 3 additions & 3 deletions requirements-dev.txt
@@ -19,12 +19,12 @@ networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
numpy==1.21.1 ; python_full_version >= "3.7.2" and python_version < "3.8"
numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
packaging==23.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
-pathspec==0.11.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
-platformdirs==3.9.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
+pathspec==0.11.2 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
+platformdirs==3.10.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
pluggy==1.2.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
pytest-cov==4.1.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
pytest==7.4.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
-ruff==0.0.278 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
+ruff==0.0.284 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
sympy==1.12 ; python_version >= "3.8" and python_full_version < "4.0.0"
tomli==2.0.1 ; python_full_version >= "3.7.2" and python_version < "3.11"
torch==1.13.1+cpu ; python_full_version >= "3.7.2" and python_version < "3.8"
2 changes: 1 addition & 1 deletion tests/test_optimizers.py
@@ -476,6 +476,6 @@ def test_lomo_optimizer(precision, environment):
if precision == 16:
optimizer.clip_coef = 0.9

-loss = sphere_loss(list(model.parameters())[0])
+loss = sphere_loss(next(iter(model.parameters())))
optimizer.grad_norm(loss)
optimizer.fused_backward(loss, lr=0.1)
