From 61da39077ba8b2dbcd6ba2ee8e82e335ffe18cce Mon Sep 17 00:00:00 2001
From: Olivier Verdier
Date: Sun, 29 Sep 2024 15:56:10 +0200
Subject: [PATCH] MNT: DiffeoOptimizer.step more idiomatic

---
 diffeopt/optim.py | 27 +++++++++++++++++----------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/diffeopt/optim.py b/diffeopt/optim.py
index 3510e52..2fba86c 100644
--- a/diffeopt/optim.py
+++ b/diffeopt/optim.py
@@ -3,17 +3,24 @@ import torch
 from .group.representation import Perturbation
 
-class DiffeoOptimizer(torch.optim.Optimizer, ABC):
-    def step(self, closure: Callable | None=None):
-        loss = None
-        if closure is not None:
+from torch.optim import Optimizer  # type: ignore[attr-defined]
+
+class DiffeoOptimizer(Optimizer, ABC):
+
+    @torch.no_grad()
+    def step(self, closure: Callable[[], float] | None = None) -> float | None:  # type: ignore[override]
+        loss: float | None = None
+        if closure is not None:
+            with torch.enable_grad():
                 loss = closure()
-        for group in self.param_groups:
-            cometric = group['cometric']
-            for p in group['params']:
-                momentum = p.grad
-                velocity = cometric(momentum)
-                self._update_parameter(p, velocity, group)
+
+        for group in self.param_groups:
+            cometric = group['cometric']
+            for p in group['params']:
+                momentum = p.grad
+                velocity = cometric(momentum)
+
+        return loss
 
     @abstractmethod
     def _update_parameter(self, parameter: Perturbation, velocity: torch.Tensor, group: dict[str, Any]) -> None:
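
Usage sketch (not part of the patch): a minimal illustration of how the revised base class is meant to be subclassed. Only DiffeoOptimizer, its abstract _update_parameter hook, and the per-group 'cometric' entry come from diffeopt/optim.py; the GradientFlow name, the lr option, the identity cometric, and the plain-tensor update below are illustrative assumptions, not part of the library.

from typing import Any

import torch

from diffeopt.optim import DiffeoOptimizer


class GradientFlow(DiffeoOptimizer):
    """Hypothetical concrete optimizer built on the DiffeoOptimizer base class."""

    def _update_parameter(self, parameter, velocity: torch.Tensor, group: dict[str, Any]) -> None:
        # Illustrative update only: move the parameter against the velocity,
        # scaled by a step size taken from the param group.
        parameter.data.add_(velocity, alpha=-group['lr'])


# Each param group carries a 'cometric' callable that maps the stored
# gradient (the momentum) to a velocity; the identity is used here for brevity.
params = [torch.nn.Parameter(torch.zeros(3))]
opt = GradientFlow(params, dict(cometric=lambda m: m, lr=0.1))

loss = (params[0] - torch.ones(3)).pow(2).sum()
loss.backward()
opt.step()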