just remove beartype, cleanup types
lucidrains committed Jun 17, 2024
1 parent a1b3342 commit 2c36ca6
Showing 3 changed files with 13 additions and 18 deletions.
4 changes: 1 addition & 3 deletions ema_pytorch/ema_pytorch.py
@@ -7,8 +7,7 @@
 from torch import nn, Tensor
 from torch.nn import Module
 
-from beartype import beartype
-from beartype.typing import Set
+from typing import Set
 
 def exists(val):
     return val is not None
@@ -46,7 +45,6 @@ class EMA(Module):
         min_value (float): The minimum EMA decay rate. Default: 0.
     """
 
-    @beartype
     def __init__(
         self,
         model: Module,
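
With the @beartype decorator removed, the type hints on EMA.__init__ are now documentation only and are no longer checked at runtime. A minimal usage sketch follows; apart from the model argument shown in the diff, the keyword arguments and the update()/forward calls are recalled from the project README and should be treated as assumptions, not as part of this commit.

import torch
from ema_pytorch import EMA

net = torch.nn.Linear(512, 512)

# keyword arguments below are assumed from the README, not from this diff
ema = EMA(
    net,
    beta = 0.9999,              # exponential moving average factor
    update_after_step = 100,    # begin EMA updates only after this many calls
    update_every = 10           # update the shadow weights every 10th call
)

# in the training loop, after each optimizer step
ema.update()

# evaluate with the EMA copy (assumed to forward to the averaged model)
out = ema(torch.randn(1, 512))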
24 changes: 11 additions & 13 deletions ema_pytorch/post_hoc_ema.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from pathlib import Path
 from copy import deepcopy
 from functools import partial
@@ -8,8 +10,7 @@
 
 import numpy as np
 
-from beartype import beartype
-from beartype.typing import Set, Tuple, Optional
+from typing import Set, Tuple
 
 def exists(val):
     return val is not None
@@ -47,13 +48,12 @@ class KarrasEMA(Module):
     can either use gamma or sigma_rel from paper
     """
 
-    @beartype
     def __init__(
         self,
         model: Module,
-        sigma_rel: Optional[float] = None,
-        gamma: Optional[float] = None,
-        ema_model: Optional[Module] = None,   # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
+        sigma_rel: float | None = None,
+        gamma: float | None = None,
+        ema_model: Module | None = None,      # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
         update_every: int = 100,
         frozen: bool = False,
         param_or_buffer_names_no_ema: Set[str] = set(),
@@ -259,12 +259,11 @@ def solve_weights(t_i, gamma_i, t_r, gamma_r):
 
 class PostHocEMA(Module):
 
-    @beartype
     def __init__(
         self,
         model: Module,
-        sigma_rels: Optional[Tuple[float, ...]] = None,
-        gammas: Optional[Tuple[float, ...]] = None,
+        sigma_rels: Tuple[float, ...] | None = None,
+        gammas: Tuple[float, ...] | None = None,
         checkpoint_every_num_steps: int = 1000,
         checkpoint_folder: str = './post-hoc-ema-checkpoints',
         checkpoint_dtype: torch.dtype = torch.float16,
@@ -326,12 +325,11 @@ def checkpoint(self):
         pkg = deepcopy(ema_model).to(self.checkpoint_dtype).state_dict()
         torch.save(pkg, str(path))
 
-    @beartype
     def synthesize_ema_model(
         self,
-        gamma: Optional[float] = None,
-        sigma_rel: Optional[float] = None,
-        step: Optional[int] = None,
+        gamma: float | None = None,
+        sigma_rel: float | None = None,
+        step: int | None = None,
     ) -> KarrasEMA:
         assert exists(gamma) ^ exists(sigma_rel)
         device = self.device
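
The new from __future__ import annotations line is what lets the PEP 604 spelling (float | None) replace beartype.typing's Optional on interpreters older than Python 3.10: under PEP 563, annotations are stored as strings and never evaluated at import time. A minimal sketch of the pattern, using a hypothetical helper name that is not part of this repository:

from __future__ import annotations  # PEP 563: annotations are kept as strings

from torch.nn import Module

# hypothetical helper used only to illustrate the annotation style
def maybe_copy(model: Module, sigma_rel: float | None = None) -> Module | None:
    # without the __future__ import, `float | None` would raise TypeError at
    # import time on Python 3.8/3.9; with it, the expression is never evaluated
    if sigma_rel is None:
        return None
    return model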
3 changes: 1 addition & 2 deletions setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.0',
+  version = '0.5.1',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',
@@ -16,7 +16,6 @@
     'exponential moving average'
   ],
   install_requires=[
-    'beartype',
     'torch>=1.6',
   ],
   classifiers=[
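
With beartype dropped from install_requires, torch is the only remaining runtime dependency and the public API is unchanged. A rough sketch of the post-hoc EMA path follows; only the constructor keywords and the synthesize_ema_model signature visible in the diff above are confirmed by this commit, while the update() call is assumed from the project README.

import torch
from ema_pytorch import PostHocEMA

net = torch.nn.Linear(512, 512)

posthoc_emas = PostHocEMA(
    net,
    sigma_rels = (0.05, 0.28),                        # two EMAs at different relative widths
    checkpoint_every_num_steps = 1000,                # default shown in the diff above
    checkpoint_folder = './post-hoc-ema-checkpoints'  # default shown in the diff above
)

# during training, after each optimizer step (method assumed from the README)
posthoc_emas.update()

# afterwards, synthesize an EMA model at a new sigma_rel from the saved checkpoints
ema_model = posthoc_emas.synthesize_ema_model(sigma_rel = 0.15)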
