diff --git a/ema_pytorch/ema_pytorch.py b/ema_pytorch/ema_pytorch.py
index 69def05..af2b914 100644
--- a/ema_pytorch/ema_pytorch.py
+++ b/ema_pytorch/ema_pytorch.py
@@ -14,16 +14,28 @@ def exists(val):
 def get_module_device(m: Module):
     return next(m.parameters()).device
 
-def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False):
+def maybe_coerce_dtype(t, dtype):
+    if t.dtype == dtype:
+        return t
+
+    return t.to(dtype)
+
+def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False, coerce_dtype = False):
     if auto_move_device:
         src = src.to(tgt.device)
 
+    if coerce_dtype:
+        src = maybe_coerce_dtype(src, tgt.dtype)
+
     tgt.copy_(src)
 
-def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False):
+def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False, coerce_dtype = False):
     if auto_move_device:
         src = src.to(tgt.device)
 
+    if coerce_dtype:
+        src = maybe_coerce_dtype(src, tgt.dtype)
+
     tgt.lerp_(src, weight)
 
 class EMA(Module):
@@ -64,7 +76,8 @@ def __init__(
         allow_different_devices = False,  # if the EMA model is on a different device (say CPU), automatically move the tensor
         use_foreach = False,
         forward_method_names: Tuple[str, ...] = (),
-        move_ema_to_online_device = False
+        move_ema_to_online_device = False,
+        coerce_dtype = False
     ):
         super().__init__()
         self.beta = beta
@@ -108,8 +121,8 @@ def __init__(
 
         # tensor update functions
 
-        self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices)
-        self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices)
+        self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype)
+        self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype)
 
         # updating hyperparameters
 
@@ -130,6 +143,10 @@ def __init__(
 
         self.allow_different_devices = allow_different_devices
 
+        # whether to coerce dtype when copy or lerp from online to EMA model
+
+        self.coerce_dtype = coerce_dtype
+
         # whether to move EMA model to online model device automatically
 
         self.move_ema_to_online_device = move_ema_to_online_device
@@ -279,6 +296,10 @@ def update_moving_average(self, ma_model, current_model):
             tensors_to_copy = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_copy]
             tensors_to_lerp = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_lerp]
 
+        if self.coerce_dtype:
+            tensors_to_copy = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_copy]
+            tensors_to_lerp = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_lerp]
+
         if len(tensors_to_copy) > 0:
             tgt_copy, src_copy = zip(*tensors_to_copy)
             torch._foreach_copy_(tgt_copy, src_copy)
diff --git a/setup.py b/setup.py
index 70db6e4..453ec3e 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.6.2',
+  version = '0.6.3',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',
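
A minimal usage sketch of the new `coerce_dtype` flag, assuming the existing public EMA interface (passing a prepared `ema_model`, `beta`, `update_every`, and calling `update()`); the toy network, dtypes, and hyperparameters below are illustrative assumptions, not part of the diff.

    # Illustrative sketch only: keep a full-precision EMA copy of a reduced-precision online model.
    # The toy Linear layers and dtype choices here are assumptions for demonstration.

    import torch
    from torch import nn
    from ema_pytorch import EMA

    online_model = nn.Linear(512, 512).bfloat16()   # online weights in bf16
    ema_model    = nn.Linear(512, 512).float()      # EMA copy kept in fp32

    ema = EMA(
        online_model,
        ema_model = ema_model,
        beta = 0.9999,
        update_every = 1,
        coerce_dtype = True    # new flag from this diff: cast online tensors to the EMA tensor's dtype before copy_ / lerp_
    )

    # after each optimizer step:
    ema.update()               # bf16 -> fp32 coercion happens inside inplace_copy / inplace_lerp

Without the flag, `Tensor.lerp_` raises on a dtype mismatch between the EMA and online parameters; with `coerce_dtype = True` the source tensor is converted via `maybe_coerce_dtype` before the in-place copy or lerp (and likewise in the `_foreach` path).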