Merge pull request #3 from ramanakumars/unifying_arch
Updating architecture
Showing 17 changed files with 609 additions and 573 deletions.
README.md
@@ -1,2 +1,30 @@
# patchGAN
patchGAN model for image segmentation

[![PyPI version](https://badge.fury.io/py/patchGAN.svg)](https://badge.fury.io/py/patchGAN)

UNet-based GAN model for image segmentation using a patch-wise discriminator.
Based on the [pix2pix](https://phillipi.github.io/pix2pix/) model.

## Installation

Install the package with pip:
```
pip install patchgan
```

Upgrading an existing install:
```
pip install -U patchgan
```

Get the current development branch:
```
pip install -U git+https://github.com/ramanakumars/patchGAN.git
```

## Training
You can train the patchGAN model with a config file and the `patchgan_train` command:
```
patchgan_train --config_file train_coco.yaml --n_epochs 100 --batch_size 16
```
See `examples/train_coco.yaml` for the corresponding config for the COCO stuff dataset.
examples/train_coco.yaml
@@ -0,0 +1,29 @@
dataset:
  type: COCOStuff
  augmentation: randomcrop+flip
  size: 256
  train_data:
    images: /d1/rsankar/data/COCOstuff/train2017
    masks: /d1/rsankar/data/COCOstuff/train2017
    labels: [1, 2, 3, 4, 5, 6, 7]
  validation_data:
    images: /d1/rsankar/data/COCOstuff/val2017
    masks: /d1/rsankar/data/COCOstuff/val2017
    labels: [1, 2, 3, 4, 5, 6, 7]
model_params:
  gen_filts: 32
  disc_filts: 16
  activation: relu
  use_dropout: True
  final_activation: sigmoid
  n_disc_layers: 5
checkpoint_path: ./checkpoints/checkpoint-COCO/
load_last_checkpoint: True
train_params:
  loss_type: weighted_bce
  seg_alpha: 200
  gen_learning_rate: 1.e-3
  disc_learning_rate: 1.e-3
  decay_rate: 0.95
  save_freq: 5
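For reference, the `model_params` block maps directly onto the `UNet` and `Discriminator` constructors used by the inference entry point later in this diff. A minimal sketch of loading this config and building the two networks by hand (this is not the `patchgan_train` implementation; the 3 input and 7 output channels are assumptions that follow from RGB COCOStuff images and the seven `labels` listed above):

```python
import yaml
from patchgan.unet import UNet
from patchgan.disc import Discriminator

# Load the training config shown above
with open('examples/train_coco.yaml', 'r') as infile:
    config = yaml.safe_load(infile)

model_params = config['model_params']

# 3 input channels (RGB) and 7 output channels (one per entry in `labels`)
generator = UNet(3, 7, model_params['gen_filts'],
                 activation=model_params['activation'],
                 final_act=model_params.get('final_activation', 'sigmoid'))
discriminator = Discriminator(3 + 7, model_params['disc_filts'],
                              n_layers=model_params['n_disc_layers'])
```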
patchgan/__init__.py
@@ -1,5 +1,8 @@
-from .unet import *
-from .io import *
-from .losses import *
-from .utils import *
+from .unet import UNet
+from .disc import Discriminator
+from .trainer import Trainer
 from .version import __version__
+
+__all__ = [
+    'UNet', 'Discriminator', 'Trainer', '__version__'
+]
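With the unified exports, downstream code can pull the main classes straight from the package instead of reaching into the submodules; a minimal sketch:

```python
from patchgan import UNet, Discriminator, Trainer, __version__

print(__version__)
# Equivalent to the submodule imports used elsewhere in this diff, e.g.
# from patchgan.unet import UNet
# from patchgan.disc import Discriminator
```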
patchgan/disc.py
@@ -0,0 +1,49 @@
from torch import nn
from .transfer import Transferable


class Discriminator(nn.Module, Transferable):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int) -- the number of channels in input images
            ndf (int) -- the number of filters in the last conv layer
            n_layers (int) -- the number of conv layers in the discriminator
            norm_layer -- normalization layer
        """
        super(Discriminator, self).__init__()
        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw,
                              stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=False),
                nn.Tanh(),
                norm_layer(ndf * nf_mult)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=False),
            nn.Tanh(),
            norm_layer(ndf * nf_mult)
        ]

        # output 1 channel prediction map
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw,
                               stride=1, padding=padw), nn.Sigmoid()]
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
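As a quick sanity check of the patch-wise design, here is a hedged usage sketch; the channel and filter counts are taken from the COCO config above, and the 6x6 output size is simply what the kernel sizes and strides above work out to for a 256x256 input:

```python
import torch
from patchgan.disc import Discriminator

# Discriminator over a concatenated image (3 ch) + mask (7 ch) stack,
# with disc_filts=16 and n_disc_layers=5 as in the COCO config above.
disc = Discriminator(input_nc=3 + 7, ndf=16, n_layers=5)

x = torch.randn(1, 10, 256, 256)   # one image/mask stack
patch_scores = disc(x)             # per-patch real/fake probabilities (Sigmoid output)
print(patch_scores.shape)          # torch.Size([1, 1, 6, 6])
```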
@@ -0,0 +1,171 @@
import torch
from torchinfo import summary
from patchgan.unet import UNet
from patchgan.disc import Discriminator
from patchgan.io import COCOStuffDataset
import yaml
import tqdm
import os
import numpy as np
import importlib.machinery
import argparse


def n_crop(image, size, overlap):
    c, height, width = image.shape

    effective_size = int(overlap * size)

    ncropsy = int(np.ceil(height / effective_size))
    ncropsx = int(np.ceil(width / effective_size))

    crops = torch.zeros((ncropsx * ncropsy, c, size, size), device=image.device)

    for j in range(ncropsy):
        for i in range(ncropsx):
            starty = j * effective_size
            startx = i * effective_size

            starty -= max([starty + size - height, 0])
            startx -= max([startx + size - width, 0])

            crops[j * ncropsy + i, :] = image[:, starty:starty + size, startx:startx + size]

    return crops


def build_mask(masks, crop_size, image_size, threshold, overlap):
    n, c, height, width = masks.shape
    image_height, image_width = image_size
    mask = np.zeros((c, *image_size))
    count = np.zeros((c, *image_size))

    effective_size = int(overlap * crop_size)

    ncropsy = int(np.ceil(image_height / effective_size))
    ncropsx = int(np.ceil(image_width / effective_size))

    for j in range(ncropsy):
        for i in range(ncropsx):
            starty = j * effective_size
            startx = i * effective_size
            starty -= max([starty + crop_size - image_height, 0])
            startx -= max([startx + crop_size - image_width, 0])
            endy = starty + crop_size
            endx = startx + crop_size

            mask[:, starty:endy, startx:endx] += masks[j * ncropsy + i, :]
            count[:, starty:endy, startx:endx] += 1
    mask = mask / count

    if threshold > 0:
        mask[mask >= threshold] = 1
        mask[mask < threshold] = 0

    if c > 1:
        return np.argmax(mask, axis=0)
    else:
        return mask[0]


def patchgan_infer():
    parser = argparse.ArgumentParser(
        prog='PatchGAN',
        description='Train the PatchGAN architecture'
    )

    parser.add_argument('-c', '--config_file', required=True, type=str, help='Location of the config YAML file')
    parser.add_argument('--dataloader_workers', default=4, type=int, help='Number of workers to use with dataloader (set to 0 to disable multithreading)')
    parser.add_argument('-d', '--device', default='auto', help='Device to use to train the model (CUDA=GPU)')
    parser.add_argument('--summary', default=True, action='store_true', help="Print summary of the models")

    args = parser.parse_args()

    if args.device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    elif args.device in ['cuda', 'cpu']:
        device = args.device

    print(f"Running with {device}")

    with open(args.config_file, 'r') as infile:
        config = yaml.safe_load(infile)

    dataset_params = config['dataset']
    dataset_path = dataset_params['dataset_path']

    size = dataset_params.get('size', 256)

    dataset_kwargs = {}
    if dataset_params['type'] == 'COCOStuff':
        Dataset = COCOStuffDataset
        in_channels = 3
        labels = dataset_params.get('labels', [1])
        out_channels = len(labels)
        dataset_kwargs['labels'] = labels
    else:
        try:
            spec = importlib.machinery.SourceFileLoader('io', 'io.py')
            Dataset = spec.load_module().__getattribute__(dataset_params['type'])
        except FileNotFoundError:
            print("Make sure io.py is in the working directory!")
            raise
        except (ImportError, ModuleNotFoundError):
            print(f"io.py does not contain {dataset_params['type']}")
            raise
        in_channels = dataset_params.get('in_channels', 3)
        out_channels = dataset_params.get('out_channels', 1)

    assert hasattr(Dataset, 'get_filename') and callable(Dataset.get_filename), \
        f"Dataset class {Dataset.__name__} must have the get_filename method which returns the image filename for a given index"

    datagen = Dataset(dataset_path, **dataset_kwargs)

    model_params = config['model_params']
    gen_filts = model_params['gen_filts']
    disc_filts = model_params['disc_filts']
    n_disc_layers = model_params['n_disc_layers']
    activation = model_params['activation']
    final_activation = model_params.get('final_activation', 'sigmoid')

    # create the generator
    generator = UNet(in_channels, out_channels, gen_filts, activation=activation, final_act=final_activation).to(device)

    # create the discriminator
    discriminator = Discriminator(in_channels + out_channels, disc_filts, n_layers=n_disc_layers).to(device)

    if args.summary:
        summary(generator, [1, in_channels, size, size], device=device)
        summary(discriminator, [1, in_channels + out_channels, size, size], device=device)

    checkpoint_paths = config['checkpoint_paths']
    gen_checkpoint = checkpoint_paths['generator']
    dsc_checkpoint = checkpoint_paths['discriminator']

    infer_params = config.get('infer_params', {})
    output_path = infer_params.get('output_path', 'predictions/')

    if not os.path.exists(output_path):
        os.makedirs(output_path)
        print(f"Created folder {output_path}")

    generator.eval()
    discriminator.eval()

    generator.load_state_dict(torch.load(gen_checkpoint, map_location=device))
    discriminator.load_state_dict(torch.load(dsc_checkpoint, map_location=device))

    threshold = infer_params.get('threshold', 0)
    overlap = infer_params.get('overlap', 0.9)

    for i, data in enumerate(tqdm.tqdm(datagen, desc='Predicting', dynamic_ncols=True, ascii=True)):
        imgs = n_crop(data, size, overlap)
        out_fname, _ = os.path.splitext(datagen.get_filename(i))

        with torch.no_grad():
            img_tensor = torch.Tensor(imgs).to(device)
            masks = generator(img_tensor).cpu().numpy()

        mask = build_mask(masks, size, data.shape[1:], threshold, overlap)

        Dataset.save_mask(mask, output_path, out_fname)
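Assuming this entry point is exposed as a `patchgan_infer` console command (mirroring the `patchgan_train` command in the README), inference then only needs a config file; the file name below is illustrative. The config must provide the keys read above: a top-level `dataset` block (`type`, `dataset_path`, `size`, and `labels` for COCOStuff), `model_params`, `checkpoint_paths` with `generator` and `discriminator` weight files, and an optional `infer_params` block (`output_path`, `threshold`, `overlap`).

```
patchgan_infer --config_file infer_coco.yaml --device cuda
```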