From 2e0ec579e2339e4ff712d3c64427b76828001234 Mon Sep 17 00:00:00 2001 From: junhsss Date: Sat, 25 Mar 2023 06:50:56 +0900 Subject: [PATCH] feat: upscale before lpips inference --- consistency/loss.py | 6 +++++- setup.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/consistency/loss.py b/consistency/loss.py index 75c9731..57a0c70 100644 --- a/consistency/loss.py +++ b/consistency/loss.py @@ -42,8 +42,12 @@ def _append_net_type(net_type: str): self.l1_weight = l1_weight def forward(self, input, target): + upscaled_input = F.interpolate(input, (224, 224), mode="bilinear") + upscaled_target = F.interpolate(target, (224, 224), mode="bilinear") + lpips_loss = sum( - _lpips_loss(input, target) for _lpips_loss in self.lpips_losses + _lpips_loss(upscaled_input, upscaled_target) + for _lpips_loss in self.lpips_losses ) return lpips_loss + self.l1_weight * F.l1_loss(input, target) diff --git a/setup.py b/setup.py index add5720..b184236 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import find_packages, setup -__version__ = "0.2.3" +__version__ = "0.2.4" setup( name="consistency",