From 838fcb3acc7dc6ff8b6e165c18bf0d512fd34f31 Mon Sep 17 00:00:00 2001 From: julie Date: Thu, 28 Sep 2023 11:14:07 +0200 Subject: [PATCH 01/15] feat: add discriminator unet --- models/modules/discriminators.py | 179 +++++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) diff --git a/models/modules/discriminators.py b/models/modules/discriminators.py index 5662df779..4871fc3aa 100644 --- a/models/modules/discriminators.py +++ b/models/modules/discriminators.py @@ -151,3 +151,182 @@ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): def forward(self, input): """Standard forward.""" return self.net(input) + + +class UnetDiscriminator(nn.Module): + """Create a Unet-based discriminator""" + + def __init__( + self, + input_nc, + output_nc, + num_downs, + ngf=64, + norm_layer=nn.BatchNorm2d, + use_dropout=False, + ): + """Construct a Unet discriminator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + ngf (int) -- the number of filters in the last conv layer, here ngf=64, so inner_nc=64*8=512 + norm_layer -- normalization layer + + We construct the U-Net from the innermost layer to the outermost layer. + It is a recursive process. + """ + super(UnetDiscriminator, self).__init__() + # construct unet structure + # add the innermost layer + unet_block = UnetSkipConnectionBlock( + ngf * 8, + ngf * 8, + input_nc=None, + submodule=None, + norm_layer=norm_layer, + innermost=True, + ) + # add intermediate layers with ngf * 8 filters + for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock( + ngf * 8, + ngf * 8, + input_nc=None, + submodule=unet_block, + norm_layer=norm_layer, + use_dropout=use_dropout, + ) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock( + ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer + ) + unet_block = UnetSkipConnectionBlock( + ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer + ) + unet_block = UnetSkipConnectionBlock( + ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer + ) + + # add the outermost layer + self.model = UnetSkipConnectionBlock( + output_nc, + ngf, + input_nc=input_nc, + submodule=unet_block, + outermost=True, + norm_layer=norm_layer, + ) + + def compute_feats(self, input, extract_layer_ids=[]): + output, feats = self.model(input, feats=[]) + return_feats = [] + for i, feat in enumerate(feats): + if i in extract_layer_ids: + return_feats.append(feat) + + return output, return_feats + + def forward(self, input): + output, _ = self.compute_feats(input) + return output + + def get_feats(self, input, extract_layer_ids=[]): + _, feats = self.compute_feats(input, extract_layer_ids) + + return feats + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__( + self, + outer_nc, + inner_nc, + input_nc=None, + submodule=None, + outermost=False, + innermost=False, + norm_layer=nn.BatchNorm2d, + use_dropout=False, + ): + """Construct a Unet submodule with skip connections. + + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + """ + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d( + input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias + ) + downrelu = nn.LeakyReLU(0.2, False) # True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(False) # True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d( + inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1 + ) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d( + inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias + ) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d( + inner_nc * 2, + outer_nc, + kernel_size=4, + stride=2, + padding=1, + bias=use_bias, + ) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x, feats): + output = self.model[0](x) + return_feats = feats + [output] + + for layer in self.model[1:]: + if isinstance(layer, UnetSkipConnectionBlock): + output, return_feats = layer(output, return_feats) + else: + output = layer(output) + + if not self.outermost: # add skip connections + output = torch.cat([x, output], 1) + + return output, return_feats From 95f96bae6f67df76e25c1462cbecb212bc57a2a2 Mon Sep 17 00:00:00 2001 From: julie Date: Thu, 28 Sep 2023 14:01:41 +0200 Subject: [PATCH 02/15] feat: add unet_128_d --- models/gan_networks.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/models/gan_networks.py b/models/gan_networks.py index af8266392..ba76ab4d4 100644 --- a/models/gan_networks.py +++ b/models/gan_networks.py @@ -1,4 +1,5 @@ import os +import torch import torch.nn as nn import functools from torch.optim import lr_scheduler @@ -19,6 +20,8 @@ from .modules.resnet_architecture.resnet_generator import ResnetGenerator_attn from .modules.discriminators import NLayerDiscriminator from .modules.discriminators import PixelDiscriminator +from .modules.discriminators import UnetDiscriminator + from .modules.classifiers import ( torch_model, @@ -244,6 +247,7 @@ def define_G( def define_D( D_netDs, model_input_nc, + model_output_nc, D_ndf, D_n_layers, D_norm, @@ -273,6 +277,7 @@ def define_D( Parameters: model_input_nc (int) -- the number of channels in input images + model_output_nc (int) -- the number of channels in output images D_ndf (int) -- the number of filters in the first conv layer netD (str) -- the architecture's name: basic | n_layers | pixel D_n_layers (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' @@ -432,6 +437,16 @@ def define_D( ) return_nets[netD] = init_net(net, model_init_type, model_init_gain) + elif netD == "unet_128_d": + net = UnetDiscriminator( + model_input_nc, + model_output_nc, + 7, + D_ndf, + norm_layer=norm_layer, + use_dropout=D_dropout, + ) + else: raise NotImplementedError( "Discriminator model name [%s] is not recognized" % netD From 05c97f838a32078c52a5b408b96b7464004bbf18 Mon Sep 17 00:00:00 2001 From: julie Date: Thu, 28 Sep 2023 14:10:15 +0200 Subject: [PATCH 03/15] feat: update action functions ReLu and LeakyReLU into True --- models/modules/discriminators.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/models/modules/discriminators.py b/models/modules/discriminators.py index 4871fc3aa..8c1125203 100644 --- a/models/modules/discriminators.py +++ b/models/modules/discriminators.py @@ -1,5 +1,5 @@ import functools - +import torch import numpy as np from torch import nn from torch.nn import functional as F @@ -278,9 +278,9 @@ def __init__( downconv = nn.Conv2d( input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias ) - downrelu = nn.LeakyReLU(0.2, False) # True) + downrelu = nn.LeakyReLU(0.2, True) downnorm = norm_layer(inner_nc) - uprelu = nn.ReLU(False) # True) + uprelu = nn.ReLU(True) upnorm = norm_layer(outer_nc) if outermost: From 4ad56fd6d7a7f75503644949111ae93ec8f60ab5 Mon Sep 17 00:00:00 2001 From: julie Date: Thu, 28 Sep 2023 14:30:33 +0200 Subject: [PATCH 04/15] feat: add option of unet_128_d in --D_netDs --- options/base_options.py | 1 + 1 file changed, 1 insertion(+) diff --git a/options/base_options.py b/options/base_options.py index 6ddfe3561..5d5a18163 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -392,6 +392,7 @@ def initialize(self, parser): "depth", "mask", "sam", + "unet_128_d", ] + list(TORCH_MODEL_CLASSES.keys()), help="specify discriminator architecture, another option, --D_n_layers allows you to specify the layers in the n_layers discriminator. NB: duplicated arguments are ignored. Values: basic, n_layers, pixel, projected_d, temporal, vision_aided, depth, mask, sam", From 3b56b7e59efe94033f773a8f88512ce59f127afa Mon Sep 17 00:00:00 2001 From: julie Date: Fri, 29 Sep 2023 10:24:01 +0200 Subject: [PATCH 05/15] feat: line34 add unet_128_d --- tests/test_run_nosemantic.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_run_nosemantic.py b/tests/test_run_nosemantic.py index 145c586a9..99a6575a1 100644 --- a/tests/test_run_nosemantic.py +++ b/tests/test_run_nosemantic.py @@ -31,7 +31,11 @@ "cycle_gan", ] -D_netDs = [["projected_d", "basic"], ["projected_d", "basic", "depth"]] +D_netDs = [ + ["projected_d", "basic"], + ["projected_d", "basic", "depth"], + ["projected_d", "basic", "unet_128_d"], +] train_feat_wavelet = [False, True] From a72ea0ce6e9b7fb862b5ed5ffb0cc0f569730b86 Mon Sep 17 00:00:00 2001 From: julie Date: Fri, 29 Sep 2023 10:24:01 +0200 Subject: [PATCH 06/15] feat: line34 add unet_128_d --- models/gan_networks.py | 9 +++++++-- tests/test_run_nosemantic.py | 6 +++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/models/gan_networks.py b/models/gan_networks.py index ba76ab4d4..51e39ccc8 100644 --- a/models/gan_networks.py +++ b/models/gan_networks.py @@ -248,7 +248,9 @@ def define_D( D_netDs, model_input_nc, model_output_nc, + num_downs, D_ndf, + D_ngf, D_n_layers, D_norm, D_dropout, @@ -279,6 +281,8 @@ def define_D( model_input_nc (int) -- the number of channels in input images model_output_nc (int) -- the number of channels in output images D_ndf (int) -- the number of filters in the first conv layer + D_ngf(int) -- the number of filters in the last cov layer + num_downs(int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, image of size 128x128 will become of size 1x1 at the bottleneck netD (str) -- the architecture's name: basic | n_layers | pixel D_n_layers (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' D_norm (str) -- the type of normalization layers used in the network. @@ -441,11 +445,12 @@ def define_D( net = UnetDiscriminator( model_input_nc, model_output_nc, - 7, - D_ndf, + num_downs=7, + D_ngf=64, norm_layer=norm_layer, use_dropout=D_dropout, ) + return_nets[netD] = init_net(net, model_init_type, model_init_gain) else: raise NotImplementedError( diff --git a/tests/test_run_nosemantic.py b/tests/test_run_nosemantic.py index 145c586a9..99a6575a1 100644 --- a/tests/test_run_nosemantic.py +++ b/tests/test_run_nosemantic.py @@ -31,7 +31,11 @@ "cycle_gan", ] -D_netDs = [["projected_d", "basic"], ["projected_d", "basic", "depth"]] +D_netDs = [ + ["projected_d", "basic"], + ["projected_d", "basic", "depth"], + ["projected_d", "basic", "unet_128_d"], +] train_feat_wavelet = [False, True] From 17c7089c17cc332b1bba28782a89e601815344a8 Mon Sep 17 00:00:00 2001 From: julie Date: Fri, 29 Sep 2023 14:46:52 +0200 Subject: [PATCH 07/15] feat: add num_donws and D_ngf=64 --- models/gan_networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/gan_networks.py b/models/gan_networks.py index 51e39ccc8..e28154485 100644 --- a/models/gan_networks.py +++ b/models/gan_networks.py @@ -446,7 +446,7 @@ def define_D( model_input_nc, model_output_nc, num_downs=7, - D_ngf=64, + D_ngf=64, # the final conv has D_ngf*8=512 filter norm_layer=norm_layer, use_dropout=D_dropout, ) From bb21b0cf5285ee2c549762a25c3d1d540be190ae Mon Sep 17 00:00:00 2001 From: julie Date: Fri, 29 Sep 2023 15:56:22 +0200 Subject: [PATCH 08/15] feat: add num_downs and ngf --- models/d_unet.py | 195 +++++++++++++++++++++++++++++++++++++++++ models/gan_networks.py | 4 +- train_joligen_gan.sh | 6 ++ 3 files changed, 203 insertions(+), 2 deletions(-) create mode 100644 models/d_unet.py create mode 100755 train_joligen_gan.sh diff --git a/models/d_unet.py b/models/d_unet.py new file mode 100644 index 000000000..8bb230c8b --- /dev/null +++ b/models/d_unet.py @@ -0,0 +1,195 @@ +from torch import nn +import torch +import functools +from torchinfo import summary + + +class UnetDiscriminator(nn.Module): + """Create a Unet-based discriminator""" + + def __init__( + self, + input_nc, + output_nc, + num_downs, + ngf=64, + norm_layer=nn.BatchNorm2d, + use_dropout=False, + ): + """Construct a Unet discriminator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + ngf (int) -- the number of filters in the last conv layer, here ngf=64, so inner_nc=64*8=512 + norm_layer -- normalization layer + + We construct the U-Net from the innermost layer to the outermost layer. + It is a recursive process. + """ + super(UnetDiscriminator, self).__init__() + # construct unet structure + # add the innermost layer + unet_block = UnetSkipConnectionBlock( + ngf * 8, + ngf * 8, + input_nc=None, + submodule=None, + norm_layer=norm_layer, + innermost=True, + ) + # add intermediate layers with ngf * 8 filters + for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock( + ngf * 8, + ngf * 8, + input_nc=None, + submodule=unet_block, + norm_layer=norm_layer, + use_dropout=use_dropout, + ) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock( + ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer + ) + unet_block = UnetSkipConnectionBlock( + ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer + ) + unet_block = UnetSkipConnectionBlock( + ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer + ) + + # add the outermost layer + self.model = UnetSkipConnectionBlock( + output_nc, + ngf, + input_nc=input_nc, + submodule=unet_block, + outermost=True, + norm_layer=norm_layer, + ) + + + def compute_feats(self, input, extract_layer_ids=[]): + output, feats = self.model(input, feats=[]) + return_feats = [] + for i, feat in enumerate(feats): + if i in extract_layer_ids: + return_feats.append(feat) + + return output, return_feats + + def forward(self, input): + output, _ = self.compute_feats(input) + return output + + def get_feats(self, input, extract_layer_ids=[]): + _, feats = self.compute_feats(input, extract_layer_ids) + + return feats + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__( + self, + outer_nc, + inner_nc, + input_nc=None, + submodule=None, + outermost=False, + innermost=False, + norm_layer=nn.BatchNorm2d, + use_dropout=False, + ): + """Construct a Unet submodule with skip connections. + + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + """ + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d( + input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias + ) + downrelu = nn.LeakyReLU(0.2,False)# True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(False)#True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d( + inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1 + ) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d( + inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias + ) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d( + inner_nc * 2, + outer_nc, + kernel_size=4, + stride=2, + padding=1, + bias=use_bias, + ) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x, feats): + output = self.model[0](x) + return_feats = feats + [output] + + for layer in self.model[1:]: + if isinstance(layer, UnetSkipConnectionBlock): + output, return_feats = layer(output, return_feats) + else: + output = layer(output) + + if not self.outermost: # add skip connections + output = torch.cat([x, output], 1) + + return output, return_feats + + +######### print architecture +input_par1=3 +input_par2=3 +input_par3=9 +ins=UnetDiscriminator(input_nc=input_par1,output_nc=input_par2,num_downs=input_par3) +print(ins) + +######### one example in detail +summary(ins, input_size=(3,1024,1024), batch_dim=0,col_names=["input_size", "output_size", "num_params", "kernel_size","mult_adds"], row_settings=["var_names"], depth=input_par3+1) diff --git a/models/gan_networks.py b/models/gan_networks.py index 266da845d..eceff6461 100644 --- a/models/gan_networks.py +++ b/models/gan_networks.py @@ -444,8 +444,8 @@ def define_D( net = UnetDiscriminator( model_input_nc, model_output_nc, - 7, # the number of downsamplings - D_ndf, # the final conv has D_ngf*8=512 filter + num_downs=7, # the number of downsamplings + ngf=64, # the final conv has D_ngf*8=512 filter norm_layer=norm_layer, use_dropout=D_dropout, ) diff --git a/train_joligen_gan.sh b/train_joligen_gan.sh new file mode 100755 index 000000000..439d62794 --- /dev/null +++ b/train_joligen_gan.sh @@ -0,0 +1,6 @@ +python3 train.py --dataroot /data3/juliew/joligen_gan/joliGEN_WIP/datasets/noglasses2glasses_ffhq\ + --checkpoints_dir ./checkpoints\ + --name noglasses2glasses\ + --output_display_env noglasses2glasses\ + --config_json examples/example_gan_noglasses2glasses.json\ + --D_netDs projected_d unet_128_d From 3cc9ef6e0a91d25ecd27918040fce56602f83819 Mon Sep 17 00:00:00 2001 From: julie Date: Fri, 29 Sep 2023 16:17:55 +0200 Subject: [PATCH 09/15] feat: add D_num_downs and D_ngf --- models/gan_networks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/models/gan_networks.py b/models/gan_networks.py index eceff6461..08ecd8733 100644 --- a/models/gan_networks.py +++ b/models/gan_networks.py @@ -248,7 +248,7 @@ def define_D( D_netDs, model_input_nc, model_output_nc, - num_downs, + D_num_downs, D_ndf, D_ngf, D_n_layers, @@ -444,8 +444,8 @@ def define_D( net = UnetDiscriminator( model_input_nc, model_output_nc, - num_downs=7, # the number of downsamplings - ngf=64, # the final conv has D_ngf*8=512 filter + D_num_downs=7, # the number of downsamplings + D_ngf=64, # the final conv has D_ngf*8=512 filter norm_layer=norm_layer, use_dropout=D_dropout, ) From f8da6596256565362ba7e71c666a17d662590a25 Mon Sep 17 00:00:00 2001 From: julie Date: Fri, 29 Sep 2023 16:20:05 +0200 Subject: [PATCH 10/15] feat: add D_num_downs and D_ngf --- options/base_options.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/options/base_options.py b/options/base_options.py index 5d5a18163..197e07490 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -378,6 +378,20 @@ def initialize(self, parser): default=64, help="# of discrim filters in the first conv layer", ) + parser.add_argument( + "--D_ngf", + type=int, + default=64, + help="#*8 of discrim filters in the last conv layer", + ) + + parser.add_argument( + "--D_num_downs", + type=int, + default=7, + help="# of downsampling", + ) + parser.add_argument( "--D_netDs", type=str, From ddcc22bcb5f9ad08493e6cbb3d9246ff3c427d7c Mon Sep 17 00:00:00 2001 From: julie Date: Fri, 29 Sep 2023 17:42:58 +0200 Subject: [PATCH 11/15] feat: D_num_downs --- models/modules/discriminators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/modules/discriminators.py b/models/modules/discriminators.py index 8c1125203..063a518df 100644 --- a/models/modules/discriminators.py +++ b/models/modules/discriminators.py @@ -160,7 +160,7 @@ def __init__( self, input_nc, output_nc, - num_downs, + D_num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, @@ -169,7 +169,7 @@ def __init__( Parameters: input_nc (int) -- the number of channels in input images output_nc (int) -- the number of channels in output images - num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, + D_num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, image of size 128x128 will become of size 1x1 # at the bottleneck ngf (int) -- the number of filters in the last conv layer, here ngf=64, so inner_nc=64*8=512 norm_layer -- normalization layer From 232ae288c4633d1250b701b4821fe575bc3c632d Mon Sep 17 00:00:00 2001 From: julie Date: Fri, 29 Sep 2023 17:47:33 +0200 Subject: [PATCH 12/15] feat: D_ngf and D_num_downs --- models/modules/discriminators.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/models/modules/discriminators.py b/models/modules/discriminators.py index 063a518df..be96ff0f2 100644 --- a/models/modules/discriminators.py +++ b/models/modules/discriminators.py @@ -161,7 +161,7 @@ def __init__( input_nc, output_nc, D_num_downs, - ngf=64, + D_ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, ): @@ -171,7 +171,7 @@ def __init__( output_nc (int) -- the number of channels in output images D_num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, image of size 128x128 will become of size 1x1 # at the bottleneck - ngf (int) -- the number of filters in the last conv layer, here ngf=64, so inner_nc=64*8=512 + D_ngf (int) -- the number of filters in the last conv layer, here ngf=64, so inner_nc=64*8=512 norm_layer -- normalization layer We construct the U-Net from the innermost layer to the outermost layer. @@ -181,18 +181,18 @@ def __init__( # construct unet structure # add the innermost layer unet_block = UnetSkipConnectionBlock( - ngf * 8, - ngf * 8, + D_ngf * 8, + D_ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, ) # add intermediate layers with ngf * 8 filters - for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + for i in range(D_num_downs - 5): # add intermediate layers with ngf * 8 filters unet_block = UnetSkipConnectionBlock( - ngf * 8, - ngf * 8, + D_ngf * 8, + D_ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, @@ -200,19 +200,27 @@ def __init__( ) # gradually reduce the number of filters from ngf * 8 to ngf unet_block = UnetSkipConnectionBlock( - ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer + D_ngf * 4, + D_ngf * 8, + input_nc=None, + submodule=unet_block, + norm_layer=norm_layer, ) unet_block = UnetSkipConnectionBlock( - ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer + D_ngf * 2, + D_ngf * 4, + input_nc=None, + submodule=unet_block, + norm_layer=norm_layer, ) unet_block = UnetSkipConnectionBlock( - ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer + D_ngf, D_ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer ) # add the outermost layer self.model = UnetSkipConnectionBlock( output_nc, - ngf, + D_ngf, input_nc=input_nc, submodule=unet_block, outermost=True, From 4eae0212960a54492209dc8da8205c7a4cb72640 Mon Sep 17 00:00:00 2001 From: julie Date: Mon, 2 Oct 2023 10:45:47 +0200 Subject: [PATCH 13/15] feat: D_num_donws --- models/gan_networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/gan_networks.py b/models/gan_networks.py index 08ecd8733..dd94d898d 100644 --- a/models/gan_networks.py +++ b/models/gan_networks.py @@ -444,8 +444,8 @@ def define_D( net = UnetDiscriminator( model_input_nc, model_output_nc, - D_num_downs=7, # the number of downsamplings - D_ngf=64, # the final conv has D_ngf*8=512 filter + D_num_downs, # the number of downsamplings + ngf=D_ngf, # the final conv has D_ngf*8=512 filter norm_layer=norm_layer, use_dropout=D_dropout, ) From 8cca1e08e294bf3d12e0b5428c530fa797ce7092 Mon Sep 17 00:00:00 2001 From: julie Date: Mon, 2 Oct 2023 10:47:20 +0200 Subject: [PATCH 14/15] feat: ngf --- models/gan_networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/gan_networks.py b/models/gan_networks.py index dd94d898d..010c6a3d2 100644 --- a/models/gan_networks.py +++ b/models/gan_networks.py @@ -445,7 +445,7 @@ def define_D( model_input_nc, model_output_nc, D_num_downs, # the number of downsamplings - ngf=D_ngf, # the final conv has D_ngf*8=512 filter + D_ngf, # the final conv has D_ngf*8=512 filter norm_layer=norm_layer, use_dropout=D_dropout, ) From a4c692e3ecae92ca87fd5c909fb19477021035a1 Mon Sep 17 00:00:00 2001 From: julie Date: Tue, 24 Oct 2023 15:05:24 +0200 Subject: [PATCH 15/15] feat:finalizing the unet discriminator original from models/modules/unet_architecture/unet_generator.py --- data/base_dataset.py | 68 ++++++++++---------- models/base_gan_model.py | 51 +++++++++++++-- models/gan_networks.py | 6 +- models/modules/discriminators.py | 61 +++++++++++------- models/modules/loss.py | 103 +++++++++++++++++++++++++++++++ options/base_options.py | 2 +- 6 files changed, 226 insertions(+), 65 deletions(-) diff --git a/data/base_dataset.py b/data/base_dataset.py index a0b707392..d32c8727f 100644 --- a/data/base_dataset.py +++ b/data/base_dataset.py @@ -15,7 +15,7 @@ if torch.__version__[0] == "2": torchvision.disable_beta_transforms_warning() - from torchvision import datapoints + from torchvision import tv_tensors as datapoints from torchvision.transforms.v2 import functional as F2 import torchvision.transforms.functional as F @@ -442,15 +442,7 @@ def get_transform_ref( if convert: transform_list += [transforms.ToTensor()] - """if grayscale: - transform_list += [transforms.Normalize((0.5,), (0.5,))] - else: - transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]""" - transforms.Normalize( - mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711), - ) return transforms.Compose(transform_list) @@ -573,8 +565,8 @@ def __call__(self, img, mask, bbox=None): w, h = img.size bbox = np.array([0, 0, w, h]) # sets bbox to full image size if torch.__version__[0] == "2": - tbbox = datapoints.BoundingBox( - bbox, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=img.size + tbbox = datapoints.BoundingBoxes( + bbox, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=img.size ) else: tbbox = bbox # placeholder @@ -1066,10 +1058,21 @@ class ComposeMaskList(transforms.Compose): >>> ]) """ - def __call__(self, imgs, masks=None): + def __call__(self, imgs, masks=None, bbox=None): + if bbox is None: + w, h = imgs[0].size + bbox = np.array([0, 0, w, h]) # sets bbox to full image size + if torch.__version__[0] == "2": + tbbox = datapoints.BoundingBoxes( + bbox, + format=datapoints.BoundingBoxFormat.XYXY, + canvas_size=imgs[0].size, + ) + else: + tbbox = bbox # placeholder for t in self.transforms: - imgs, masks = t(imgs, masks) - return imgs, masks + imgs, masks, tbbox = t(imgs, masks, tbbox) + return imgs, masks, tbbox class GrayscaleMaskList(transforms.Grayscale): @@ -1088,7 +1091,7 @@ class GrayscaleMaskList(transforms.Grayscale): def __init__(self, num_output_channels=1): self.num_output_channels = num_output_channels - def __call__(self, imgs, masks): + def __call__(self, imgs, masks, bbox): """ Args: img (PIL Image): Image to be converted to grayscale. @@ -1102,7 +1105,7 @@ def __call__(self, imgs, masks): F.to_grayscale(img, num_output_channels=self.num_output_channels) ) - return return_imgs, masks + return return_imgs, masks, bbox def __repr__(self): return self.__class__.__name__ + "(num_output_channels={0})".format( @@ -1123,7 +1126,7 @@ class ResizeMaskList(transforms.Resize): ``PIL.Image.BILINEAR`` """ - def __call__(self, imgs, masks): + def __call__(self, imgs, masks, bbox): """ Args: img (PIL Image): Image to be scaled. @@ -1131,6 +1134,7 @@ def __call__(self, imgs, masks): Returns: PIL Image: Rescaled image. """ + return_imgs = [] return_masks = [] @@ -1145,7 +1149,7 @@ def __call__(self, imgs, masks): return_masks.append( F.resize(mask, self.size, interpolation=InterpolationMode.NEAREST) ) - return return_imgs, return_masks + return return_imgs, return_masks, F2.resize(bbox, self.size) class RandomCropMaskList(transforms.RandomCrop): @@ -1184,7 +1188,7 @@ class RandomCropMaskList(transforms.RandomCrop): """ - def __call__(self, imgs, masks): + def __call__(self, imgs, masks, bbox): """ Args: img (PIL Image): Image to be cropped. @@ -1216,7 +1220,7 @@ def __call__(self, imgs, masks): else: for mask in masks: return_masks.append(F.crop(mask, i, j, h, w)) - return return_imgs, return_masks + return return_imgs, return_masks, F2.crop(bbox, i, j, h, w) class RandomHorizontalFlipMaskList(transforms.RandomHorizontalFlip): @@ -1226,7 +1230,7 @@ class RandomHorizontalFlipMaskList(transforms.RandomHorizontalFlip): p (float): probability of the image being flipped. Default value is 0.5 """ - def __call__(self, imgs, masks): + def __call__(self, imgs, masks, bbox): """ Args: img (PIL Image): Image to be flipped. @@ -1244,9 +1248,9 @@ def __call__(self, imgs, masks): else: return_masks = None - return return_imgs, return_masks + return return_imgs, return_masks, F2.hflip(bbox) else: - return imgs, masks + return imgs, masks, bbox class ToTensorMaskList(transforms.ToTensor): @@ -1260,7 +1264,7 @@ class ToTensorMaskList(transforms.ToTensor): In the other cases, tensors are returned without scaling. """ - def __call__(self, imgs, masks): + def __call__(self, imgs, masks, bbox): """ Args: pic (PIL Image or numpy.ndarray): Image to be converted to tensor. @@ -1278,7 +1282,7 @@ def __call__(self, imgs, masks): ) else: return_masks = None - return return_imgs, return_masks + return return_imgs, return_masks, bbox.data class RandomRotationMaskList(transforms.RandomRotation): @@ -1316,7 +1320,7 @@ def get_params(degrees): return angle - def __call__(self, imgs, masks): + def __call__(self, imgs, masks, bbox): """ Args: img (PIL Image): Image to be rotated. @@ -1335,7 +1339,7 @@ def __call__(self, imgs, masks): else: return_masks = None - return return_imgs, return_masks + return return_imgs, return_masks, F2.rotate(bbox, angle) class NormalizeMaskList(transforms.Normalize): @@ -1354,7 +1358,7 @@ class NormalizeMaskList(transforms.Normalize): """ - def __call__(self, tensor_imgs, tensor_masks): + def __call__(self, tensor_imgs, tensor_masks, tensor_bbox): """ Args: tensor (Tensor): Tensor image of size (C, H, W) to be normalized. @@ -1369,7 +1373,7 @@ def __call__(self, tensor_imgs, tensor_masks): F.normalize(tensor_img, self.mean, self.std, self.inplace) ) - return return_imgs, tensor_masks + return return_imgs, tensor_masks, tensor_bbox def __repr__(self): return self.__class__.__name__ + "(mean={0}, std={1})".format( @@ -1387,7 +1391,7 @@ def set_params(self, p, translate, scale_min, scale_max, shear): self.scale_max = scale_max self.shear = shear - def __call__(self, imgs, masks): + def __call__(self, imgs, masks, bbox): if random.random() > 1.0 - self.p: affine_params = self.get_params( (0, 0), @@ -1405,6 +1409,6 @@ def __call__(self, imgs, masks): else: return_masks = None - return return_imgs, return_masks + return return_imgs, return_masks, F2.affine(bbox, *affine_params) else: - return imgs, masks + return imgs, masks, bbox diff --git a/models/base_gan_model.py b/models/base_gan_model.py index 8b69a7a7e..5ba8c7755 100644 --- a/models/base_gan_model.py +++ b/models/base_gan_model.py @@ -9,7 +9,7 @@ from torchviz import make_dot # for FID -from data.base_dataset import get_transform +#from data.base_dataset import get_transform from util.diff_aug import DiffAugment from util.discriminator import DiscriminatorInfo @@ -396,7 +396,6 @@ def compute_D_loss(self): loss_name, loss_value, ) - self.loss_D_tot += loss_value def compute_G_loss_GAN_generic( @@ -439,8 +438,9 @@ def compute_G_loss(self): getattr(self, loss_function)() def compute_G_loss_GAN(self): - """Calculate GAN losses for generator(s)""" + """Calculate GAN losses for generator(s)""" + for discriminator in self.discriminators: if "mask" in discriminator.name: continue @@ -460,7 +460,7 @@ def compute_G_loss_GAN(self): netD, domain, loss, - fake_name=fake_name, + fake_name=fake_name, real_name=real_name, ) @@ -474,7 +474,6 @@ def compute_G_loss_GAN(self): loss_name, loss_value, ) - self.loss_G_tot += loss_value if self.opt.train_temporal_criterion: @@ -586,11 +585,51 @@ def set_discriminators_info(self): real_name = "temporal_real" compute_every = self.opt.D_temporal_every - else: + elif "unet" in discriminator_name: + loss_calculator = loss.DualDiscriminatorGANLoss( + netD=getattr(self, "net"+ discriminator_name), + device=self.device, + dataaug_APA_p=self.opt.dataaug_APA_p, + dataaug_APA_target=self.opt.dataaug_APA_target, + train_batch_size=self.opt.train_batch_size, + dataaug_APA_nimg=self.opt.dataaug_APA_nimg, + dataaug_APA_every=self.opt.dataaug_APA_every, + dataaug_D_label_smooth=self.opt.dataaug_D_label_smooth, + train_gan_mode=train_gan_mode, + dataaug_APA=self.opt.dataaug_APA, + dataaug_D_diffusion=dataaug_D_diffusion, + dataaug_D_diffusion_every=dataaug_D_diffusion_every, + ) fake_name = None real_name = None compute_every = 1 + + elif "unet_discriminator_mha" in discriminator_name: + loss_calculator = loss.DualDiscriminatorGANLoss( + netD=getattr(self, "net"+ discriminator_name), + device=self.device, + dataaug_APA_p=self.opt.dataaug_APA_p, + dataaug_APA_target=self.opt.dataaug_APA_target, + train_batch_size=self.opt.train_batch_size, + dataaug_APA_nimg=self.opt.dataaug_APA_nimg, + dataaug_APA_every=self.opt.dataaug_APA_every, + dataaug_D_label_smooth=self.opt.dataaug_D_label_smooth, + train_gan_mode=train_gan_mode, + dataaug_APA=self.opt.dataaug_APA, + dataaug_D_diffusion=dataaug_D_diffusion, + dataaug_D_diffusion_every=dataaug_D_diffusion_every, + ) + fake_name = None + real_name = None + compute_every = 1 + + else : + fake_name = None + real_name = None + compute_every = 1 + + if self.opt.train_use_contrastive_loss_D: loss_calculator = ( loss.DiscriminatorContrastiveLoss( diff --git a/models/gan_networks.py b/models/gan_networks.py index 010c6a3d2..fc760bc3d 100644 --- a/models/gan_networks.py +++ b/models/gan_networks.py @@ -241,6 +241,7 @@ def define_G( raise NotImplementedError( "Generator model name [%s] is not recognized" % G_netG ) + print("netG is {}".format(net)) return init_net(net, model_init_type, model_init_gain) @@ -440,7 +441,7 @@ def define_D( ) return_nets[netD] = init_net(net, model_init_type, model_init_gain) - elif netD == "unet_128_d": + elif netD == "unet": net = UnetDiscriminator( model_input_nc, model_output_nc, @@ -451,11 +452,12 @@ def define_D( ) return_nets[netD] = init_net(net, model_init_type, model_init_gain) + else: raise NotImplementedError( "Discriminator model name [%s] is not recognized" % netD ) - + print("discriminator is {}".format(return_nets)) return return_nets diff --git a/models/modules/discriminators.py b/models/modules/discriminators.py index be96ff0f2..333ba8dc2 100644 --- a/models/modules/discriminators.py +++ b/models/modules/discriminators.py @@ -153,6 +153,7 @@ def forward(self, input): return self.net(input) + class UnetDiscriminator(nn.Module): """Create a Unet-based discriminator""" @@ -228,22 +229,17 @@ def __init__( ) def compute_feats(self, input, extract_layer_ids=[]): - output, feats = self.model(input, feats=[]) + output, feats, output_encoder_inside = self.model(input, feats=[]) return_feats = [] for i, feat in enumerate(feats): if i in extract_layer_ids: return_feats.append(feat) - return output, return_feats + return output, return_feats, output_encoder_inside def forward(self, input): - output, _ = self.compute_feats(input) - return output - - def get_feats(self, input, extract_layer_ids=[]): - _, feats = self.compute_feats(input, extract_layer_ids) - - return feats + output, _, output_encoder_inside = self.compute_feats(input) + return output, output_encoder_inside class UnetSkipConnectionBlock(nn.Module): @@ -263,24 +259,27 @@ def __init__( norm_layer=nn.BatchNorm2d, use_dropout=False, ): - """Construct a Unet submodule with skip connections. - - Parameters: - outer_nc (int) -- the number of filters in the outer conv layer - inner_nc (int) -- the number of filters in the inner conv layer - input_nc (int) -- the number of channels in input images/features - submodule (UnetSkipConnectionBlock) -- previously defined submodules - outermost (bool) -- if this module is the outermost module - innermost (bool) -- if this module is the innermost module - norm_layer -- normalization layer - use_dropout (bool) -- if use dropout layers. - """ super(UnetSkipConnectionBlock, self).__init__() self.outermost = outermost + self.innermost = innermost + + # Move the bottleneck conv layers initialization to innermost condition + if self.innermost: + self.bottleneck_conv_cor2 = nn.Conv2d( + inner_nc, outer_nc, kernel_size=2, stride=1, padding=0, bias=True + ) + self.bottleneck_conv_cor1 = nn.Conv2d( + inner_nc, outer_nc, kernel_size=1, stride=1, padding=0, bias=True + ) + + self.flatten = nn.Flatten() + self.tanh = nn.Tanh() + if type(norm_layer) == functools.partial: use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: input_nc = outer_nc downconv = nn.Conv2d( @@ -324,17 +323,31 @@ def __init__( self.model = nn.Sequential(*model) - def forward(self, x, feats): + + def forward(self, x, feats, output_encoder_inside=None): output = self.model[0](x) return_feats = feats + [output] for layer in self.model[1:]: if isinstance(layer, UnetSkipConnectionBlock): - output, return_feats = layer(output, return_feats) + output, return_feats, output_encoder_inside = layer( + output, return_feats, output_encoder_inside=output_encoder_inside + ) else: output = layer(output) + # Only apply the bottleneck convolutions if it's the innermost block + if self.innermost and isinstance(layer, nn.ReLU): + output_encoder = output + if hasattr(self, 'bottleneck_conv_cor2') and output_encoder.shape[2] == 2: + output_encoder_conv = self.bottleneck_conv_cor2(output_encoder) + elif hasattr(self, 'bottleneck_conv_cor1'): + output_encoder_conv = self.bottleneck_conv_cor1(output_encoder) + + output_encoder_inside = self.tanh(output_encoder_conv) + if not self.outermost: # add skip connections output = torch.cat([x, output], 1) - return output, return_feats + return output, return_feats, output_encoder_inside + diff --git a/models/modules/loss.py b/models/modules/loss.py index 13d810f36..b962aba22 100644 --- a/models/modules/loss.py +++ b/models/modules/loss.py @@ -239,6 +239,7 @@ def compute_loss_D(self, netD, real, fake, fake_2=None): def compute_loss_G(self, netD, real, fake): self.real = real + print() self.fake = fake def update(self, niter): @@ -394,6 +395,108 @@ def compute_loss_G(self, netD, real, fake): return loss_G +class DualDiscriminatorGANLoss(DiscriminatorLoss): + """ + Unet loss which includes loss from encoder and decoder, reference is 2002.12655 + """ + + def __init__( + self, + netD, + device, + dataaug_APA_p, + dataaug_APA_target, + train_batch_size, + dataaug_APA_nimg, + dataaug_APA_every, + dataaug_D_label_smooth, + train_gan_mode, + dataaug_APA, + dataaug_D_diffusion, + dataaug_D_diffusion_every, + ): + super().__init__( + netD, + device, + dataaug_APA_p, + dataaug_APA_target, + train_batch_size, + dataaug_APA_nimg, + dataaug_APA_every, + dataaug_APA, + dataaug_D_diffusion, + dataaug_D_diffusion_every, + ) + if dataaug_D_label_smooth: + target_real_label = 0.9 + else: + target_real_label = 1.0 + + self.gan_mode = train_gan_mode + + self.criterionGAN = GANLoss( + self.gan_mode, target_real_label=target_real_label + ).to(self.device) + + def compute_loss_D(self, netD, real, fake, fake_2): + """Calculate GAN loss for the discriminator + Parameters: + netD (network) -- the discriminator D + real (tensor array) -- real images + fake (tensor array) -- images generated by a generator + Return the discriminator loss. + We also call loss_D.backward() to calculate the gradients. + """ + super().compute_loss_D(netD, real, fake, fake_2) + + # Real + pred_real_pixel, pred_real_bottleneck = netD(self.real) + + loss_pred_real_pixel = self.criterionGAN(pred_real_pixel, True) + loss_pred_real_bottleneck = self.criterionGAN(pred_real_bottleneck, True) + self.loss_D_real = loss_pred_real_pixel + loss_pred_real_bottleneck + + # Fake + lambda_loss = 0.5 + pred_fake_pixel, pred_fake_bottleneck = netD(self.fake.detach()) + loss_D_fake_pixel = self.criterionGAN(pred_fake_pixel, False) + loss_D_fake_bottleneck = self.criterionGAN(pred_fake_bottleneck, False) + + loss_D_fake = loss_D_fake_bottleneck + loss_D_fake_pixel + + # Combined loss and calculate gradients + loss_D = (self.loss_D_real + loss_D_fake) * lambda_loss + + ##########print the two difference loss value + loss_D_pixel = loss_pred_real_pixel + loss_D_fake_pixel + loss_D_bottleneck = loss_pred_real_bottleneck + loss_D_fake_bottleneck + + print("loss.py D loss of the pixel(decoder) real {} and fake {}, so the total is {} ".format( loss_pred_real_pixel, loss_D_fake_pixel,loss_D_pixel)) + print( + "loss.py D loss of the bottleneck(encoder) real {} and fake {}, so the total is {}".format( loss_pred_real_bottleneck, loss_D_fake_bottleneck, loss_D_bottleneck) + ) + + return loss_D + + def compute_loss_G(self, netD, real, fake): + + super().compute_loss_G(netD, real, fake) + print("net D here is {}".format(type(netD))) + pred_fake_pixel, pred_fake_bottleneck = netD(self.fake) + + loss_G_pixel = self.criterionGAN(pred_fake_pixel, True, relu=False) + loss_G_bottleneck = self.criterionGAN(pred_fake_bottleneck, True, relu=False) + ##############print the two difference loss values + + loss_D_fake = loss_G_pixel + loss_G_bottleneck + + print("loss.py G fake loss of the pixel(decoder) is {}".format(loss_G_pixel)) + print( + "loss.py G fake loss of the bottleneck(encoder) is {}".format(loss_G_bottleneck) + ) + return loss_D_fake + + class MultiScaleDiffusionLoss(nn.Module): """ Multiscale diffusion loss such as in 2301.11093. diff --git a/options/base_options.py b/options/base_options.py index 197e07490..73f0a58b2 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -406,7 +406,7 @@ def initialize(self, parser): "depth", "mask", "sam", - "unet_128_d", + "unet", ] + list(TORCH_MODEL_CLASSES.keys()), help="specify discriminator architecture, another option, --D_n_layers allows you to specify the layers in the n_layers discriminator. NB: duplicated arguments are ignored. Values: basic, n_layers, pixel, projected_d, temporal, vision_aided, depth, mask, sam",