diff --git a/ema_pytorch/ema_pytorch.py b/ema_pytorch/ema_pytorch.py
index bdb84d8..a673bb5 100644
--- a/ema_pytorch/ema_pytorch.py
+++ b/ema_pytorch/ema_pytorch.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Set, Tuple
 from copy import deepcopy
 from functools import partial
 
@@ -7,8 +8,6 @@
 from torch import nn, Tensor
 from torch.nn import Module
 
-from typing import Set
-
 def exists(val):
     return val is not None
 
@@ -60,7 +59,8 @@ def __init__(
         ignore_startswith_names: Set[str] = set(),
         include_online_model = True,          # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
         allow_different_devices = False,      # if the EMA model is on a different device (say CPU), automatically move the tensor
-        use_foreach = False
+        use_foreach = False,
+        forward_method_names: Tuple[str, ...] = ()
     ):
         super().__init__()
         self.beta = beta
@@ -91,6 +91,12 @@ def __init__(
         for p in self.ema_model.parameters():
             p.detach_()
 
+        # forwarding methods
+
+        for forward_method_name in forward_method_names:
+            fn = getattr(self.ema_model, forward_method_name)
+            setattr(self, forward_method_name, fn)
+
         # parameter and buffer names
 
         self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
diff --git a/setup.py b/setup.py
index e62b25a..524f8c4 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.1',
+  version = '0.5.2',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',
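
Note (not part of the patch): a minimal usage sketch of the new `forward_method_names` option added above. The `Net` module and its `generate` method are hypothetical stand-ins for any non-`forward` inference entrypoint an online model might expose; per the patch, each named method is looked up on the EMA copy at construction time and bound directly onto the `EMA` wrapper.

import torch
from torch import nn
from ema_pytorch import EMA

# hypothetical online model with an extra inference method besides forward

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(512, 512)

    def forward(self, x):
        return self.proj(x)

    def generate(self, x):
        # hypothetical non-forward entrypoint
        return self.forward(x).argmax(dim = -1)

net = Net()

# expose `generate` on the EMA wrapper by name

ema = EMA(net, beta = 0.999, forward_method_names = ('generate',))

x = torch.randn(1, 512)
out = ema.generate(x)   # dispatches to ema.ema_model.generate

Since the methods are resolved once in `__init__` and bound to `self.ema_model`, call sites can write `ema.generate(...)` instead of reaching through `ema.ema_model.generate(...)`.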