From 90ed2d4216417e6aba2dd29a83b483f572c19777 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 23 Nov 2022 20:23:37 +0100 Subject: [PATCH] fix(optimizer): add closure support (needed for PL) --- setup.py | 2 +- truegrad/nn.py | 2 +- truegrad/optim.py | 6 ++++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index aed8efd..86d44d1 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ name='truegrad', license='BSD', description='PyTorch interface for TrueGrad-AdamW', - version='0.0.8', + version='0.0.9', long_description=README, url='https://github.com/clashluke/truegrad', packages=setuptools.find_packages(), diff --git a/truegrad/nn.py b/truegrad/nn.py index 049cb9c..4b529e8 100644 --- a/truegrad/nn.py +++ b/truegrad/nn.py @@ -143,7 +143,7 @@ def forward(ctx, out, fn, args, kwargs) -> torch.Tensor: @staticmethod def backward(ctx, dy: torch.Tensor) -> Tuple[None, None, None, None]: def _square(x: Union[torch.Tensor, None]): - if isinstance(x, TrueGradParameter): + if isinstance(x, torch.nn.Parameter): x = x.data if not isinstance(x, torch.Tensor) or not torch.is_floating_point(x): return x diff --git a/truegrad/optim.py b/truegrad/optim.py index 58345fd..466addf 100644 --- a/truegrad/optim.py +++ b/truegrad/optim.py @@ -16,6 +16,11 @@ def __init__(self, params, lr: float = 1e-3, @torch.no_grad() def step(self, closure=None): + if closure is None: + loss = None + else: + with torch.enable_grad(): + loss = closure() for group in self.param_groups: if len(group["betas"]) == 2: beta1, beta2 = group["betas"] @@ -74,3 +79,4 @@ def step(self, closure=None): alpha = alpha * adam_update.norm() / update.norm() p.add_(update, alpha=alpha) + return loss