Skip to content

Commit

Permalink
manual mode for PostHocEMA , contributed by @kalekundert
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 21, 2024
1 parent 1700597 commit 4941807
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
11 changes: 8 additions & 3 deletions ema_pytorch/post_hoc_ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def default(val, d):
def first(arr):
return arr[0]

def divisible_by(num, den):
return (num % den) == 0

def get_module_device(m: Module):
return next(m.parameters()).device

Expand Down Expand Up @@ -337,9 +340,11 @@ def update(self):
for ema_model in self.ema_models:
ema_model.update()

if not (self.checkpoint_every_num_steps == 'manual'):
if not (self.step.item() % self.checkpoint_every_num_steps):
self.checkpoint()
if self.checkpoint_every_num_steps == 'manual':
return

if divisible_by(self.step.item(), self.checkpoint_every_num_steps):
self.checkpoint()

def checkpoint(self):
step = self.step.item()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'ema-pytorch',
packages = find_packages(exclude=[]),
version = '0.7.5',
version = '0.7.6',
license='MIT',
description = 'Easy way to keep track of exponential moving average version of your pytorch module',
author = 'Phil Wang',
Expand Down

0 comments on commit 4941807

Please sign in to comment.