Commit 11c5931

allow for coercing dtype, in the case that the online model weights changed dtype

lucidrains committed Sep 28, 2024
1 parent 866a3e8 commit 11c5931
Showing 2 changed files with 27 additions and 6 deletions.
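
If the online model's weights change dtype mid-training (say, a cast to half precision) while the EMA copy stays in full precision, the in-place copy/lerp updates hit a dtype mismatch. The new `coerce_dtype` flag casts the online tensors to the EMA tensors' dtype before each update. A minimal usage sketch (the toy model and hyperparameters below are illustrative, not part of the commit):

```python
import torch
from torch import nn
from ema_pytorch import EMA

net = nn.Linear(512, 512)                           # stand-in online model
ema = EMA(net, beta = 0.999, coerce_dtype = True)   # flag added in this commit

net.half()     # online weights change dtype mid-training
ema.update()   # fp16 online weights are coerced to the EMA copy's dtype before copy / lerp
```
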
31 changes: 26 additions & 5 deletions ema_pytorch/ema_pytorch.py
@@ -14,16 +14,28 @@ def exists(val):
 def get_module_device(m: Module):
     return next(m.parameters()).device
 
-def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False):
+def maybe_coerce_dtype(t, dtype):
+    if t.dtype == dtype:
+        return t
+
+    return t.to(dtype)
+
+def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False, coerce_dtype = False):
     if auto_move_device:
         src = src.to(tgt.device)
 
+    if coerce_dtype:
+        src = maybe_coerce_dtype(src, tgt.dtype)
+
     tgt.copy_(src)
 
-def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False):
+def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False, coerce_dtype = False):
     if auto_move_device:
         src = src.to(tgt.device)
 
+    if coerce_dtype:
+        src = maybe_coerce_dtype(src, tgt.dtype)
+
     tgt.lerp_(src, weight)
 
 class EMA(Module):
@@ -64,7 +76,8 @@ def __init__(
         allow_different_devices = False,  # if the EMA model is on a different device (say CPU), automatically move the tensor
         use_foreach = False,
         forward_method_names: Tuple[str, ...] = (),
-        move_ema_to_online_device = False
+        move_ema_to_online_device = False,
+        coerce_dtype = False
     ):
         super().__init__()
         self.beta = beta
@@ -108,8 +121,8 @@ def __init__(
 
         # tensor update functions
 
-        self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices)
-        self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices)
+        self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype)
+        self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype)
 
         # updating hyperparameters
 
@@ -130,6 +143,10 @@ def __init__(
 
         self.allow_different_devices = allow_different_devices
 
+        # whether to coerce dtype when copy or lerp from online to EMA model
+
+        self.coerce_dtype = coerce_dtype
+
         # whether to move EMA model to online model device automatically
 
         self.move_ema_to_online_device = move_ema_to_online_device
@@ -279,6 +296,10 @@ def update_moving_average(self, ma_model, current_model):
             tensors_to_copy = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_copy]
             tensors_to_lerp = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_lerp]
 
+        if self.coerce_dtype:
+            tensors_to_copy = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_copy]
+            tensors_to_lerp = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_lerp]
+
         if len(tensors_to_copy) > 0:
             tgt_copy, src_copy = zip(*tensors_to_copy)
             torch._foreach_copy_(tgt_copy, src_copy)
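
As a quick sanity check of the helpers above (a sketch, assuming the module layout shown in this diff): `Tensor.lerp_` requires its `end` argument to match `self`'s dtype, so lerping fp16 online weights into an fp32 EMA tensor raises unless the source is coerced first.

```python
import torch
from ema_pytorch.ema_pytorch import inplace_lerp

ema_w = torch.zeros(3)                            # EMA copy kept in fp32
online_w = torch.ones(3, dtype = torch.float16)   # online model was cast to fp16

# without coerce_dtype = True, tgt.lerp_(src, weight) raises on the dtype mismatch
inplace_lerp(ema_w, online_w, 0.1, coerce_dtype = True)
print(ema_w)  # tensor([0.1000, 0.1000, 0.1000])
```
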
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.6.2',
+  version = '0.6.3',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',
