Commit

feat: add LPIPSLoss

junhsss committed Mar 22, 2023
1 parent 6faefdf commit 8a4b0ce
Showing 3 changed files with 153 additions and 56 deletions.
19 changes: 19 additions & 0 deletions consistency/loss.py
@@ -0,0 +1,19 @@
+from typing import Literal
+
+from torch import nn
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+
+
+class LPIPSLoss(nn.Module):
+    def __init__(self, net_type: Literal["vgg", "alex", "squeeze"] = "vgg"):
+        super().__init__()
+        self.lpips = LearnedPerceptualImagePatchSimilarity(net_type=net_type)
+        self.lpips.requires_grad_(False)
+
+    @staticmethod
+    def clamp(x):
+        return x.clamp(-1, 1)
+
+    def forward(self, input, target):
+        lpips_loss = self.lpips(self.clamp(input), self.clamp(target))
+        return lpips_loss
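
Note: the wrapper freezes the metric's weights and clamps both inputs to [-1, 1], the range LearnedPerceptualImagePatchSimilarity expects for (N, 3, H, W) images when its normalize option is left at the default. A minimal usage sketch (not part of the commit; the shapes are illustrative):

    import torch

    from consistency.loss import LPIPSLoss

    loss_fn = LPIPSLoss(net_type="vgg")  # "alex" and "squeeze" are lighter backbones
    pred = torch.randn(4, 3, 32, 32, requires_grad=True)  # stand-in for a model output
    target = torch.rand(4, 3, 32, 32) * 2 - 1             # reference images in [-1, 1]
    loss = loss_fn(pred, target)  # scalar tensor; lower means more perceptually similar
    loss.backward()               # gradients flow to pred, not to the frozen LPIPS weights
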
186 changes: 131 additions & 55 deletions examples/train.py
@@ -10,6 +10,7 @@
 from torchvision import transforms
 
 from consistency import Consistency
+from consistency.loss import LPIPSLoss
 
 
 def parse_args():
@@ -119,6 +120,8 @@ def parse_args():
         type=int,
         default=0,
     )
+    parser.add_argument("--ckpt-path", type=str)
+    parser.add_argument("--wandb-id", type=str)
     args = parser.parse_args()
     return args

@@ -164,64 +167,137 @@ def __getitem__(self, index: int) -> torch.Tensor:
         num_workers=args.dataloader_num_workers,
     )
 
-    consistency = Consistency(
-        model=UNet2DModel(
-            sample_size=args.resolution,
-            in_channels=3,
-            out_channels=3,
-            layers_per_block=2,
-            block_out_channels=(128, 128, 256, 256, 512, 512),
-            down_block_types=(
-                "DownBlock2D",
-                "DownBlock2D",
-                "DownBlock2D",
-                "DownBlock2D",
-                "AttnDownBlock2D",
-                "DownBlock2D",
+    if args.ckpt_path:
+        consistency = Consistency.load_from_checkpoint(
+            checkpoint_path=args.ckpt_path,
+            model=UNet2DModel(
+                sample_size=args.resolution,
+                in_channels=3,
+                out_channels=3,
+                layers_per_block=2,
+                block_out_channels=(128, 128, 256, 256, 512, 512),
+                down_block_types=(
+                    "DownBlock2D",
+                    "DownBlock2D",
+                    "DownBlock2D",
+                    "DownBlock2D",
+                    "AttnDownBlock2D",
+                    "DownBlock2D",
+                ),
+                up_block_types=(
+                    "UpBlock2D",
+                    "AttnUpBlock2D",
+                    "UpBlock2D",
+                    "UpBlock2D",
+                    "UpBlock2D",
+                    "UpBlock2D",
+                ),
             ),
-            up_block_types=(
-                "UpBlock2D",
-                "AttnUpBlock2D",
-                "UpBlock2D",
-                "UpBlock2D",
-                "UpBlock2D",
-                "UpBlock2D",
-            ),
-        ),
-        learning_rate=args.learning_rate,
-        data_std=args.data_std,
-        time_min=args.time_min,
-        time_max=args.time_max,
-        bins_min=args.bins_min,
-        bins_max=args.bins_max,
-        bins_rho=args.bins_rho,
-        initial_ema_decay=args.initial_ema_decay,
-        samples_path=args.sample_path,
-        save_samples_every_n_epoch=args.save_samples_every_n_epoch,
-        num_samples=args.num_samples,
-        sample_steps=args.sample_steps,
-        sample_ema=args.sample_ema,
-        sample_seed=args.sample_seed,
-    )
-
-    trainer = Trainer(
-        accelerator="auto",
-        logger=WandbLogger(project="consistency", log_model=True),
-        callbacks=[
-            ModelCheckpoint(
-                dirpath="ckpt",
-                save_top_k=3,
-                monitor="loss",
+            loss_fn=LPIPSLoss(),
+            learning_rate=args.learning_rate,
+            data_std=args.data_std,
+            time_min=args.time_min,
+            time_max=args.time_max,
+            bins_min=args.bins_min,
+            bins_max=args.bins_max,
+            bins_rho=args.bins_rho,
+            initial_ema_decay=args.initial_ema_decay,
+            samples_path=args.sample_path,
+            save_samples_every_n_epoch=args.save_samples_every_n_epoch,
+            num_samples=args.num_samples,
+            sample_steps=args.sample_steps,
+            sample_ema=args.sample_ema,
+            sample_seed=args.sample_seed,
+        )
+
+        trainer = Trainer(
+            accelerator="auto",
+            logger=WandbLogger(
+                project="consistency",
+                log_model=True,
+                id=args.wandb_id,
+                resume="must",
             )
-        ],
-        max_epochs=args.max_epochs,
-        precision=16,
-        log_every_n_steps=args.log_every_n_steps,
-        gradient_clip_algorithm="norm",
-        gradient_clip_val=1.0,
-    )
+            if args.wandb_id
+            else WandbLogger(
+                project="consistency",
+                log_model=True,
+            ),
+            callbacks=[
+                ModelCheckpoint(
+                    dirpath="ckpt",
+                    save_top_k=3,
+                    monitor="loss",
+                )
+            ],
+            max_epochs=args.max_epochs,
+            precision=16,
+            log_every_n_steps=args.log_every_n_steps,
+            gradient_clip_algorithm="norm",
+            gradient_clip_val=1.0,
+        )
+        trainer.fit(consistency, dataloader, ckpt_path=args.ckpt_path)
 
-    trainer.fit(consistency, dataloader)
+    else:
+        consistency = Consistency(
+            model=UNet2DModel(
+                sample_size=args.resolution,
+                in_channels=3,
+                out_channels=3,
+                layers_per_block=2,
+                block_out_channels=(128, 128, 256, 256, 512, 512),
+                down_block_types=(
+                    "DownBlock2D",
+                    "DownBlock2D",
+                    "DownBlock2D",
+                    "DownBlock2D",
+                    "AttnDownBlock2D",
+                    "DownBlock2D",
+                ),
+                up_block_types=(
+                    "UpBlock2D",
+                    "AttnUpBlock2D",
+                    "UpBlock2D",
+                    "UpBlock2D",
+                    "UpBlock2D",
+                    "UpBlock2D",
+                ),
+            ),
+            loss_fn=LPIPSLoss(),
+            learning_rate=args.learning_rate,
+            data_std=args.data_std,
+            time_min=args.time_min,
+            time_max=args.time_max,
+            bins_min=args.bins_min,
+            bins_max=args.bins_max,
+            bins_rho=args.bins_rho,
+            initial_ema_decay=args.initial_ema_decay,
+            samples_path=args.sample_path,
+            save_samples_every_n_epoch=args.save_samples_every_n_epoch,
+            num_samples=args.num_samples,
+            sample_steps=args.sample_steps,
+            sample_ema=args.sample_ema,
+            sample_seed=args.sample_seed,
+        )
+
+        trainer = Trainer(
+            accelerator="auto",
+            logger=WandbLogger(project="consistency", log_model=True),
+            callbacks=[
+                ModelCheckpoint(
+                    dirpath="ckpt",
+                    save_top_k=3,
+                    monitor="loss",
+                )
+            ],
+            max_epochs=args.max_epochs,
+            precision=16,
+            log_every_n_steps=args.log_every_n_steps,
+            gradient_clip_algorithm="norm",
+            gradient_clip_val=1.0,
+        )
+
+        trainer.fit(consistency, dataloader)
 
 
 if __name__ == "__main__":
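
Note: together the two new flags make interrupted runs restartable. Consistency.load_from_checkpoint rebuilds the LightningModule from saved weights, trainer.fit(..., ckpt_path=args.ckpt_path) additionally restores optimizer and progress state, and passing id with resume="must" tells the W&B logger to reattach to the existing run (and fail fast if the id is unknown). A hypothetical invocation, with placeholder checkpoint name and run id, and any data arguments matching the original run:

    python examples/train.py --ckpt-path ckpt/epoch=9-step=5000.ckpt --wandb-id 1a2b3c4d
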
4 changes: 3 additions & 1 deletion setup.py
@@ -1,6 +1,6 @@
 from setuptools import find_packages, setup
 
-__version__ = "0.1.2"
+__version__ = "0.2.0"
 
 setup(
     name="consistency",
@@ -19,6 +19,8 @@
         "torchvision",
         "pytorch-lightning",
         "diffusers",
+        "torchmetrics",
+        "lpips",
     ],
     classifiers=[
         "Development Status :: 4 - Beta",
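
Note: both new requirements back the loss module. torchmetrics provides LearnedPerceptualImagePatchSimilarity, and torchmetrics releases of this era required the standalone lpips package at runtime for that metric, so both are declared explicitly (an assumption about the contemporaneous torchmetrics behavior). Refreshing an editable install picks both up:

    pip install -e .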
