From 189fc04001ea53988ea2b84b54dc37eedce86bbd Mon Sep 17 00:00:00 2001 From: Emmanuel Benazera Date: Wed, 11 Sep 2024 09:37:19 +0000 Subject: [PATCH] fix: allowing for no NCE with cut --- models/cut_model.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/models/cut_model.py b/models/cut_model.py index 352870648..ba8c9ddbd 100644 --- a/models/cut_model.py +++ b/models/cut_model.py @@ -62,7 +62,7 @@ def modify_commandline_options(parser, is_train=True): type=util.str2bool, nargs="?", const=True, - default=True, + default=False, help="use NCE loss for identity mapping: NCE(G(Y), Y))", ) @@ -388,14 +388,18 @@ def __init__(self, opt, rank): # Making groups self.networks_groups = [] - optimizers = ["optimizer_G", "optimizer_F"] + networks_to_optimize = ["G_A"] + optimizers = ["optimizer_G"] + if self.opt.alg_cut_lambda_NCE > 0.0: + optimizers.append("optimizer_F") + networks_to_optimize.append("F") losses_backward = ["loss_G_tot"] if self.opt.model_multimodal: # optimizers.append("optimizer_E") losses_backward.append("loss_G_z") self.group_G = NetworkGroup( - networks_to_optimize=["G_A", "F"], + networks_to_optimize=networks_to_optimize, forward_functions=["forward"], backward_functions=["compute_G_loss"], loss_names_list=["loss_names_G"], @@ -431,7 +435,9 @@ def __init__(self, opt, rank): self.set_discriminators_info() # Losses names - losses_G = ["G_NCE"] + losses_G = [] + if opt.alg_cut_lambda_NCE > 0.0: + losses_G += ["G_NCE"] if opt.alg_cut_supervised_loss != [""]: losses_G += ["G_supervised"] losses_D = []