feat: accept only UNet2DModel variants
junhsss committed Mar 28, 2023
1 parent 7a74e8e commit 7346087
Showing 1 changed file with 12 additions and 43 deletions.
55 changes: 12 additions & 43 deletions consistency/consistency.py
@@ -23,29 +23,13 @@
 import wandb
 
 
-class DiffusersWrapper(nn.Module):
-    def __init__(self, unet: UNet2DModel):
-        super().__init__()
-        self.unet = unet
-
-    def forward(
-        self,
-        images: torch.Tensor,
-        times: torch.Tensor,
-    ):
-        out: UNet2DOutput = self.unet(images, times)
-        return out.sample
-
-
 class Consistency(LightningModule):
     def __init__(
         self,
-        model: nn.Module,
+        model: UNet2DModel,
         *,
         loss_fn: nn.Module = nn.MSELoss(),
         learning_rate: float = 1e-4,
-        image_size: Optional[int] = None,
-        channels: Optional[int] = None,
         data_std: float = 0.5,
         time_min: float = 0.002,
         time_max: float = 80.0,
@@ -67,27 +51,10 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        if isinstance(model, UNet2DModel):
-            if image_size:
-                raise TypeError("'image_size' is not supported for UNet2DModel")
-            if channels:
-                raise TypeError("'channels' is not supported for UNet2DModel")
-
-            self.model = DiffusersWrapper(model)
-            self.model_ema = DiffusersWrapper(copy.deepcopy(model))
-            self.image_size = model.sample_size
-            self.channels = model.in_channels
-
-        else:
-            if not image_size:
-                raise TypeError("'image_size' should be provided.")
-            if not channels:
-                raise TypeError("'channels' should be provided.")
-
-            self.model = model
-            self.model_ema = copy.deepcopy(model)
-            self.image_size = image_size
-            self.channels = channels
+        self.model = model
+        self.model_ema = copy.deepcopy(model)
+        self.image_size = model.sample_size
+        self.channels = model.in_channels
 
         self.model_ema.requires_grad_(False)
 
@@ -169,18 +136,20 @@ def _forward(
         )
         out_coef = self.data_std * times / (times.pow(2) + self.data_std**2).pow(0.5)
 
-        output = self.image_time_product(
+        out: UNet2DOutput = model(images, times)
+
+        out = self.image_time_product(
             images,
             skip_coef,
         ) + self.image_time_product(
-            model(images, times),
+            out.sample,
             out_coef,
         )
 
         if clip:
-            return output.clamp(-1.0, 1.0)
+            return out.clamp(-1.0, 1.0)
 
-        return output
+        return out
 
     def training_step(self, images: torch.Tensor, *args, **kwargs):
         _bins = self.bins
@@ -246,7 +215,7 @@ def on_train_batch_end(self, *args, **kwargs):
             and self.trainer.global_step > 0
         ):
             pipeline = ConsistencyPipeline(
-                unet=self.model_ema.unet if self.use_ema else self.model.unet,
+                unet=self.model_ema if self.use_ema else self.model,
             )
 
             pipeline.save_pretrained(self.model_id)
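
With this change, Consistency takes a diffusers UNet2DModel directly; image_size and channels are read from the model itself (model.sample_size and model.in_channels) rather than passed as arguments, and the DiffusersWrapper indirection is gone. A minimal usage sketch, assuming the package exposes Consistency at the top level and that the remaining constructor arguments keep the defaults shown in the diff; the UNet2DModel settings below are illustrative, not from the commit:

from diffusers import UNet2DModel
from consistency import Consistency

# Illustrative configuration; any UNet2DModel variant works.
unet = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)

# image_size and channels are no longer accepted; they come from the model,
# and plain nn.Module models are no longer supported.
consistency = Consistency(model=unet)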