From a90e14ffc570b7415d231e92d3205e7083a2cb0e Mon Sep 17 00:00:00 2001 From: junhsss Date: Fri, 24 Mar 2023 21:28:34 +0900 Subject: [PATCH] feat: add clipping --- consistency/consistency.py | 15 +++++++++++++-- consistency/loss.py | 7 ++++--- setup.py | 2 +- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/consistency/consistency.py b/consistency/consistency.py index 3d340b2..c854428 100644 --- a/consistency/consistency.py +++ b/consistency/consistency.py @@ -115,13 +115,19 @@ def forward( ): return self._forward(self.model, images, times) - def _forward(self, model: nn.Module, images: torch.Tensor, times: torch.Tensor): + def _forward( + self, + model: nn.Module, + images: torch.Tensor, + times: torch.Tensor, + clip: bool = True, + ): skip_coef = self.data_std**2 / ( (times - self.time_min).pow(2) + self.data_std**2 ) out_coef = self.data_std * times / (times.pow(2) + self.data_std**2).pow(0.5) - return self.image_time_product( + output = self.image_time_product( images, skip_coef, ) + self.image_time_product( @@ -129,6 +135,11 @@ def _forward(self, model: nn.Module, images: torch.Tensor, times: torch.Tensor): out_coef, ) + if clip: + return output.clamp(-1.0, 1.0) + + return output + def training_step(self, images: torch.Tensor, *args, **kwargs): noise = torch.randn(images.shape, device=images.device) timesteps = torch.randint( diff --git a/consistency/loss.py b/consistency/loss.py index 96396b7..b21108a 100644 --- a/consistency/loss.py +++ b/consistency/loss.py @@ -1,4 +1,5 @@ -from typing import Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Tuple, Union import torch.nn.functional as F from torch import nn @@ -55,8 +56,8 @@ def forward(self, input, target): return ( lpips_loss - + self.overflow_weight * F(input, self.clamp(input)) - + self.l1_weight * F(input, target) + + self.overflow_weight * F.l1_loss(input, self.clamp(input)) + + self.l1_weight * F.l1_loss(input, target) ) diff --git a/setup.py b/setup.py index 9ac5260..add5720 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import find_packages, setup -__version__ = "0.2.2" +__version__ = "0.2.3" setup( name="consistency",