Skip to content

Commit

Permalink
feat: add clipping
Browse files Browse the repository at this point in the history
  • Loading branch information
junhsss committed Mar 24, 2023
1 parent f9256f8 commit a90e14f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
15 changes: 13 additions & 2 deletions consistency/consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,20 +115,31 @@ def forward(
):
return self._forward(self.model, images, times)

def _forward(self, model: nn.Module, images: torch.Tensor, times: torch.Tensor):
def _forward(
self,
model: nn.Module,
images: torch.Tensor,
times: torch.Tensor,
clip: bool = True,
):
skip_coef = self.data_std**2 / (
(times - self.time_min).pow(2) + self.data_std**2
)
out_coef = self.data_std * times / (times.pow(2) + self.data_std**2).pow(0.5)

return self.image_time_product(
output = self.image_time_product(
images,
skip_coef,
) + self.image_time_product(
model(images, times),
out_coef,
)

if clip:
return output.clamp(-1.0, 1.0)

return output

def training_step(self, images: torch.Tensor, *args, **kwargs):
noise = torch.randn(images.shape, device=images.device)
timesteps = torch.randint(
Expand Down
7 changes: 4 additions & 3 deletions consistency/loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Tuple, Union

import torch.nn.functional as F
from torch import nn
Expand Down Expand Up @@ -55,8 +56,8 @@ def forward(self, input, target):

return (
lpips_loss
+ self.overflow_weight * F(input, self.clamp(input))
+ self.l1_weight * F(input, target)
+ self.overflow_weight * F.l1_loss(input, self.clamp(input))
+ self.l1_weight * F.l1_loss(input, target)
)


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import find_packages, setup

__version__ = "0.2.2"
__version__ = "0.2.3"

setup(
name="consistency",
Expand Down

0 comments on commit a90e14f

Please sign in to comment.