From 73460876c9288480dd860ea9b79e158f884aebaa Mon Sep 17 00:00:00 2001
From: junhsss
Date: Tue, 28 Mar 2023 16:34:52 +0900
Subject: [PATCH] feat: accept only UNet2DModel variants

---
 consistency/consistency.py | 55 +++++++++-----------------------------
 1 file changed, 12 insertions(+), 43 deletions(-)

diff --git a/consistency/consistency.py b/consistency/consistency.py
index 7631cd5..40d9c6e 100644
--- a/consistency/consistency.py
+++ b/consistency/consistency.py
@@ -23,29 +23,13 @@
 import wandb
 
 
-class DiffusersWrapper(nn.Module):
-    def __init__(self, unet: UNet2DModel):
-        super().__init__()
-        self.unet = unet
-
-    def forward(
-        self,
-        images: torch.Tensor,
-        times: torch.Tensor,
-    ):
-        out: UNet2DOutput = self.unet(images, times)
-        return out.sample
-
-
 class Consistency(LightningModule):
     def __init__(
         self,
-        model: nn.Module,
+        model: UNet2DModel,
         *,
         loss_fn: nn.Module = nn.MSELoss(),
         learning_rate: float = 1e-4,
-        image_size: Optional[int] = None,
-        channels: Optional[int] = None,
         data_std: float = 0.5,
         time_min: float = 0.002,
         time_max: float = 80.0,
@@ -67,27 +51,10 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        if isinstance(model, UNet2DModel):
-            if image_size:
-                raise TypeError("'image_size' is not supported for UNet2DModel")
-            if channels:
-                raise TypeError("'channels' is not supported for UNet2DModel")
-
-            self.model = DiffusersWrapper(model)
-            self.model_ema = DiffusersWrapper(copy.deepcopy(model))
-            self.image_size = model.sample_size
-            self.channels = model.in_channels
-
-        else:
-            if not image_size:
-                raise TypeError("'image_size' should be provided.")
-            if not channels:
-                raise TypeError("'channels' should be provided.")
-
-            self.model = model
-            self.model_ema = copy.deepcopy(model)
-            self.image_size = image_size
-            self.channels = channels
+        self.model = model
+        self.model_ema = copy.deepcopy(model)
+        self.image_size = model.sample_size
+        self.channels = model.in_channels
 
         self.model_ema.requires_grad_(False)
 
@@ -169,18 +136,20 @@ def _forward(
         )
         out_coef = self.data_std * times / (times.pow(2) + self.data_std**2).pow(0.5)
 
-        output = self.image_time_product(
+        out: UNet2DOutput = model(images, times)
+
+        out = self.image_time_product(
             images,
             skip_coef,
         ) + self.image_time_product(
-            model(images, times),
+            out.sample,
             out_coef,
         )
 
         if clip:
-            return output.clamp(-1.0, 1.0)
+            return out.clamp(-1.0, 1.0)
 
-        return output
+        return out
 
     def training_step(self, images: torch.Tensor, *args, **kwargs):
         _bins = self.bins
@@ -246,7 +215,7 @@ def on_train_batch_end(self, *args, **kwargs):
             and self.trainer.global_step > 0
        ):
             pipeline = ConsistencyPipeline(
-                unet=self.model_ema.unet if self.use_ema else self.model.unet,
+                unet=self.model_ema if self.use_ema else self.model,
             )
             pipeline.save_pretrained(self.model_id)
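
Note (not part of the patch): a minimal usage sketch of the constructor after this change. It assumes the class is importable as "from consistency import Consistency" and uses placeholder UNet settings; the only behavior taken from the patch itself is that image_size and channels are no longer passed explicitly and are instead read from the model's sample_size and in_channels.

    # Illustrative sketch only; the import path and UNet settings are assumptions.
    from diffusers import UNet2DModel

    from consistency import Consistency

    # The UNet2DModel carries its own spatial size and channel count,
    # which Consistency now reads as image_size and channels.
    unet = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)

    # No image_size= or channels= arguments anymore; they are inferred from the UNet.
    consistency = Consistency(unet, learning_rate=1e-4)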