Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat discriminator unet #550

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 45 additions & 6 deletions models/base_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torchviz import make_dot

# for FID
from data.base_dataset import get_transform
#from data.base_dataset import get_transform
from util.diff_aug import DiffAugment
from util.discriminator import DiscriminatorInfo

Expand Down Expand Up @@ -401,7 +401,6 @@ def compute_D_loss(self):
loss_name,
loss_value,
)

self.loss_D_tot += loss_value

def compute_G_loss_GAN_generic(
Expand Down Expand Up @@ -444,8 +443,9 @@ def compute_G_loss(self):
getattr(self, loss_function)()

def compute_G_loss_GAN(self):
"""Calculate GAN losses for generator(s)"""
"""Calculate GAN losses for generator(s)"""


for discriminator in self.discriminators:
if "mask" in discriminator.name:
continue
Expand All @@ -465,7 +465,7 @@ def compute_G_loss_GAN(self):
netD,
domain,
loss,
fake_name=fake_name,
fake_name=fake_name,
real_name=real_name,
)

Expand All @@ -479,7 +479,6 @@ def compute_G_loss_GAN(self):
loss_name,
loss_value,
)

self.loss_G_tot += loss_value

if self.opt.train_temporal_criterion:
Expand Down Expand Up @@ -562,11 +561,51 @@ def set_discriminators_info(self):
real_name = "temporal_real"
compute_every = self.opt.D_temporal_every

else:
elif "unet" in discriminator_name:
loss_calculator = loss.DualDiscriminatorGANLoss(
netD=getattr(self, "net"+ discriminator_name),
device=self.device,
dataaug_APA_p=self.opt.dataaug_APA_p,
dataaug_APA_target=self.opt.dataaug_APA_target,
train_batch_size=self.opt.train_batch_size,
dataaug_APA_nimg=self.opt.dataaug_APA_nimg,
dataaug_APA_every=self.opt.dataaug_APA_every,
dataaug_D_label_smooth=self.opt.dataaug_D_label_smooth,
train_gan_mode=train_gan_mode,
dataaug_APA=self.opt.dataaug_APA,
dataaug_D_diffusion=dataaug_D_diffusion,
dataaug_D_diffusion_every=dataaug_D_diffusion_every,
)
fake_name = None
real_name = None
compute_every = 1


elif "unet_discriminator_mha" in discriminator_name:
loss_calculator = loss.DualDiscriminatorGANLoss(
netD=getattr(self, "net"+ discriminator_name),
device=self.device,
dataaug_APA_p=self.opt.dataaug_APA_p,
dataaug_APA_target=self.opt.dataaug_APA_target,
train_batch_size=self.opt.train_batch_size,
dataaug_APA_nimg=self.opt.dataaug_APA_nimg,
dataaug_APA_every=self.opt.dataaug_APA_every,
dataaug_D_label_smooth=self.opt.dataaug_D_label_smooth,
train_gan_mode=train_gan_mode,
dataaug_APA=self.opt.dataaug_APA,
dataaug_D_diffusion=dataaug_D_diffusion,
dataaug_D_diffusion_every=dataaug_D_diffusion_every,
)
fake_name = None
real_name = None
compute_every = 1

else :
fake_name = None
real_name = None
compute_every = 1


