fix(optimizer): add closure support (needed for PL)
ClashLuke committed Nov 23, 2022 · 1 parent 2c713a3 · commit 90ed2d4
Showing 3 changed files with 8 additions and 2 deletions.
setup.py (1 addition, 1 deletion)

@@ -10,7 +10,7 @@
     name='truegrad',
     license='BSD',
     description='PyTorch interface for TrueGrad-AdamW',
-    version='0.0.8',
+    version='0.0.9',
     long_description=README,
     url='https://github.com/clashluke/truegrad',
     packages=setuptools.find_packages(),
truegrad/nn.py (1 addition, 1 deletion)

@@ -143,7 +143,7 @@ def forward(ctx, out, fn, args, kwargs) -> torch.Tensor:
     @staticmethod
     def backward(ctx, dy: torch.Tensor) -> Tuple[None, None, None, None]:
         def _square(x: Union[torch.Tensor, None]):
-            if isinstance(x, TrueGradParameter):
+            if isinstance(x, torch.nn.Parameter):
                 x = x.data
             if not isinstance(x, torch.Tensor) or not torch.is_floating_point(x):
                 return x
truegrad/optim.py (6 additions, 0 deletions)

@@ -16,6 +16,11 @@ def __init__(self, params, lr: float = 1e-3,

     @torch.no_grad()
     def step(self, closure=None):
+        if closure is None:
+            loss = None
+        else:
+            with torch.enable_grad():
+                loss = closure()
         for group in self.param_groups:
             if len(group["betas"]) == 2:
                 beta1, beta2 = group["betas"]

@@ -74,3 +79,4 @@ def step(self, closure=None):
                 alpha = alpha * adam_update.norm() / update.norm()

             p.add_(update, alpha=alpha)
+        return loss
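
This change follows the standard torch.optim closure convention, which frameworks such as PyTorch Lightning ("PL" in the commit message) rely on: step() may receive a callable that re-runs the forward and backward pass and returns the loss. A minimal usage sketch follows, assuming TGAdamW is the optimizer class in truegrad/optim.py; the toy model, data, and loss are illustrative placeholders (a real TrueGrad run would also need TrueGrad-compatible modules), not part of the commit:

import torch
from truegrad.optim import TGAdamW  # assumed optimizer class name

model = torch.nn.Linear(4, 1)  # placeholder model, for illustration only
opt = TGAdamW(model.parameters(), lr=1e-3)
x, y = torch.randn(8, 4), torch.randn(8, 1)

def closure():
    # step() wraps this call in torch.enable_grad(), so backward()
    # works even though step() itself is decorated with @torch.no_grad().
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = opt.step(closure)  # step() now returns the closure's loss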
