Skip to content

Commit

Permalink
fix: allowing for no NCE with cut
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Sep 11, 2024
1 parent 92ad57d commit 189fc04
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions models/cut_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def modify_commandline_options(parser, is_train=True):
type=util.str2bool,
nargs="?",
const=True,
default=True,
default=False,
help="use NCE loss for identity mapping: NCE(G(Y), Y))",
)

Expand Down Expand Up @@ -388,14 +388,18 @@ def __init__(self, opt, rank):
# Making groups
self.networks_groups = []

optimizers = ["optimizer_G", "optimizer_F"]
networks_to_optimize = ["G_A"]
optimizers = ["optimizer_G"]
if self.opt.alg_cut_lambda_NCE > 0.0:
optimizers.append("optimizer_F")
networks_to_optimize.append("F")
losses_backward = ["loss_G_tot"]

if self.opt.model_multimodal:
# optimizers.append("optimizer_E")
losses_backward.append("loss_G_z")
self.group_G = NetworkGroup(
networks_to_optimize=["G_A", "F"],
networks_to_optimize=networks_to_optimize,
forward_functions=["forward"],
backward_functions=["compute_G_loss"],
loss_names_list=["loss_names_G"],
Expand Down Expand Up @@ -431,7 +435,9 @@ def __init__(self, opt, rank):
self.set_discriminators_info()

# Losses names
losses_G = ["G_NCE"]
losses_G = []
if opt.alg_cut_lambda_NCE > 0.0:
losses_G += ["G_NCE"]
if opt.alg_cut_supervised_loss != [""]:
losses_G += ["G_supervised"]
losses_D = []
Expand Down

0 comments on commit 189fc04

Please sign in to comment.