Skip to content

Commit

Permalink
fix: diff across input/output channels in gans
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Oct 9, 2024
1 parent 58edbef commit dad437e
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion models/cut_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,10 @@ def compute_G_loss_cut(self):
# Fake losses
if self.real_A.size(1) != self.fake_B.size(1):
# hack: fake_B and real_A do not have the same number of channels
diffc = self.fake_B.size(1) - self.real_A.size(1)
if self.real_A.size(1) > self.fake_B.size(1):
diffc = self.real_A.size(1) - self.fake_B.size(1)
else:
diffc = self.fake_B.size(1) - self.real_A.size(1)
assert diffc > 0
add1 = torch.zeros(
self.real_A.size(0), 1, self.real_A.size(2), self.real_A.size(3)
Expand Down

0 comments on commit dad437e

Please sign in to comment.