From 9d9e8ffecc1382f3b08fbe96f56da55a08eedee2 Mon Sep 17 00:00:00 2001 From: Rahul Goel <54021162+rahul-goel@users.noreply.github.com> Date: Wed, 11 Sep 2024 21:09:48 +0200 Subject: [PATCH] Fused differentiable SSIM (#396) * add fused ssim * remove submdoule, add requirement --- examples/requirements.txt | 1 + examples/simple_trainer.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/requirements.txt b/examples/requirements.txt index 86dbea609..6ebe9abeb 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -17,3 +17,4 @@ tyro>=0.8.8 Pillow tensorboard pyyaml +git+https://github.com/rahul-goel/fused-ssim@84422e0da94c516220eb3acedb907e68809e9e01 diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index ccf979dec..53c751575 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -20,6 +20,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure +from fused_ssim import fused_ssim from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from typing_extensions import Literal, assert_never from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed @@ -552,8 +553,8 @@ def train(self): # loss l1loss = F.l1_loss(colors, pixels) - ssimloss = 1.0 - self.ssim( - pixels.permute(0, 3, 1, 2), colors.permute(0, 3, 1, 2) + ssimloss = 1.0 - fused_ssim( + colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid" ) loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda if cfg.depth_loss: