Skip to content

Commit

Permalink
Fused differentiable SSIM (#396)
Browse files Browse the repository at this point in the history
* add fused ssim

* remove submdoule, add requirement
  • Loading branch information
rahul-goel authored Sep 11, 2024
1 parent d0dca4f commit 9d9e8ff
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
1 change: 1 addition & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ tyro>=0.8.8
Pillow
tensorboard
pyyaml
git+https://github.com/rahul-goel/fused-ssim@84422e0da94c516220eb3acedb907e68809e9e01
5 changes: 3 additions & 2 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from fused_ssim import fused_ssim
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from typing_extensions import Literal, assert_never
from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed
Expand Down Expand Up @@ -552,8 +553,8 @@ def train(self):

# loss
l1loss = F.l1_loss(colors, pixels)
ssimloss = 1.0 - self.ssim(
pixels.permute(0, 3, 1, 2), colors.permute(0, 3, 1, 2)
ssimloss = 1.0 - fused_ssim(
colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid"
)
loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
if cfg.depth_loss:
Expand Down

0 comments on commit 9d9e8ff

Please sign in to comment.