Skip to content

Commit

Permalink
MNT: DiffeoOptimzer.step more idiomatic
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierverdier committed Sep 29, 2024
1 parent 5a869e0 commit 61da390
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions diffeopt/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,24 @@
import torch
from .group.representation import Perturbation

class DiffeoOptimizer(torch.optim.Optimizer, ABC):
def step(self, closure: Callable | None=None):
loss = None
if closure is not None:
from torch.optim import Optimizer # type: ignore[attr-defined]

class DiffeoOptimizer(Optimizer, ABC):

@torch.no_grad()
def step(self, closure:Callable[[], float] | None = None) -> float | None: # type: ignore[override]
loss: float | None = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
cometric = group['cometric']
for p in group['params']:
momentum = p.grad
velocity = cometric(momentum)
self._update_parameter(p, velocity, group)

for group in self.param_groups:
cometric = group['cometric']
for p in group['params']:
momentum = p.grad
velocity = cometric(momentum)

Check failure on line 21 in diffeopt/optim.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (F841)

diffeopt/optim.py:21:17: F841 Local variable `velocity` is assigned to but never used

Check failure on line 21 in diffeopt/optim.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (F841)

diffeopt/optim.py:21:17: F841 Local variable `velocity` is assigned to but never used

Check failure on line 21 in diffeopt/optim.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (F841)

diffeopt/optim.py:21:17: F841 Local variable `velocity` is assigned to but never used

Check failure on line 21 in diffeopt/optim.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Ruff (F841)

diffeopt/optim.py:21:17: F841 Local variable `velocity` is assigned to but never used

return loss

@abstractmethod
def _update_parameter(self, parameter: Perturbation, velocity: torch.Tensor, group: dict[str, Any]) -> None:
Expand Down

0 comments on commit 61da390

Please sign in to comment.