Merge pull request #3 from ramanakumars/unifying_arch
Updating architecture
ramanakumars authored Aug 22, 2023
2 parents cca0caa + 588b2c4 commit 22d3e50
Showing 17 changed files with 609 additions and 573 deletions.
30 changes: 29 additions & 1 deletion README.md
@@ -1,2 +1,30 @@
# patchGAN
patchGAN model for image segmentation

[![PyPI version](https://badge.fury.io/py/patchGAN.svg)](https://badge.fury.io/py/patchGAN)

UNet-based GAN model for image segmentation using a patch-wise discriminator.
Based on the [pix2pix](https://phillipi.github.io/pix2pix/) model.

## Installation

Install the package with pip:
```
pip install patchgan
```

Upgrade an existing install:
```
pip install -U patchgan
```

Get the current development branch:
```
pip install -U git+https://github.com/ramanakumars/patchGAN.git
```

## Training
You can train the patchGAN model with a config file and the `patchgan_train` command:
```
patchgan_train --config_file train_coco.yaml --n_epochs 100 --batch_size 16
```
See `examples/train_coco.yaml` for the corresponding config for the COCO-Stuff dataset.
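
This commit also adds an inference entry point (`patchgan/infer.py`, shown below). Assuming it is exposed as a `patchgan_infer` console script analogous to `patchgan_train`, and given a hypothetical `infer_coco.yaml` providing the `dataset`, `model_params`, and `checkpoint_paths` sections that script reads, inference would look like:
```
patchgan_infer --config_file infer_coco.yaml --device cuda
```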
29 changes: 29 additions & 0 deletions examples/train_coco.yaml
@@ -0,0 +1,29 @@
dataset:
  type: COCOStuff
  augmentation: randomcrop+flip
  size: 256
  train_data:
    images: /d1/rsankar/data/COCOstuff/train2017
    masks: /d1/rsankar/data/COCOstuff/train2017
    labels: [1, 2, 3, 4, 5, 6, 7]
  validation_data:
    images: /d1/rsankar/data/COCOstuff/val2017
    masks: /d1/rsankar/data/COCOstuff/val2017
    labels: [1, 2, 3, 4, 5, 6, 7]
model_params:
  gen_filts: 32
  disc_filts: 16
  activation: relu
  use_dropout: True
  final_activation: sigmoid
  n_disc_layers: 5
  checkpoint_path: ./checkpoints/checkpoint-COCO/
  load_last_checkpoint: True
train_params:
  loss_type: weighted_bce
  seg_alpha: 200
  gen_learning_rate: 1.e-3
  disc_learning_rate: 1.e-3
  decay_rate: 0.95
  save_freq: 5

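Since this is plain YAML, the config can be inspected the same way the new `infer.py` (below) reads it, via `yaml.safe_load`. A minimal sketch, assuming the file sits at `examples/train_coco.yaml`:
```python
import yaml

# parse the training config the same way infer.py does
with open('examples/train_coco.yaml', 'r') as infile:
    config = yaml.safe_load(infile)

print(config['model_params']['gen_filts'])   # 32
print(config['train_params']['loss_type'])   # weighted_bce
```
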
11 changes: 7 additions & 4 deletions patchgan/__init__.py
@@ -1,5 +1,8 @@
-from .unet import *
-from .io import *
-from .losses import *
-from .utils import *
+from .unet import UNet
+from .disc import Discriminator
+from .trainer import Trainer
+from .version import __version__
 
+__all__ = [
+    'UNet', 'Discriminator', 'Trainer', '__version__'
+]
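
With the star imports replaced by an explicit `__all__`, the main classes now come straight from the package root. A quick sanity check, assuming the package is installed:
```python
from patchgan import UNet, Discriminator, Trainer, __version__

print(__version__)
```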
49 changes: 49 additions & 0 deletions patchgan/disc.py
@@ -0,0 +1,49 @@
from torch import nn
from .transfer import Transferable


class Discriminator(nn.Module, Transferable):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(Discriminator, self).__init__()
        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw,
                              stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=False),
                nn.Tanh(),
                norm_layer(ndf * nf_mult)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=False),
            nn.Tanh(),
            norm_layer(ndf * nf_mult)
        ]

        # output 1 channel prediction map
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw,
                               stride=1, padding=padw), nn.Sigmoid()]
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
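
To make the patch-wise behaviour concrete: the final stride-1 convolution plus `Sigmoid` emits a grid of per-patch real/fake scores rather than a single scalar. A minimal sketch (the channel counts here are illustrative, not taken from the repo):
```python
import torch
from patchgan.disc import Discriminator

# e.g. a 3-channel image concatenated with a 1-channel mask -> 4 input channels
disc = Discriminator(input_nc=4, ndf=16, n_layers=3)
scores = disc(torch.rand(1, 4, 256, 256))
print(scores.shape)  # torch.Size([1, 1, 30, 30]): one sigmoid score per patch
```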
171 changes: 171 additions & 0 deletions patchgan/infer.py
@@ -0,0 +1,171 @@
import torch
from torchinfo import summary
from patchgan.unet import UNet
from patchgan.disc import Discriminator
from patchgan.io import COCOStuffDataset
import yaml
import tqdm
import os
import numpy as np
import importlib.machinery
import argparse


def n_crop(image, size, overlap):
    c, height, width = image.shape

    effective_size = int(overlap * size)

    ncropsy = int(np.ceil(height / effective_size))
    ncropsx = int(np.ceil(width / effective_size))

    crops = torch.zeros((ncropsx * ncropsy, c, size, size), device=image.device)

    for j in range(ncropsy):
        for i in range(ncropsx):
            starty = j * effective_size
            startx = i * effective_size

            starty -= max([starty + size - height, 0])
            startx -= max([startx + size - width, 0])

            crops[j * ncropsy + i, :] = image[:, starty:starty + size, startx:startx + size]

    return crops


def build_mask(masks, crop_size, image_size, threshold, overlap):
    n, c, height, width = masks.shape
    image_height, image_width = image_size
    mask = np.zeros((c, *image_size))
    count = np.zeros((c, *image_size))

    effective_size = int(overlap * crop_size)

    ncropsy = int(np.ceil(image_height / effective_size))
    ncropsx = int(np.ceil(image_width / effective_size))

    for j in range(ncropsy):
        for i in range(ncropsx):
            starty = j * effective_size
            startx = i * effective_size
            starty -= max([starty + crop_size - image_height, 0])
            startx -= max([startx + crop_size - image_width, 0])
            endy = starty + crop_size
            endx = startx + crop_size

            mask[:, starty:endy, startx:endx] += masks[j * ncropsy + i, :]
            count[:, starty:endy, startx:endx] += 1
    mask = mask / count

    if threshold > 0:
        mask[mask >= threshold] = 1
        mask[mask < threshold] = 0

    if c > 1:
        return np.argmax(mask, axis=0)
    else:
        return mask[0]


def patchgan_infer():
    parser = argparse.ArgumentParser(
        prog='PatchGAN',
        description='Train the PatchGAN architecture'
    )

    parser.add_argument('-c', '--config_file', required=True, type=str, help='Location of the config YAML file')
    parser.add_argument('--dataloader_workers', default=4, type=int, help='Number of workers to use with dataloader (set to 0 to disable multithreading)')
    parser.add_argument('-d', '--device', default='auto', help='Device to use to train the model (CUDA=GPU)')
    parser.add_argument('--summary', default=True, action='store_true', help="Print summary of the models")

    args = parser.parse_args()

    if args.device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    elif args.device in ['cuda', 'cpu']:
        device = args.device

    print(f"Running with {device}")

    with open(args.config_file, 'r') as infile:
        config = yaml.safe_load(infile)

    dataset_params = config['dataset']
    dataset_path = dataset_params['dataset_path']

    size = dataset_params.get('size', 256)

    dataset_kwargs = {}
    if dataset_params['type'] == 'COCOStuff':
        Dataset = COCOStuffDataset
        in_channels = 3
        labels = dataset_params.get('labels', [1])
        out_channels = len(labels)
        dataset_kwargs['labels'] = labels
    else:
        try:
            spec = importlib.machinery.SourceFileLoader('io', 'io.py')
            Dataset = spec.load_module().__getattribute__(dataset_params['type'])
        except FileNotFoundError:
            print("Make sure io.py is in the working directory!")
            raise
        except (ImportError, ModuleNotFoundError):
            print(f"io.py does not contain {dataset_params['type']}")
            raise
        in_channels = dataset_params.get('in_channels', 3)
        out_channels = dataset_params.get('out_channels', 1)

    assert hasattr(Dataset, 'get_filename') and callable(Dataset.get_filename), \
        f"Dataset class {Dataset.__name__} must have the get_filename method which returns the image filename for a given index"

    datagen = Dataset(dataset_path, **dataset_kwargs)

    model_params = config['model_params']
    gen_filts = model_params['gen_filts']
    disc_filts = model_params['disc_filts']
    n_disc_layers = model_params['n_disc_layers']
    activation = model_params['activation']
    final_activation = model_params.get('final_activation', 'sigmoid')

    # create the generator
    generator = UNet(in_channels, out_channels, gen_filts, activation=activation, final_act=final_activation).to(device)

    # create the discriminator
    discriminator = Discriminator(in_channels + out_channels, disc_filts, n_layers=n_disc_layers).to(device)

    if args.summary:
        summary(generator, [1, in_channels, size, size], device=device)
        summary(discriminator, [1, in_channels + out_channels, size, size], device=device)

    checkpoint_paths = config['checkpoint_paths']
    gen_checkpoint = checkpoint_paths['generator']
    dsc_checkpoint = checkpoint_paths['discriminator']

    infer_params = config.get('infer_params', {})
    output_path = infer_params.get('output_path', 'predictions/')

    if not os.path.exists(output_path):
        os.makedirs(output_path)
        print(f"Created folder {output_path}")

    generator.eval()
    discriminator.eval()

    generator.load_state_dict(torch.load(gen_checkpoint, map_location=device))
    discriminator.load_state_dict(torch.load(dsc_checkpoint, map_location=device))

    threshold = infer_params.get('threshold', 0)
    overlap = infer_params.get('overlap', 0.9)

    for i, data in enumerate(tqdm.tqdm(datagen, desc='Predicting', dynamic_ncols=True, ascii=True)):
        imgs = n_crop(data, size, overlap)
        out_fname, _ = os.path.splitext(datagen.get_filename(i))

        with torch.no_grad():
            img_tensor = torch.Tensor(imgs).to(device)
            masks = generator(img_tensor).cpu().numpy()

        mask = build_mask(masks, size, data.shape[1:], threshold, overlap)

        Dataset.save_mask(mask, output_path, out_fname)
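
The `n_crop`/`build_mask` pair above is the tiling logic: `n_crop` slices an image into overlapping `size` x `size` crops (shifting the last row and column back so no crop runs off the edge), and `build_mask` averages the per-crop predictions back into a full-resolution mask. A round-trip sketch with random stand-in data (shapes are illustrative):
```python
import torch
from patchgan.infer import n_crop, build_mask

image = torch.rand(3, 300, 400)               # (C, H, W)
crops = n_crop(image, size=256, overlap=0.9)  # -> (4, 3, 256, 256) here
# stand-in for generator output: one mask channel per crop
masks = torch.rand(crops.shape[0], 1, 256, 256).numpy()
full = build_mask(masks, crop_size=256, image_size=image.shape[1:],
                  threshold=0.5, overlap=0.9)
print(full.shape)                             # (300, 400), binarized at 0.5
```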