if self.opt.train_use_contrastive_loss_D:
loss_calculator = (
loss.DiscriminatorContrastiveLoss(
Expand Down
195 changes: 195 additions & 0 deletions models/d_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
from torch import nn
import torch
import functools
from torchinfo import summary


class UnetDiscriminator(nn.Module):
    """U-Net based discriminator.

    Builds a standard recursive U-Net (from the innermost block outward) and
    exposes intermediate activations via compute_feats/get_feats, mirroring
    the feature-extraction API used elsewhere in the project.
    """

    def __init__(
        self,
        input_nc,
        output_nc,
        num_downs,
        ngf=64,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
    ):
        """Construct a U-Net discriminator.

        Parameters:
            input_nc (int)  -- number of channels in input images
            output_nc (int) -- number of channels in the output map
            num_downs (int) -- number of downsamplings in the U-Net; e.g. with
                               num_downs == 7 a 128x128 image becomes 1x1 at
                               the bottleneck
            ngf (int)       -- number of filters in the outermost conv layer;
                               the innermost layers use ngf * 8 (512 for ngf=64)
            norm_layer      -- normalization layer class (or functools.partial)
            use_dropout (bool) -- use dropout in the intermediate blocks
        """
        super().__init__()
        # Innermost (bottleneck) block: ngf*8 -> ngf*8.
        unet_block = UnetSkipConnectionBlock(
            ngf * 8,
            ngf * 8,
            input_nc=None,
            submodule=None,
            norm_layer=norm_layer,
            innermost=True,
        )
        # Intermediate blocks keep ngf*8 filters; num_downs - 5 of them,
        # because the bottleneck plus the four resolution-changing blocks
        # below account for the other 5 downsampling levels.
        for _ in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(
                ngf * 8,
                ngf * 8,
                input_nc=None,
                submodule=unet_block,
                norm_layer=norm_layer,
                use_dropout=use_dropout,
            )
        # Gradually reduce the number of filters from ngf*8 down to ngf.
        for mult in (4, 2, 1):
            unet_block = UnetSkipConnectionBlock(
                ngf * mult,
                ngf * mult * 2,
                input_nc=None,
                submodule=unet_block,
                norm_layer=norm_layer,
            )
        # Outermost block: maps input_nc -> output_nc.
        self.model = UnetSkipConnectionBlock(
            output_nc,
            ngf,
            input_nc=input_nc,
            submodule=unet_block,
            outermost=True,
            norm_layer=norm_layer,
        )

    def compute_feats(self, input, extract_layer_ids=()):
        """Run the network.

        Returns (output, feats) where feats holds only the intermediate
        activations whose index appears in extract_layer_ids.
        Note: default changed from a mutable [] to an immutable () —
        equivalent for membership tests, but safe against shared-state bugs.
        """
        output, feats = self.model(input, feats=[])
        return_feats = [feat for i, feat in enumerate(feats) if i in extract_layer_ids]
        return output, return_feats

    def forward(self, input):
        """Return the discriminator output map for `input`."""
        output, _ = self.compute_feats(input)
        return output

    def get_feats(self, input, extract_layer_ids=()):
        """Return only the selected intermediate features for `input`."""
        _, feats = self.compute_feats(input, extract_layer_ids)
        return feats


class UnetSkipConnectionBlock(nn.Module):
    """U-Net submodule with a skip connection.

    X -------------------identity----------------------
      |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(
        self,
        outer_nc,
        inner_nc,
        input_nc=None,
        submodule=None,
        outermost=False,
        innermost=False,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
    ):
        """Construct a U-Net submodule with a skip connection.

        Parameters:
            outer_nc (int)   -- number of filters in the outer conv layer
            inner_nc (int)   -- number of filters in the inner conv layer
            input_nc (int)   -- number of channels in input images/features
                                (defaults to outer_nc when None)
            submodule        -- previously built inner UnetSkipConnectionBlock
            outermost (bool) -- whether this is the outermost module
            innermost (bool) -- whether this is the innermost module
            norm_layer       -- normalization layer class (or functools.partial)
            use_dropout (bool) -- append dropout (intermediate blocks only)
        """
        super().__init__()
        self.outermost = outermost

        # The conv bias is redundant when the norm layer carries its own
        # affine shift (BatchNorm); InstanceNorm needs the bias.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        if input_nc is None:
            input_nc = outer_nc

        down_conv = nn.Conv2d(
            input_nc,
            inner_nc,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=use_bias,
        )

        if outermost:
            # Non-inner blocks upsample the concatenation [skip, submodule out],
            # hence inner_nc * 2 input channels.
            up_conv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1
            )
            layers = [down_conv, submodule, nn.ReLU(False), up_conv, nn.Tanh()]
        elif innermost:
            up_conv = nn.ConvTranspose2d(
                inner_nc,
                outer_nc,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=use_bias,
            )
            layers = [
                nn.LeakyReLU(0.2, False),
                down_conv,
                nn.ReLU(False),
                up_conv,
                norm_layer(outer_nc),
            ]
        else:
            up_conv = nn.ConvTranspose2d(
                inner_nc * 2,
                outer_nc,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=use_bias,
            )
            layers = [
                nn.LeakyReLU(0.2, False),
                down_conv,
                norm_layer(inner_nc),
                submodule,
                nn.ReLU(False),
                up_conv,
                norm_layer(outer_nc),
            ]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x, feats):
        """Apply the block, collecting the first layer's activation.

        Returns (output, feats + collected activations). Every block other
        than the outermost concatenates its input onto the output along the
        channel dimension (the skip connection).
        """
        out = self.model[0](x)
        collected = feats + [out]

        for module in list(self.model)[1:]:
            if isinstance(module, UnetSkipConnectionBlock):
                out, collected = module(out, collected)
            else:
                out = module(out)

        if self.outermost:
            return out, collected
        # Skip connection: stack the block input on top of the upsampled output.
        return torch.cat([x, out], 1), collected


def _print_architecture():
    """Debug helper: build a sample UnetDiscriminator and dump its
    architecture plus a per-layer torchinfo summary.

    Previously this ran at module level, so merely importing this file
    instantiated a 9-level U-Net and printed a 1024x1024 summary; it is now
    guarded behind __main__ so imports have no side effects.
    """
    input_nc = 3
    output_nc = 3
    num_downs = 9
    ins = UnetDiscriminator(
        input_nc=input_nc, output_nc=output_nc, num_downs=num_downs
    )
    print(ins)
    # One worked example in detail on a 1024x1024 input.
    summary(
        ins,
        input_size=(3, 1024, 1024),
        batch_dim=0,
        col_names=[
            "input_size",
            "output_size",
            "num_params",
            "kernel_size",
            "mult_adds",
        ],
        row_settings=["var_names"],
        depth=num_downs + 1,
    )


if __name__ == "__main__":
    _print_architecture()
23 changes: 22 additions & 1 deletion models/gan_networks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import torch
import torch.nn as nn
import functools
from torch.optim import lr_scheduler
Expand All @@ -19,6 +20,8 @@
from .modules.resnet_architecture.resnet_generator import ResnetGenerator_attn
from .modules.discriminators import NLayerDiscriminator
from .modules.discriminators import PixelDiscriminator
from .modules.discriminators import UnetDiscriminator


from .modules.classifiers import (
torch_model,
Expand Down Expand Up @@ -238,13 +241,17 @@ def define_G(
raise NotImplementedError(
"Generator model name [%s] is not recognized" % G_netG
)
print("netG is {}".format(net))
return init_net(net, model_init_type, model_init_gain)


def define_D(
D_netDs,
model_input_nc,
model_output_nc,
D_num_downs,
D_ndf,
D_ngf,
D_n_layers,
D_norm,
D_dropout,
Expand Down Expand Up @@ -273,7 +280,9 @@ def define_D(

Parameters:
model_input_nc (int) -- the number of channels in input images
model_output_nc (int) -- the number of channels in output images
D_ndf (int) -- the number of filters in the first conv layer
num_downs(int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, image of size 128x128 will become of size 1x1 at the bottleneck
netD (str) -- the architecture's name: basic | n_layers | pixel
D_n_layers (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
D_norm (str) -- the type of normalization layers used in the network.
Expand Down Expand Up @@ -432,11 +441,23 @@ def define_D(
)
return_nets[netD] = init_net(net, model_init_type, model_init_gain)

elif netD == "unet":
net = UnetDiscriminator(
model_input_nc,
model_output_nc,
D_num_downs, # the number of downsamplings
D_ngf, # the final conv has D_ngf*8=512 filter
norm_layer=norm_layer,
use_dropout=D_dropout,
)
return_nets[netD] = init_net(net, model_init_type, model_init_gain)


else:
raise NotImplementedError(
"Discriminator model name [%s] is not recognized" % netD
)

print("discriminator is {}".format(return_nets))
return return_nets


Expand Down
Loading
Loading