Skip to content

Commit

Permalink
feat: add sampling from the ema model
Browse files Browse the repository at this point in the history
  • Loading branch information
junhsss committed Mar 22, 2023
1 parent a30729e commit cf17f5e
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 5 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Implementation of Consistency Models [(Song et al., 2023)](https://arxiv.org/abs

![image](./assets/consistency_models.png)

## Install
## Installation

```sh
$ pip install consistency
Expand All @@ -32,7 +32,9 @@ consistency = Consistency(
)

samples = consistency.sample(16)
samples = consistency.sample(16, steps=5) # multi-step generation

# multi-step sampling from the EMA model
samples = consistency.sample(16, steps=5, use_ema=True)
```

`Consistency` is self-contained, with the training logic and all necessary schedules built in. It subclasses `LightningModule`, so it is intended to be used with `Lightning.Trainer`.
Expand Down
18 changes: 16 additions & 2 deletions consistency/consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
save_samples_every_n_epoch: int = 10,
num_samples: int = 16,
sample_steps: int = 1,
sample_ema: bool = False,
sample_seed: int = 0,
) -> None:
super().__init__()
Expand Down Expand Up @@ -253,6 +254,7 @@ def on_train_start(self) -> None:
num_samples=self.num_samples,
steps=self.sample_steps,
seed=self.sample_seed,
use_ema=self.sample_ema,
)

@rank_zero_only
Expand All @@ -266,6 +268,7 @@ def on_train_epoch_end(self) -> None:
num_samples=self.num_samples,
steps=self.sample_steps,
seed=self.sample_seed,
use_ema=self.sample_ema,
)

@torch.no_grad()
Expand Down Expand Up @@ -319,18 +322,29 @@ def sample(
images
+ math.sqrt(time.item() ** 2 - self.time_min**2) * noise
)
images = self(images, time[None])
images = self._forward(
self.model_ema if use_ema else self.model,
images,
time[None],
)

return images

@torch.no_grad()
def save_samples(
self,
filename: str,
num_samples: int = 16,
steps: int = 1,
use_ema: bool = False,
seed: int = 0,
):
samples = self.sample(num_samples=num_samples, steps=steps, seed=seed)
samples = self.sample(
num_samples=num_samples,
steps=steps,
use_ema=use_ema,
seed=seed,
)
samples.mul_(0.5).add_(0.5)
grid = make_grid(
samples,
Expand Down
2 changes: 2 additions & 0 deletions examples/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def parse_args():
type=int,
default=5,
)
parser.add_argument("--sample-ema", action="store_true")
parser.add_argument(
"--sample-seed",
type=int,
Expand Down Expand Up @@ -201,6 +202,7 @@ def __getitem__(self, index: int) -> torch.Tensor:
save_samples_every_n_epoch=args.save_samples_every_n_epoch,
num_samples=args.num_samples,
sample_steps=args.sample_steps,
sample_ema=args.sample_ema,
sample_seed=args.sample_seed,
)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages

__version__ = "0.1.1"
__version__ = "0.1.2"

setup(
name="consistency",
Expand Down

0 comments on commit cf17f5e

Please sign in to comment.