diff --git a/models/base_gan_model.py b/models/base_gan_model.py index ce7f775b8..4c6956f16 100644 --- a/models/base_gan_model.py +++ b/models/base_gan_model.py @@ -9,7 +9,7 @@ from torchviz import make_dot # for FID -from data.base_dataset import get_transform +#from data.base_dataset import get_transform from util.diff_aug import DiffAugment from util.discriminator import DiscriminatorInfo @@ -401,7 +401,6 @@ def compute_D_loss(self): loss_name, loss_value, ) - self.loss_D_tot += loss_value def compute_G_loss_GAN_generic( @@ -444,8 +443,9 @@ def compute_G_loss(self): getattr(self, loss_function)() def compute_G_loss_GAN(self): - """Calculate GAN losses for generator(s)""" + """Calculate GAN losses for generator(s)""" + for discriminator in self.discriminators: if "mask" in discriminator.name: continue @@ -465,7 +465,7 @@ def compute_G_loss_GAN(self): netD, domain, loss, - fake_name=fake_name, + fake_name=fake_name, real_name=real_name, ) @@ -479,7 +479,6 @@ def compute_G_loss_GAN(self): loss_name, loss_value, ) - self.loss_G_tot += loss_value if self.opt.train_temporal_criterion: @@ -562,11 +561,51 @@ def set_discriminators_info(self): real_name = "temporal_real" compute_every = self.opt.D_temporal_every - else: + elif "unet" in discriminator_name: + loss_calculator = loss.DualDiscriminatorGANLoss( + netD=getattr(self, "net"+ discriminator_name), + device=self.device, + dataaug_APA_p=self.opt.dataaug_APA_p, + dataaug_APA_target=self.opt.dataaug_APA_target, + train_batch_size=self.opt.train_batch_size, + dataaug_APA_nimg=self.opt.dataaug_APA_nimg, + dataaug_APA_every=self.opt.dataaug_APA_every, + dataaug_D_label_smooth=self.opt.dataaug_D_label_smooth, + train_gan_mode=train_gan_mode, + dataaug_APA=self.opt.dataaug_APA, + dataaug_D_diffusion=dataaug_D_diffusion, + dataaug_D_diffusion_every=dataaug_D_diffusion_every, + ) fake_name = None real_name = None compute_every = 1 + + elif "unet_discriminator_mha" in discriminator_name: + loss_calculator = loss.DualDiscriminatorGANLoss( + netD=getattr(self, "net"+ discriminator_name), + device=self.device, + dataaug_APA_p=self.opt.dataaug_APA_p, + dataaug_APA_target=self.opt.dataaug_APA_target, + train_batch_size=self.opt.train_batch_size, + dataaug_APA_nimg=self.opt.dataaug_APA_nimg, + dataaug_APA_every=self.opt.dataaug_APA_every, + dataaug_D_label_smooth=self.opt.dataaug_D_label_smooth, + train_gan_mode=train_gan_mode, + dataaug_APA=self.opt.dataaug_APA, + dataaug_D_diffusion=dataaug_D_diffusion, + dataaug_D_diffusion_every=dataaug_D_diffusion_every, + ) + fake_name = None + real_name = None + compute_every = 1 + + else : + fake_name = None + real_name = None + compute_every = 1 + + if self.opt.train_use_contrastive_loss_D: loss_calculator = ( loss.DiscriminatorContrastiveLoss( diff --git a/models/d_unet.py b/models/d_unet.py new file mode 100644 index 000000000..8bb230c8b --- /dev/null +++ b/models/d_unet.py @@ -0,0 +1,195 @@ +from torch import nn +import torch +import functools +from torchinfo import summary + + +class UnetDiscriminator(nn.Module): + """Create a Unet-based discriminator""" + + def __init__( + self, + input_nc, + output_nc, + num_downs, + ngf=64, + norm_layer=nn.BatchNorm2d, + use_dropout=False, + ): + """Construct a Unet discriminator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + ngf (int) -- the number of filters in the last conv layer, here ngf=64, so inner_nc=64*8=512 + norm_layer -- normalization layer + + We construct the U-Net from the innermost layer to the outermost layer. + It is a recursive process. + """ + super(UnetDiscriminator, self).__init__() + # construct unet structure + # add the innermost layer + unet_block = UnetSkipConnectionBlock( + ngf * 8, + ngf * 8, + input_nc=None, + submodule=None, + norm_layer=norm_layer, + innermost=True, + ) + # add intermediate layers with ngf * 8 filters + for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock( + ngf * 8, + ngf * 8, + input_nc=None, + submodule=unet_block, + norm_layer=norm_layer, + use_dropout=use_dropout, + ) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock( + ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer + ) + unet_block = UnetSkipConnectionBlock( + ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer + ) + unet_block = UnetSkipConnectionBlock( + ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer + ) + + # add the outermost layer + self.model = UnetSkipConnectionBlock( + output_nc, + ngf, + input_nc=input_nc, + submodule=unet_block, + outermost=True, + norm_layer=norm_layer, + ) + + + def compute_feats(self, input, extract_layer_ids=[]): + output, feats = self.model(input, feats=[]) + return_feats = [] + for i, feat in enumerate(feats): + if i in extract_layer_ids: + return_feats.append(feat) + + return output, return_feats + + def forward(self, input): + output, _ = self.compute_feats(input) + return output + + def get_feats(self, input, extract_layer_ids=[]): + _, feats = self.compute_feats(input, extract_layer_ids) + + return feats + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__( + self, + outer_nc, + inner_nc, + input_nc=None, + submodule=None, + outermost=False, + innermost=False, + norm_layer=nn.BatchNorm2d, + use_dropout=False, + ): + """Construct a Unet submodule with skip connections. + + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + """ + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d( + input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias + ) + downrelu = nn.LeakyReLU(0.2,False)# True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(False)#True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d( + inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1 + ) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d( + inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias + ) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d( + inner_nc * 2, + outer_nc, + kernel_size=4, + stride=2, + padding=1, + bias=use_bias, + ) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x, feats): + output = self.model[0](x) + return_feats = feats + [output] + + for layer in self.model[1:]: + if isinstance(layer, UnetSkipConnectionBlock): + output, return_feats = layer(output, return_feats) + else: + output = layer(output) + + if not self.outermost: # add skip connections + output = torch.cat([x, output], 1) + + return output, return_feats + + +######### print architecture +input_par1=3 +input_par2=3 +input_par3=9 +ins=UnetDiscriminator(input_nc=input_par1,output_nc=input_par2,num_downs=input_par3) +print(ins) + +######### one example in detail +summary(ins, input_size=(3,1024,1024), batch_dim=0,col_names=["input_size", "output_size", "num_params", "kernel_size","mult_adds"], row_settings=["var_names"], depth=input_par3+1) diff --git a/models/gan_networks.py b/models/gan_networks.py index af8266392..fc760bc3d 100644 --- a/models/gan_networks.py +++ b/models/gan_networks.py @@ -1,4 +1,5 @@ import os +import torch import torch.nn as nn import functools from torch.optim import lr_scheduler @@ -19,6 +20,8 @@ from .modules.resnet_architecture.resnet_generator import ResnetGenerator_attn from .modules.discriminators import NLayerDiscriminator from .modules.discriminators import PixelDiscriminator +from .modules.discriminators import UnetDiscriminator + from .modules.classifiers import ( torch_model, @@ -238,13 +241,17 @@ def define_G( raise NotImplementedError( "Generator model name [%s] is not recognized" % G_netG ) + print("netG is {}".format(net)) return init_net(net, model_init_type, model_init_gain) def define_D( D_netDs, model_input_nc, + model_output_nc, + D_num_downs, D_ndf, + D_ngf, D_n_layers, D_norm, D_dropout, @@ -273,7 +280,9 @@ def define_D( Parameters: model_input_nc (int) -- the number of channels in input images + model_output_nc (int) -- the number of channels in output images D_ndf (int) -- the number of filters in the first conv layer + num_downs(int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, image of size 128x128 will become of size 1x1 at the bottleneck netD (str) -- the architecture's name: basic | n_layers | pixel D_n_layers (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' D_norm (str) -- the type of normalization layers used in the network. @@ -432,11 +441,23 @@ def define_D( ) return_nets[netD] = init_net(net, model_init_type, model_init_gain) + elif netD == "unet": + net = UnetDiscriminator( + model_input_nc, + model_output_nc, + D_num_downs, # the number of downsamplings + D_ngf, # the final conv has D_ngf*8=512 filter + norm_layer=norm_layer, + use_dropout=D_dropout, + ) + return_nets[netD] = init_net(net, model_init_type, model_init_gain) + + else: raise NotImplementedError( "Discriminator model name [%s] is not recognized" % netD ) - + print("discriminator is {}".format(return_nets)) return return_nets diff --git a/models/modules/discriminators.py b/models/modules/discriminators.py index 5662df779..333ba8dc2 100644 --- a/models/modules/discriminators.py +++ b/models/modules/discriminators.py @@ -1,5 +1,5 @@ import functools - +import torch import numpy as np from torch import nn from torch.nn import functional as F @@ -151,3 +151,203 @@ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): def forward(self, input): """Standard forward.""" return self.net(input) + + + +class UnetDiscriminator(nn.Module): + """Create a Unet-based discriminator""" + + def __init__( + self, + input_nc, + output_nc, + D_num_downs, + D_ngf=64, + norm_layer=nn.BatchNorm2d, + use_dropout=False, + ): + """Construct a Unet discriminator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + D_num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + D_ngf (int) -- the number of filters in the last conv layer, here ngf=64, so inner_nc=64*8=512 + norm_layer -- normalization layer + + We construct the U-Net from the innermost layer to the outermost layer. + It is a recursive process. + """ + super(UnetDiscriminator, self).__init__() + # construct unet structure + # add the innermost layer + unet_block = UnetSkipConnectionBlock( + D_ngf * 8, + D_ngf * 8, + input_nc=None, + submodule=None, + norm_layer=norm_layer, + innermost=True, + ) + # add intermediate layers with ngf * 8 filters + for i in range(D_num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock( + D_ngf * 8, + D_ngf * 8, + input_nc=None, + submodule=unet_block, + norm_layer=norm_layer, + use_dropout=use_dropout, + ) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock( + D_ngf * 4, + D_ngf * 8, + input_nc=None, + submodule=unet_block, + norm_layer=norm_layer, + ) + unet_block = UnetSkipConnectionBlock( + D_ngf * 2, + D_ngf * 4, + input_nc=None, + submodule=unet_block, + norm_layer=norm_layer, + ) + unet_block = UnetSkipConnectionBlock( + D_ngf, D_ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer + ) + + # add the outermost layer + self.model = UnetSkipConnectionBlock( + output_nc, + D_ngf, + input_nc=input_nc, + submodule=unet_block, + outermost=True, + norm_layer=norm_layer, + ) + + def compute_feats(self, input, extract_layer_ids=[]): + output, feats, output_encoder_inside = self.model(input, feats=[]) + return_feats = [] + for i, feat in enumerate(feats): + if i in extract_layer_ids: + return_feats.append(feat) + + return output, return_feats, output_encoder_inside + + def forward(self, input): + output, _, output_encoder_inside = self.compute_feats(input) + return output, output_encoder_inside + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__( + self, + outer_nc, + inner_nc, + input_nc=None, + submodule=None, + outermost=False, + innermost=False, + norm_layer=nn.BatchNorm2d, + use_dropout=False, + ): + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + self.innermost = innermost + + # Move the bottleneck conv layers initialization to innermost condition + if self.innermost: + self.bottleneck_conv_cor2 = nn.Conv2d( + inner_nc, outer_nc, kernel_size=2, stride=1, padding=0, bias=True + ) + self.bottleneck_conv_cor1 = nn.Conv2d( + inner_nc, outer_nc, kernel_size=1, stride=1, padding=0, bias=True + ) + + self.flatten = nn.Flatten() + self.tanh = nn.Tanh() + + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d( + input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias + ) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d( + inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1 + ) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d( + inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias + ) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d( + inner_nc * 2, + outer_nc, + kernel_size=4, + stride=2, + padding=1, + bias=use_bias, + ) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + + def forward(self, x, feats, output_encoder_inside=None): + output = self.model[0](x) + return_feats = feats + [output] + + for layer in self.model[1:]: + if isinstance(layer, UnetSkipConnectionBlock): + output, return_feats, output_encoder_inside = layer( + output, return_feats, output_encoder_inside=output_encoder_inside + ) + else: + output = layer(output) + + # Only apply the bottleneck convolutions if it's the innermost block + if self.innermost and isinstance(layer, nn.ReLU): + output_encoder = output + if hasattr(self, 'bottleneck_conv_cor2') and output_encoder.shape[2] == 2: + output_encoder_conv = self.bottleneck_conv_cor2(output_encoder) + elif hasattr(self, 'bottleneck_conv_cor1'): + output_encoder_conv = self.bottleneck_conv_cor1(output_encoder) + + output_encoder_inside = self.tanh(output_encoder_conv) + + if not self.outermost: # add skip connections + output = torch.cat([x, output], 1) + + return output, return_feats, output_encoder_inside + diff --git a/models/modules/loss.py b/models/modules/loss.py index 13d810f36..b962aba22 100644 --- a/models/modules/loss.py +++ b/models/modules/loss.py @@ -239,6 +239,7 @@ def compute_loss_D(self, netD, real, fake, fake_2=None): def compute_loss_G(self, netD, real, fake): self.real = real + print() self.fake = fake def update(self, niter): @@ -394,6 +395,108 @@ def compute_loss_G(self, netD, real, fake): return loss_G +class DualDiscriminatorGANLoss(DiscriminatorLoss): + """ + Unet loss which includes loss from encoder and decoder, reference is 2002.12655 + """ + + def __init__( + self, + netD, + device, + dataaug_APA_p, + dataaug_APA_target, + train_batch_size, + dataaug_APA_nimg, + dataaug_APA_every, + dataaug_D_label_smooth, + train_gan_mode, + dataaug_APA, + dataaug_D_diffusion, + dataaug_D_diffusion_every, + ): + super().__init__( + netD, + device, + dataaug_APA_p, + dataaug_APA_target, + train_batch_size, + dataaug_APA_nimg, + dataaug_APA_every, + dataaug_APA, + dataaug_D_diffusion, + dataaug_D_diffusion_every, + ) + if dataaug_D_label_smooth: + target_real_label = 0.9 + else: + target_real_label = 1.0 + + self.gan_mode = train_gan_mode + + self.criterionGAN = GANLoss( + self.gan_mode, target_real_label=target_real_label + ).to(self.device) + + def compute_loss_D(self, netD, real, fake, fake_2): + """Calculate GAN loss for the discriminator + Parameters: + netD (network) -- the discriminator D + real (tensor array) -- real images + fake (tensor array) -- images generated by a generator + Return the discriminator loss. + We also call loss_D.backward() to calculate the gradients. + """ + super().compute_loss_D(netD, real, fake, fake_2) + + # Real + pred_real_pixel, pred_real_bottleneck = netD(self.real) + + loss_pred_real_pixel = self.criterionGAN(pred_real_pixel, True) + loss_pred_real_bottleneck = self.criterionGAN(pred_real_bottleneck, True) + self.loss_D_real = loss_pred_real_pixel + loss_pred_real_bottleneck + + # Fake + lambda_loss = 0.5 + pred_fake_pixel, pred_fake_bottleneck = netD(self.fake.detach()) + loss_D_fake_pixel = self.criterionGAN(pred_fake_pixel, False) + loss_D_fake_bottleneck = self.criterionGAN(pred_fake_bottleneck, False) + + loss_D_fake = loss_D_fake_bottleneck + loss_D_fake_pixel + + # Combined loss and calculate gradients + loss_D = (self.loss_D_real + loss_D_fake) * lambda_loss + + ##########print the two difference loss value + loss_D_pixel = loss_pred_real_pixel + loss_D_fake_pixel + loss_D_bottleneck = loss_pred_real_bottleneck + loss_D_fake_bottleneck + + print("loss.py D loss of the pixel(decoder) real {} and fake {}, so the total is {} ".format( loss_pred_real_pixel, loss_D_fake_pixel,loss_D_pixel)) + print( + "loss.py D loss of the bottleneck(encoder) real {} and fake {}, so the total is {}".format( loss_pred_real_bottleneck, loss_D_fake_bottleneck, loss_D_bottleneck) + ) + + return loss_D + + def compute_loss_G(self, netD, real, fake): + + super().compute_loss_G(netD, real, fake) + print("net D here is {}".format(type(netD))) + pred_fake_pixel, pred_fake_bottleneck = netD(self.fake) + + loss_G_pixel = self.criterionGAN(pred_fake_pixel, True, relu=False) + loss_G_bottleneck = self.criterionGAN(pred_fake_bottleneck, True, relu=False) + ##############print the two difference loss values + + loss_D_fake = loss_G_pixel + loss_G_bottleneck + + print("loss.py G fake loss of the pixel(decoder) is {}".format(loss_G_pixel)) + print( + "loss.py G fake loss of the bottleneck(encoder) is {}".format(loss_G_bottleneck) + ) + return loss_D_fake + + class MultiScaleDiffusionLoss(nn.Module): """ Multiscale diffusion loss such as in 2301.11093. diff --git a/options/base_options.py b/options/base_options.py index 2aba144ce..1cb0275a4 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -379,6 +379,20 @@ def initialize(self, parser): default=64, help="# of discrim filters in the first conv layer", ) + parser.add_argument( + "--D_ngf", + type=int, + default=64, + help="#*8 of discrim filters in the last conv layer", + ) + + parser.add_argument( + "--D_num_downs", + type=int, + default=7, + help="# of downsampling", + ) + parser.add_argument( "--D_netDs", type=str, @@ -393,6 +407,7 @@ def initialize(self, parser): "depth", "mask", "sam", + "unet", ] + list(TORCH_MODEL_CLASSES.keys()), help="specify discriminator architecture, another option, --D_n_layers allows you to specify the layers in the n_layers discriminator. NB: duplicated arguments are ignored. Values: basic, n_layers, pixel, projected_d, temporal, vision_aided, depth, mask, sam", diff --git a/tests/test_run_nosemantic.py b/tests/test_run_nosemantic.py index e15c1edc9..62a2460a4 100644 --- a/tests/test_run_nosemantic.py +++ b/tests/test_run_nosemantic.py @@ -32,7 +32,11 @@ "cycle_gan", ] -D_netDs = [["projected_d", "basic"], ["projected_d", "basic", "depth"]] +D_netDs = [ + ["projected_d", "basic"], + ["projected_d", "basic", "depth"], + ["projected_d", "basic", "unet_128_d"], +] train_feat_wavelet = [False, True] diff --git a/train_joligen_gan.sh b/train_joligen_gan.sh new file mode 100755 index 000000000..439d62794 --- /dev/null +++ b/train_joligen_gan.sh @@ -0,0 +1,6 @@ +python3 train.py --dataroot /data3/juliew/joligen_gan/joliGEN_WIP/datasets/noglasses2glasses_ffhq\ + --checkpoints_dir ./checkpoints\ + --name noglasses2glasses\ + --output_display_env noglasses2glasses\ + --config_json examples/example_gan_noglasses2glasses.json\ + --D_netDs projected_d unet_128_d