diff --git a/ema_pytorch/ema_pytorch.py b/ema_pytorch/ema_pytorch.py
index e0d6826..7f523c4 100644
--- a/ema_pytorch/ema_pytorch.py
+++ b/ema_pytorch/ema_pytorch.py
@@ -185,13 +185,21 @@ def init_ema(
         self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
         self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}

+    def add_to_optimizer_post_step_hook(self, optimizer):
+        assert hasattr(optimizer, 'register_step_post_hook')
+
+        def hook(*_):
+            self.update()
+
+        optimizer.register_step_post_hook(hook)
+
     @property
     def model(self):
         return self.online_model if self.include_online_model else self.online_model[0]

     def eval(self):
         return self.ema_model.eval()
-    
+
     def restore_ema_model_device(self):
         device = self.initted.device
         self.ema_model.to(device)
diff --git a/setup.py b/setup.py
index 9bb79bc..08ed4c8 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.7.0',
+  version = '0.7.1',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',
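
A minimal usage sketch of the new method. It relies on torch.optim.Optimizer.register_step_post_hook (available since PyTorch 1.13); the toy model, optimizer, and hyperparameters below are illustrative, not part of this diff.

    import torch
    from torch import nn
    from ema_pytorch import EMA

    net = nn.Linear(512, 512)
    opt = torch.optim.Adam(net.parameters(), lr = 3e-4)

    ema = EMA(net, beta = 0.9999, update_every = 10)

    # instead of calling ema.update() manually after every opt.step(),
    # register the EMA update once as a post-step hook on the optimizer
    ema.add_to_optimizer_post_step_hook(opt)

    for _ in range(100):
        loss = net(torch.randn(1, 512)).sum()
        loss.backward()
        opt.step()       # ema.update() now runs automatically after this
        opt.zero_grad()

The hook signature hook(*_) discards the (optimizer, args, kwargs) arguments PyTorch passes to post-step hooks, since EMA.update() needs none of them.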