From aab50c9ccb90ede0ff0827aac0ba1313db4c0b72 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sun, 30 Jul 2023 18:17:01 -0500 Subject: [PATCH 01/31] updating architecture --- patchgan/__init__.py | 10 +- patchgan/disc.py | 49 ++++++++++ patchgan/trainer.py | 164 +++++++++++++++++--------------- patchgan/unet.py | 222 +------------------------------------------ 4 files changed, 143 insertions(+), 302 deletions(-) create mode 100644 patchgan/disc.py diff --git a/patchgan/__init__.py b/patchgan/__init__.py index d7bdc41..f9bfc38 100644 --- a/patchgan/__init__.py +++ b/patchgan/__init__.py @@ -1,5 +1,7 @@ -from .unet import * -from .io import * -from .losses import * -from .utils import * +from .unet import UNet +from .disc import Discriminator from .trainer import Trainer + +__all__ = [ + 'UNet', 'Discriminator', 'Trainer' +] diff --git a/patchgan/disc.py b/patchgan/disc.py new file mode 100644 index 0000000..66b6772 --- /dev/null +++ b/patchgan/disc.py @@ -0,0 +1,49 @@ +from torch import nn +from .transfer import Transferable + + +class Discriminator(nn.Module, Transferable): + """Defines a PatchGAN discriminator""" + + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(Discriminator, self).__init__() + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=2, padding=padw, bias=False), + nn.LeakyReLU(0.2, True), + norm_layer(ndf * nf_mult) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=1, padding=padw, bias=False), + nn.LeakyReLU(0.2, True), + norm_layer(ndf * nf_mult) + ] + + # output 1 channel prediction map + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, + stride=1, padding=padw), nn.Sigmoid()] + self.model = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.model(input) diff --git a/patchgan/trainer.py b/patchgan/trainer.py index 97335c4..ce53f30 100644 --- a/patchgan/trainer.py +++ b/patchgan/trainer.py @@ -2,49 +2,31 @@ import os import tqdm import glob +import numpy as np from torch import optim from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau -from .losses import fc_tversky, adv_loss +from .losses import fc_tversky, bce_loss +from torch.nn.functional import binary_cross_entropy from collections import defaultdict -import numpy as np device = 'cuda' if torch.cuda.is_available() else 'cpu' -# custom weights initialization called on generator and discriminator -# scaling here means std -def weights_init(net, init_type='normal', scaling=0.02): - """Initialize network weights. - Parameters: - net (network) -- network to be initialized - init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal - init_gain (float) -- scaling factor for normal, xavier and orthogonal. - We use 'normal' in the original pix2pix and CycleGAN paper. 
But xavier and kaiming might - work better for some applications. Feel free to try yourself. - """ - def init_func(m): # define the initialization function - classname = m.__class__.__name__ - if hasattr(m, 'weight') and (classname.find('Conv')) != -1: - torch.nn.init.normal_(m.weight.data, 0.0, scaling) - # BatchNorm Layer's weight is not a matrix; only normal distribution applies. - elif classname.find('BatchNorm') != -1: - torch.nn.init.normal_(m.weight.data, 1.0, scaling) - torch.nn.init.constant_(m.bias.data, 0.0) - - class Trainer: ''' Trainer module which contains both the full training driver which calls the train_batch method ''' - disc_alpha = 1. - fc_gamma = 0.75 - fc_beta = 0.7 + + seg_alpha = 200 + loss_type = 'tversky' + tversky_beta = 0.75 + tversky_gamma = 0.75 neptune_config = None - def __init__(self, generator, discriminator, savefolder): + def __init__(self, generator, discriminator, savefolder, device='cuda'): ''' Store the generator and discriminator info ''' @@ -54,6 +36,7 @@ def __init__(self, generator, discriminator, savefolder): self.generator = generator self.discriminator = discriminator + self.device = device if savefolder[-1] != '/': savefolder += '/' @@ -68,23 +51,36 @@ def batch(self, x, y, train=False): ''' Train the generator and discriminator on a single batch ''' - torch.autograd.set_detect_anomaly(True) - # convert the input image and mask to tensors - img_tensor = torch.as_tensor(x, dtype=torch.float).to(device) - target_tensor = torch.as_tensor(y, dtype=torch.float).to(device) + if not isinstance(x, torch.Tensor): + input_tensor = torch.as_tensor(x, dtype=torch.float).to(self.device) + target_tensor = torch.as_tensor(y, dtype=torch.float).to(self.device) + else: + input_tensor = x.to(self.device, non_blocking=True) + target_tensor = y.to(self.device, non_blocking=True) - gen_img = self.generator(img_tensor) + # train the generator + gen_img = self.generator(input_tensor) - disc_inp_fake = torch.cat((img_tensor, gen_img), 1) + disc_inp_fake = torch.cat((input_tensor, gen_img), 1) disc_fake = self.discriminator(disc_inp_fake) labels_real = torch.full(disc_fake.shape, 1, dtype=torch.float, device=device) labels_fake = torch.full(disc_fake.shape, 0, dtype=torch.float, device=device) - gen_loss_tversky = fc_tversky(target_tensor, gen_img, beta=self.fc_beta, gamma=self.fc_gamma) - gen_loss_disc = adv_loss(disc_fake, labels_real) - gen_loss = gen_loss_tversky + self.disc_alpha * gen_loss_disc + if self.loss_type == 'tversky': + gen_loss = fc_tversky(target_tensor, gen_img, + beta=self.tversky_beta, + gamma=self.tversky_gamma) * self.seg_alpha + elif self.loss_type == 'weighted_bce': + if gen_img.shape[1] > 1: + weight = 1 - torch.sum(target_tensor, dim=(2, 3), keepdim=True) / torch.sum(target_tensor) + else: + weight = torch.ones_like(target_tensor) + gen_loss = binary_cross_entropy(gen_img, target_tensor, weight=weight) * self.seg_alpha + + gen_loss_disc = bce_loss(disc_fake, labels_real) + gen_loss = gen_loss + gen_loss_disc if train: self.generator.zero_grad() @@ -92,33 +88,33 @@ def batch(self, x, y, train=False): self.gen_optimizer.step() # Train the discriminator - # On the real image if train: self.discriminator.zero_grad() - disc_inp_real = torch.cat((img_tensor, target_tensor), 1) + disc_inp_real = torch.cat((input_tensor, target_tensor), 1) disc_real = self.discriminator(disc_inp_real) - disc_inp_fake = torch.cat((img_tensor, gen_img.detach()), 1) + disc_inp_fake = torch.cat((input_tensor, gen_img.detach()), 1) disc_fake = 
self.discriminator(disc_inp_fake) - loss_real = adv_loss(disc_real, labels_real.detach()) - loss_fake = adv_loss(disc_fake, labels_fake) + loss_real = bce_loss(disc_real, labels_real.detach()) + loss_fake = bce_loss(disc_fake, labels_fake) disc_loss = (loss_fake + loss_real) / 2. if train: disc_loss.backward() self.disc_optimizer.step() - keys = ['gen', 'tversky', 'gdisc', 'discr', 'discf', 'disc'] - mean_loss_i = [gen_loss.item(), gen_loss_tversky.item(), gen_loss_disc.item(), + keys = ['gen', 'gen_loss', 'gdisc', 'discr', 'discf', 'disc'] + mean_loss_i = [gen_loss.item(), gen_loss.item(), gen_loss_disc.item(), loss_real.item(), loss_fake.item(), disc_loss.item()] loss = {key: val for key, val in zip(keys, mean_loss_i)} return loss - def train(self, train_data, val_data, epochs, dsc_learning_rate=1.e-4, - gen_learning_rate=1.e-3, save_freq=10, lr_decay=None, decay_freq=5, reduce_on_plateau=False): + def train(self, train_data, val_data, epochs, dsc_learning_rate=1.e-3, + gen_learning_rate=1.e-3, save_freq=10, lr_decay=None, decay_freq=5, + reduce_on_plateau=False): ''' Training driver which loads the optimizer and calls the `train_batch` method. Also handles checkpoint saving @@ -126,7 +122,7 @@ def train(self, train_data, val_data, epochs, dsc_learning_rate=1.e-4, ------ train_data : DataLoader object Training data that is mapped using the DataLoader or - MmapDataLoader object defined in patchgan/io.py + MmapDataLoader object defined in io.py val_data : DataLoader object Validation data loaded in using the DataLoader or MmapDataLoader object @@ -155,10 +151,8 @@ def train(self, train_data, val_data, epochs, dsc_learning_rate=1.e-4, ''' if (lr_decay is not None) and not reduce_on_plateau: - gen_lr = gen_learning_rate * \ - (lr_decay)**((self.start - 1) / (decay_freq)) - dsc_lr = dsc_learning_rate * \ - (lr_decay)**((self.start - 1) / (decay_freq)) + gen_lr = gen_learning_rate * (lr_decay)**((self.start - 1) / (decay_freq)) + dsc_lr = dsc_learning_rate * (lr_decay)**((self.start - 1) / (decay_freq)) else: gen_lr = gen_learning_rate dsc_lr = dsc_learning_rate @@ -168,23 +162,18 @@ def train(self, train_data, val_data, epochs, dsc_learning_rate=1.e-4, self.neptune_config['model/parameters/dsc_learning_rate'] = dsc_lr self.neptune_config['model/parameters/start'] = self.start self.neptune_config['model/parameters/n_epochs'] = epochs - self.neptune_config['model/parameters/fc_beta'] = self.fc_beta - self.neptune_config['model/parameters/fc_gamma'] = self.fc_gamma - self.neptune_config['model/parameters/disc_alpha'] = self.disc_alpha # create the Adam optimzers - self.gen_optimizer = optim.Adam( - self.generator.parameters(), lr=gen_lr) - self.disc_optimizer = optim.Adam( - self.discriminator.parameters(), lr=dsc_lr) + self.gen_optimizer = optim.NAdam( + self.generator.parameters(), lr=gen_lr, betas=(0.9, 0.999)) + self.disc_optimizer = optim.NAdam( + self.discriminator.parameters(), lr=dsc_lr, betas=(0.9, 0.999)) # set up the learning rate scheduler with exponential lr decay if reduce_on_plateau: gen_scheduler = ReduceLROnPlateau(self.gen_optimizer, verbose=True) - dsc_scheduler = ReduceLROnPlateau( - self.disc_optimizer, verbose=True) - if self.neptune_config is not None: - self.neptune_config['model/parameters/scheduler'] = 'ReduceLROnPlateau' + dsc_scheduler = ReduceLROnPlateau(self.disc_optimizer, verbose=True) + self.neptune_config['model/parameters/scheduler'] = 'ReduceLROnPlateau' elif lr_decay is not None: gen_scheduler = ExponentialLR(self.gen_optimizer, gamma=lr_decay) dsc_scheduler 
= ExponentialLR(self.disc_optimizer, gamma=lr_decay) @@ -210,9 +199,10 @@ def train(self, train_data, val_data, epochs, dsc_learning_rate=1.e-4, print("-------------------------------------------------------") # batch loss data - pbar = tqdm.tqdm(train_data, desc='Training: ', dynamic_ncols=True, ascii=True) + pbar = tqdm.tqdm(train_data, desc='Training: ', dynamic_ncols=True) - train_data.shuffle() + if hasattr(train_data, 'shuffle'): + train_data.shuffle() # set to training mode self.generator.train() @@ -220,10 +210,10 @@ def train(self, train_data, val_data, epochs, dsc_learning_rate=1.e-4, losses = defaultdict(list) # loop through the training data - for i, (input_img, target_img) in enumerate(pbar): + for i, (input_img, target_mask) in enumerate(pbar): # train on this batch - batch_loss = self.batch(input_img, target_img, train=True) + batch_loss = self.batch(input_img, target_mask, train=True) # append the current batch loss loss_mean = {} @@ -231,8 +221,7 @@ def train(self, train_data, val_data, epochs, dsc_learning_rate=1.e-4, losses[key].append(value) loss_mean[key] = np.mean(losses[key], axis=0) - loss_str = " ".join( - [f"{key}: {value:.2e}" for key, value in loss_mean.items()]) + loss_str = " ".join([f"{key}: {value:.2e}" for key, value in loss_mean.items()]) pbar.set_postfix_str(loss_str) @@ -242,31 +231,28 @@ def train(self, train_data, val_data, epochs, dsc_learning_rate=1.e-4, if self.neptune_config is not None: self.neptune_config['train/gen_loss'].append(loss_mean['gen']) - self.neptune_config['train/disc_loss'].append( - loss_mean['disc']) + self.neptune_config['train/disc_loss'].append(loss_mean['disc']) # validate every `validation_freq` epochs self.discriminator.eval() self.generator.eval() - pbar = tqdm.tqdm(val_data, desc='Validation: ', ascii=True, dynamic_ncols=True) + pbar = tqdm.tqdm(val_data, desc='Validation: ') - val_data.shuffle() + if hasattr(val_data, 'shuffle'): + val_data.shuffle() losses = defaultdict(list) # loop through the training data - for i, (input_img, target_img) in enumerate(pbar): - - # train on this batch - with torch.no_grad(): - batch_loss = self.batch(input_img, target_img, train=False) + for i, (input_img, target_mask) in enumerate(pbar): + # validate on this batch + batch_loss = self.batch(input_img, target_mask, train=False) loss_mean = {} for key, value in batch_loss.items(): losses[key].append(value) loss_mean[key] = np.mean(losses[key], axis=0) - loss_str = " ".join( - [f"{key}: {value:.2e}" for key, value in loss_mean.items()]) + loss_str = " ".join([f"{key}: {value:.2e}" for key, value in loss_mean.items()]) pbar.set_postfix_str(loss_str) @@ -331,3 +317,25 @@ def load(self, generator_save, discriminator_save): dfname = discriminator_save.split('/')[-1] print( f"Loaded checkpoints from {gfname} and {dfname}") + +# custom weights initialization called on generator and discriminator +# scaling here means std + + +def weights_init(net, init_type='normal', scaling=0.02): + """Initialize network weights. + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. 
+    """
+    def init_func(m):  # define the initialization function
+        classname = m.__class__.__name__
+        if hasattr(m, 'weight') and (classname.find('Conv')) != -1:
+            torch.nn.init.xavier_uniform_(m.weight.data)
+        # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
+        elif classname.find('BatchNorm') != -1:
+            torch.nn.init.normal_(m.weight.data, 1.0, scaling)
+            torch.nn.init.constant_(m.bias.data, 0.0)
+
+    net.apply(init_func)  # apply the initialization function
diff --git a/patchgan/unet.py b/patchgan/unet.py
index 29c0096..dd55fd4 100755
--- a/patchgan/unet.py
+++ b/patchgan/unet.py
@@ -1,226 +1,8 @@
 import torch
-import functools
 from torch import nn
-from torch.nn.parameter import Parameter
 from collections import OrderedDict
 from itertools import chain
-
-
-class Transferable():
-    def __init__(self):
-        super(Transferable, self).__init__()
-
-    def load_transfer_data(self, checkpoint, verbose=False):
-        state_dict = torch.load(checkpoint, map_location=next(self.parameters()).device)
-        own_state = self.state_dict()
-        state_names = list(own_state.keys())
-        count = 0
-        for name, param in state_dict.items():
-            if isinstance(param, Parameter):
-                # backwards compatibility for serialized parameters
-                param = param.data
-
-            # find the weight with the closest name to this
-            sub_name = '.'.join(name.split('.')[-2:])
-            own_state_name = [n for n in state_names if sub_name in n]
-            if len(own_state_name) == 1:
-                own_state_name = own_state_name[0]
-            else:
-                if verbose:
-                    print(f'{name} not found')
-                continue
-
-            if param.shape == own_state[own_state_name].data.shape:
-                own_state[own_state_name].copy_(param)
-                count += 1
-
-        if count == 0:
-            print("WARNING: Could not transfer over any weights!")
-        else:
-            print(f"Loaded weights for {count} layers")
-
-
-class UnetSkipConnectionBlock(nn.Module):
-    """Defines the Unet submodule with skip connection.
-        X -------------------identity----------------------
-        |-- downsampling -- |submodule| -- upsampling --|
-    """
-
-    def __init__(self, outer_nc, inner_nc, input_nc=None,
-                 submodule=None, outermost=False, innermost=False,
-                 activation='tanh', norm_layer=nn.BatchNorm2d, use_dropout=False, layer=1):
-        """Construct a Unet submodule with skip connections.
-        Parameters:
-            outer_nc (int) -- the number of filters in the outer conv layer
-            inner_nc (int) -- the number of filters in the inner conv layer
-            input_nc (int) -- the number of channels in input images/features
-            submodule (UnetSkipConnectionBlock) -- previously defined submodules
-            outermost (bool)    -- if this module is the outermost module
-            innermost (bool)    -- if this module is the innermost module
-            norm_layer          -- normalization layer
-            use_dropout (bool)  -- if use dropout layers.
- """ - super(UnetSkipConnectionBlock, self).__init__() - self.outermost = outermost - if input_nc is None: - input_nc = outer_nc - downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, - stride=2, padding=1, bias=False) - - if activation == 'tanh': - downact = nn.Tanh() - upact = nn.Tanh() - else: - downact = nn.LeakyReLU(0.2, True) - upact = nn.ReLU(True) - - downnorm = norm_layer(inner_nc) - upnorm = norm_layer(outer_nc) - - if outermost: - upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, - kernel_size=4, stride=2, - padding=1) - - if outer_nc == 1: - upact = nn.Sigmoid() - else: - upact = nn.Softmax(dim=1) - down = OrderedDict([(f'DownConv{layer}', downconv), - (f'DownAct{layer}', downact), - (f'DownNorm{layer}', downnorm)]) - up = OrderedDict([(f'UpConv{layer}', upconv), - (f'UpAct{layer}', upact)]) - if use_dropout: - model = OrderedDict(chain(down.items(), - [(f'SubModule{layer}', submodule)], - up.items())) - else: - model = OrderedDict(chain(down.items(), - [(f'SubModule{layer}', - submodule)], up.items())) # down + [submodule] + up - elif innermost: - upconv = nn.ConvTranspose2d(inner_nc, outer_nc, - kernel_size=4, stride=2, - padding=1, bias=False) - down = OrderedDict([(f'DownConv{layer}', downconv), - (f'DownAct{layer}', downact)]) - up = OrderedDict([(f'UpConv{layer}', upconv), - (f'UpAct{layer}', upact), - (f'UpNorm{layer}', upnorm)]) - model = OrderedDict(chain(down.items(), - up.items())) # down + [submodule] + up - else: - upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, - kernel_size=4, stride=2, - padding=1, bias=False) - down = OrderedDict([(f'DownConv{layer}', downconv), - (f'DownAct{layer}', downact), - (f'DownNorm{layer}', downnorm)]) - up = OrderedDict([(f'UpConv{layer}', upconv), - (f'UpAct{layer}', upact), - (f'UpNorm{layer}', upnorm)]) - - if use_dropout: - model = OrderedDict(chain(down.items(), - [(f'EncDropout{layer}', - nn.Dropout(0.5))], - [(f'SubModule{layer}', submodule)], - up.items(), - [(f'DecDropout{layer}', nn.Dropout(0.5))])) - else: - model = OrderedDict(chain(down.items(), - [(f'SubModule{layer}', submodule)], - up.items())) # down + [submodule] + up - - self.model = nn.Sequential(model) - - def forward(self, x): - if self.outermost: - return self.model(x) - else: # add skip connections - return torch.cat([x, self.model(x)], 1) - - -def get_norm_layer(): - """Return a normalization layer - For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). - """ - norm_type = 'batch' - if norm_type == 'batch': - norm_layer = functools.partial( - nn.BatchNorm2d, affine=True, track_running_stats=True) - return norm_layer - - -# custom weights initialization called on generator and discriminator -# scaling here means std -def weights_init(net, init_type='normal', scaling=0.02): - """Initialize network weights. - Parameters: - net (network) -- network to be initialized - init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal - init_gain (float) -- scaling factor for normal, xavier and orthogonal. - We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might - work better for some applications. Feel free to try yourself. - """ - def init_func(m): # define the initialization function - classname = m.__class__.__name__ - if hasattr(m, 'weight') and (classname.find('Conv')) != -1: - torch.nn.init.normal_(m.weight.data, 0.0, scaling) - # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 
- elif classname.find('BatchNorm2d') != -1: - torch.nn.init.normal_(m.weight.data, 1.0, scaling) - torch.nn.init.constant_(m.bias.data, 0.0) - - net.apply(init_func) # apply the initialization function - - -class UnetGenerator(nn.Module, Transferable): - """Create a Unet-based generator""" - - def __init__(self, input_nc, output_nc, nf=64, - norm_layer=nn.BatchNorm2d, use_dropout=False, - activation='tanh'): - """Construct a Unet generator - Parameters: - input_nc (int) -- the number of channels in input images - output_nc (int) -- the number of channels in output images - num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, - image of size 128x128 will become of size 1x1 # at the bottleneck - nf (int) -- the number of filters in the last conv layer - norm_layer -- normalization layer - We construct the U-Net from the innermost layer to the outermost layer. - It is a recursive process. - """ - super(UnetGenerator, self).__init__() - # construct unet structure - unet_block = UnetSkipConnectionBlock(nf * 8, nf * 8, input_nc=None, activation=activation, - submodule=None, norm_layer=norm_layer, - use_dropout=use_dropout, innermost=True, layer=7) - - # add intermediate layers with ngf * 8 filters - unet_block = UnetSkipConnectionBlock(nf * 8, nf * 8, input_nc=None, activation=activation, - submodule=unet_block, norm_layer=norm_layer, - use_dropout=use_dropout, layer=6) - unet_block = UnetSkipConnectionBlock(nf * 8, nf * 8, input_nc=None, activation=activation, - submodule=unet_block, norm_layer=norm_layer, - use_dropout=use_dropout, layer=5) - - # gradually reduce the number of filters from nf * 8 to nf - unet_block = UnetSkipConnectionBlock(nf * 4, nf * 8, input_nc=None, activation=activation, - submodule=unet_block, norm_layer=norm_layer, layer=4) - unet_block = UnetSkipConnectionBlock(nf * 2, nf * 4, input_nc=None, activation=activation, - submodule=unet_block, norm_layer=norm_layer, layer=3) - unet_block = UnetSkipConnectionBlock(nf, nf * 2, input_nc=None, activation=activation, - submodule=unet_block, norm_layer=norm_layer, layer=2) - self.model = UnetSkipConnectionBlock(output_nc, nf, input_nc=input_nc, activation=activation, - submodule=unet_block, outermost=True, - norm_layer=norm_layer, layer=1) # add the outermost layer - - def forward(self, input): - """Standard forward""" - return self.model(input) +from .transfer import Transferable class DownSampleBlock(nn.Module): @@ -291,7 +73,7 @@ def forward(self, x): class UNet(nn.Module, Transferable): def __init__(self, input_nc, output_nc, nf=64, - norm_layer=nn.BatchNorm2d, use_dropout=False, + norm_layer=nn.InstanceNorm2d, use_dropout=False, activation='tanh', final_act='softmax'): super(UNet, self).__init__() From 4a0b1a338e01c0ddd0794b675d99739fcefb8a92 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sun, 30 Jul 2023 18:25:11 -0500 Subject: [PATCH 02/31] updating losses --- patchgan/losses.py | 27 +-------------------------- 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/patchgan/losses.py b/patchgan/losses.py index 2420481..d120ce6 100644 --- a/patchgan/losses.py +++ b/patchgan/losses.py @@ -6,9 +6,6 @@ def tversky(y_true, y_pred, beta, batch_mean=True): tp = torch.sum(y_true * y_pred, axis=(1, 2, 3)) fn = torch.sum((1. - y_pred) * y_true, axis=(1, 2, 3)) fp = torch.sum(y_pred * (1. - y_true), axis=(1, 2, 3)) - # tversky = reduce_mean(tp)/(reduce_mean(tp) + - # beta*reduce_mean(fn) + - # (1. - beta)*reduce_mean(fp)) tversky = tp /\ (tp + beta * fn + (1. 
- beta) * fp) @@ -20,21 +17,9 @@ def tversky(y_true, y_pred, beta, batch_mean=True): def fc_tversky(y_true, y_pred, beta, gamma=0.75, batch_mean=True): smooth = 1 - ''' - y_true_pos = torch.flatten(y_true) - y_pred_pos = torch.flatten(y_pred) - true_pos = torch.sum(y_true_pos * y_pred_pos, axis=(1,2,3)) - false_neg = torch.sum(y_true_pos * (1-y_pred_pos), axis=(1,2,3)) - false_pos = torch.sum((1-y_true_pos)*y_pred_pos, axis=(1,2,3)) - - answer = (true_pos + smooth)/(true_pos + beta*false_neg + (1-beta)*false_pos + smooth) - ''' tp = torch.sum(y_true * y_pred, axis=(1, 2, 3)) fn = torch.sum((1. - y_pred) * y_true, axis=(1, 2, 3)) fp = torch.sum(y_pred * (1. - y_true), axis=(1, 2, 3)) - # tversky = reduce_mean(tp)/(reduce_mean(tp) + - # beta*reduce_mean(fn) + - # (1. - beta)*reduce_mean(fp)) tversky = (tp + smooth) /\ (tp + beta * fn + (1. - beta) * fp + smooth) @@ -47,14 +32,4 @@ def fc_tversky(y_true, y_pred, beta, gamma=0.75, batch_mean=True): # alias -adv_loss = nn.BCELoss() - - -def generator_loss(generated_img, target_img): - gen_loss = fc_tversky(target_img, generated_img, beta=0.7, gamma=0.75) - return gen_loss - - -def discriminator_loss(output, label): - disc_loss = adv_loss(output, label) - return disc_loss +bce_loss = nn.BCELoss() From 9a4cb6778b752c04b0c1a21218c372900b2a553d Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sun, 30 Jul 2023 18:25:22 -0500 Subject: [PATCH 03/31] deleting unnecessary utils --- patchgan/utils.py | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 patchgan/utils.py diff --git a/patchgan/utils.py b/patchgan/utils.py deleted file mode 100644 index 627f0a2..0000000 --- a/patchgan/utils.py +++ /dev/null @@ -1,17 +0,0 @@ -import numpy as np - -'''Data Preprocessing''' -def crop_images_batch(inputs, target): - batch_size = len(inputs) - new_imgs = np.zeros([batch_size, 4, 256, 256]) - new_masks = np.zeros([batch_size, 1, 256, 256]) - height, width = 350, 350 - xs, ys = np.random.uniform(low=0,high=int(height-256),size=batch_size), np.random.uniform(low=0,high=int(width-256),size=batch_size) - for i in range(batch_size): - x, y = xs[i], ys[i] - start_x, end_x = int(x), int(x)+256 - new_imgs[i,:,:,:] = inputs[i,:,start_x : end_x, int(y): int(y)+256]/255. 
- new_masks[i,:,:,:] = target[i,:,int(x): int(x)+256,int(y): int(y)+256] - - return new_imgs, new_masks - From 4d906d0f60e47b19c8d3f26f37c56a10efa0d814 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sun, 30 Jul 2023 18:25:47 -0500 Subject: [PATCH 04/31] removing old training scripts --- train_ff.py | 45 ------------------------------------------- train_ff_nc.py | 48 ---------------------------------------------- transfer_FC.py | 52 -------------------------------------------------- 3 files changed, 145 deletions(-) delete mode 100644 train_ff.py delete mode 100644 train_ff_nc.py delete mode 100644 transfer_FC.py diff --git a/train_ff.py b/train_ff.py deleted file mode 100644 index d4fd5c7..0000000 --- a/train_ff.py +++ /dev/null @@ -1,45 +0,0 @@ -import numpy as np -from torchinfo import summary -from patchgan.unet import UNet, Discriminator, get_norm_layer -from patchgan.io import MmapDataGenerator -from patchgan.trainer import Trainer, device - -# nc_file = './data/FloatingForest/data/trainval.nc' -mmap_imgs = '../shuffled_data_b_cropped/train_aug_imgs.npy' -mmap_mask = '../shuffled_data_b_cropped/train_aug_mask.npy' -batch_size = 48 -traindata = MmapDataGenerator(mmap_imgs, mmap_mask, batch_size) - -# nc_file_val = './data/FloatingForest/data/test.nc' -mmap_imgs_val = '../shuffled_data_b_cropped/valid_aug_imgs.npy' -mmap_mask_val = '../shuffled_data_b_cropped/valid_aug_mask.npy' -batch_size = 48 -val_dl = MmapDataGenerator(mmap_imgs_val, mmap_mask_val, batch_size) - -GEN_FILTS = 32 -DISC_FILTS = 16 -ACTIV = 'relu' - -IN_NC = 4 -OUT_NC = 1 - -norm_layer = get_norm_layer() - -# create the generator -generator = UNet(IN_NC, OUT_NC, GEN_FILTS, norm_layer=norm_layer, - use_dropout=False, activation=ACTIV).to(device) - -# create the discriminator -discriminator = Discriminator(IN_NC + OUT_NC, DISC_FILTS, n_layers=3, norm_layer=norm_layer).to(device) - -summary(generator, [1, 4, 256, 256]) - -# create the training object and start training -trainer = Trainer(generator, discriminator, - f'checkpoints-{GEN_FILTS}-{DISC_FILTS}-{ACTIV}/') - -G_loss, D_loss = trainer.train(traindata, val_dl, 200, gen_learning_rate=5.e-4, - dsc_learning_rate=1.e-4, lr_decay=0.95) - -# save the loss history -np.savez('loss_history.npz', D_loss=D_loss, G_loss=G_loss) diff --git a/train_ff_nc.py b/train_ff_nc.py deleted file mode 100644 index 8f78b34..0000000 --- a/train_ff_nc.py +++ /dev/null @@ -1,48 +0,0 @@ -import numpy as np -import torch -from torch import optim -from torch import nn -from torchinfo import summary -import tqdm -from patchgan import * - -nc_file = './data/FloatingForest/data/trainval.nc' -batch_size= 48 -traindata = DataGenerator(nc_file, batch_size) - -nc_file_val = './data/FloatingForest/data/test.nc' -batch_size= 48 -val_dl = DataGenerator(nc_file, batch_size) - - -GEN_FILTS = 32 -DISC_FILTS = 32 -ACTIV = 'tanh' - -# create the generator -norm_layer = get_norm_layer() -generator = UnetGenerator(4, 1, GEN_FILTS, norm_layer=norm_layer, - use_dropout=False, activation=ACTIV) -generator.apply(weights_init) -generator = generator.cuda() - -# create the discriminator -discriminator = Discriminator(5, DISC_FILTS, n_layers=3, norm_layer=norm_layer).cuda() -discriminator.apply(weights_init) - -summary(generator, [1, 4, 256, 256]) - -# create the training object and start training -trainer = Trainer(generator, discriminator, - f'checkpoints-{GEN_FILTS}-{DISC_FILTS}-{ACTIV}/', crop=False) - -try: - trainer.load_last_checkpoint() -except Exception as e: - raise(e) - -G_loss_plot, D_loss_plot = 
trainer.train(traindata, val_dl, 200, learning_rate=1.e-3) - -# save the loss history -np.savez('loss_history.npz', D_loss = D_loss_plot ,G_loss = G_loss_plot) - diff --git a/transfer_FC.py b/transfer_FC.py deleted file mode 100644 index b087e0d..0000000 --- a/transfer_FC.py +++ /dev/null @@ -1,52 +0,0 @@ -import numpy as np -from torchinfo import summary -from patchgan.unet import UNet, Discriminator, get_norm_layer -from patchgan.io import MmapDataGenerator -from patchgan.trainer import Trainer, device - -# nc_file = './data/FloatingForest/data/trainval.nc' -mmap_imgs = '../shuffled_data_b_cropped/train_aug_imgs.npy' -mmap_mask = '../shuffled_data_b_cropped/train_aug_mask.npy' -batch_size = 48 -traindata = MmapDataGenerator(mmap_imgs, mmap_mask, batch_size) - -# nc_file_val = './data/FloatingForest/data/test.nc' -mmap_imgs_val = '../shuffled_data_b_cropped/valid_aug_imgs.npy' -mmap_mask_val = '../shuffled_data_b_cropped/valid_aug_mask.npy' -batch_size = 48 -val_dl = MmapDataGenerator(mmap_imgs_val, mmap_mask_val, batch_size) - -GEN_FILTS = 32 -DISC_FILTS = 16 -ACTIV = 'relu' - -IN_NC = 4 -OUT_NC = 1 - -norm_layer = get_norm_layer() - -# create the generator -generator = UNet(IN_NC, OUT_NC, GEN_FILTS, norm_layer=norm_layer, - use_dropout=False, activation=ACTIV).to(device) - -# create the discriminator -discriminator = Discriminator(IN_NC + OUT_NC, DISC_FILTS, n_layers=3, norm_layer=norm_layer).to(device) - -summary(generator, [1, 4, 256, 256]) - - -# create the training object and start training -trainer = Trainer(generator, discriminator, - f'checkpoints-{GEN_FILTS}-{DISC_FILTS}-{ACTIV}/') -generator.load_transfer_data( - '/home/fortson/manth145/codes/patchGAN_FF_ImageNet/patchGAN/FC_checkpoints-64-32-relu/generator_epoch_50.pth' -) -discriminator.load_transfer_data( - '/home/fortson/manth145/codes/patchGAN_FF_ImageNet/patchGAN/FC_checkpoints-64-32-relu/discriminator_epoch_50.pth' -) - -G_loss, D_loss = trainer.train(traindata, val_dl, 200, gen_learning_rate=5.e-4, - dsc_learning_rate=1.e-4, lr_decay=0.95) - -# save the loss history -np.savez('loss_history.npz', D_loss=D_loss, G_loss=G_loss) From 7bc6de0f02680ee409e085e2b83b3a0e815cc430 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sun, 30 Jul 2023 18:26:01 -0500 Subject: [PATCH 05/31] added transfer learning module --- patchgan/transfer.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 patchgan/transfer.py diff --git a/patchgan/transfer.py b/patchgan/transfer.py new file mode 100644 index 0000000..3fedbff --- /dev/null +++ b/patchgan/transfer.py @@ -0,0 +1,15 @@ +from torch.nn.parameter import Parameter + + +class Transferable(): + def __init__(self): + super(Transferable, self).__init__() + + def load_transfer_data(self, state_dict): + own_state = self.state_dict() + for name, param in state_dict.items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + if param.shape == own_state[name].data.shape: + own_state[name].copy_(param) From 30f0ce5c11f796e2cef431627a361b5118beb714 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sun, 30 Jul 2023 18:41:02 -0500 Subject: [PATCH 06/31] adding FloatingForests and COCOStuff datagenerators --- patchgan/io.py | 152 ++++++++++++++++++++++++++++--------------------- 1 file changed, 86 insertions(+), 66 deletions(-) diff --git a/patchgan/io.py b/patchgan/io.py index eee0ca2..3deedf4 100644 --- a/patchgan/io.py +++ b/patchgan/io.py @@ -1,100 +1,120 @@ +import torch +from torch.utils.data import 
Dataset import numpy as np -import netCDF4 as nc +import glob +import os +from torchvision.io import read_image, ImageReadMode +from torchvision.transforms import Compose, Resize, RandomCrop, RandomHorizontalFlip, RandomVerticalFlip +from einops import rearrange +import rasterio -class DataGenerator(): - def __init__(self, nc_file, batch_size, indices=None): - self.nc_file = nc_file +class FloatingForestsDataset(Dataset): + augmentation = None - self.batch_size = batch_size + def __init__(self, imgfolder, maskfolder, size=256, augmentation='randomcrop'): + self.images = sorted(glob.glob(os.path.join(imgfolder, "*.tif"))) + self.masks = sorted(glob.glob(os.path.join(maskfolder, "*.tif"))) + self.size = size - if indices is not None: - self.indices = indices - self.ndata = len(indices) - else: - with nc.Dataset(nc_file, 'r') as dset: - self.ndata = int(dset.dimensions['file'].size) - self.indices = np.arange(self.ndata) + self.image_ids = [int(os.path.basename(image).replace('.tif', '')) for image in self.images] + self.mask_ids = [int(os.path.basename(image).replace('.tif', '')) for image in self.masks] - print(f"Found data with {self.ndata} images") + assert np.all(self.image_ids == self.mask_ids), "Image IDs and Mask IDs do not match!" - def shuffle(self): - np.random.shuffle(self.indices) + if augmentation == 'randomcrop': + self.augmentation = RandomCrop(size=(size, size)) + elif augmentation == 'randomcrop+flip': + self.augmentation = Compose([ + RandomCrop(size=(size, size)), + RandomHorizontalFlip(0.25), + RandomVerticalFlip(0.25) + ]) + + print(f"Loaded {len(self)} images") def __len__(self): - return self.ndata // self.batch_size + return len(self.images) def __getitem__(self, index): - batch_indices = self.indices[index * - self.batch_size:(index + 1) * self.batch_size] - - if len(batch_indices) < 1: - raise StopIteration + image_file = self.images[index] + mask_file = self.masks[index] - return self.get_from_indices(batch_indices) + img = rasterio.open(image_file) + img_stacked = np.dstack([img.read(1), img.read(2), img.read(3), img.read(4)]) - def get_from_indices(self, indices): - with nc.Dataset(self.nc_file, 'r') as dset: - imgs = dset.variables['imgs'][indices, - :, :, :].astype(float) / 255. - mask = dset.variables['mask'][indices, :, :].astype(float) + # clean up artifacts in the data + img_stacked[np.abs(img_stacked) > 1.e20] = 0 + img_stacked[np.isnan(img_stacked)] = 0 - return imgs, np.expand_dims(mask, axis=1) + # remove negative signal + img_stacked = img_stacked - np.percentile(img_stacked.flatten(), 2) - def get_meta(self, key, index=None): - if index is not None: - batch_indices = self.indices[index * - self.batch_size:(index + 1) * self.batch_size] - else: - batch_indices = self.indices + norm = np.nansum(img_stacked, axis=-1, keepdims=True) + img_stacked = img_stacked / (norm + 1.e-3) - with nc.Dataset(self.nc_file, 'r') as dset: - var = dset.variables[key][batch_indices] + # normalize the image per-pixel + img = np.clip(img_stacked, 0, 1) + img[~np.isfinite(img)] = 0. - return var + # add the mask so we can crop it + mask = rasterio.open(mask_file).read(1) + mask[mask < 0.] = 0. 
+ data_stacked = np.concatenate((img, np.expand_dims(mask, -1)), axis=-1) + data_stacked = rearrange(torch.Tensor(data_stacked), "h w c -> c h w") + if self.augmentation is not None: + data_stacked = self.augmentation(data_stacked) -class MmapDataGenerator(DataGenerator): - def __init__(self, img_file, mask_file, batch_size, indices=None): - self.img_file = img_file - self.mask_file = mask_file + return data_stacked[:4, :, :], data_stacked[4, :, :].unsqueeze(0) - self.batch_size = batch_size - self.imgs = np.load(self.img_file, mmap_mode='r') - self.mask = np.load(self.mask_file, mmap_mode='r') +class COCOStuffDataset(Dataset): + augmentation = None - self.ndata = len(self.imgs) + def __init__(self, imgfolder, maskfolder, labels=[1], size=256, augmentation='resize'): + self.images = np.asarray(sorted(glob.glob(os.path.join(imgfolder, "*.jpg")))) + self.masks = np.asarray(sorted(glob.glob(os.path.join(maskfolder, "*.png")))) + self.size = size + self.labels = np.sort(labels) - if indices is not None: - self.indices = indices - else: - self.indices = np.arange(self.ndata) + self.image_ids = [int(os.path.basename(image).replace('.jpg', '')) for image in self.images] + self.mask_ids = [int(os.path.basename(image).replace('.png', '')) for image in self.masks] - print(f"Found data with {self.ndata} images") + assert np.all(self.image_ids == self.mask_ids), "Image IDs and Mask IDs do not match!" - def get_from_indices(self, batch_indices): - img = self.imgs[batch_indices, :].astype(float) / 255 - mask = self.mask[batch_indices, :].astype(float) + if augmentation == 'randomcrop': + self.augmentation = Resize(size=(size, size), antialias=None) + elif augmentation == 'randomcrop+flip': + self.augmentation = Compose([ + Resize(size=(size, size), antialias=None), + RandomHorizontalFlip(0.25), + RandomVerticalFlip(0.25), + ]) - return img, mask + print(f"Loaded {len(self)} images") + def __len__(self): + return len(self.images) -def create_generators(generator, val_split=0.1, **kwargs): - gen = generator(**kwargs) + def __getitem__(self, index): + image_file = self.images[index] + mask_file = self.masks[index] - ndata = gen.ndata + img = read_image(image_file, ImageReadMode.RGB) / 255. 
+ labels = read_image(mask_file, ImageReadMode.GRAY) + 1 - print(f"Creating generators from {ndata} images") + # add the mask so we can crop it + data_stacked = torch.cat((img, labels), dim=0) - inds = np.arange(ndata) - np.random.shuffle(inds) + if self.augmentation is not None: + data_stacked = self.augmentation(data_stacked) - val_split_ind = int(ndata * val_split) - val_ind = inds[:val_split_ind] - training_ind = inds[val_split_ind:] + img = data_stacked[:3, :] + labels = data_stacked[3, :] - train_data = generator(**kwargs, indices=training_ind) - val_data = generator(**kwargs, indices=val_ind) + mask = torch.zeros((len(self.labels), labels.shape[0], labels.shape[1])) + for i, label in enumerate(self.labels): + mask[i, labels == label] = 1 - return train_data, val_data + return img, mask From f4c80004e0a211b9be310518af8b3343eaf53eb3 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sun, 30 Jul 2023 18:41:17 -0500 Subject: [PATCH 07/31] adding setup script --- setup.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 setup.py diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..4901980 --- /dev/null +++ b/setup.py @@ -0,0 +1,18 @@ +from setuptools import setup, find_packages + +setup( + name='patchGAN', + version='0.1', + description='patchGAN image segmentation model in PyTorch', + license='GNU General Public License v3', + url='https://github.com/ramanakumars/patchGAN', + author='Kameswara Mantha, Ramanakumar Sankar, Lucy Fortson', + author_email='manth145@umn.edu, rsankar@umn.edu, lfortson@umn.edu', + packages=find_packages(), + install_requires=[ + 'numpy>=1.21.0,<1.25.2', + 'torch>=1.13.0,<=2.0.1', + 'torchvision>=0.14.0<=0.15.0', + 'tqdm>=4.62.3,<=4.65.0', + ] +) From 30a3658da53ab9fa778c73ed8e8223e6190ea620 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sun, 30 Jul 2023 18:41:31 -0500 Subject: [PATCH 08/31] adding test training script --- train.py | 91 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 train.py diff --git a/train.py b/train.py new file mode 100644 index 0000000..0fec57b --- /dev/null +++ b/train.py @@ -0,0 +1,91 @@ +import torch +from torchinfo import summary +from patchgan.unet import UNet +from patchgan.disc import Discriminator +from patchgan.io import FloatingForestsDataset, COCOStuffDataset +from patchgan.trainer import Trainer +from torch.utils.data import DataLoader +import yaml +import argparse + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + prog='PatchGAN', + description='Train the PatchGAN architecture' + ) + + parser.add_argument('-c', '--config_file', required=True, type=str, help='Location of the config YAML file') + parser.add_argument('-b', '--batch_size', default=16, type=int, help='Number of images per batch') + parser.add_argument('--dataloader_workers', default=4, type=int, help='Number of workers to use with dataloader (set to 0 to disable multithreading)') + parser.add_argument('-n', '--n_epochs', required=True, type=int, help='Number of epochs to train the model') + parser.add_argument('-d', '--device', default='auto', help='Device to use to train the model (CUDA=GPU)') + parser.add_argument('--summary', default=True, action='store_true', help="Print summary of the models") + + args = parser.parse_args() + + if args.device == 'auto': + device = 'cuda' if torch.cuda.is_available() else 'cpu' + elif args.device in ['cuda', 'cpu']: + device = args.device + + with open(args.config_file, 'r') as infile: + config 
= yaml.safe_load(infile) + + train_data_paths = config['train_data'] + val_data_paths = config['validation_data'] + + size = config['dataset'].get('size', 256) + augmentation = config['dataset'].get('augmentation', 'randomcrop') + + dataset_kwargs = {} + if config['dataset']['type'] == 'FloatingForests': + Dataset = FloatingForestsDataset + in_channels = 4 + out_channels = 1 + elif config['dataset']['type'] == 'COCOStuff': + Dataset = COCOStuffDataset + in_channels = 3 + labels = config['dataset'].get('labels', [1]) + out_channels = len(labels) + dataset_kwargs['labels'] = labels + + train_datagen = Dataset(train_data_paths['images'], train_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) + val_datagen = Dataset(val_data_paths['images'], val_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) + + gen_filts = config['model_params']['gen_filts'] + disc_filts = config['model_params']['disc_filts'] + n_disc_layers = config['model_params']['n_disc_layers'] + activation = config['model_params']['activation'] + use_dropout = config['model_params'].get('use_dropout', True) + final_activation = config['model_params'].get('final_activation', 'sigmoid') + + train_data = DataLoader(train_datagen, num_workers=args.dataloader_workers, batch_size=args.batch_size, shuffle=True, pin_memory=True) + val_data = DataLoader(val_datagen, num_workers=args.dataloader_workers, batch_size=args.batch_size, shuffle=True, pin_memory=True) + + # create the generator + generator = UNet(in_channels, out_channels, gen_filts, use_dropout=use_dropout, activation=activation, final_act=final_activation).to(device) + + # create the discriminator + discriminator = Discriminator(in_channels + out_channels, disc_filts, n_layers=n_disc_layers).to(device) + + if args.summary: + summary(generator, [1, in_channels, size, size]) + summary(discriminator, [1, in_channels + out_channels, size, size]) + + checkpoint_path = config.get('checkpoint_path', './checkpoints/') + + trainer = Trainer(generator, discriminator, savefolder=checkpoint_path) + + if config.get('load_last_checkpoint', False): + trainer.load_last_checkpoint() + + train_params = config['train_params'] + + trainer.loss_type = train_params['loss_type'] + trainer.seg_alpha = train_params['seg_alpha'] + + trainer.train(train_data, val_data, args.n_epochs, + dsc_learning_rate=train_params['disc_learning_rate'], + gen_learning_rate=train_params['gen_learning_rate'], + lr_decay=train_params.get('decay_rate', None), + save_freq=train_params.get('save_freq', 10)) From ba61cb4ca6b718e36706d8e28e32bd938dca24bf Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sun, 30 Jul 2023 19:19:48 -0500 Subject: [PATCH 09/31] adding versioning --- patchgan/__init__.py | 3 ++- patchgan/version.py | 1 + setup.py | 21 ++++++++++++++++++--- 3 files changed, 21 insertions(+), 4 deletions(-) create mode 100644 patchgan/version.py diff --git a/patchgan/__init__.py b/patchgan/__init__.py index f9bfc38..fb52b0a 100644 --- a/patchgan/__init__.py +++ b/patchgan/__init__.py @@ -1,7 +1,8 @@ from .unet import UNet from .disc import Discriminator from .trainer import Trainer +from .version import __version__ __all__ = [ - 'UNet', 'Discriminator', 'Trainer' + 'UNet', 'Discriminator', 'Trainer', '__version__' ] diff --git a/patchgan/version.py b/patchgan/version.py new file mode 100644 index 0000000..11d27f8 --- /dev/null +++ b/patchgan/version.py @@ -0,0 +1 @@ +__version__ = '0.1' diff --git a/setup.py b/setup.py index 4901980..1990351 100644 --- 
a/setup.py +++ b/setup.py @@ -1,9 +1,24 @@ from setuptools import setup, find_packages +import os + +here = os.path.abspath(os.path.dirname(__file__)) + +try: + with open(os.path.join(here, 'README.md'), 'r') as fh: + long_description = fh.read() +except FileNotFoundError: + long_description = '' + +version = {} +with open(os.path.join(here, 'patchgan/version.py')) as ver_file: + exec(ver_file.read(), version) setup( name='patchGAN', - version='0.1', + version=version['__version__'], description='patchGAN image segmentation model in PyTorch', + long_description=long_description, + long_description_content_type='text/markdown', license='GNU General Public License v3', url='https://github.com/ramanakumars/patchGAN', author='Kameswara Mantha, Ramanakumar Sankar, Lucy Fortson', @@ -11,8 +26,8 @@ packages=find_packages(), install_requires=[ 'numpy>=1.21.0,<1.25.2', - 'torch>=1.13.0,<=2.0.1', - 'torchvision>=0.14.0<=0.15.0', + 'torch>=1.13.0', + 'torchvision>=0.14.0,<=0.15.0', 'tqdm>=4.62.3,<=4.65.0', ] ) From e45fba376af9e9b997fefa0260180da9cae344a0 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sun, 30 Jul 2023 19:33:37 -0500 Subject: [PATCH 10/31] updated README --- README.md | 27 ++++++++++++++++- train.py | 91 ------------------------------------------------------- 2 files changed, 26 insertions(+), 92 deletions(-) delete mode 100644 train.py diff --git a/README.md b/README.md index 45ac661..fe7c71b 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,27 @@ # patchGAN -patchGAN model for image segmentation + +[![PyPI version](https://badge.fury.io/py/patchGAN.svg)](https://badge.fury.io/py/patchGAN) + +UNet-based GAN model for image segmentation using a patch-wise discriminator. +Based on the [pix2pix](https://phillipi.github.io/pix2pix/) model. + +## Installation + +Install the package with pip: +``` +pip install patchgan +``` + +Upgrading existing install: +``` +pip install -U patchgan +``` + +Get the current development branch: +``` +pip install -U git+https://github.com/ramanakumars/patchGAN.git +``` + +## Usage +See `examples/train.py` for an example training script and +`examples/train_coco.yaml` for the corresponding config for the COCO stuff dataset. 
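+
+The models can also be driven directly from Python. Below is a minimal sketch of the
+same pipeline (the dataset paths, label list, filter counts and learning rates are
+illustrative placeholders):
+```
+import torch
+from torch.utils.data import DataLoader
+from patchgan import UNet, Discriminator, Trainer
+from patchgan.io import COCOStuffDataset
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+# 3-channel images, 1 mask channel; 'randomcrop' fixes every sample to 256x256
+train_data = DataLoader(COCOStuffDataset('data/train_imgs', 'data/train_masks',
+                                         labels=[1], augmentation='randomcrop'),
+                        batch_size=16, shuffle=True)
+val_data = DataLoader(COCOStuffDataset('data/val_imgs', 'data/val_masks',
+                                       labels=[1], augmentation='randomcrop'),
+                      batch_size=16)
+
+generator = UNet(3, 1, 32, activation='relu', final_act='sigmoid').to(device)
+discriminator = Discriminator(3 + 1, 16, n_layers=3).to(device)
+
+# checkpoints are written to savefolder; train() runs the full GAN loop
+trainer = Trainer(generator, discriminator, savefolder='./checkpoints/')
+trainer.train(train_data, val_data, 100,
+              gen_learning_rate=1.e-3, dsc_learning_rate=1.e-3)
+```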
diff --git a/train.py b/train.py deleted file mode 100644 index 0fec57b..0000000 --- a/train.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -from torchinfo import summary -from patchgan.unet import UNet -from patchgan.disc import Discriminator -from patchgan.io import FloatingForestsDataset, COCOStuffDataset -from patchgan.trainer import Trainer -from torch.utils.data import DataLoader -import yaml -import argparse - -if __name__ == '__main__': - parser = argparse.ArgumentParser( - prog='PatchGAN', - description='Train the PatchGAN architecture' - ) - - parser.add_argument('-c', '--config_file', required=True, type=str, help='Location of the config YAML file') - parser.add_argument('-b', '--batch_size', default=16, type=int, help='Number of images per batch') - parser.add_argument('--dataloader_workers', default=4, type=int, help='Number of workers to use with dataloader (set to 0 to disable multithreading)') - parser.add_argument('-n', '--n_epochs', required=True, type=int, help='Number of epochs to train the model') - parser.add_argument('-d', '--device', default='auto', help='Device to use to train the model (CUDA=GPU)') - parser.add_argument('--summary', default=True, action='store_true', help="Print summary of the models") - - args = parser.parse_args() - - if args.device == 'auto': - device = 'cuda' if torch.cuda.is_available() else 'cpu' - elif args.device in ['cuda', 'cpu']: - device = args.device - - with open(args.config_file, 'r') as infile: - config = yaml.safe_load(infile) - - train_data_paths = config['train_data'] - val_data_paths = config['validation_data'] - - size = config['dataset'].get('size', 256) - augmentation = config['dataset'].get('augmentation', 'randomcrop') - - dataset_kwargs = {} - if config['dataset']['type'] == 'FloatingForests': - Dataset = FloatingForestsDataset - in_channels = 4 - out_channels = 1 - elif config['dataset']['type'] == 'COCOStuff': - Dataset = COCOStuffDataset - in_channels = 3 - labels = config['dataset'].get('labels', [1]) - out_channels = len(labels) - dataset_kwargs['labels'] = labels - - train_datagen = Dataset(train_data_paths['images'], train_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) - val_datagen = Dataset(val_data_paths['images'], val_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) - - gen_filts = config['model_params']['gen_filts'] - disc_filts = config['model_params']['disc_filts'] - n_disc_layers = config['model_params']['n_disc_layers'] - activation = config['model_params']['activation'] - use_dropout = config['model_params'].get('use_dropout', True) - final_activation = config['model_params'].get('final_activation', 'sigmoid') - - train_data = DataLoader(train_datagen, num_workers=args.dataloader_workers, batch_size=args.batch_size, shuffle=True, pin_memory=True) - val_data = DataLoader(val_datagen, num_workers=args.dataloader_workers, batch_size=args.batch_size, shuffle=True, pin_memory=True) - - # create the generator - generator = UNet(in_channels, out_channels, gen_filts, use_dropout=use_dropout, activation=activation, final_act=final_activation).to(device) - - # create the discriminator - discriminator = Discriminator(in_channels + out_channels, disc_filts, n_layers=n_disc_layers).to(device) - - if args.summary: - summary(generator, [1, in_channels, size, size]) - summary(discriminator, [1, in_channels + out_channels, size, size]) - - checkpoint_path = config.get('checkpoint_path', './checkpoints/') - - trainer = Trainer(generator, discriminator, 
savefolder=checkpoint_path) - - if config.get('load_last_checkpoint', False): - trainer.load_last_checkpoint() - - train_params = config['train_params'] - - trainer.loss_type = train_params['loss_type'] - trainer.seg_alpha = train_params['seg_alpha'] - - trainer.train(train_data, val_data, args.n_epochs, - dsc_learning_rate=train_params['disc_learning_rate'], - gen_learning_rate=train_params['gen_learning_rate'], - lr_decay=train_params.get('decay_rate', None), - save_freq=train_params.get('save_freq', 10)) From 583694010da451dfca23b8f850c70741c0941a6b Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sun, 30 Jul 2023 19:33:46 -0500 Subject: [PATCH 11/31] adding example training script --- examples/train.py | 89 ++++++++++++++++++++++++++++++++++++++++ examples/train_coco.yaml | 29 +++++++++++++ 2 files changed, 118 insertions(+) create mode 100644 examples/train.py create mode 100644 examples/train_coco.yaml diff --git a/examples/train.py b/examples/train.py new file mode 100644 index 0000000..2daca4a --- /dev/null +++ b/examples/train.py @@ -0,0 +1,89 @@ +import torch +from torchinfo import summary +from patchgan.unet import UNet +from patchgan.disc import Discriminator +from patchgan.io import COCOStuffDataset +from patchgan.trainer import Trainer +from torch.utils.data import DataLoader +import yaml +import argparse + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + prog='PatchGAN', + description='Train the PatchGAN architecture' + ) + + parser.add_argument('-c', '--config_file', required=True, type=str, help='Location of the config YAML file') + parser.add_argument('-b', '--batch_size', default=16, type=int, help='Number of images per batch') + parser.add_argument('--dataloader_workers', default=4, type=int, help='Number of workers to use with dataloader (set to 0 to disable multithreading)') + parser.add_argument('-n', '--n_epochs', required=True, type=int, help='Number of epochs to train the model') + parser.add_argument('-d', '--device', default='auto', help='Device to use to train the model (CUDA=GPU)') + parser.add_argument('--summary', default=True, action='store_true', help="Print summary of the models") + + args = parser.parse_args() + + if args.device == 'auto': + device = 'cuda' if torch.cuda.is_available() else 'cpu' + elif args.device in ['cuda', 'cpu']: + device = args.device + + with open(args.config_file, 'r') as infile: + config = yaml.safe_load(infile) + + train_data_paths = config['train_data'] + val_data_paths = config['validation_data'] + + size = config['dataset'].get('size', 256) + augmentation = config['dataset'].get('augmentation', 'randomcrop') + + dataset_kwargs = {} + if config['dataset']['type'] == 'COCOStuff': + Dataset = COCOStuffDataset + in_channels = 3 + labels = config['dataset'].get('labels', [1]) + out_channels = len(labels) + dataset_kwargs['labels'] = labels + else: + raise ValueError(f"{config['dataset']['type']} is unknown!") + + train_datagen = Dataset(train_data_paths['images'], train_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) + val_datagen = Dataset(val_data_paths['images'], val_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) + + gen_filts = config['model_params']['gen_filts'] + disc_filts = config['model_params']['disc_filts'] + n_disc_layers = config['model_params']['n_disc_layers'] + activation = config['model_params']['activation'] + use_dropout = config['model_params'].get('use_dropout', True) + final_activation = 
config['model_params'].get('final_activation', 'sigmoid') + + train_data = DataLoader(train_datagen, num_workers=args.dataloader_workers, batch_size=args.batch_size, shuffle=True, pin_memory=True) + val_data = DataLoader(val_datagen, num_workers=args.dataloader_workers, batch_size=args.batch_size, shuffle=True, pin_memory=True) + + # create the generator + generator = UNet(in_channels, out_channels, gen_filts, use_dropout=use_dropout, activation=activation, final_act=final_activation).to(device) + + # create the discriminator + discriminator = Discriminator(in_channels + out_channels, disc_filts, n_layers=n_disc_layers).to(device) + + if args.summary: + summary(generator, [1, in_channels, size, size]) + summary(discriminator, [1, in_channels + out_channels, size, size]) + + checkpoint_path = config.get('checkpoint_path', './checkpoints/') + + trainer = Trainer(generator, discriminator, savefolder=checkpoint_path) + + if config.get('load_last_checkpoint', False): + trainer.load_last_checkpoint() + + train_params = config['train_params'] + + trainer.loss_type = train_params['loss_type'] + trainer.seg_alpha = train_params['seg_alpha'] + + trainer.train(train_data, val_data, args.n_epochs, + dsc_learning_rate=train_params['disc_learning_rate'], + gen_learning_rate=train_params['gen_learning_rate'], + lr_decay=train_params.get('decay_rate', None), + save_freq=train_params.get('save_freq', 10)) diff --git a/examples/train_coco.yaml b/examples/train_coco.yaml new file mode 100644 index 0000000..f5a9a1f --- /dev/null +++ b/examples/train_coco.yaml @@ -0,0 +1,29 @@ +dataset: + type: COCOStuff + augmentation: randomcrop+flip + size: 256 +train_data: + images: /d1/rsankar/data/COCOstuff/train2017 + masks: /d1/rsankar/data/COCOstuff/train2017 + labels: [1, 2, 3, 4, 5, 6, 7] +validation_data: + images: /d1/rsankar/data/COCOstuff/val2017 + masks: /d1/rsankar/data/COCOstuff/val2017 + labels: [1, 2, 3, 4, 5, 6, 7] +model_params: + gen_filts: 32 + disc_filts: 16 + activation: relu + use_dropout: True + final_activation: sigmoid + n_disc_layers: 5 +checkpoint_path: ./checkpoints/checkpoint-COCO/ +load_last_checkpoint: True +train_params: + loss_type: weighted_bce + seg_alpha: 200 + gen_learning_rate: 1.e-3 + disc_learning_rate: 1.e-3 + decay_rate: 0.95 + save_freq: 5 + From 7bf789f7622ff53e1d535ebb7432fa10ef896323 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sun, 30 Jul 2023 20:04:02 -0500 Subject: [PATCH 12/31] adding training command line argument --- README.md | 9 +++-- patchgan/train.py | 94 +++++++++++++++++++++++++++++++++++++++++++++++ setup.py | 5 +++ 3 files changed, 105 insertions(+), 3 deletions(-) create mode 100644 patchgan/train.py diff --git a/README.md b/README.md index fe7c71b..b04b09e 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,9 @@ Get the current development branch: pip install -U git+https://github.com/ramanakumars/patchGAN.git ``` -## Usage -See `examples/train.py` for an example training script and -`examples/train_coco.yaml` for the corresponding config for the COCO stuff dataset. +## Training +You can train the patchGAN model with a config file and the `patchgan_train` command: +``` +patchgan_train --config_file train_coco.yaml --n_epochs 100 --batch_size 16 +``` +See `examples/train_coco.yaml` for the corresponding config for the COCO stuff dataset. 
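+
+## Inference
+A trained generator can be reloaded for prediction. This is a minimal sketch, assuming
+the model parameters from the config above and that checkpoints are saved as state
+dicts with the default `generator_epoch_*.pth` naming:
+```
+import torch
+from patchgan import UNet
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+# rebuild the generator with the same parameters used in training
+generator = UNet(3, 7, 32, activation='relu', final_act='sigmoid').to(device)
+state = torch.load('checkpoints/checkpoint-COCO/generator_epoch_100.pth',
+                   map_location=device)
+generator.load_state_dict(state)
+generator.eval()
+
+# stand-in for a batch of images: (N, 3, 256, 256), scaled to [0, 1]
+images = torch.rand(1, 3, 256, 256, device=device)
+with torch.no_grad():
+    masks = generator(images)  # (1, 7, 256, 256) per-class probability maps
+```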
diff --git a/patchgan/train.py b/patchgan/train.py new file mode 100644 index 0000000..db24ac0 --- /dev/null +++ b/patchgan/train.py @@ -0,0 +1,94 @@ +import torch +from torchinfo import summary +from patchgan.unet import UNet +from patchgan.disc import Discriminator +from patchgan.io import COCOStuffDataset +from patchgan.trainer import Trainer +from torch.utils.data import DataLoader +import yaml +import importlib.machinery +import argparse + + +def patchgan_train(): + parser = argparse.ArgumentParser( + prog='PatchGAN', + description='Train the PatchGAN architecture' + ) + + parser.add_argument('-c', '--config_file', required=True, type=str, help='Location of the config YAML file') + parser.add_argument('-b', '--batch_size', default=16, type=int, help='Number of images per batch') + parser.add_argument('--dataloader_workers', default=4, type=int, help='Number of workers to use with dataloader (set to 0 to disable multithreading)') + parser.add_argument('-n', '--n_epochs', required=True, type=int, help='Number of epochs to train the model') + parser.add_argument('-d', '--device', default='auto', help='Device to use to train the model (CUDA=GPU)') + parser.add_argument('--summary', default=True, action='store_true', help="Print summary of the models") + + args = parser.parse_args() + + if args.device == 'auto': + device = 'cuda' if torch.cuda.is_available() else 'cpu' + elif args.device in ['cuda', 'cpu']: + device = args.device + + with open(args.config_file, 'r') as infile: + config = yaml.safe_load(infile) + + train_data_paths = config['train_data'] + val_data_paths = config['validation_data'] + + size = config['dataset'].get('size', 256) + augmentation = config['dataset'].get('augmentation', 'randomcrop') + + dataset_kwargs = {} + if config['dataset']['type'] == 'COCOStuff': + Dataset = COCOStuffDataset + in_channels = 3 + labels = config['dataset'].get('labels', [1]) + out_channels = len(labels) + dataset_kwargs['labels'] = labels + else: + spec = importlib.machinery.SourceFileLoader('', 'io.py') + Dataset = spec.load_module().__getattribute__(config['dataset']['type']) + in_channels = 4 + out_channels = 1 + + train_datagen = Dataset(train_data_paths['images'], train_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) + val_datagen = Dataset(val_data_paths['images'], val_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) + + gen_filts = config['model_params']['gen_filts'] + disc_filts = config['model_params']['disc_filts'] + n_disc_layers = config['model_params']['n_disc_layers'] + activation = config['model_params']['activation'] + use_dropout = config['model_params'].get('use_dropout', True) + final_activation = config['model_params'].get('final_activation', 'sigmoid') + + train_data = DataLoader(train_datagen, num_workers=args.dataloader_workers, batch_size=args.batch_size, shuffle=True, pin_memory=True) + val_data = DataLoader(val_datagen, num_workers=args.dataloader_workers, batch_size=args.batch_size, shuffle=True, pin_memory=True) + + # create the generator + generator = UNet(in_channels, out_channels, gen_filts, use_dropout=use_dropout, activation=activation, final_act=final_activation).to(device) + + # create the discriminator + discriminator = Discriminator(in_channels + out_channels, disc_filts, n_layers=n_disc_layers).to(device) + + if args.summary: + summary(generator, [1, in_channels, size, size]) + summary(discriminator, [1, in_channels + out_channels, size, size]) + + checkpoint_path = config.get('checkpoint_path', 
'./checkpoints/') + + trainer = Trainer(generator, discriminator, savefolder=checkpoint_path) + + if config.get('load_last_checkpoint', False): + trainer.load_last_checkpoint() + + train_params = config['train_params'] + + trainer.loss_type = train_params['loss_type'] + trainer.seg_alpha = train_params['seg_alpha'] + + trainer.train(train_data, val_data, args.n_epochs, + dsc_learning_rate=train_params['disc_learning_rate'], + gen_learning_rate=train_params['gen_learning_rate'], + lr_decay=train_params.get('decay_rate', None), + save_freq=train_params.get('save_freq', 10)) diff --git a/setup.py b/setup.py index 1990351..9d068b5 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,11 @@ author='Kameswara Mantha, Ramanakumar Sankar, Lucy Fortson', author_email='manth145@umn.edu, rsankar@umn.edu, lfortson@umn.edu', packages=find_packages(), + entry_points={ + 'console_scripts': [ + 'patchgan_train = patchgan.train:patchgan_train' + ] + }, install_requires=[ 'numpy>=1.21.0,<1.25.2', 'torch>=1.13.0', From d988f231090e952c0482dd636f1fa4f27af25f5c Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sun, 30 Jul 2023 20:12:31 -0500 Subject: [PATCH 13/31] removing training example --- examples/train.py | 89 ----------------------------------------------- 1 file changed, 89 deletions(-) delete mode 100644 examples/train.py diff --git a/examples/train.py b/examples/train.py deleted file mode 100644 index 2daca4a..0000000 --- a/examples/train.py +++ /dev/null @@ -1,89 +0,0 @@ -import torch -from torchinfo import summary -from patchgan.unet import UNet -from patchgan.disc import Discriminator -from patchgan.io import COCOStuffDataset -from patchgan.trainer import Trainer -from torch.utils.data import DataLoader -import yaml -import argparse - -if __name__ == '__main__': - parser = argparse.ArgumentParser( - prog='PatchGAN', - description='Train the PatchGAN architecture' - ) - - parser.add_argument('-c', '--config_file', required=True, type=str, help='Location of the config YAML file') - parser.add_argument('-b', '--batch_size', default=16, type=int, help='Number of images per batch') - parser.add_argument('--dataloader_workers', default=4, type=int, help='Number of workers to use with dataloader (set to 0 to disable multithreading)') - parser.add_argument('-n', '--n_epochs', required=True, type=int, help='Number of epochs to train the model') - parser.add_argument('-d', '--device', default='auto', help='Device to use to train the model (CUDA=GPU)') - parser.add_argument('--summary', default=True, action='store_true', help="Print summary of the models") - - args = parser.parse_args() - - if args.device == 'auto': - device = 'cuda' if torch.cuda.is_available() else 'cpu' - elif args.device in ['cuda', 'cpu']: - device = args.device - - with open(args.config_file, 'r') as infile: - config = yaml.safe_load(infile) - - train_data_paths = config['train_data'] - val_data_paths = config['validation_data'] - - size = config['dataset'].get('size', 256) - augmentation = config['dataset'].get('augmentation', 'randomcrop') - - dataset_kwargs = {} - if config['dataset']['type'] == 'COCOStuff': - Dataset = COCOStuffDataset - in_channels = 3 - labels = config['dataset'].get('labels', [1]) - out_channels = len(labels) - dataset_kwargs['labels'] = labels - else: - raise ValueError(f"{config['dataset']['type']} is unknown!") - - train_datagen = Dataset(train_data_paths['images'], train_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) - val_datagen = Dataset(val_data_paths['images'], 
val_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) - - gen_filts = config['model_params']['gen_filts'] - disc_filts = config['model_params']['disc_filts'] - n_disc_layers = config['model_params']['n_disc_layers'] - activation = config['model_params']['activation'] - use_dropout = config['model_params'].get('use_dropout', True) - final_activation = config['model_params'].get('final_activation', 'sigmoid') - - train_data = DataLoader(train_datagen, num_workers=args.dataloader_workers, batch_size=args.batch_size, shuffle=True, pin_memory=True) - val_data = DataLoader(val_datagen, num_workers=args.dataloader_workers, batch_size=args.batch_size, shuffle=True, pin_memory=True) - - # create the generator - generator = UNet(in_channels, out_channels, gen_filts, use_dropout=use_dropout, activation=activation, final_act=final_activation).to(device) - - # create the discriminator - discriminator = Discriminator(in_channels + out_channels, disc_filts, n_layers=n_disc_layers).to(device) - - if args.summary: - summary(generator, [1, in_channels, size, size]) - summary(discriminator, [1, in_channels + out_channels, size, size]) - - checkpoint_path = config.get('checkpoint_path', './checkpoints/') - - trainer = Trainer(generator, discriminator, savefolder=checkpoint_path) - - if config.get('load_last_checkpoint', False): - trainer.load_last_checkpoint() - - train_params = config['train_params'] - - trainer.loss_type = train_params['loss_type'] - trainer.seg_alpha = train_params['seg_alpha'] - - trainer.train(train_data, val_data, args.n_epochs, - dsc_learning_rate=train_params['disc_learning_rate'], - gen_learning_rate=train_params['gen_learning_rate'], - lr_decay=train_params.get('decay_rate', None), - save_freq=train_params.get('save_freq', 10)) From 32707250a569999142419b0df372f7ebd31a4272 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sun, 30 Jul 2023 20:14:03 -0500 Subject: [PATCH 14/31] adding verbosity to dataset import --- patchgan/train.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/patchgan/train.py b/patchgan/train.py index db24ac0..a37f3af 100644 --- a/patchgan/train.py +++ b/patchgan/train.py @@ -47,8 +47,15 @@ def patchgan_train(): out_channels = len(labels) dataset_kwargs['labels'] = labels else: - spec = importlib.machinery.SourceFileLoader('', 'io.py') - Dataset = spec.load_module().__getattribute__(config['dataset']['type']) + try: + spec = importlib.machinery.SourceFileLoader('io', 'io.py') + Dataset = spec.load_module().__getattribute__(config['dataset']['type']) + except FileNotFoundError: + print("Make sure io.py is in the working directory!") + raise + except (ImportError, ModuleNotFoundError): + print(f"io.py does not contain {config['dataset']['type']}") + raise in_channels = 4 out_channels = 1 From df43ad00a7a9077c4a5f9dae262381597350ad68 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Wed, 2 Aug 2023 18:54:19 -0500 Subject: [PATCH 15/31] adding missing dependencies --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index 9d068b5..8e0aa22 100644 --- a/setup.py +++ b/setup.py @@ -34,5 +34,7 @@ 'torch>=1.13.0', 'torchvision>=0.14.0,<=0.15.0', 'tqdm>=4.62.3,<=4.65.0', + 'torchinfo>=1.5.0', + 'pyyaml', ] ) From 7215cb13f05e25f0acf68fbffd0523708dd8f88b Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Wed, 2 Aug 2023 18:54:29 -0500 Subject: [PATCH 16/31] removing FloatingForests datagen --- patchgan/io.py | 74 ++++---------------------------------------------- 1 file
changed, 6 insertions(+), 68 deletions(-) diff --git a/patchgan/io.py b/patchgan/io.py index 3deedf4..c49fa4f 100644 --- a/patchgan/io.py +++ b/patchgan/io.py @@ -4,69 +4,7 @@ import glob import os from torchvision.io import read_image, ImageReadMode -from torchvision.transforms import Compose, Resize, RandomCrop, RandomHorizontalFlip, RandomVerticalFlip -from einops import rearrange -import rasterio - - -class FloatingForestsDataset(Dataset): - augmentation = None - - def __init__(self, imgfolder, maskfolder, size=256, augmentation='randomcrop'): - self.images = sorted(glob.glob(os.path.join(imgfolder, "*.tif"))) - self.masks = sorted(glob.glob(os.path.join(maskfolder, "*.tif"))) - self.size = size - - self.image_ids = [int(os.path.basename(image).replace('.tif', '')) for image in self.images] - self.mask_ids = [int(os.path.basename(image).replace('.tif', '')) for image in self.masks] - - assert np.all(self.image_ids == self.mask_ids), "Image IDs and Mask IDs do not match!" - - if augmentation == 'randomcrop': - self.augmentation = RandomCrop(size=(size, size)) - elif augmentation == 'randomcrop+flip': - self.augmentation = Compose([ - RandomCrop(size=(size, size)), - RandomHorizontalFlip(0.25), - RandomVerticalFlip(0.25) - ]) - - print(f"Loaded {len(self)} images") - - def __len__(self): - return len(self.images) - - def __getitem__(self, index): - image_file = self.images[index] - mask_file = self.masks[index] - - img = rasterio.open(image_file) - img_stacked = np.dstack([img.read(1), img.read(2), img.read(3), img.read(4)]) - - # clean up artifacts in the data - img_stacked[np.abs(img_stacked) > 1.e20] = 0 - img_stacked[np.isnan(img_stacked)] = 0 - - # remove negative signal - img_stacked = img_stacked - np.percentile(img_stacked.flatten(), 2) - - norm = np.nansum(img_stacked, axis=-1, keepdims=True) - img_stacked = img_stacked / (norm + 1.e-3) - - # normalize the image per-pixel - img = np.clip(img_stacked, 0, 1) - img[~np.isfinite(img)] = 0. - - # add the mask so we can crop it - mask = rasterio.open(mask_file).read(1) - mask[mask < 0.] = 0. - data_stacked = np.concatenate((img, np.expand_dims(mask, -1)), axis=-1) - data_stacked = rearrange(torch.Tensor(data_stacked), "h w c -> c h w") - - if self.augmentation is not None: - data_stacked = self.augmentation(data_stacked) - - return data_stacked[:4, :, :], data_stacked[4, :, :].unsqueeze(0) +from torchvision import transforms class COCOStuffDataset(Dataset): @@ -84,12 +22,12 @@ def __init__(self, imgfolder, maskfolder, labels=[1], size=256, augmentation='re assert np.all(self.image_ids == self.mask_ids), "Image IDs and Mask IDs do not match!" 
if augmentation == 'randomcrop': - self.augmentation = Resize(size=(size, size), antialias=None) + self.augmentation = transforms.Resize(size=(size, size), antialias=None) elif augmentation == 'randomcrop+flip': - self.augmentation = Compose([ - Resize(size=(size, size), antialias=None), - RandomHorizontalFlip(0.25), - RandomVerticalFlip(0.25), + self.augmentation = transforms.Compose([ + transforms.Resize(size=(size, size), antialias=None), + transforms.RandomHorizontalFlip(0.25), + transforms.RandomVerticalFlip(0.25), ]) print(f"Loaded {len(self)} images") From fc8e11ccdb85969f374ae5cdbf3de9f57cc4665e Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Wed, 2 Aug 2023 18:54:41 -0500 Subject: [PATCH 17/31] adding transfer learning to training script --- patchgan/train.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/patchgan/train.py b/patchgan/train.py index a37f3af..875f511 100644 --- a/patchgan/train.py +++ b/patchgan/train.py @@ -88,6 +88,11 @@ def patchgan_train(): if config.get('load_last_checkpoint', False): trainer.load_last_checkpoint() + elif config.get('transfer_learn', {}).get('generator_checkpoint', None) is not None: + gen_checkpoint = config['transfer_learn']['generator_checkpoint'] + dsc_checkpoint = config['transfer_learn']['discriminator_checkpoint'] + generator.load_transfer_data(torch.load(gen_checkpoint, map_location=args.device)) + discriminator.load_transfer_data(torch.load(dsc_checkpoint, map_location=args.device)) train_params = config['train_params'] From 11b75de679cd161e23e261f320942f89a481e66d Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Wed, 2 Aug 2023 18:59:53 -0500 Subject: [PATCH 18/31] updating dependencies --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 8e0aa22..7226241 100644 --- a/setup.py +++ b/setup.py @@ -32,8 +32,8 @@ install_requires=[ 'numpy>=1.21.0,<1.25.2', 'torch>=1.13.0', - 'torchvision>=0.14.0,<=0.15.0', - 'tqdm>=4.62.3,<=4.65.0', + 'torchvision>=0.14.0', + 'tqdm>=4.62.3', 'torchinfo>=1.5.0', 'pyyaml', ] From 4f52d44500462a9d904e8da977b29f15393c5529 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Wed, 2 Aug 2023 19:33:44 -0500 Subject: [PATCH 19/31] moving dataset arguments into dataset keyword for training --- patchgan/train.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/patchgan/train.py b/patchgan/train.py index 875f511..30ee5c3 100644 --- a/patchgan/train.py +++ b/patchgan/train.py @@ -33,28 +33,29 @@ def patchgan_train(): with open(args.config_file, 'r') as infile: config = yaml.safe_load(infile) - train_data_paths = config['train_data'] - val_data_paths = config['validation_data'] + dataset_params = config['dataset'] + train_data_paths = dataset_params['train_data'] + val_data_paths = dataset_params['validation_data'] - size = config['dataset'].get('size', 256) - augmentation = config['dataset'].get('augmentation', 'randomcrop') + size = dataset_params.get('size', 256) + augmentation = dataset_params.get('augmentation', 'randomcrop') dataset_kwargs = {} - if config['dataset']['type'] == 'COCOStuff': + if dataset_params['type'] == 'COCOStuff': Dataset = COCOStuffDataset in_channels = 3 - labels = config['dataset'].get('labels', [1]) + labels = dataset_params.get('labels', [1]) out_channels = len(labels) dataset_kwargs['labels'] = labels else: try: spec = importlib.machinery.SourceFileLoader('io', 'io.py') - Dataset = spec.load_module().__getattribute__(config['dataset']['type']) + Dataset =
spec.load_module().__getattribute__(dataset_params['type']) except FileNotFoundError: print("Make sure io.py is in the working directory!") raise except (ImportError, ModuleNotFoundError): - print(f"io.py does not contain {config['dataset']['type']}") + print(f"io.py does not contain {dataset_params['type']}") raise in_channels = 4 out_channels = 1 @@ -62,12 +63,13 @@ def patchgan_train(): train_datagen = Dataset(train_data_paths['images'], train_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) val_datagen = Dataset(val_data_paths['images'], val_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) - gen_filts = config['model_params']['gen_filts'] - disc_filts = config['model_params']['disc_filts'] - n_disc_layers = config['model_params']['n_disc_layers'] - activation = config['model_params']['activation'] - use_dropout = config['model_params'].get('use_dropout', True) - final_activation = config['model_params'].get('final_activation', 'sigmoid') + model_params = config['model_params'] + gen_filts = model_params['gen_filts'] + disc_filts = model_params['disc_filts'] + n_disc_layers = model_params['n_disc_layers'] + activation = model_params['activation'] + use_dropout = model_params.get('use_dropout', True) + final_activation = model_params.get('final_activation', 'sigmoid') train_data = DataLoader(train_datagen, num_workers=args.dataloader_workers, batch_size=args.batch_size, shuffle=True, pin_memory=True) val_data = DataLoader(val_datagen, num_workers=args.dataloader_workers, batch_size=args.batch_size, shuffle=True, pin_memory=True) From 157d33aca110313d2201f0fa240e2a1951a2b6bd Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Wed, 2 Aug 2023 19:40:04 -0500 Subject: [PATCH 20/31] fixing typo in device --- patchgan/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/patchgan/train.py b/patchgan/train.py index 30ee5c3..1da6dbd 100644 --- a/patchgan/train.py +++ b/patchgan/train.py @@ -93,8 +93,8 @@ def patchgan_train(): elif config.get('transfer_learn', {}).get('generator_checkpoint', None) is not None: gen_checkpoint = config['transfer_learn']['generator_checkpoint'] dsc_checkpoint = config['transfer_learn']['discriminator_checkpoint'] - generator.load_transfer_data(torch.load(gen_checkpoint, map_location=args.device)) - discriminator.load_transfer_data(torch.load(dsc_checkpoint, map_location=args.device)) + generator.load_transfer_data(torch.load(gen_checkpoint, map_location=device)) + discriminator.load_transfer_data(torch.load(dsc_checkpoint, map_location=device)) train_params = config['train_params'] From c5426bb492517205e7235df6accc84c7ce4e6308 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Tue, 8 Aug 2023 01:07:39 -0500 Subject: [PATCH 21/31] added verbosity to transfer --- patchgan/transfer.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/patchgan/transfer.py b/patchgan/transfer.py index 3fedbff..486f2e4 100644 --- a/patchgan/transfer.py +++ b/patchgan/transfer.py @@ -7,9 +7,20 @@ def __init__(self): def load_transfer_data(self, state_dict): own_state = self.state_dict() + count = 0 for name, param in state_dict.items(): if isinstance(param, Parameter): # backwards compatibility for serialized parameters param = param.data if param.shape == own_state[name].data.shape: own_state[name].copy_(param) + count += 1 + + if count > 0: + print(f"Loaded {count} weights out of {len(state_dict)}") + else: + raise InvalidCheckpointError("Could not load transfer weights") + + +class 
InvalidCheckpointError(Exception): + pass From 888902c2d3987981c93d9aa4845877d9999bb3b5 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Tue, 8 Aug 2023 01:08:18 -0500 Subject: [PATCH 22/31] changed disc activation to tanh --- patchgan/disc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/patchgan/disc.py b/patchgan/disc.py index 66b6772..b6f1ea6 100644 --- a/patchgan/disc.py +++ b/patchgan/disc.py @@ -26,7 +26,7 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d): sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=False), - nn.LeakyReLU(0.2, True), + nn.Tanh(), norm_layer(ndf * nf_mult) ] @@ -35,7 +35,7 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d): sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=False), - nn.LeakyReLU(0.2, True), + nn.Tanh(), norm_layer(ndf * nf_mult) ] From bf251634a0422047c7611b9b077300c85c96c73f Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Tue, 8 Aug 2023 01:08:45 -0500 Subject: [PATCH 23/31] changed weights init to use instancenorm --- patchgan/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/patchgan/trainer.py b/patchgan/trainer.py index ce53f30..e6b05ce 100644 --- a/patchgan/trainer.py +++ b/patchgan/trainer.py @@ -336,6 +336,6 @@ def init_func(m): # define the initialization function if hasattr(m, 'weight') and (classname.find('Conv')) != -1: torch.nn.init.xavier_uniform_(m.weight.data) # BatchNorm Layer's weight is not a matrix; only normal distribution applies. - elif classname.find('BatchNorm') != -1: + elif classname.find('InstanceNorm') != -1: torch.nn.init.xavier_uniform_(m.weight.data, 1.0) torch.nn.init.constant_(m.bias.data, 0.0) From 0873381b968d57f94e80fdcd3fb08c5b3ed6b843 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Tue, 8 Aug 2023 01:08:56 -0500 Subject: [PATCH 24/31] added more options to training script --- patchgan/train.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/patchgan/train.py b/patchgan/train.py index 1da6dbd..5cd3ff2 100644 --- a/patchgan/train.py +++ b/patchgan/train.py @@ -4,7 +4,7 @@ from patchgan.disc import Discriminator from patchgan.io import COCOStuffDataset from patchgan.trainer import Trainer -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, random_split import yaml import importlib.machinery import argparse @@ -34,8 +34,15 @@ def patchgan_train(): config = yaml.safe_load(infile) dataset_params = config['dataset'] - train_data_paths = dataset_params['train_data'] - val_data_paths = dataset_params['validation_data'] + if ('train_data' in dataset_params) and ('validation_data' in dataset_params): + train_data_paths = dataset_params['train_data'] + val_data_paths = dataset_params['validation_data'] + train_val_split = None + elif ('data' in dataset_params) and ('train_val_split' in dataset_params): + data_paths = dataset_params['data'] + train_val_split = dataset_params['train_val_split'] + else: + raise AttributeError("Please provide either the training and validation data paths or a train/val split!") size = dataset_params.get('size', 256) augmentation = dataset_params.get('augmentation', 'randomcrop') @@ -57,11 +64,15 @@ def patchgan_train(): except (ImportError, ModuleNotFoundError): print(f"io.py does not contain {dataset_params['type']}") raise - in_channels = 4 - out_channels = 1 + 
in_channels = dataset_params.get('in_channels', 3) + out_channels = dataset_params.get('out_channels', 1) - train_datagen = Dataset(train_data_paths['images'], train_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) - val_datagen = Dataset(val_data_paths['images'], val_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) + if train_val_split is None: + train_datagen = Dataset(train_data_paths['images'], train_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) + val_datagen = Dataset(val_data_paths['images'], val_data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) + else: + datagen = Dataset(data_paths['images'], data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) + train_datagen, val_datagen = random_split(datagen, train_val_split) model_params = config['model_params'] gen_filts = model_params['gen_filts'] @@ -71,8 +82,13 @@ def patchgan_train(): use_dropout = model_params.get('use_dropout', True) final_activation = model_params.get('final_activation', 'sigmoid') - train_data = DataLoader(train_datagen, num_workers=args.dataloader_workers, batch_size=args.batch_size, shuffle=True, pin_memory=True) - val_data = DataLoader(val_datagen, num_workers=args.dataloader_workers, batch_size=args.batch_size, shuffle=True, pin_memory=True) + dloader_kwargs = {} + if args.dataloader_workers > 0: + dloader_kwargs['num_workers'] = args.dataloader_workers + dloader_kwargs['persistent_workers'] = True + + train_data = DataLoader(train_datagen, batch_size=args.batch_size, shuffle=True, pin_memory=True, **dloader_kwargs) + val_data = DataLoader(val_datagen, batch_size=args.batch_size, shuffle=True, pin_memory=True, **dloader_kwargs) # create the generator generator = UNet(in_channels, out_channels, gen_filts, use_dropout=use_dropout, activation=activation, final_act=final_activation).to(device) From ee3962afd54373ad3bba7d698ac9e31e89723fa8 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Tue, 8 Aug 2023 01:09:10 -0500 Subject: [PATCH 25/31] adding inference script --- patchgan/infer.py | 171 ++++++++++++++++++++++++++++++++++++++++ setup.py | 6 +- 2 files changed, 176 insertions(+), 1 deletion(-) create mode 100644 patchgan/infer.py diff --git a/patchgan/infer.py b/patchgan/infer.py new file mode 100644 index 0000000..96f21b9 --- /dev/null +++ b/patchgan/infer.py @@ -0,0 +1,171 @@ +import torch +from torchinfo import summary +from patchgan.unet import UNet +from patchgan.disc import Discriminator +from patchgan.io import COCOStuffDataset +import yaml +import tqdm +import os +import numpy as np +import importlib.machinery +import argparse +import matplotlib.pyplot as plt + + +def n_crop(image, size, overlap): + c, height, width = image.shape + + effective_size = int(overlap * size) + + ncropsy = int(np.ceil(height / effective_size)) + ncropsx = int(np.ceil(width / effective_size)) + + crops = torch.zeros((ncropsx * ncropsy, c, size, size), device=image.device) + + for j in range(ncropsy): + for i in range(ncropsx): + starty = j * effective_size + startx = i * effective_size + + starty -= max([starty + size - height, 0]) + startx -= max([startx + size - width, 0]) + + crops[j * ncropsx + i, :] = image[:, starty:starty + size, startx:startx + size] + + return crops + + +def build_mask(masks, crop_size, image_size, threshold, overlap): + n, c, height, width = masks.shape + image_height, image_width = image_size + mask = np.zeros((c, *image_size)) + count = np.zeros((c,
*image_size)) + + effective_size = int(overlap * crop_size) + + ncropsy = int(np.ceil(image_height / effective_size)) + ncropsx = int(np.ceil(image_width / effective_size)) + + for j in range(ncropsy): + for i in range(ncropsx): + starty = j * effective_size + startx = i * effective_size + starty -= max([starty + crop_size - image_height, 0]) + startx -= max([startx + crop_size - image_width, 0]) + endy = starty + crop_size + endx = startx + crop_size + + mask[:, starty:endy, startx:endx] += masks[j * ncropsx + i, :] + count[:, starty:endy, startx:endx] += 1 + mask = mask / count + + if threshold > 0: + mask[mask >= threshold] = 1 + mask[mask < threshold] = 0 + + if c > 1: + return np.argmax(mask, axis=0) + else: + return mask[0] + + +def patchgan_infer(): + parser = argparse.ArgumentParser( + prog='PatchGAN', + description='Run inference with the PatchGAN architecture' + ) + + parser.add_argument('-c', '--config_file', required=True, type=str, help='Location of the config YAML file') + parser.add_argument('--dataloader_workers', default=4, type=int, help='Number of workers to use with dataloader (set to 0 to disable multithreading)') + parser.add_argument('-d', '--device', default='auto', help='Device to use to run the model (CUDA=GPU)') + parser.add_argument('--summary', default=True, action='store_true', help="Print summary of the models") + + args = parser.parse_args() + + if args.device == 'auto': + device = 'cuda' if torch.cuda.is_available() else 'cpu' + elif args.device in ['cuda', 'cpu']: + device = args.device + + with open(args.config_file, 'r') as infile: + config = yaml.safe_load(infile) + + dataset_params = config['dataset'] + dataset_path = dataset_params['dataset_path'] + + size = dataset_params.get('size', 256) + + dataset_kwargs = {} + if dataset_params['type'] == 'COCOStuff': + Dataset = COCOStuffDataset + in_channels = 3 + labels = dataset_params.get('labels', [1]) + out_channels = len(labels) + dataset_kwargs['labels'] = labels + else: + try: + spec = importlib.machinery.SourceFileLoader('io', 'io.py') + Dataset = spec.load_module().__getattribute__(dataset_params['type']) + except FileNotFoundError: + print("Make sure io.py is in the working directory!") + raise + except (ImportError, ModuleNotFoundError): + print(f"io.py does not contain {dataset_params['type']}") + raise + in_channels = dataset_params.get('in_channels', 3) + out_channels = dataset_params.get('out_channels', 1) + + assert hasattr(Dataset, 'get_filename') and callable(Dataset.get_filename),\ + f"Dataset class {Dataset.__name__} must have the get_filename method which returns the image filename for a given index" + + datagen = Dataset(dataset_path, size=size, **dataset_kwargs) + + model_params = config['model_params'] + gen_filts = model_params['gen_filts'] + disc_filts = model_params['disc_filts'] + n_disc_layers = model_params['n_disc_layers'] + activation = model_params['activation'] + use_dropout = model_params.get('use_dropout', True) + final_activation = model_params.get('final_activation', 'sigmoid') + + # create the generator + generator = UNet(in_channels, out_channels, gen_filts, use_dropout=use_dropout, activation=activation, final_act=final_activation).to(device) + + # create the discriminator + discriminator = Discriminator(in_channels + out_channels, disc_filts, n_layers=n_disc_layers).to(device) + + if args.summary: + summary(generator, [1, in_channels, size, size]) + summary(discriminator, [1, in_channels + out_channels, size, size]) + + checkpoint_paths = config.get('checkpoint_paths') +
gen_checkpoint = checkpoint_paths['generator'] + dsc_checkpoint = checkpoint_paths['discriminator'] + + infer_params = config.get('infer_params', {}) + output_path = infer_params.get('output_path', 'predictions/') + + if not os.path.exists(output_path): + os.makedirs(output_path) + print(f"Created folder {output_path}") + + generator.eval() + discriminator.eval() + + generator.load_state_dict(torch.load(gen_checkpoint, map_location=device)) + discriminator.load_state_dict(torch.load(dsc_checkpoint, map_location=device)) + + threshold = infer_params.get('threshold', 0) + overlap = infer_params.get('overlap', 0.9) + + for i, data in enumerate(tqdm.tqdm(datagen, desc='Predicting', dynamic_ncols=True, ascii=True)): + imgs = n_crop(data, size, overlap) + out_fname, _ = os.path.splitext(datagen.get_filename(i)) + + with torch.no_grad(): + img_tensor = torch.Tensor(imgs).to(device) + masks = generator(img_tensor).cpu().numpy() + + mask = build_mask(masks, datagen.size, data.shape[1:], threshold, overlap) + + plt.imsave(os.path.join(output_path, out_fname + ".png"), mask, cmap='gray') diff --git a/setup.py b/setup.py index 7226241..99f8a84 100644 --- a/setup.py +++ b/setup.py @@ -26,15 +26,19 @@ packages=find_packages(), entry_points={ 'console_scripts': [ - 'patchgan_train = patchgan.train:patchgan_train' + 'patchgan_train = patchgan.train:patchgan_train', + 'patchgan_infer = patchgan.infer:patchgan_infer' ] }, install_requires=[ 'numpy>=1.21.0,<1.25.2', 'torch>=1.13.0', + 'matplotlib>3.5.0', 'torchvision>=0.14.0', 'tqdm>=4.62.3', 'torchinfo>=1.5.0', 'pyyaml', + 'patchify', + 'einops' ] ) From a0514614c70d22b4466693dad638497d3f48f276 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Tue, 8 Aug 2023 01:11:48 -0500 Subject: [PATCH 26/31] fixing flake issues --- patchgan/infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/patchgan/infer.py b/patchgan/infer.py index 96f21b9..35abae0 100644 --- a/patchgan/infer.py +++ b/patchgan/infer.py @@ -115,7 +115,7 @@ def patchgan_infer(): in_channels = dataset_params.get('in_channels', 3) out_channels = dataset_params.get('out_channels', 1) - assert hasattr(Dataset, 'get_filename') and callable(Dataset.get_filename),\ + assert hasattr(Dataset, 'get_filename') and callable(Dataset.get_filename), \ f"Dataset class {Dataset.__name__} must have the get_filename method which returns the image filename for a given index" From 11d8f01794f706732648b80312e5485bfad1b2d9 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Tue, 8 Aug 2023 10:54:02 -0500 Subject: [PATCH 27/31] making checkpoint_paths a necessary option for inference --- patchgan/infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/patchgan/infer.py b/patchgan/infer.py index 35abae0..a1c90fd 100644 --- a/patchgan/infer.py +++ b/patchgan/infer.py @@ -138,7 +138,7 @@ def patchgan_infer(): - checkpoint_paths = config.get('checkpoint_paths') + checkpoint_paths = config['checkpoint_paths'] gen_checkpoint = checkpoint_paths['generator'] dsc_checkpoint = checkpoint_paths['discriminator'] From ad3a142cfbca483ed3211db980b6ef75b3e286cf Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sat, 19 Aug 2023 11:03:44 -0500 Subject: [PATCH 28/31] removing size param from datagen --- patchgan/infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/patchgan/infer.py
b/patchgan/infer.py index a1c90fd..eb7e4ff 100644 --- a/patchgan/infer.py +++ b/patchgan/infer.py @@ -118,7 +118,7 @@ def patchgan_infer(): assert hasattr(Dataset, 'get_filename') and callable(Dataset.get_filename), \ f"Dataset class {Dataset.__name__} must have the get_filename method which returns the image filename for a given index" - datagen = Dataset(dataset_path, size=size, **dataset_kwargs) + datagen = Dataset(dataset_path, **dataset_kwargs) model_params = config['model_params'] gen_filts = model_params['gen_filts'] @@ -166,6 +166,6 @@ def patchgan_infer(): img_tensor = torch.Tensor(imgs).to(device) masks = generator(img_tensor).cpu().numpy() - mask = build_mask(masks, datagen.size, data.shape[1:], threshold, overlap) + mask = build_mask(masks, size, data.shape[1:], threshold, overlap) plt.imsave(os.path.join(output_path, out_fname + ".png"), mask, cmap='gray') From acce01e9d05a1091c2981c16c198c588fc327678 Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Sat, 19 Aug 2023 11:40:17 -0500 Subject: [PATCH 29/31] removing dropouts from the inference --- patchgan/infer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/patchgan/infer.py b/patchgan/infer.py index eb7e4ff..1647c98 100644 --- a/patchgan/infer.py +++ b/patchgan/infer.py @@ -125,11 +125,10 @@ def patchgan_infer(): disc_filts = model_params['disc_filts'] n_disc_layers = model_params['n_disc_layers'] activation = model_params['activation'] - use_dropout = model_params.get('use_dropout', True) final_activation = model_params.get('final_activation', 'sigmoid') # create the generator - generator = UNet(in_channels, out_channels, gen_filts, use_dropout=use_dropout, activation=activation, final_act=final_activation).to(device) + generator = UNet(in_channels, out_channels, gen_filts, activation=activation, final_act=final_activation).to(device) # create the discriminator discriminator = Discriminator(in_channels + out_channels, disc_filts, n_layers=n_disc_layers).to(device) From 1fa482c3948769a7d4f1fc4e5352371e4d99c67d Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Tue, 22 Aug 2023 14:33:40 -0500 Subject: [PATCH 30/31] fixing device switching bug. 
creating a Dataset.save_mask requirement to save mask files --- patchgan/infer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/patchgan/infer.py b/patchgan/infer.py index 1647c98..322de97 100644 --- a/patchgan/infer.py +++ b/patchgan/infer.py @@ -9,7 +9,6 @@ import numpy as np import importlib.machinery import argparse -import matplotlib.pyplot as plt def n_crop(image, size, overlap): @@ -87,6 +86,8 @@ def patchgan_infer(): elif args.device in ['cuda', 'cpu']: device = args.device + print(f"Running with {device}") + with open(args.config_file, 'r') as infile: config = yaml.safe_load(infile) @@ -134,8 +135,8 @@ def patchgan_infer(): discriminator = Discriminator(in_channels + out_channels, disc_filts, n_layers=n_disc_layers).to(device) if args.summary: - summary(generator, [1, in_channels, size, size]) - summary(discriminator, [1, in_channels + out_channels, size, size]) + summary(generator, [1, in_channels, size, size], device=device) + summary(discriminator, [1, in_channels + out_channels, size, size], device=device) checkpoint_paths = config['checkpoint_paths'] gen_checkpoint = checkpoint_paths['generator'] @@ -167,4 +168,4 @@ def patchgan_infer(): mask = build_mask(masks, size, data.shape[1:], threshold, overlap) - plt.imsave(os.path.join(output_path, out_fname + ".png"), mask, cmap='gray') + Dataset.save_mask(mask, output_path, out_fname) From 588b2c44d10036e7b99f01254e6d13cae5953aaa Mon Sep 17 00:00:00 2001 From: Ramanakumar Sankar Date: Tue, 22 Aug 2023 14:34:08 -0500 Subject: [PATCH 31/31] adding version bump --- patchgan/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/patchgan/version.py b/patchgan/version.py index 11d27f8..b650ceb 100644 --- a/patchgan/version.py +++ b/patchgan/version.py @@ -1 +1 @@ -__version__ = '0.1' +__version__ = '0.2'
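A closing note on the inference path: `n_crop` and `build_mask` in `patchgan/infer.py` tile an arbitrarily sized image into overlapping crops, run each crop through the generator, and average the overlapping predictions back into a full-resolution mask. A minimal round-trip sketch, assuming the module as added in PATCH 25 and substituting random arrays for a real image and the generator output:

```
import torch
import numpy as np
from patchgan.infer import n_crop, build_mask

size, overlap = 256, 0.9
image = torch.rand(3, 480, 640)       # C x H x W image, larger than one crop

crops = n_crop(image, size, overlap)  # -> (n_crops, 3, 256, 256)

# infer.py feeds each crop through the generator; a random single-channel
# prediction per crop stands in for that here
masks = np.random.rand(crops.shape[0], 1, size, size)

# overlapping predictions are averaged, then binarized at the threshold
full_mask = build_mask(masks, size, image.shape[1:], 0.5, overlap)
print(full_mask.shape)                # (480, 640)
```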