From 55959a3fc7b50045388651a65af159c8d2f03734 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 20 Oct 2024 06:53:35 -0700
Subject: [PATCH] allow for auto-calling ema after optimizer step using
 `register_step_post_hook`

---
 ema_pytorch/ema_pytorch.py | 10 +++++++++-
 setup.py                   |  2 +-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/ema_pytorch/ema_pytorch.py b/ema_pytorch/ema_pytorch.py
index e0d6826..7d31e56 100644
--- a/ema_pytorch/ema_pytorch.py
+++ b/ema_pytorch/ema_pytorch.py
@@ -185,13 +185,21 @@ def init_ema(
         self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
         self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
 
+    def add_to_optimizer_post_step_hook(self, optimizer):
+        assert hasattr(optimizer, 'register_step_post_hook')
+
+        def hook(*_):
+            self.update()
+
+        return optimizer.register_step_post_hook(hook)
+
     @property
     def model(self):
         return self.online_model if self.include_online_model else self.online_model[0]
 
     def eval(self):
         return self.ema_model.eval()
-    
+
     def restore_ema_model_device(self):
         device = self.initted.device
         self.ema_model.to(device)

diff --git a/setup.py b/setup.py
index 9bb79bc..adb1b9f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.7.0',
+  version = '0.7.2',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',
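
A minimal usage sketch of the new method, for context. The model, optimizer,
and hyperparameters below are placeholders and not part of the patch; note
that the hook simply calls `update()`, which still respects the EMA instance's
own `update_after_step` / `update_every` gating:

    import torch
    from torch import nn
    from ema_pytorch import EMA

    net = nn.Linear(512, 512)

    # update_every = 1 so every optimizer step also moves the EMA weights
    ema = EMA(net, beta = 0.9999, update_every = 1)

    opt = torch.optim.Adam(net.parameters(), lr = 3e-4)

    # from this point on, ema.update() fires automatically after opt.step()
    handle = ema.add_to_optimizer_post_step_hook(opt)

    data = torch.randn(8, 512)
    loss = net(data).sum()
    loss.backward()
    opt.step()        # EMA update happens here via the registered post-step hook
    opt.zero_grad()

    # the returned torch RemovableHandle can detach the hook if needed
    handle.remove()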