Skip to content

Commit

Permalink
feat: upscale before lpips inference
Browse files Browse the repository at this point in the history
  • Loading branch information
junhsss committed Mar 24, 2023
1 parent 6dad56b commit 2e0ec57
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion consistency/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@ def _append_net_type(net_type: str):
self.l1_weight = l1_weight

def forward(self, input, target):
upscaled_input = F.interpolate(input, (224, 224), mode="bilinear")
upscaled_target = F.interpolate(target, (224, 224), mode="bilinear")

lpips_loss = sum(
_lpips_loss(input, target) for _lpips_loss in self.lpips_losses
_lpips_loss(upscaled_input, upscaled_target)
for _lpips_loss in self.lpips_losses
)

return lpips_loss + self.l1_weight * F.l1_loss(input, target)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import find_packages, setup

__version__ = "0.2.3"
__version__ = "0.2.4"

setup(
name="consistency",
Expand Down

0 comments on commit 2e0ec57

Please sign in to comment.