feat: consistency models with supervised losses
beniz committed Sep 11, 2024
1 parent cd264de commit ed701ad
Showing 1 changed file with 58 additions and 0 deletions.
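In short, the commit wires optional supervised perceptual losses (LPIPS and/or DISTS from the piq package) into the consistency-model generator objective. A minimal sketch of how the terms combine, with placeholder scalars standing in for the torch tensors used by the real code:

# lambda_G corresponds to --alg_diffusion_lambda_G, lambda_perceptual to --alg_cm_lambda_perceptual
lambda_G = 1.0
lambda_perceptual = 1.0
cm_loss, lpips_loss, dists_loss = 0.12, 0.31, 0.05  # example values only
loss_G_tot = lambda_G * cm_loss + lambda_perceptual * (lpips_loss + dists_loss)
print(loss_G_tot)  # roughly 0.48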
models/cm_model.py: 58 additions & 0 deletions
@@ -17,6 +17,8 @@
from . import diffusion_networks
from .base_diffusion_model import BaseDiffusionModel

from piq import DISTS, LPIPS


def pseudo_huber_loss(input, target):
"""Computes the pseudo huber loss.
@@ -51,6 +53,34 @@ def modify_commandline_options(parser, is_train=True):
            default=1000000,
            help="number of steps before reaching the fully discretized consistency model sampling schedule",
        )
        parser.add_argument(
            "--alg_cm_perceptual_loss",
            type=str,
            default=[""],
            nargs="*",
            choices=["", "LPIPS", "DISTS"],
            help="optional supervised perceptual loss",
        )
        parser.add_argument(
            "--alg_cm_lambda_perceptual",
            type=float,
            default=1.0,
            help="weight for LPIPS and DISTS perceptual losses",
        )
        parser.add_argument(
            "--alg_cm_dists_mean",
            default=[0.485, 0.456, 0.406],  # Imagenet default
            nargs="*",
            type=float,
            help="mean for DISTS perceptual loss",
        )
        parser.add_argument(
            "--alg_cm_dists_std",
            default=[0.229, 0.224, 0.225],  # Imagenet default
            nargs="*",
            type=float,
            help="std for DISTS perceptual loss",
        )
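For reference, a self-contained sketch of how these nargs="*" options parse (only the option names and defaults above are taken from the diff; the standalone parser is illustrative):

import argparse

p = argparse.ArgumentParser()
p.add_argument("--alg_cm_perceptual_loss", type=str, default=[""], nargs="*",
               choices=["", "LPIPS", "DISTS"])
p.add_argument("--alg_cm_lambda_perceptual", type=float, default=1.0)
p.add_argument("--alg_cm_dists_mean", type=float, default=[0.485, 0.456, 0.406], nargs="*")

opt = p.parse_args(
    "--alg_cm_perceptual_loss LPIPS DISTS --alg_cm_lambda_perceptual 0.5".split()
)
print(opt.alg_cm_perceptual_loss)    # ['LPIPS', 'DISTS']
print(opt.alg_cm_lambda_perceptual)  # 0.5
print(opt.alg_cm_dists_mean)         # [0.485, 0.456, 0.406] (default kept)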

        if is_train:
            parser = CMModel.modify_commandline_options_train(parser)
@@ -145,12 +175,22 @@ def __init__(self, opt, rank):
        self.networks_groups.append(self.group_G)

        losses_G = []
        if opt.alg_cm_perceptual_loss != [""]:
            losses_G += ["G_perceptual"]
        self.loss_names_G += losses_G
        self.loss_names = self.loss_names_G.copy()

        # Itercalculator
        self.iter_calculator_init()

        # perceptual losses
        if "LPIPS" in self.opt.alg_cm_perceptual_loss:
            self.criterionLPIPS = LPIPS().to(self.device)
        if "DISTS" in self.opt.alg_cm_perceptual_loss:
            self.criterionDISTS = DISTS(
                mean=self.opt.alg_cm_dists_mean, std=self.opt.alg_cm_dists_std
            ).to(self.device)
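As a standalone reference, the piq criteria constructed above can be exercised like this (batch shapes and the [0, 1] value range are assumptions about the expected inputs):

import torch
from piq import DISTS, LPIPS

pred = torch.rand(2, 3, 256, 256)    # NCHW, values in [0, 1]
target = torch.rand(2, 3, 256, 256)

lpips = LPIPS()  # pretrained feature weights are downloaded on first use
dists = DISTS(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

print(lpips(pred, target))  # scalar tensor
print(dists(pred, target))  # scalar tensor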

    def set_input(self, data):
        if (
            len(data["A"].to(self.device).shape) == 5
@@ -237,6 +277,24 @@ def compute_cm_loss(self):

        self.loss_G_tot = loss * self.opt.alg_diffusion_lambda_G

        # perceptual losses, if any
        if "LPIPS" in self.opt.alg_cm_perceptual_loss:
            self.loss_G_perceptual_lpips = torch.mean(
                self.criterionLPIPS(y_0, mask_pred_x)
            )
        else:
            self.loss_G_perceptual_lpips = 0
        if "DISTS" in self.opt.alg_cm_perceptual_loss:
            self.loss_G_perceptual_dists = self.criterionDISTS(y_0, mask_pred_x)
        else:
            self.loss_G_perceptual_dists = 0

        if self.loss_G_perceptual_lpips > 0 or self.loss_G_perceptual_dists > 0:
            self.loss_G_perceptual = self.opt.alg_cm_lambda_perceptual * (
                self.loss_G_perceptual_lpips + self.loss_G_perceptual_dists
            )
            self.loss_G_tot += self.loss_G_perceptual

    def inference(self, nb_imgs, offset=0):

        if hasattr(self.netG_A, "module"):
