From df13d4ef4c4dd4c164de590aefc83c3c73e9b80b Mon Sep 17 00:00:00 2001
From: Siddhesh Thakur
Date: Thu, 29 Apr 2021 03:38:05 +0530
Subject: [PATCH 01/13] single run works

---
 BrainMaGe/tester/test_single_run.py |  57 ++++++++--
 brain_mage_single_run               | 171 ++++++++++++++--------
 setup.py                            |   2 +-
 3 files changed, 133 insertions(+), 97 deletions(-)

diff --git a/BrainMaGe/tester/test_single_run.py b/BrainMaGe/tester/test_single_run.py
index 5b3ff03..ecc2653 100755
--- a/BrainMaGe/tester/test_single_run.py
+++ b/BrainMaGe/tester/test_single_run.py
@@ -10,7 +10,6 @@
 import os
 import sys
 import time
-import pandas as pd
 import torch
 import nibabel as nib
 import tqdm
@@ -20,8 +19,12 @@
 from scipy.ndimage.morphology import binary_fill_holes
 from BrainMaGe.models.networks import fetch_model
 from BrainMaGe.utils import csv_creator_adv
-from BrainMaGe.utils.utils_test import pad_image, process_image, interpolate_image,\
-    padder_and_cropper
+from BrainMaGe.utils.utils_test import (
+    pad_image,
+    process_image,
+    interpolate_image,
+    padder_and_cropper,
+)
 
 
 def postprocess_prediction(seg):
@@ -32,7 +35,8 @@ def postprocess_prediction(seg):
     seg[lbls != largest_region] = 0
     return seg
 
-def infer_single_ma(hparams):
+
+def infer_single_ma(input_path, output_path, weights, mask_path=None, device='cpu'):
     start = time.asctime()
     startstamp = time.time()
     print("\nHostname :" + str(os.getenv("HOSTNAME")))
@@ -40,8 +44,45 @@
     print("\nStart Stamp:" + str(startstamp))
     sys.stdout.flush()
     print("Generating Test csv")
-    if not os.path.exists(os.path.join(hparams.results_dir)):
-        os.mkdir(hparams.results_dir)
-    temp_dir = os.path.join(hparams.results_dir, 'Temp')
-    subjects = hparams.subjects
\ No newline at end of file
+    model = fetch_model(
+        modelname="resunet", num_channels=1, num_classes=2, num_filters=16
+    )
+    checkpoint = torch.load(weights)
+    model.load_state_dict(checkpoint["model_state_dict"])
+
+    if device != "cpu":
+        model.cuda()
+    model.eval()
+
+    patient_nib = nib.load(input_path)
+    image = patient_nib.get_fdata()
+    old_shape = patient_nib.shape
+    image = process_image(image)
+    image = resize(
+        image, (128, 128, 128), order=3, mode="edge", cval=0, anti_aliasing=False
+    )
+    image = image[np.newaxis, np.newaxis, ...]
+    image = torch.FloatTensor(image)
+    if device != "cpu":
+        image = image.cuda()
+    with torch.no_grad():
+        output = model(image)
+        output = output.cpu().numpy()[0][0]
+    to_save = interpolate_image(output, patient_nib.shape)
+    to_save[to_save >= 0.9] = 1
+    to_save[to_save < 0.9] = 0
+    to_save_nib = nib.Nifti1Image(to_save, patient_nib.affine)
+    nib.save(to_save_nib, os.path.join(output_path))
+
+    print("Done with running the model.")
+
+    if mask_path is not None:
+        print("You chose to save the brain. We are now saving it with the masks.")
+        brain_data = patient_nib.get_fdata()
+        brain_data[to_save == 0] = 0
+        to_save_brain = nib.Nifti1Image(brain_data, patient_nib.affine)
+        nib.save(to_save_brain, os.path.join(mask_path))
+
+    print("Thank you for using BrainMaGe")
+    print("*" * 60)
diff --git a/brain_mage_single_run b/brain_mage_single_run
index 78b0aab..f58fc1f 100755
--- a/brain_mage_single_run
+++ b/brain_mage_single_run
@@ -9,107 +9,102 @@ Created on Sat May 30 01:05:59 2020
 from __future__ import absolute_import, print_function, division
 import argparse
 import os
-import pandas as pd
 from BrainMaGe.tester import test_single_run
 import pkg_resources
 
 
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(prog='BrainMaGe', formatter_class=argparse.RawTextHelpFormatter,
-                                     description='\nThis code was implemented for Deep Learning '+\
-                                     'based training and inference of 3D-U-Net,\n3D-Res-U-Net models for '+\
-                                     'Brain Extraction a.k.a Skull Stripping in biomedical NIfTI volumes.\n'+\
-                                     'The project is hosted at: https://github.com/CBICA/BrainMaGe * \n'+\
-                                     'See the documentation for details on its use.\n'+\
-                                     'If you are using this tool, please cite out paper.'
-                                     'This software accompanies the research presented in:\n'+\
-                                     'Thakur et al., \'Brain Extraction on MRI Scans in Presence of Diffuse\n'+\
-                                     'Glioma:Multi-institutional Performance Evaluation of Deep Learning Methods'+\
-                                     'and Robust Modality-Agnostic Training\'.\n'+\
-                                     'DOI: 10.1016/j.neuroimage.2020.117081\n' +\
-                                     'We hope our work helps you in your endeavours.\n'+ '\n'\
-                                     'Copyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.\n'\
-                                     'For questions and feedback contact: software@cbica.upenn.edu')
-
-    parser.add_argument('-i', '--input',
-                        help='Should be either a file path of a modality or an input folder.
-                        'If folder is passed all files ending with .nii.gz '
-                        'within that folder will be Skull Stripped.',
-                        required=True, type=str)
-
-    parser.add_argument('-o', '--output',
-                        help='Should be either a filename or a folder.\n'+\
-                        'In the case of single file or a folder for input',
-                        required=False, type=str)
-
-    parser.add_argument('-dev', default='0', dest='device', type=str,
-                        help='used to set on which device the prediction will run.\n'+
-                        'Must be either int or str. Use int for GPU id or\n'+
-                        '\'cpu\' to run on CPU. Avoid training on CPU. \n'+
-                        'Default for selecting first GPU is set to -dev 0\n',
-                        required=False)
-
-    parser.add_argument('-load', default=None, dest='load', type=str,
-                        help='If the location of the weight file is passed, the internal methods\n'+\
-                        'are overridden to apply these weights to the model. We warn against\n'+\
-                        'the usage of this unless you know what you are passing. C')
-
-    parser.add_argument('-v', '--version', action='version',
-                        version=pkg_resources.require("BrainMaGe")[0].version+'\n\nCopyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.', help="Show program's version number and exit.")
-
-    parser.add_argument('-save_brain', default=1, type=int, required=False, dest='save_brain',
-                        help='if set to 0 the segmentation mask will be only produced and\n'+\
-                        'and the mask will not be applied on the input image to produce\n'+\
-                        ' a brain. This step is to be only applied if you trust this\n'+\
-                        'software and do not feel the need for Manual QC. This will save\n'+\
-                        ' you some time. This is useless for training though.')
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog="BrainMaGe",
+        formatter_class=argparse.RawTextHelpFormatter,
+        description="\nThis code was implemented for Deep Learning "
+        + "based training and inference of 3D-U-Net,\n3D-Res-U-Net models for "
+        + "Brain Extraction a.k.a Skull Stripping in biomedical NIfTI volumes.\n"
+        + "The project is hosted at: https://github.com/CBICA/BrainMaGe * \n"
+        + "See the documentation for details on its use.\n"
+        + "If you are using this tool, please cite our paper."
+        "This software accompanies the research presented in:\n"
+        + "Thakur et al., 'Brain Extraction on MRI Scans in Presence of Diffuse\n"
+        + "Glioma: Multi-institutional Performance Evaluation of Deep Learning Methods "
+        + "and Robust Modality-Agnostic Training'.\n"
+        + "DOI: 10.1016/j.neuroimage.2020.117081\n"
+        + "We hope our work helps you in your endeavours.\n"
+        + "\n"
+        "Copyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.\n"
+        "For questions and feedback contact: software@cbica.upenn.edu",
+    )
+
+    parser.add_argument(
+        "-i",
+        "--input",
+        dest="input_path",
+        help="Should be either a file path of a modality or an input folder.\n"
+        + "If folder is passed all files ending with .nii.gz "
+        + "within that folder will be Skull Stripped.",
+        required=True,
+        type=str,
+    )
+
+    parser.add_argument(
+        "-o",
+        "--output",
+        dest="output_path",
+        help="Should be either a filename or a folder.\n"
+        + "In the case of single file or a folder for input",
+        required=True,
+        type=str,
+    )
+
+    parser.add_argument(
+        "-m",
+        "--mask_path",
+        dest="mask_path",
+        help="Should be either a filename or a folder.\n"
+        + "In the case of single file or a folder for input",
+        required=False,
+        default=None,
+        type=str,
+    )
+
+    parser.add_argument(
+        "-dev",
+        default="0",
+        dest="device",
+        type=str,
+        help="used to set on which device the prediction will run.\n"
+        + "Must be either int or str. Use int for GPU id or\n"
+        + "'cpu' to run on CPU. Avoid training on CPU. \n"
+        + "Default for selecting first GPU is set to -dev 0\n",
+        required=False,
+    )
+
+    parser.add_argument(
+        "-v",
+        "--version",
+        action="version",
+        version=pkg_resources.require("BrainMaGe")[0].version
+        + "\n\nCopyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.",
+        help="Show program's version number and exit.",
+    )
 
     args = parser.parse_args()
+
+    input_path = args.input_path
+    output_path = args.output_path
+    mask_path = args.mask_path
     DEVICE = args.device
-    # If weights are given in params, then set weights to given params
-    # else set weights to None
-    if args.load is not None:
-        weights = os.path.abspath(args.load)
-    else:
-        weights = None
-
-    if weights is not None:
-        _, ext = os.path.splitext(weights)
-        if os.path.exists(weights):
-            pass:
-        else:
-            raise ValueError("The weights file path you passed does not exist. Please check the File existence again.")
-    else: # If weights file are not passed
-        base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
-        base_dir = os.path.join(os.path.dirname(base_dir), 'BrainMaGe/weights')
-        weights = os.path.join(base_dir, 'resunet_ma.pt')
+    base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+    base_dir = os.path.join(os.path.dirname(base_dir), "BrainMaGe/weights")
+    weights = os.path.join(base_dir, "resunet_ma.pt")
 
     print("Weight file used :", weights)
 
-    if args.save_brain == 0:
-        args.save_brain = False
-    elif args.save_brain == 1:
-        args.save_brain = True
-    else:
-        raise ValueError('Unknown value for save brain : ', args.save_brain)
-
-    # Creating a Dictionary
-    hparams = {}
-    if os.path.isfile(args.input):
-        if os.path.exists(args.input)
-            subjects = [os.path.abspath(args.input)]
-        else:
-            print("The File for the subject does not exist!")
-    elif os.path.isdir(args.input):
-        if os.path.exists(args.input):
-            subjects = glob.glob(os.path.join())
-    hparams['subjects'] = subjects
 
     # Running Inference
-    test_single_run.infer_single_ma(input_subjects, output_path)
+    test_single_run.infer_single_ma(input_path, output_path, weights, mask_path, device=DEVICE)
 
-    print('*'*80)
+    print("*" * 80)
     print("Finished")
-    print('*'*80)
+    print("*" * 80)
diff --git a/setup.py b/setup.py
index 5914959..5b73743 100755
--- a/setup.py
+++ b/setup.py
@@ -26,7 +26,7 @@
         'nibabel',
         'pytorch-lightning==0.8.1'
     ],
-    scripts=['brain_mage_run', 'brain_mage_intensity_standardize'],
+    scripts=['brain_mage_run', 'brain_mage_single_run', 'brain_mage_intensity_standardize'],
     classifiers=[
         'Intended Audience :: Science/Research',
         'Programming Language :: Python',

From 0ab27950248bf32b312f5e4db1402299aadba972 Mon Sep 17 00:00:00 2001
From: Siddhesh Thakur
Date: Thu, 29 Apr 2021 03:38:38 +0530
Subject: [PATCH 02/13] refactored to follow PEP8

---
 BrainMaGe/models/networks.py             | 110 +++---
 BrainMaGe/models/seg_modules.py          | 442 +++++++++++++++--------
 BrainMaGe/tester/__init__.py             |   2 -
 BrainMaGe/tester/test_ma.py              | 209 +++++++----
 BrainMaGe/tester/test_ma_multi.py        | 121 ++++---
 BrainMaGe/tester/test_multi_4.py         |  97 +++--
 BrainMaGe/tester/test_single_run.py      |   2 +-
 BrainMaGe/trainer/lightning_networks.py  |  97 ++---
 BrainMaGe/trainer/trainer_main.py        | 164 +++++----
 BrainMaGe/utils/convert_ckpt_to_pt.py    |  29 +-
 BrainMaGe/utils/csv_creator_adv.py       | 170 +++++----
 BrainMaGe/utils/cyclicLR.py              |  39 +-
 BrainMaGe/utils/data.py                  |  24 +-
 BrainMaGe/utils/intensity_standardize.py | 274 +++++++++-----
 BrainMaGe/utils/losses.py                |  51 +--
 BrainMaGe/utils/optimizers.py            |  42 ++-
 BrainMaGe/utils/preprocess.py            | 149 +++++---
 BrainMaGe/utils/utils_test.py            |  41 ++-
 brain_mage_intensity_standardize         | 274 +++++++++-----
 brain_mage_run                           | 266 ++++++++------
 brain_mage_single_run                    |   5 +-
 setup.py                                 |  57 +--
 22 files changed, 1684 insertions(+), 981 deletions(-)

diff --git a/BrainMaGe/models/networks.py b/BrainMaGe/models/networks.py
index 7c9d37a..b892628 100755
--- a/BrainMaGe/models/networks.py
+++ b/BrainMaGe/models/networks.py
@@ -19,22 +19,22 @@ def __init__(self, n_channels, n_classes, base_filters=16):
         self.n_channels = n_channels
         self.n_classes = n_classes
         self.ins = in_conv(self.n_channels, base_filters)
-        self.ds_0 = DownsamplingModule(base_filters, base_filters*2)
-        self.en_1 = EncodingModule(base_filters*2, base_filters*2)
-        self.ds_1 = DownsamplingModule(base_filters*2, base_filters*4)
-        self.en_2 = EncodingModule(base_filters*4, base_filters*4)
-        self.ds_2 =
DownsamplingModule(base_filters*4, base_filters*8) - self.en_3 = EncodingModule(base_filters*8, base_filters*8) - self.ds_3 = DownsamplingModule(base_filters*8, base_filters*16) - self.en_4 = EncodingModule(base_filters*16, base_filters*16) - self.us_3 = UpsamplingModule(base_filters*16, base_filters*8) - self.de_3 = DecodingModule(base_filters*16, base_filters*8) - self.us_2 = UpsamplingModule(base_filters*8, base_filters*4) - self.de_2 = DecodingModule(base_filters*8, base_filters*4) - self.us_1 = UpsamplingModule(base_filters*4, base_filters*2) - self.de_1 = DecodingModule(base_filters*4, base_filters*2) - self.us_0 = UpsamplingModule(base_filters*2, 16) - self.out = out_conv(base_filters*2, self.n_classes-1) + self.ds_0 = DownsamplingModule(base_filters, base_filters * 2) + self.en_1 = EncodingModule(base_filters * 2, base_filters * 2) + self.ds_1 = DownsamplingModule(base_filters * 2, base_filters * 4) + self.en_2 = EncodingModule(base_filters * 4, base_filters * 4) + self.ds_2 = DownsamplingModule(base_filters * 4, base_filters * 8) + self.en_3 = EncodingModule(base_filters * 8, base_filters * 8) + self.ds_3 = DownsamplingModule(base_filters * 8, base_filters * 16) + self.en_4 = EncodingModule(base_filters * 16, base_filters * 16) + self.us_3 = UpsamplingModule(base_filters * 16, base_filters * 8) + self.de_3 = DecodingModule(base_filters * 16, base_filters * 8) + self.us_2 = UpsamplingModule(base_filters * 8, base_filters * 4) + self.de_2 = DecodingModule(base_filters * 8, base_filters * 4) + self.us_1 = UpsamplingModule(base_filters * 4, base_filters * 2) + self.de_1 = DecodingModule(base_filters * 4, base_filters * 2) + self.us_0 = UpsamplingModule(base_filters * 2, 16) + self.out = out_conv(base_filters * 2, self.n_classes - 1) def forward(self, x): x1 = self.ins(x) @@ -64,22 +64,22 @@ def __init__(self, n_channels, n_classes, base_filters=16): self.n_channels = n_channels self.n_classes = n_classes self.ins = in_conv(self.n_channels, base_filters, res=True) - self.ds_0 = DownsamplingModule(base_filters, base_filters*2) - self.en_1 = EncodingModule(base_filters*2, base_filters*2, res=True) - self.ds_1 = DownsamplingModule(base_filters*2, base_filters*4) - self.en_2 = EncodingModule(base_filters*4, base_filters*4, res=True) - self.ds_2 = DownsamplingModule(base_filters*4, base_filters*8) - self.en_3 = EncodingModule(base_filters*8, base_filters*8, res=True) - self.ds_3 = DownsamplingModule(base_filters*8, base_filters*16) - self.en_4 = EncodingModule(base_filters*16, base_filters*16, res=True) - self.us_3 = UpsamplingModule(base_filters*16, base_filters*8) - self.de_3 = DecodingModule(base_filters*16, base_filters*8, res=True) - self.us_2 = UpsamplingModule(base_filters*8, base_filters*4) - self.de_2 = DecodingModule(base_filters*8, base_filters*4, res=True) - self.us_1 = UpsamplingModule(base_filters*4, base_filters*2) - self.de_1 = DecodingModule(base_filters*4, base_filters*2, res=True) - self.us_0 = UpsamplingModule(base_filters*2, base_filters) - self.out = out_conv(base_filters*2, self.n_classes-1, res=True) + self.ds_0 = DownsamplingModule(base_filters, base_filters * 2) + self.en_1 = EncodingModule(base_filters * 2, base_filters * 2, res=True) + self.ds_1 = DownsamplingModule(base_filters * 2, base_filters * 4) + self.en_2 = EncodingModule(base_filters * 4, base_filters * 4, res=True) + self.ds_2 = DownsamplingModule(base_filters * 4, base_filters * 8) + self.en_3 = EncodingModule(base_filters * 8, base_filters * 8, res=True) + self.ds_3 = DownsamplingModule(base_filters * 
8, base_filters * 16) + self.en_4 = EncodingModule(base_filters * 16, base_filters * 16, res=True) + self.us_3 = UpsamplingModule(base_filters * 16, base_filters * 8) + self.de_3 = DecodingModule(base_filters * 16, base_filters * 8, res=True) + self.us_2 = UpsamplingModule(base_filters * 8, base_filters * 4) + self.de_2 = DecodingModule(base_filters * 8, base_filters * 4, res=True) + self.us_1 = UpsamplingModule(base_filters * 4, base_filters * 2) + self.de_1 = DecodingModule(base_filters * 4, base_filters * 2, res=True) + self.us_0 = UpsamplingModule(base_filters * 2, base_filters) + self.out = out_conv(base_filters * 2, self.n_classes - 1, res=True) def forward(self, x): x1 = self.ins(x) @@ -109,21 +109,27 @@ def __init__(self, n_channels, n_classes, base_filters=16): self.n_channels = n_channels self.n_classes = n_classes self.ins = in_conv(self.n_channels, base_filters) - self.ds_0 = DownsamplingModule(base_filters, base_filters*2) - self.en_1 = EncodingModule(base_filters*2, base_filters*2) - self.ds_1 = DownsamplingModule(base_filters*2, base_filters*4) - self.en_2 = EncodingModule(base_filters*4, base_filters*4) - self.ds_2 = DownsamplingModule(base_filters*4, base_filters*8) - self.en_3 = EncodingModule(base_filters*8, base_filters*8) - self.ds_3 = DownsamplingModule(base_filters*8, base_filters*16) - self.en_4 = EncodingModule(base_filters*16, base_filters*16) - self.us_4 = FCNUpsamplingModule(base_filters*16, 1, scale_factor=5) - self.us_3 = FCNUpsamplingModule(base_filters*8, 1, scale_factor=4) - self.us_2 = FCNUpsamplingModule(base_filters*4, 1, scale_factor=3) - self.us_1 = FCNUpsamplingModule(base_filters*2, 1, scale_factor=2) + self.ds_0 = DownsamplingModule(base_filters, base_filters * 2) + self.en_1 = EncodingModule(base_filters * 2, base_filters * 2) + self.ds_1 = DownsamplingModule(base_filters * 2, base_filters * 4) + self.en_2 = EncodingModule(base_filters * 4, base_filters * 4) + self.ds_2 = DownsamplingModule(base_filters * 4, base_filters * 8) + self.en_3 = EncodingModule(base_filters * 8, base_filters * 8) + self.ds_3 = DownsamplingModule(base_filters * 8, base_filters * 16) + self.en_4 = EncodingModule(base_filters * 16, base_filters * 16) + self.us_4 = FCNUpsamplingModule(base_filters * 16, 1, scale_factor=5) + self.us_3 = FCNUpsamplingModule(base_filters * 8, 1, scale_factor=4) + self.us_2 = FCNUpsamplingModule(base_filters * 4, 1, scale_factor=3) + self.us_1 = FCNUpsamplingModule(base_filters * 2, 1, scale_factor=2) self.us_0 = FCNUpsamplingModule(base_filters, 1, scale_factor=1) - self.conv_0 = nn.Conv3d(in_channels=5, out_channels=self.n_classes-1, - kernel_size=1, stride=1, padding=0, bias=True) + self.conv_0 = nn.Conv3d( + in_channels=5, + out_channels=self.n_classes - 1, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) def forward(self, x): x1 = self.ins(x) @@ -147,13 +153,15 @@ def forward(self, x): def fetch_model(modelname, num_channels, num_classes, num_filters): - if modelname == 'resunet': + if modelname == "resunet": model = resunet(num_channels, num_classes, num_filters) - elif modelname == 'unet': + elif modelname == "unet": model = resunet(num_channels, num_classes, num_filters) - elif modelname == 'fcn': + elif modelname == "fcn": model = fcn(num_channels, num_classes, num_filters) else: - raise ValueError('Check Model spelling, should be one of resunet, unet, fcn in the config'+\ - 'file!') + raise ValueError( + "Check Model spelling, should be one of resunet, unet, fcn in the config" + + "file!" 
+ ) return model diff --git a/BrainMaGe/models/seg_modules.py b/BrainMaGe/models/seg_modules.py index ff7c862..ab9e4d9 100755 --- a/BrainMaGe/models/seg_modules.py +++ b/BrainMaGe/models/seg_modules.py @@ -11,10 +11,20 @@ import torch.nn as nn import torch.nn.functional as F + class in_conv(nn.Module): - def __init__(self, input_channels, output_channels, kernel_size=3, - dropout_p=0.3, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, res=False, lrelu_inplace=True): + def __init__( + self, + input_channels, + output_channels, + kernel_size=3, + dropout_p=0.3, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + res=False, + lrelu_inplace=True, + ): """[The initial convolution to enter the network, kind of like encode] [This function will create the input convolution] Arguments: @@ -40,21 +50,36 @@ def __init__(self, input_channels, output_channels, kernel_size=3, self.inst_norm_affine = inst_norm_affine self.lrelu_inplace = lrelu_inplace self.dropout = nn.Dropout3d(dropout_p) - self.in_0 = nn.InstanceNorm3d(output_channels, - affine=self.inst_norm_affine, - track_running_stats=True) - self.in_1 = nn.InstanceNorm3d(output_channels, - affine=self.inst_norm_affine, - track_running_stats=True) - self.conv0 = nn.Conv3d(input_channels, output_channels, kernel_size=3, - stride=1, padding=(kernel_size - 1) // 2, - bias=self.conv_bias) - self.conv1 = nn.Conv3d(output_channels, output_channels, kernel_size=3, - stride=1, padding=(kernel_size - 1) // 2, - bias=self.conv_bias) - self.conv2 = nn.Conv3d(output_channels, output_channels, kernel_size=3, - stride=1, padding=(kernel_size - 1) // 2, - bias=self.conv_bias) + self.in_0 = nn.InstanceNorm3d( + output_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.in_1 = nn.InstanceNorm3d( + output_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.conv0 = nn.Conv3d( + input_channels, + output_channels, + kernel_size=3, + stride=1, + padding=(kernel_size - 1) // 2, + bias=self.conv_bias, + ) + self.conv1 = nn.Conv3d( + output_channels, + output_channels, + kernel_size=3, + stride=1, + padding=(kernel_size - 1) // 2, + bias=self.conv_bias, + ) + self.conv2 = nn.Conv3d( + output_channels, + output_channels, + kernel_size=3, + stride=1, + padding=(kernel_size - 1) // 2, + bias=self.conv_bias, + ) def forward(self, x): """The forward function for initial convolution @@ -69,22 +94,33 @@ def forward(self, x): x = self.conv0(x) if self.residual: skip = x - x = F.leaky_relu(self.in_0(x), negative_slope=self.leakiness, - inplace=self.lrelu_inplace) + x = F.leaky_relu( + self.in_0(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace + ) x = self.conv1(x) if self.dropout_p is not None and self.dropout_p > 0: x = self.dropout(x) - x = F.leaky_relu(self.in_1(x), negative_slope=self.leakiness, - inplace=self.lrelu_inplace) + x = F.leaky_relu( + self.in_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace + ) x = self.conv2(x) if self.residual: x = x + skip return x + class DownsamplingModule(nn.Module): - def __init__(self, input_channels, output_channels, leakiness=1e-2, - dropout_p=0.3, kernel_size=3, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True): + def __init__( + self, + input_channels, + output_channels, + leakiness=1e-2, + dropout_p=0.3, + kernel_size=3, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ): """[To Downsample a given input with convolution operation] [This one will be used to downsample a given comvolution while doubling the number 
filters] @@ -101,19 +137,24 @@ def __init__(self, input_channels, output_channels, leakiness=1e-2, lrelu_inplace {bool} -- [To update conv outputs with lrelu outputs] (default: {True}) """ - #nn.Module.__init__(self) + # nn.Module.__init__(self) super(DownsamplingModule, self).__init__() self.dropout_p = dropout_p self.conv_bias = conv_bias self.leakiness = leakiness self.inst_norm_affine = inst_norm_affine self.lrelu_inplace = lrelu_inplace - self.in_0 = nn.InstanceNorm3d(output_channels, - affine=self.inst_norm_affine, - track_running_stats=True) - self.conv0 = nn.Conv3d(input_channels, output_channels, kernel_size=3, - stride=2, padding=(kernel_size - 1) // 2, - bias=self.conv_bias) + self.in_0 = nn.InstanceNorm3d( + output_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.conv0 = nn.Conv3d( + input_channels, + output_channels, + kernel_size=3, + stride=2, + padding=(kernel_size - 1) // 2, + bias=self.conv_bias, + ) def forward(self, x): """[This is a forward function for ] @@ -123,35 +164,47 @@ def forward(self, x): Returns: [Tensor] -- [Returns a torch Tensor] """ - x = F.leaky_relu(self.in_0(self.conv0(x)), - negative_slope=self.leakiness, - inplace=self.lrelu_inplace) + x = F.leaky_relu( + self.in_0(self.conv0(x)), + negative_slope=self.leakiness, + inplace=self.lrelu_inplace, + ) return x + class EncodingModule(nn.Module): - def __init__(self, input_channels, output_channels, kernel_size=3, - dropout_p=0.3, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, res=False, lrelu_inplace=True): + def __init__( + self, + input_channels, + output_channels, + kernel_size=3, + dropout_p=0.3, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + res=False, + lrelu_inplace=True, + ): """[The Encoding convolution module to learn the information and use] - [This function will create the Learning convolutions] - Arguments: - input_channels {[int]} -- [the input number of channels, in our - case the number of channels from - downsample] - output_channels {[int]} -- [the output number of channels, will - determine the upcoming channels] - Keyword Arguments: - kernel_size {number} -- [size of filter] (default: {3}) - dropout_p {number} -- [dropout probablity] (default: {0.3}) - leakiness {number} -- [the negative leakiness] - (default: {1e-2}) - conv_bias {bool} -- [to use the bias in filters] - (default: {True}) - inst_norm_affine {bool} -- [affine use in norm] - (default: {True}) - res {bool} -- [to use residual connections] (default: {False}) - lrelu_inplace {bool} -- [To update conv outputs with lrelu - outputs] (default: {True}) + [This function will create the Learning convolutions] + Arguments: + input_channels {[int]} -- [the input number of channels, in our + case the number of channels from + downsample] + output_channels {[int]} -- [the output number of channels, will + determine the upcoming channels] + Keyword Arguments: + kernel_size {number} -- [size of filter] (default: {3}) + dropout_p {number} -- [dropout probablity] (default: {0.3}) + leakiness {number} -- [the negative leakiness] + (default: {1e-2}) + conv_bias {bool} -- [to use the bias in filters] + (default: {True}) + inst_norm_affine {bool} -- [affine use in norm] + (default: {True}) + res {bool} -- [to use residual connections] (default: {False}) + lrelu_inplace {bool} -- [To update conv outputs with lrelu + outputs] (default: {True}) """ nn.Module.__init__(self) self.res = res @@ -161,18 +214,28 @@ def __init__(self, input_channels, output_channels, kernel_size=3, 
self.inst_norm_affine = inst_norm_affine self.lrelu_inplace = lrelu_inplace self.dropout = nn.Dropout3d(dropout_p) - self.in_0 = nn.InstanceNorm3d(output_channels, - affine=self.inst_norm_affine, - track_running_stats=True) - self.in_1 = nn.InstanceNorm3d(output_channels, - affine=self.inst_norm_affine, - track_running_stats=True) - self.conv0 = nn.Conv3d(output_channels, output_channels, kernel_size=3, - stride=1, padding=(kernel_size - 1) // 2, - bias=self.conv_bias) - self.conv1 = nn.Conv3d(output_channels, output_channels, kernel_size=3, - stride=1, padding=(kernel_size - 1) // 2, - bias=self.conv_bias) + self.in_0 = nn.InstanceNorm3d( + output_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.in_1 = nn.InstanceNorm3d( + output_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.conv0 = nn.Conv3d( + output_channels, + output_channels, + kernel_size=3, + stride=1, + padding=(kernel_size - 1) // 2, + bias=self.conv_bias, + ) + self.conv1 = nn.Conv3d( + output_channels, + output_channels, + kernel_size=3, + stride=1, + padding=(kernel_size - 1) // 2, + bias=self.conv_bias, + ) def forward(self, x): """The forward function for initial convolution @@ -186,21 +249,25 @@ def forward(self, x): """ if self.res: skip = x - x = F.leaky_relu(self.in_0(x), negative_slope=self.leakiness, - inplace=self.lrelu_inplace) + x = F.leaky_relu( + self.in_0(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace + ) x = self.conv0(x) if self.dropout_p is not None and self.dropout_p > 0: x = self.dropout(x) - x = F.leaky_relu(self.in_1(x), negative_slope=self.leakiness, - inplace=self.lrelu_inplace) + x = F.leaky_relu( + self.in_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace + ) x = self.conv1(x) if self.res: x = x + skip return x + class Interpolate(nn.Module): - def __init__(self, size=None, scale_factor=None, mode='nearest', - align_corners=True): + def __init__( + self, size=None, scale_factor=None, mode="nearest", align_corners=True + ): super(Interpolate, self).__init__() self.align_corners = align_corners self.mode = mode @@ -208,15 +275,27 @@ def __init__(self, size=None, scale_factor=None, mode='nearest', self.size = size def forward(self, x): - return nn.functional.interpolate(x, size=self.size, - scale_factor=self.scale_factor, - mode=self.mode, - align_corners=self.align_corners) + return nn.functional.interpolate( + x, + size=self.size, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) + class UpsamplingModule(nn.Module): - def __init__(self, input_channels, output_channels, leakiness=1e-2, - lrelu_inplace=True, kernel_size=3, scale_factor=2, - conv_bias=True, inst_norm_affine=True): + def __init__( + self, + input_channels, + output_channels, + leakiness=1e-2, + lrelu_inplace=True, + kernel_size=3, + scale_factor=2, + conv_bias=True, + inst_norm_affine=True, + ): """[summary] [description] Arguments: @@ -236,10 +315,17 @@ def __init__(self, input_channels, output_channels, leakiness=1e-2, self.conv_bias = conv_bias self.leakiness = leakiness self.scale_factor = scale_factor - self.interpolate = Interpolate(scale_factor=self.scale_factor, - mode='trilinear', align_corners=True) - self.conv0 = nn.Conv3d(input_channels, output_channels, kernel_size=1, - stride=1, padding=0, bias=self.conv_bias) + self.interpolate = Interpolate( + scale_factor=self.scale_factor, mode="trilinear", align_corners=True + ) + self.conv0 = nn.Conv3d( + input_channels, + output_channels, + kernel_size=1, + 
stride=1, + padding=0, + bias=self.conv_bias, + ) def forward(self, x): """[summary] @@ -249,10 +335,19 @@ def forward(self, x): x = self.conv0(self.interpolate(x)) return x + class FCNUpsamplingModule(nn.Module): - def __init__(self, input_channels, output_channels, leakiness=1e-2, - lrelu_inplace=True, kernel_size=3, scale_factor=2, - conv_bias=True, inst_norm_affine=True): + def __init__( + self, + input_channels, + output_channels, + leakiness=1e-2, + lrelu_inplace=True, + kernel_size=3, + scale_factor=2, + conv_bias=True, + inst_norm_affine=True, + ): """[summary] [description] Arguments: @@ -272,10 +367,19 @@ def __init__(self, input_channels, output_channels, leakiness=1e-2, self.conv_bias = conv_bias self.leakiness = leakiness self.scale_factor = scale_factor - self.interpolate = Interpolate(scale_factor=2**(self.scale_factor-1), - mode='trilinear', align_corners=True) - self.conv0 = nn.Conv3d(input_channels, output_channels, kernel_size=1, - stride=1, padding=0, bias=self.conv_bias) + self.interpolate = Interpolate( + scale_factor=2 ** (self.scale_factor - 1), + mode="trilinear", + align_corners=True, + ) + self.conv0 = nn.Conv3d( + input_channels, + output_channels, + kernel_size=1, + stride=1, + padding=0, + bias=self.conv_bias, + ) def forward(self, x): """[summary] @@ -285,10 +389,19 @@ def forward(self, x): x = self.interpolate(self.conv0(x)) return x + class DecodingModule(nn.Module): - def __init__(self, input_channels, output_channels, leakiness=1e-2, - conv_bias=True, kernel_size=3, inst_norm_affine=True, - res=True, lrelu_inplace=True): + def __init__( + self, + input_channels, + output_channels, + leakiness=1e-2, + conv_bias=True, + kernel_size=3, + inst_norm_affine=True, + res=True, + lrelu_inplace=True, + ): """[The Decoding convolution module to learn the information and use later] [This function will create the Learning convolutions] @@ -313,24 +426,39 @@ def __init__(self, input_channels, output_channels, leakiness=1e-2, self.conv_bias = conv_bias self.leakiness = leakiness self.res = res - self.in_0 = nn.InstanceNorm3d(input_channels, - affine=self.inst_norm_affine, - track_running_stats=True) - self.in_1 = nn.InstanceNorm3d(output_channels, - affine=self.inst_norm_affine, - track_running_stats=True) - self.in_2 = nn.InstanceNorm3d(output_channels, - affine=self.inst_norm_affine, - track_running_stats=True) - self.conv0 = nn.Conv3d(input_channels, output_channels, kernel_size=3, - stride=1, padding=(kernel_size - 1) // 2, - bias=self.conv_bias) - self.conv1 = nn.Conv3d(output_channels, output_channels, kernel_size=3, - stride=1, padding=(kernel_size - 1) // 2, - bias=self.conv_bias) - self.conv2 = nn.Conv3d(output_channels, output_channels, kernel_size=3, - stride=1, padding=(kernel_size - 1) // 2, - bias=self.conv_bias) + self.in_0 = nn.InstanceNorm3d( + input_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.in_1 = nn.InstanceNorm3d( + output_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.in_2 = nn.InstanceNorm3d( + output_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.conv0 = nn.Conv3d( + input_channels, + output_channels, + kernel_size=3, + stride=1, + padding=(kernel_size - 1) // 2, + bias=self.conv_bias, + ) + self.conv1 = nn.Conv3d( + output_channels, + output_channels, + kernel_size=3, + stride=1, + padding=(kernel_size - 1) // 2, + bias=self.conv_bias, + ) + self.conv2 = nn.Conv3d( + output_channels, + output_channels, + kernel_size=3, + stride=1, + padding=(kernel_size - 
1) // 2, + bias=self.conv_bias, + ) def forward(self, x1, x2): x = torch.cat([x1, x2], dim=1) @@ -345,10 +473,19 @@ def forward(self, x1, x2): x = x + skip return x + class out_conv(nn.Module): - def __init__(self, input_channels, output_channels, leakiness=1e-2, - kernel_size=3, conv_bias=True, inst_norm_affine=True, - res=True, lrelu_inplace=True): + def __init__( + self, + input_channels, + output_channels, + leakiness=1e-2, + kernel_size=3, + conv_bias=True, + inst_norm_affine=True, + res=True, + lrelu_inplace=True, + ): """[The Out convolution module to learn the information and use later] [This function will create the Learning convolutions] Arguments: @@ -372,33 +509,50 @@ def __init__(self, input_channels, output_channels, leakiness=1e-2, self.conv_bias = conv_bias self.leakiness = leakiness self.res = res - self.in_0 = nn.InstanceNorm3d(input_channels, - affine=self.inst_norm_affine, - track_running_stats=True) - self.in_1 = nn.InstanceNorm3d(input_channels//2, - affine=self.inst_norm_affine, - track_running_stats=True) - self.in_2 = nn.InstanceNorm3d(input_channels//2, - affine=self.inst_norm_affine, - track_running_stats=True) - self.in_3 = nn.InstanceNorm3d(input_channels//2, - affine=self.inst_norm_affine, - track_running_stats=True) - self.conv0 = nn.Conv3d(input_channels, input_channels//2, - kernel_size=3, stride=1, - padding=(kernel_size - 1) // 2, - bias=self.conv_bias) - self.conv1 = nn.Conv3d(input_channels//2, input_channels//2, - kernel_size=3, stride=1, - padding=(kernel_size - 1) // 2, - bias=self.conv_bias) - self.conv2 = nn.Conv3d(input_channels//2, input_channels//2, - kernel_size=3, stride=1, - padding=(kernel_size - 1) // 2, - bias=self.conv_bias) - self.conv3 = nn.Conv3d(input_channels//2, output_channels, - kernel_size=1, stride=1, padding=0, - bias=self.conv_bias) + self.in_0 = nn.InstanceNorm3d( + input_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.in_1 = nn.InstanceNorm3d( + input_channels // 2, affine=self.inst_norm_affine, track_running_stats=True + ) + self.in_2 = nn.InstanceNorm3d( + input_channels // 2, affine=self.inst_norm_affine, track_running_stats=True + ) + self.in_3 = nn.InstanceNorm3d( + input_channels // 2, affine=self.inst_norm_affine, track_running_stats=True + ) + self.conv0 = nn.Conv3d( + input_channels, + input_channels // 2, + kernel_size=3, + stride=1, + padding=(kernel_size - 1) // 2, + bias=self.conv_bias, + ) + self.conv1 = nn.Conv3d( + input_channels // 2, + input_channels // 2, + kernel_size=3, + stride=1, + padding=(kernel_size - 1) // 2, + bias=self.conv_bias, + ) + self.conv2 = nn.Conv3d( + input_channels // 2, + input_channels // 2, + kernel_size=3, + stride=1, + padding=(kernel_size - 1) // 2, + bias=self.conv_bias, + ) + self.conv3 = nn.Conv3d( + input_channels // 2, + output_channels, + kernel_size=1, + stride=1, + padding=0, + bias=self.conv_bias, + ) def forward(self, x1, x2): x = torch.cat([x1, x2], dim=1) diff --git a/BrainMaGe/tester/__init__.py b/BrainMaGe/tester/__init__.py index 8bc4f8a..7d348ee 100755 --- a/BrainMaGe/tester/__init__.py +++ b/BrainMaGe/tester/__init__.py @@ -5,5 +5,3 @@ @author: siddhesh """ - - diff --git a/BrainMaGe/tester/test_ma.py b/BrainMaGe/tester/test_ma.py index 65f6d58..c9b2506 100755 --- a/BrainMaGe/tester/test_ma.py +++ b/BrainMaGe/tester/test_ma.py @@ -20,8 +20,13 @@ from scipy.ndimage.morphology import binary_fill_holes from BrainMaGe.models.networks import fetch_model from BrainMaGe.utils import csv_creator_adv -from BrainMaGe.utils.utils_test import 
pad_image, process_image, interpolate_image,\ - padder_and_cropper +from BrainMaGe.utils.utils_test import ( + pad_image, + process_image, + interpolate_image, + padder_and_cropper, +) + def postprocess_prediction(seg): mask = seg != 0 @@ -31,21 +36,27 @@ def postprocess_prediction(seg): seg[lbls != largest_region] = 0 return seg + def infer_ma(cfg, device, save_brain, weights): cfg = os.path.abspath(cfg) if os.path.isfile(cfg): - params_df = pd.read_csv(cfg, sep=' = ', names=['param_name', 'param_value'], - comment='#', skip_blank_lines=True, - engine='python').fillna(' ') + params_df = pd.read_csv( + cfg, + sep=" = ", + names=["param_name", "param_value"], + comment="#", + skip_blank_lines=True, + engine="python", + ).fillna(" ") else: - print('Missing test_params.cfg file? Please give one!') + print("Missing test_params.cfg file? Please give one!") sys.exit(0) params = {} for i in range(params_df.shape[0]): params[params_df.iloc[i, 0]] = params_df.iloc[i, 1] - params['weights'] = weights + params["weights"] = weights start = time.asctime() startstamp = time.time() print("\nHostname :" + str(os.getenv("HOSTNAME"))) @@ -54,21 +65,24 @@ def infer_ma(cfg, device, save_brain, weights): sys.stdout.flush() print("Generating Test csv") - if not os.path.exists(os.path.join(params['results_dir'])): - os.mkdir(params['results_dir']) - if not params['csv_provided'] == 'True': - print('Since CSV were not provided, we are gonna create for you') - csv_creator_adv.generate_csv(params['test_dir'], - to_save=params['results_dir'], - mode=params['mode'], ftype='test', - modalities=params['modalities']) - test_csv = os.path.join(params['results_dir'], 'test.csv') + if not os.path.exists(os.path.join(params["results_dir"])): + os.mkdir(params["results_dir"]) + if not params["csv_provided"] == "True": + print("Since CSV were not provided, we are gonna create for you") + csv_creator_adv.generate_csv( + params["test_dir"], + to_save=params["results_dir"], + mode=params["mode"], + ftype="test", + modalities=params["modalities"], + ) + test_csv = os.path.join(params["results_dir"], "test.csv") else: - test_csv = params['test_csv'] + test_csv = params["test_csv"] test_df = pd.read_csv(test_csv) test_df.ID = test_df.ID.astype(str) - temp_dir = os.path.join(params['results_dir'], 'Temp') + temp_dir = os.path.join(params["results_dir"], "Temp") os.makedirs(temp_dir, exist_ok=True) patients_dict = {} @@ -83,70 +97,84 @@ def infer_ma(cfg, device, save_brain, weights): old_affine = image.affine old_shape = image.header.get_data_shape() new_spacing = (1, 1, 1) - new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))), - int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))), - int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2])))) + new_shape = ( + int(np.round(old_spacing[0] / new_spacing[0] * float(image.shape[0]))), + int(np.round(old_spacing[1] / new_spacing[1] * float(image.shape[1]))), + int(np.round(old_spacing[2] / new_spacing[2] * float(image.shape[2]))), + ) image_data = image.get_fdata() - new_image = resize(image_data, new_shape, order=3, mode='edge', cval=0, - anti_aliasing=False) + new_image = resize( + image_data, new_shape, order=3, mode="edge", cval=0, anti_aliasing=False + ) new_affine = np.eye(4) new_affine = np.array(old_affine) for i in range(3): for j in range(3): if old_affine[i, j] != 0: - new_affine[i, j] = old_affine[i, j]*(1/old_affine[i, j]) + new_affine[i, j] = old_affine[i, j] * (1 / old_affine[i, j]) if old_affine[i, j] <= 0: - new_affine[i, 
j] = -1*(old_affine[i, j]*(1/old_affine[i, j])) + new_affine[i, j] = -1 * ( + old_affine[i, j] * (1 / old_affine[i, j]) + ) temp_image = nib.Nifti1Image(new_image, new_affine) - nib.save(temp_image, os.path.join(temp_dir, patient[0], - patient[0]+'_resamp111.nii.gz')) + nib.save( + temp_image, + os.path.join(temp_dir, patient[0], patient[0] + "_resamp111.nii.gz"), + ) temp_dict = {} - temp_dict['name'] = patient[0] - temp_dict['old_spacing'] = old_spacing - temp_dict['old_affine'] = old_affine - temp_dict['old_shape'] = old_shape - temp_dict['new_spacing'] = new_spacing - temp_dict['new_affine'] = new_affine - temp_dict['new_shape'] = new_shape - - patient_path = os.path.join(temp_dir, patient[0], - patient[0]+'_resamp111.nii.gz') + temp_dict["name"] = patient[0] + temp_dict["old_spacing"] = old_spacing + temp_dict["old_affine"] = old_affine + temp_dict["old_shape"] = old_shape + temp_dict["new_spacing"] = new_spacing + temp_dict["new_affine"] = new_affine + temp_dict["new_shape"] = new_shape + + patient_path = os.path.join( + temp_dir, patient[0], patient[0] + "_resamp111.nii.gz" + ) patient_nib = nib.load(patient_path) patient_data = patient_nib.get_fdata() patient_data, pad_info = pad_image(patient_data) patient_affine = patient_nib.affine temp_image = nib.Nifti1Image(patient_data, patient_affine) - nib.save(temp_image, os.path.join(temp_dir, patient[0], patient[0]+'_bratsized.nii.gz')) - temp_dict['pad_info'] = pad_info + nib.save( + temp_image, + os.path.join(temp_dir, patient[0], patient[0] + "_bratsized.nii.gz"), + ) + temp_dict["pad_info"] = pad_info patients_dict[patient[0]] = temp_dict - model = fetch_model(params['model'], - int(params['num_modalities']), - int(params['num_classes']), - int(params['base_filters'])) - checkpoint = torch.load(str(params['weights'])) - model.load_state_dict(checkpoint['model_state_dict']) + model = fetch_model( + params["model"], + int(params["num_modalities"]), + int(params["num_classes"]), + int(params["base_filters"]), + ) + checkpoint = torch.load(str(params["weights"])) + model.load_state_dict(checkpoint["model_state_dict"]) - if device != 'cpu': + if device != "cpu": model.cuda() model.eval() - print("Done Resampling the Data.\n") - print("--"*30) + print("--" * 30) print("Running the model on the subjects") for patient in tqdm.tqdm(test_df.values): - patient_path = os.path.join(temp_dir, patient[0], - patient[0]+'_bratsized.nii.gz') + patient_path = os.path.join( + temp_dir, patient[0], patient[0] + "_bratsized.nii.gz" + ) patient_nib = nib.load(patient_path) image = patient_nib.get_fdata() image = process_image(image) - image = resize(image, (128, 128, 128), order=3, mode='edge', cval=0, - anti_aliasing=False) + image = resize( + image, (128, 128, 128), order=3, mode="edge", cval=0, anti_aliasing=False + ) image = image[np.newaxis, np.newaxis, ...] 
image = torch.FloatTensor(image) - if device != 'cpu': + if device != "cpu": image = image.cuda() with torch.no_grad(): output = model(image) @@ -155,31 +183,46 @@ def infer_ma(cfg, device, save_brain, weights): to_save[to_save >= 0.9] = 1 to_save[to_save < 0.9] = 0 to_save_nib = nib.Nifti1Image(to_save, patient_nib.affine) - nib.save(to_save_nib, os.path.join(temp_dir, - patient[0], - patient[0]+'_bratsized_mask.nii.gz')) + nib.save( + to_save_nib, + os.path.join( + temp_dir, patient[0], patient[0] + "_bratsized_mask.nii.gz" + ), + ) current_patient_dict = patients_dict[patient[0]] - new_image = padder_and_cropper(to_save, current_patient_dict['pad_info']) + new_image = padder_and_cropper(to_save, current_patient_dict["pad_info"]) to_save_new_nib = nib.Nifti1Image(new_image, patient_nib.affine) - nib.save(to_save_new_nib, os.path.join(temp_dir, - patient[0], - patient[0]+'_resample111_mask.nii.gz')) - to_save_final = resize(new_image, current_patient_dict['old_shape'], order=3, - mode='edge', cval=0) + nib.save( + to_save_new_nib, + os.path.join( + temp_dir, patient[0], patient[0] + "_resample111_mask.nii.gz" + ), + ) + to_save_final = resize( + new_image, + current_patient_dict["old_shape"], + order=3, + mode="edge", + cval=0, + ) to_save_final[to_save_final > 0.9] = 1 to_save_final[to_save_final < 0.9] = 0 for i in range(to_save_final.shape[2]): if np.any(to_save_final[:, :, i]): to_save_final[:, :, i] = binary_fill_holes(to_save_final[:, :, i]) to_save_final = postprocess_prediction(to_save_final).astype(np.uint8) - to_save_final_nib = nib.Nifti1Image(to_save_final, - current_patient_dict['old_affine']) + to_save_final_nib = nib.Nifti1Image( + to_save_final, current_patient_dict["old_affine"] + ) - os.makedirs(os.path.join(params['results_dir'], patient[0]), exist_ok=True) + os.makedirs(os.path.join(params["results_dir"], patient[0]), exist_ok=True) - nib.save(to_save_final_nib, os.path.join(params['results_dir'], - patient[0], - patient[0]+'_mask.nii.gz')) + nib.save( + to_save_final_nib, + os.path.join( + params["results_dir"], patient[0], patient[0] + "_mask.nii.gz" + ), + ) print("Done with running the model.") if save_brain: @@ -187,18 +230,26 @@ def infer_ma(cfg, device, save_brain, weights): for patient in tqdm.tqdm(test_df.values): image = nib.load(patient[1]) image_data = image.get_fdata() - mask = nib.load(os.path.join(params['results_dir'], - patient[0], - patient[0]+'_mask.nii.gz')) + mask = nib.load( + os.path.join( + params["results_dir"], patient[0], patient[0] + "_mask.nii.gz" + ) + ) mask_data = mask.get_fdata().astype(np.int8) image_data[mask_data == 0] = 0 to_save_brain = nib.Nifti1Image(image_data, image.affine) - nib.save(to_save_brain, os.path.join(params['results_dir'], - patient[0], - patient[0]+'_brain.nii.gz')) - - print("Please check the %s folder for the intermediate outputs if you\"+\ - would like to see some intermediate steps." % (os.path.join(params['results_dir'], 'Temp'))) - print("Final output stored in : %s" % (params['results_dir'])) + nib.save( + to_save_brain, + os.path.join( + params["results_dir"], patient[0], patient[0] + "_brain.nii.gz" + ), + ) + + print( + 'Please check the %s folder for the intermediate outputs if you"+\ + would like to see some intermediate steps.' 
+ % (os.path.join(params["results_dir"], "Temp")) + ) + print("Final output stored in : %s" % (params["results_dir"])) print("Thank you for using BrainMaGe") - print('*'*60) + print("*" * 60) diff --git a/BrainMaGe/tester/test_ma_multi.py b/BrainMaGe/tester/test_ma_multi.py index e2ecf5c..a9b659e 100755 --- a/BrainMaGe/tester/test_ma_multi.py +++ b/BrainMaGe/tester/test_ma_multi.py @@ -20,8 +20,12 @@ from scipy.ndimage.morphology import binary_fill_holes from BrainMaGe.models.networks import fetch_model from BrainMaGe.utils import csv_creator_adv -from BrainMaGe.utils.utils_test import pad_image, process_image, interpolate_image,\ - padder_and_cropper +from BrainMaGe.utils.utils_test import ( + pad_image, + process_image, + interpolate_image, + padder_and_cropper, +) from multiprocessing import cpu_count, Pool @@ -37,23 +41,29 @@ def postprocess_prediction(seg): pass return seg + def _intensity_standardize(image): new_image_temp = image[image >= image.mean()] p1 = np.percentile(new_image_temp, 2) p2 = np.percentile(new_image_temp, 95) image[image > p2] = p2 - new_image = (image - p1)/p2 + new_image = (image - p1) / p2 return new_image.astype(np.float32) + def _preprocess_data(image_src): image = nib.load(image_src).get_fdata() - image = resize(image, (128, 128, 128), order=3, mode='edge', cval=0, - anti_aliasing=False) + image = resize( + image, (128, 128, 128), order=3, mode="edge", cval=0, anti_aliasing=False + ) image = _intensity_standardize(image) return image + def preprocess_batch_works(k): - sub_patients_path = [patients_path[i] for i in range(len(patients_path)) if i % n_processes == k] + sub_patients_path = [ + patients_path[i] for i in range(len(patients_path)) if i % n_processes == k + ] sub_patients = [patients[i] for i in range(len(patients)) if i % n_processes == k] for patient, patient_path in zip(sub_patients, sub_patients_path): patient_output_src = os.path.join(preprocessed_output_dir, patient) @@ -69,31 +79,36 @@ def _save_post_hard(patient_output, output_shape, output_affine, output_dir): except: pass patient_output = np.array(patient_output, dtype=np.float32) - patient_mask = resize(patient_output, output_shape, order=3, - preserve_range=True) + patient_mask = resize(patient_output, output_shape, order=3, preserve_range=True) patient_mask_post = (patient_mask > 0.5).astype(np.int8) to_save_post = nib.Nifti1Image(patient_mask_post, output_affine) - nib.save(to_save_post, os.path.join(output_dir, - os.path.basename(output_dir)+'_mask.nii.gz')) + nib.save( + to_save_post, + os.path.join(output_dir, os.path.basename(output_dir) + "_mask.nii.gz"), + ) + def process_output(patient_output_src, patient_orig_path, output_dir): patient_nib = nib.load(patient_orig_path) patient_orig_shape = patient_nib.shape patient_orig_affine = patient_nib.affine - patient_output = np.load(patient_output_src)['output'] - _save_post_hard(patient_output, patient_orig_shape, patient_orig_affine, - output_dir) + patient_output = np.load(patient_output_src)["output"] + _save_post_hard(patient_output, patient_orig_shape, patient_orig_affine, output_dir) + def postprocess_batch_works(k): - sub_patients_path = [patients_path[i] for i in range(len(patients_path)) if i % n_processes == k] + sub_patients_path = [ + patients_path[i] for i in range(len(patients_path)) if i % n_processes == k + ] sub_patients = [patients[i] for i in range(len(patients)) if i % n_processes == k] for patient, patient_path in zip(sub_patients, sub_patients_path): patient_dir = os.path.join(model_dir, patient) 
os.makedirs(patient_dir, exist_ok=True) - patient_temp_output_src = os.path.join(temp_output_dir, patient+'.npz') + patient_temp_output_src = os.path.join(temp_output_dir, patient + ".npz") print(patient, patient_temp_output_src, patient_path) process_output(patient_temp_output_src, patient_path, patient_dir) + def infer_ma(hparams): global patients, patients_path, n_processes, model_dir, preprocessed_output_dir, temp_output_dir model_dir = hparams.model_dir @@ -117,51 +132,57 @@ def infer_ma(hparams): print("Number of classes :", hparams.num_classes) print("Base Filters :", hparams.base_filters) print("Load Weights :", hparams.weights) - + print("Generating Test csv") - if not os.path.exists(hparams['results_dir']): + if not os.path.exists(hparams["results_dir"]): os.mkdir(hparams.results_dir) - if not hparams.csv_provided == 'True': - print('Since CSV were not provided, we are gonna create for you') - csv_creator_adv.generate_csv(hparams.test_dir, - to_save=hparams.results_dir, - mode=hparams.mode, - ftype='test', - modalities=hparams.modalities) - test_csv = os.path.join(hparams.results_dir, 'test.csv') + if not hparams.csv_provided == "True": + print("Since CSV were not provided, we are gonna create for you") + csv_creator_adv.generate_csv( + hparams.test_dir, + to_save=hparams.results_dir, + mode=hparams.mode, + ftype="test", + modalities=hparams.modalities, + ) + test_csv = os.path.join(hparams.results_dir, "test.csv") else: test_csv = hparams.test_csv n_processes = int(hparams.threads) - model = fetch_model(hparams.model, - int(hparams.num_modalities), - int(hparams.num_classes), - int(hparams.base_filters)) + model = fetch_model( + hparams.model, + int(hparams.num_modalities), + int(hparams.num_classes), + int(hparams.base_filters), + ) checkpoint = torch.load(str(hparams.weights)) model.load_state_dict(checkpoint.model_state_dict) - if hparams.device != 'cpu': + if hparams.device != "cpu": model.cuda() model.eval() test_df = pd.read_csv(test_csv) - preprocessed_output_dir = os.path.join(hparams.model_dir, 'preprocessed') + preprocessed_output_dir = os.path.join(hparams.model_dir, "preprocessed") os.makedirs(preprocessed_output_dir, exist_ok=True) patients = test_df.iloc[:, 0].astype(str) patients_path = test_df.iloc[:, 1] n_processes = int(hparams.threads) if len(patients) < n_processes: print("\n*********** WARNING ***********") - print("You are processing less number of patients as compared to the\n"+ - "threads provided, which means you are asking for more resources than \n"+ - "necessary which is not a great practice. Anyway, we have accounted for that \n"+ - "and reduced the number of threads to the maximum number of patients for \n"+ - "better resource management!\n") + print( + "You are processing less number of patients as compared to the\n" + + "threads provided, which means you are asking for more resources than \n" + + "necessary which is not a great practice. 
Anyway, we have accounted for that \n" + + "and reduced the number of threads to the maximum number of patients for \n" + + "better resource management!\n" + ) n_processes = len(patients) - print('*'*80) - print('Intializing preprocessing') - print('*'*80) - print("Initiating the CPU workload on %d threads.\n\n"%n_processes) + print("*" * 80) + print("Intializing preprocessing") + print("*" * 80) + print("Initiating the CPU workload on %d threads.\n\n" % n_processes) print("Currently processing the following patients : ") START = time.time() pool = Pool(processes=n_processes) @@ -170,16 +191,20 @@ def infer_ma(hparams): print("\n\n Preprocessing time taken : {} seconds".format(END - START)) # Load the preprocessed patients to the dataloader - print('*'*80) - print('Intializing Deep Neural Network') - print('*'*80) + print("*" * 80) + print("Intializing Deep Neural Network") + print("*" * 80) START = time.time() print("Initiating the GPU workload on CUDA threads.\n\n") print("Currently processing the following patients : ") - preprocessed_data_dir = os.path.join(hparams.model_dir, 'preprocessed') - temp_output_dir = os.path.join(hparams.model_dir, 'temp_output') + preprocessed_data_dir = os.path.join(hparams.model_dir, "preprocessed") + temp_output_dir = os.path.join(hparams.model_dir, "temp_output") os.makedirs(temp_output_dir, exist_ok=True) dataset_infer = VolSegDatasetInfer(preprocessed_data_dir) - infer_loader = DataLoader(dataset_infer, batch_size=int(hparams.batch_size), - shuffle=False, num_workers=int(hparams.threads), - pin_memory=False) \ No newline at end of file + infer_loader = DataLoader( + dataset_infer, + batch_size=int(hparams.batch_size), + shuffle=False, + num_workers=int(hparams.threads), + pin_memory=False, + ) diff --git a/BrainMaGe/tester/test_multi_4.py b/BrainMaGe/tester/test_multi_4.py index 5c5aafa..c34bd5c 100755 --- a/BrainMaGe/tester/test_multi_4.py +++ b/BrainMaGe/tester/test_multi_4.py @@ -31,6 +31,7 @@ def postprocess_prediction(seg): seg[lbls != largest_region] = 0 return seg + def infer_multi_4(cfg, device, save_brain, weights): """ Inference using multi modality network @@ -52,14 +53,19 @@ def infer_multi_4(cfg, device, save_brain, weights): cfg = os.path.abspath(cfg) if os.path.isfile(cfg): - params_df = pd.read_csv(cfg, sep=' = ', names=['param_name', 'param_value'], - comment='#', skip_blank_lines=True, - engine='python').fillna(' ') + params_df = pd.read_csv( + cfg, + sep=" = ", + names=["param_name", "param_value"], + comment="#", + skip_blank_lines=True, + engine="python", + ).fillna(" ") else: - print('Missing test_params.cfg file? Please give one!') + print("Missing test_params.cfg file? 
Please give one!") sys.exit(0) params = {} - params['weights'] = weights + params["weights"] = weights for i in range(params_df.shape[0]): params[params_df.iloc[i, 0]] = params_df.iloc[i, 1] start = time.asctime() @@ -70,44 +76,49 @@ def infer_multi_4(cfg, device, save_brain, weights): sys.stdout.flush() print("Generating Test csv") - if not os.path.exists(os.path.join(params['results_dir'])): - os.mkdir(params['results_dir']) - if not params['csv_provided'] == 'True': - print('Since CSV were not provided, we are gonna create for you') - csv_creator_adv.generate_csv(params['test_dir'], - to_save=params['results_dir'], - mode=params['mode'], ftype='test', - modalities=params['modalities']) - test_csv = os.path.join(params['results_dir'], 'test.csv') + if not os.path.exists(os.path.join(params["results_dir"])): + os.mkdir(params["results_dir"]) + if not params["csv_provided"] == "True": + print("Since CSV were not provided, we are gonna create for you") + csv_creator_adv.generate_csv( + params["test_dir"], + to_save=params["results_dir"], + mode=params["mode"], + ftype="test", + modalities=params["modalities"], + ) + test_csv = os.path.join(params["results_dir"], "test.csv") else: - test_csv = params['test_csv'] + test_csv = params["test_csv"] test_df = pd.read_csv(test_csv) - model = fetch_model(params['model'], - int(params['num_modalities']), - int(params['num_classes']), - int(params['base_filters'])) - if device != 'cpu': + model = fetch_model( + params["model"], + int(params["num_modalities"]), + int(params["num_classes"]), + int(params["base_filters"]), + ) + if device != "cpu": model.cuda() - checkpoint = torch.load(str(params['weights'])) - model.load_state_dict(checkpoint['model_state_dict']) + checkpoint = torch.load(str(params["weights"])) + model.load_state_dict(checkpoint["model_state_dict"]) model.eval() for patient in tqdm.tqdm(test_df.values): - os.makedirs(os.path.join(params['results_dir'], patient[0]), exist_ok=True) - nmods = params['num_modalities'] + os.makedirs(os.path.join(params["results_dir"], patient[0]), exist_ok=True) + nmods = params["num_modalities"] stack = np.zeros([int(nmods), 128, 128, 128], dtype=np.float32) for i in range(int(nmods)): - image_path = patient[i+1] + image_path = patient[i + 1] patient_nib = nib.load(image_path) image = patient_nib.get_fdata() image = preprocess_image(patient_nib) stack[i] = image stack = stack[np.newaxis, ...] stack = torch.FloatTensor(stack) - if device != 'cpu': + if device != "cpu": image = stack.cuda() with torch.no_grad(): output = model(image) @@ -121,27 +132,37 @@ def infer_multi_4(cfg, device, save_brain, weights): to_save[:, :, i] = binary_fill_holes(to_save[:, :, i]) to_save = postprocess_prediction(to_save).astype(np.uint8) to_save_mask = nib.Nifti1Image(to_save, patient_nib.affine) - nib.save(to_save_mask, os.path.join(params['results_dir'], patient[0], - patient[0]+'_mask.nii.gz')) + nib.save( + to_save_mask, + os.path.join( + params["results_dir"], patient[0], patient[0] + "_mask.nii.gz" + ), + ) print("Done with running the model.") if save_brain: print("You chose to save the brain. 
We are now saving it with the masks.") for patient in tqdm.tqdm(test_df.values): - nmods = params['num_modalities'] - mask_nib = nib.load(os.path.join(params['results_dir'], patient[0], - patient[0]+'_mask.nii.gz')) + nmods = params["num_modalities"] + mask_nib = nib.load( + os.path.join( + params["results_dir"], patient[0], patient[0] + "_mask.nii.gz" + ) + ) mask_data = mask_nib.get_fdata().astype(np.int8) for i in range(int(nmods)): - image_name = os.path.basename(patient[i+1]).strip('.nii.gz') - image_path = patient[i+1] + image_name = os.path.basename(patient[i + 1]).strip(".nii.gz") + image_path = patient[i + 1] patient_nib = nib.load(image_path) image_data = patient_nib.get_fdata() image_data[mask_data == 0] = 0 to_save_image = nib.Nifti1Image(image_data, patient_nib.affine) - nib.save(to_save_image, os.path.join(params['results_dir'], - patient[0], - image_name+'_brain.nii.gz')) + nib.save( + to_save_image, + os.path.join( + params["results_dir"], patient[0], image_name + "_brain.nii.gz" + ), + ) - print("Final output stored in : %s" % (params['results_dir'])) + print("Final output stored in : %s" % (params["results_dir"])) print("Thank you for using BrainMaGe") - print('*'*60) + print("*" * 60) diff --git a/BrainMaGe/tester/test_single_run.py b/BrainMaGe/tester/test_single_run.py index ecc2653..d312e38 100755 --- a/BrainMaGe/tester/test_single_run.py +++ b/BrainMaGe/tester/test_single_run.py @@ -36,7 +36,7 @@ def postprocess_prediction(seg): return seg -def infer_single_ma(input_path, output_path, weights, mask_path=None, device='cpu'): +def infer_single_ma(input_path, output_path, weights, mask_path=None, device="cpu"): start = time.asctime() startstamp = time.time() print("\nHostname :" + str(os.getenv("HOSTNAME"))) diff --git a/BrainMaGe/trainer/lightning_networks.py b/BrainMaGe/trainer/lightning_networks.py index c39dac2..7a1efc1 100755 --- a/BrainMaGe/trainer/lightning_networks.py +++ b/BrainMaGe/trainer/lightning_networks.py @@ -20,10 +20,12 @@ class SkullStripper(ptl.LightningModule): def __init__(self, params): super(SkullStripper, self).__init__() self.params = params - self.model = fetch_model(params['model'], - int(self.params['num_modalities']), - int(self.params['num_classes']), - int(self.params['base_filters'])) + self.model = fetch_model( + params["model"], + int(self.params["num_modalities"]), + int(self.params["num_classes"]), + int(self.params["base_filters"]), + ) def forward(self, x): return self.model(x) @@ -33,62 +35,73 @@ def my_loss(self, output, mask): return loss def training_step(self, batch, batch_nb): - image, mask = batch['image_data'], batch['ground_truth_data'] + image, mask = batch["image_data"], batch["ground_truth_data"] output = self.forward(image) loss = self.my_loss(output, mask) dice_score = dice(output, mask) - return {'loss': loss, - 'dice': dice_score} + return {"loss": loss, "dice": dice_score} def validation_step(self, batch, batch_nb): - image, mask = batch['image_data'], batch['ground_truth_data'] + image, mask = batch["image_data"], batch["ground_truth_data"] output = self.forward(image) loss = self.my_loss(output, mask) dice_score = dice(output, mask) - return {'val_loss': loss, - 'val_dice': dice_score} + return {"val_loss": loss, "val_dice": dice_score} def validation_epoch_end(self, outputs): - avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() - avg_dice = torch.stack([x['val_dice'] for x in outputs]).mean() - logs = {'avg_val_loss': avg_loss, 'avg_val_dice': avg_dice} - print("Average validation loss :", avg_loss, 
"Average validation dice", avg_dice) - return {'val_loss': avg_loss, - 'val_dice': avg_dice, - 'progress_bar': logs, - 'log': logs} + avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() + avg_dice = torch.stack([x["val_dice"] for x in outputs]).mean() + logs = {"avg_val_loss": avg_loss, "avg_val_dice": avg_dice} + print( + "Average validation loss :", avg_loss, "Average validation dice", avg_dice + ) + return { + "val_loss": avg_loss, + "val_dice": avg_dice, + "progress_bar": logs, + "log": logs, + } def configure_optimizers(self): # Setting up the optimizer - optimizer = fetch_optimizer(self.params['optimizer'], - self.params['learning_rate'], - self.model) + optimizer = fetch_optimizer( + self.params["optimizer"], self.params["learning_rate"], self.model + ) # Setting up an optimizer lr reducer - lr_milestones = [int(i) - for i in self.params['lr_milestones'][1:-1].split(',')] - decay_milestones = [int(i) - for i in - self.params['decay_milestones'][1:-1].split(',')] - scheduler = CyclicCosAnnealingLR(optimizer, - milestones=lr_milestones, - decay_milestones=decay_milestones, - eta_min=5e-6) + lr_milestones = [int(i) for i in self.params["lr_milestones"][1:-1].split(",")] + decay_milestones = [ + int(i) for i in self.params["decay_milestones"][1:-1].split(",") + ] + scheduler = CyclicCosAnnealingLR( + optimizer, + milestones=lr_milestones, + decay_milestones=decay_milestones, + eta_min=5e-6, + ) return [optimizer], [scheduler] @ptl.data_loader def train_dataloader(self): - dataset_train = SkullStripDataset(self.params['train_csv'], self.params, - test=False) - return DataLoader(dataset_train, - batch_size=int(self.params['batch_size']), - shuffle=True, num_workers=4, - pin_memory=True) + dataset_train = SkullStripDataset( + self.params["train_csv"], self.params, test=False + ) + return DataLoader( + dataset_train, + batch_size=int(self.params["batch_size"]), + shuffle=True, + num_workers=4, + pin_memory=True, + ) @ptl.data_loader def val_dataloader(self): - dataset_valid = SkullStripDataset(self.params['validation_csv'], self.params, - test=False) - return DataLoader(dataset_valid, - batch_size=int(self.params['batch_size']), - shuffle=False, num_workers=4, - pin_memory=True) + dataset_valid = SkullStripDataset( + self.params["validation_csv"], self.params, test=False + ) + return DataLoader( + dataset_valid, + batch_size=int(self.params["batch_size"]), + shuffle=False, + num_workers=4, + pin_memory=True, + ) diff --git a/BrainMaGe/trainer/trainer_main.py b/BrainMaGe/trainer/trainer_main.py index ecf0b8b..5a4c563 100755 --- a/BrainMaGe/trainer/trainer_main.py +++ b/BrainMaGe/trainer/trainer_main.py @@ -17,6 +17,7 @@ from BrainMaGe.utils.csv_creator_adv import generate_csv from BrainMaGe.trainer.lightning_networks import SkullStripper + def train_network(cfg, device, weights): """ Receiving a configuration file and a device, the training is pushed through this file @@ -42,11 +43,16 @@ def train_network(cfg, device, weights): print("Checking for this cfg file : ", cfg) # READING FROM A CFG FILE and check if file exists or not if os.path.isfile(cfg): - params_df = pd.read_csv(cfg, sep=' = ', names=['param_name', 'param_value'], - comment='#', skip_blank_lines=True, - engine='python').fillna(' ') + params_df = pd.read_csv( + cfg, + sep=" = ", + names=["param_name", "param_value"], + comment="#", + skip_blank_lines=True, + engine="python", + ).fillna(" ") else: - print('Missing train_params.cfg file?') + print("Missing train_params.cfg file?") sys.exit(0) # Reading in all the 
parameters @@ -55,98 +61,106 @@ def train_network(cfg, device, weights): params[params_df.iloc[i, 0]] = params_df.iloc[i, 1] print(type(device), device) if type(device) != str: - params['device'] = str(device) - params['weights'] = weights + params["device"] = str(device) + params["weights"] = weights # Although uneccessary, we still do this - if not os.path.isdir(str(params['model_dir'])): - os.mkdir(params['model_dir']) + if not os.path.isdir(str(params["model_dir"])): + os.mkdir(params["model_dir"]) # PRINT PARSED ARGS print("\n\n") - print("Training Folder Dir :", params['train_dir']) - print("Validation Dir :", params['validation_dir']) - print("Model Directory :", params['model_dir']) - print("Mode :", params['mode']) - print("Number of modalities :", params['num_modalities']) - print("Modalities :", params['modalities']) - print("Number of classes :", params['num_classes']) - print("Max Number of epochs :", params['max_epochs']) - print("Batch size :", params['batch_size']) - print("Optimizer :", params['optimizer']) - print("Learning Rate :", params['learning_rate']) - print("Learning Rate Milestones:", params['lr_milestones']) - print("Patience to decay :", params['decay_milestones']) - print("Early Stopping Patience :", params['early_stop_patience']) - print("Depth Layers :", params['layers']) - print("Model used :", params['model']) - print("Weights used :", params['weights']) + print("Training Folder Dir :", params["train_dir"]) + print("Validation Dir :", params["validation_dir"]) + print("Model Directory :", params["model_dir"]) + print("Mode :", params["mode"]) + print("Number of modalities :", params["num_modalities"]) + print("Modalities :", params["modalities"]) + print("Number of classes :", params["num_classes"]) + print("Max Number of epochs :", params["max_epochs"]) + print("Batch size :", params["batch_size"]) + print("Optimizer :", params["optimizer"]) + print("Learning Rate :", params["learning_rate"]) + print("Learning Rate Milestones:", params["lr_milestones"]) + print("Patience to decay :", params["decay_milestones"]) + print("Early Stopping Patience :", params["early_stop_patience"]) + print("Depth Layers :", params["layers"]) + print("Model used :", params["model"]) + print("Weights used :", params["weights"]) sys.stdout.flush() print("Device Given :", device) sys.stdout.flush() # Although uneccessary, we still do this - os.makedirs(params['model_dir'], exist_ok=True) + os.makedirs(params["model_dir"], exist_ok=True) print("Current Device : ", torch.cuda.current_device()) print("Device Count on Machine : ", torch.cuda.device_count()) print("Device Name : ", torch.cuda.get_device_name()) print("Cuda Availibility : ", torch.cuda.is_available()) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - print('Using device:', device) - if device.type == 'cuda': - print('Memory Usage:') - print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3, 1), - 'GB') - print('Cached: ', round(torch.cuda.memory_cached(0)/1024**3, 1), 'GB') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print("Using device:", device) + if device.type == "cuda": + print("Memory Usage:") + print("Allocated:", round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), "GB") + print("Cached: ", round(torch.cuda.memory_cached(0) / 1024 ** 3, 1), "GB") sys.stdout.flush() # We generate CSV for training if not provided print("Generating CSV Files") # Generating training csv files - if not params['csv_provided'] == 'True': - print('Since CSV were not 
provided, we are gonna create for you') - generate_csv(params['train_dir'], - to_save=params['model_dir'], - mode=params['mode'], ftype='train', - modalities=params['modalities']) - generate_csv(params['validation_dir'], - to_save=params['model_dir'], - mode=params['mode'], ftype='validation', - modalities=params['modalities']) - params['train_csv'] = os.path.join(params['model_dir'], 'train.csv') - params['validation_csv'] = os.path.join(params['model_dir'], 'validation.csv') + if not params["csv_provided"] == "True": + print("Since CSV were not provided, we are gonna create for you") + generate_csv( + params["train_dir"], + to_save=params["model_dir"], + mode=params["mode"], + ftype="train", + modalities=params["modalities"], + ) + generate_csv( + params["validation_dir"], + to_save=params["model_dir"], + mode=params["mode"], + ftype="validation", + modalities=params["modalities"], + ) + params["train_csv"] = os.path.join(params["model_dir"], "train.csv") + params["validation_csv"] = os.path.join(params["model_dir"], "validation.csv") else: # Taken directly from params pass - os.makedirs(params['model_dir'], exist_ok=True) - - log_dir = os.path.join(params['model_dir']) - checkpoint_callback = ModelCheckpoint(filepath=os.path.join(log_dir, - 'checkpoints'), - monitor='val_loss', - verbose=True, - save_top_k=1, - mode='auto', - save_weights_only=False, - prefix=str('deep_resunet_'+ - params['base_filters']) - ) - stop_callback = EarlyStopping(monitor='val_loss', mode='auto', - patience=int(params['early_stop_patience']), - verbose=True) + os.makedirs(params["model_dir"], exist_ok=True) + + log_dir = os.path.join(params["model_dir"]) + checkpoint_callback = ModelCheckpoint( + filepath=os.path.join(log_dir, "checkpoints"), + monitor="val_loss", + verbose=True, + save_top_k=1, + mode="auto", + save_weights_only=False, + prefix=str("deep_resunet_" + params["base_filters"]), + ) + stop_callback = EarlyStopping( + monitor="val_loss", + mode="auto", + patience=int(params["early_stop_patience"]), + verbose=True, + ) model = SkullStripper(params) res_ckpt = weights - trainer = Trainer(checkpoint_callback=checkpoint_callback, - early_stop_callback=stop_callback, - default_root_dir=params['model_dir'], - gpus=params['device'], - fast_dev_run=False, - max_epochs=int(params['max_epochs']), - min_epochs=int(params['min_epochs']), - distributed_backend='ddp', - weights_summary='full', - weights_save_path=params['model_dir'], - amp_level='O1', - num_sanity_val_steps=5, - resume_from_checkpoint=res_ckpt, - ) + trainer = Trainer( + checkpoint_callback=checkpoint_callback, + early_stop_callback=stop_callback, + default_root_dir=params["model_dir"], + gpus=params["device"], + fast_dev_run=False, + max_epochs=int(params["max_epochs"]), + min_epochs=int(params["min_epochs"]), + distributed_backend="ddp", + weights_summary="full", + weights_save_path=params["model_dir"], + amp_level="O1", + num_sanity_val_steps=5, + resume_from_checkpoint=res_ckpt, + ) trainer.fit(model) diff --git a/BrainMaGe/utils/convert_ckpt_to_pt.py b/BrainMaGe/utils/convert_ckpt_to_pt.py index 34f26d9..ec79cf3 100755 --- a/BrainMaGe/utils/convert_ckpt_to_pt.py +++ b/BrainMaGe/utils/convert_ckpt_to_pt.py @@ -10,14 +10,23 @@ from argparse import ArgumentParser import torch -if __name__ == '__main__': +if __name__ == "__main__": - parser = ArgumentParser(description='Convert the .ckpt files to .pt files') - parser.add_argument('-i', '--input', dest='input', - help='Input .ckpt file generated by lightning', required=True) - 
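# A minimal standalone sketch of what this converter does, assuming a Lightning
# checkpoint whose weight keys all carry the "model." prefix; the file names used
# below are hypothetical examples, not paths shipped with BrainMaGe.
import torch

def ckpt_to_pt(ckpt_file, pt_file):
    """Drop the 6-character 'model.' prefix from every key and re-wrap the weights."""
    weight_load = torch.load(ckpt_file, map_location="cpu")
    new_state_dict = {
        key[len("model."):]: value for key, value in weight_load["state_dict"].items()
    }
    torch.save({"model_state_dict": new_state_dict}, pt_file)

# Example (hypothetical paths): ckpt_to_pt("checkpoints/deep_resunet_16-epoch=42.ckpt", "resunet_ma.pt")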
parser.add_argument('-o', '--output', dest='output', - help='Output .pt file to be generated, must be with extension of .pt', - required=True) + parser = ArgumentParser(description="Convert the .ckpt files to .pt files") + parser.add_argument( + "-i", + "--input", + dest="input", + help="Input .ckpt file generated by lightning", + required=True, + ) + parser.add_argument( + "-o", + "--output", + dest="output", + help="Output .pt file to be generated, must be with extension of .pt", + required=True, + ) args = parser.parse_args() ckpt_file = os.path.abspath(args.input) @@ -27,10 +36,10 @@ weight_load = torch.load(ckpt_file) print("Load Successful! Converting file.") new_state_dict = {} - for key in weight_load['state_dict'].keys(): + for key in weight_load["state_dict"].keys(): new_key = key[6:] - new_state_dict[new_key] = weight_load['state_dict'][key] - model_state_dict = {'model_state_dict' : new_state_dict} + new_state_dict[new_key] = weight_load["state_dict"][key] + model_state_dict = {"model_state_dict": new_state_dict} print("Conversion successful!") torch.save(model_state_dict, pt_file) print("File saved successfully at :", pt_file) diff --git a/BrainMaGe/utils/csv_creator_adv.py b/BrainMaGe/utils/csv_creator_adv.py index 2f17b48..6438d91 100755 --- a/BrainMaGe/utils/csv_creator_adv.py +++ b/BrainMaGe/utils/csv_creator_adv.py @@ -29,30 +29,34 @@ def rex_o4a_csv(folder_path, to_save, ftype, modalities): : ['t1', 't2', 't1ce']] """ modalities = modalities[1:-1] - modalities = re.findall('[^, \']+', modalities) + modalities = re.findall("[^, ']+", modalities) if not modalities: - print("Could not find modalities! Are you sure you have put in \ - something in the modalities field?") + print( + "Could not find modalities! Are you sure you have put in \ + something in the modalities field?" + ) sys.exit(0) - if ftype == 'test': - csv_file = open(os.path.join(to_save, ftype+'.csv'), 'w+') - csv_file.write('ID,Image_Path\n') + if ftype == "test": + csv_file = open(os.path.join(to_save, ftype + ".csv"), "w+") + csv_file.write("ID,Image_Path\n") else: - csv_file = open(os.path.join(to_save, ftype+'.csv'), 'w+') - csv_file.write('ID,gt_path,Image_path\n') + csv_file = open(os.path.join(to_save, ftype + ".csv"), "w+") + csv_file.write("ID,gt_path,Image_path\n") folders = os.listdir(folder_path) for folder in folders: for modality in modalities: - csv_file.write(folder+'_'+modality+',') - if ftype != 'test': - ground_truth = glob.glob(os.path.join(folder_path, folder, - '*mask.nii.gz'))[0] + csv_file.write(folder + "_" + modality + ",") + if ftype != "test": + ground_truth = glob.glob( + os.path.join(folder_path, folder, "*mask.nii.gz") + )[0] csv_file.write(ground_truth) - csv_file.write(',') - img = glob.glob(os.path.join(folder_path, folder, - '*'+modality+'.nii.gz'))[0] + csv_file.write(",") + img = glob.glob( + os.path.join(folder_path, folder, "*" + modality + ".nii.gz") + )[0] csv_file.write(img) - csv_file.write('\n') + csv_file.write("\n") csv_file.close() @@ -70,35 +74,39 @@ def rex_sin_csv(folder_path, to_save, ftype, modalities): : ['t1']] """ modalities = modalities[1:-1] - modalities = re.findall('[^, \']+', modalities) + modalities = re.findall("[^, ']+", modalities) if len(modalities) > 1: print("Found more than one modality, exiting!") sys.exit(0) if not modalities: - print("Could not find modalities! Are you sure you have put in \ - something in the modalities field?") + print( + "Could not find modalities! Are you sure you have put in \ + something in the modalities field?" 
+ ) sys.exit(0) - if ftype == 'test': - csv_file = open(os.path.join(to_save, ftype+'.csv'), 'w+') - csv_file.write('ID,') + if ftype == "test": + csv_file = open(os.path.join(to_save, ftype + ".csv"), "w+") + csv_file.write("ID,") else: - csv_file = open(os.path.join(to_save, ftype+'.csv'), 'w+') - csv_file.write('ID,gt_path,') + csv_file = open(os.path.join(to_save, ftype + ".csv"), "w+") + csv_file.write("ID,gt_path,") modality = modalities[0] - csv_file.write(modality+'_path\n') + csv_file.write(modality + "_path\n") folders = os.listdir(folder_path) for folder in folders: csv_file.write(folder) - csv_file.write(',') - if ftype != 'test': - ground_truth = glob.glob(os.path.join(folder_path, folder, - '*mask.nii.gz'))[0] + csv_file.write(",") + if ftype != "test": + ground_truth = glob.glob(os.path.join(folder_path, folder, "*mask.nii.gz"))[ + 0 + ] csv_file.write(ground_truth) - csv_file.write(',') - img = glob.glob(os.path.join(folder_path, folder, - '*'+modality+'.nii.gz'))[0] + csv_file.write(",") + img = glob.glob(os.path.join(folder_path, folder, "*" + modality + ".nii.gz"))[ + 0 + ] csv_file.write(img) - csv_file.write('\n') + csv_file.write("\n") csv_file.close() @@ -116,40 +124,45 @@ def rex_mul_csv(folder_path, to_save, ftype, modalities): : ['t1']] """ modalities = modalities[1:-1] - modalities = re.findall('[^, \']+', modalities) + modalities = re.findall("[^, ']+", modalities) if not modalities: - print("Could not find modalities! Are you sure you have put in \ - something in the modalities field?") + print( + "Could not find modalities! Are you sure you have put in \ + something in the modalities field?" + ) sys.exit(0) - if ftype == 'test': - csv_file = open(os.path.join(to_save, ftype+'.csv'), 'w+') - csv_file.write('ID,') + if ftype == "test": + csv_file = open(os.path.join(to_save, ftype + ".csv"), "w+") + csv_file.write("ID,") else: - csv_file = open(os.path.join(to_save, ftype+'.csv'), 'w+') - csv_file.write('ID,gt_path,') + csv_file = open(os.path.join(to_save, ftype + ".csv"), "w+") + csv_file.write("ID,gt_path,") for modality in modalities[:-1]: - csv_file.write(modality+'_path,') + csv_file.write(modality + "_path,") modality = modalities[-1] - csv_file.write(modality+'_path\n') + csv_file.write(modality + "_path\n") folders = os.listdir(folder_path) for folder in folders: csv_file.write(folder) - csv_file.write(',') - if ftype != 'test': - ground_truth = glob.glob(os.path.join(folder_path, folder, - '*mask.nii.gz'))[0] + csv_file.write(",") + if ftype != "test": + ground_truth = glob.glob(os.path.join(folder_path, folder, "*mask.nii.gz"))[ + 0 + ] csv_file.write(ground_truth) - csv_file.write(',') + csv_file.write(",") for modality in modalities[:-1]: - img = glob.glob(os.path.join(folder_path, folder, - '*'+modality+'.nii.gz'))[0] + img = glob.glob( + os.path.join(folder_path, folder, "*" + modality + ".nii.gz") + )[0] csv_file.write(img) - csv_file.write(',') + csv_file.write(",") modality = modalities[-1] - img = glob.glob(os.path.join(folder_path, folder, - '*'+modality+'.nii.gz'))[0] + img = glob.glob(os.path.join(folder_path, folder, "*" + modality + ".nii.gz"))[ + 0 + ] csv_file.write(img) - csv_file.write('\n') + csv_file.write("\n") csv_file.close() @@ -164,20 +177,20 @@ def rex_bids_csv(folder_path, to_save, ftype): if file type is set to test, it does not look for ground truths] """ - if ftype == 'test': - csv_file = open(os.path.join(to_save, ftype+'.csv'), 'w+') - csv_file.write('ID,') + if ftype == "test": + csv_file = open(os.path.join(to_save, 
ftype + ".csv"), "w+") + csv_file.write("ID,") else: - csv_file = open(os.path.join(to_save, ftype+'.csv'), 'w+') - csv_file.write('ID,gt_path,') + csv_file = open(os.path.join(to_save, ftype + ".csv"), "w+") + csv_file.write("ID,gt_path,") # load BIDS dataset into memory layout = BIDSLayout(folder_path) bids_df = layout.to_df() bids_modality_df = { - 't1': bids_df[bids_df['suffix'] == "T1w"], - 't2': bids_df[bids_df['suffix'] == "T2w"], - 'flair': bids_df[bids_df['suffix'] == "FLAIR"], - 't1ce': bids_df[bids_df['suffix'] == "T1CE"] + "t1": bids_df[bids_df["suffix"] == "T1w"], + "t2": bids_df[bids_df["suffix"] == "T2w"], + "flair": bids_df[bids_df["suffix"] == "FLAIR"], + "t1ce": bids_df[bids_df["suffix"] == "T1CE"], } # check what modalities the dataset contains modalities = [] @@ -186,27 +199,28 @@ def rex_bids_csv(folder_path, to_save, ftype): modalities.append(modality) # write headers for those modalities for modality in modalities[:-1]: - csv_file.write(modality+'_path,') + csv_file.write(modality + "_path,") modality = modalities[-1] - csv_file.write(modality+'_path\n') + csv_file.write(modality + "_path\n") # write image paths for each subject for sub in layout.get_subjects(): csv_file.write(sub) - csv_file.write(',') - if ftype != 'test': - ground_truth = glob.glob(os.path.join(folder_path, sub, '*mask.nii.gz'))[0] + csv_file.write(",") + if ftype != "test": + ground_truth = glob.glob(os.path.join(folder_path, sub, "*mask.nii.gz"))[0] csv_file.write(ground_truth) - csv_file.write(',') + csv_file.write(",") for modality in modalities[:-1]: - img = bids_modality_df[modality][bids_df['subject'] == sub].path.values + img = bids_modality_df[modality][bids_df["subject"] == sub].path.values csv_file.write(img[0]) - csv_file.write(',') + csv_file.write(",") modality = modalities[-1] - img = bids_modality_df[modality][bids_df['subject'] == sub].path.values + img = bids_modality_df[modality][bids_df["subject"] == sub].path.values csv_file.write(img[0]) - csv_file.write('\n') + csv_file.write("\n") csv_file.close() + def generate_csv(folder_path, to_save, mode, ftype, modalities): """[Function to generate CSV] [This function takes a look at the data directory and the modes and @@ -218,14 +232,14 @@ def generate_csv(folder_path, to_save, mode, ftype, modalities): ftype {[string]} -- [description] modalities {[string]} -- [description] """ - print("Generating ", ftype, '.csv', sep='') - if mode.lower() == 'ma': + print("Generating ", ftype, ".csv", sep="") + if mode.lower() == "ma": rex_o4a_csv(folder_path, to_save, ftype, modalities) - elif mode.lower() == 'single': + elif mode.lower() == "single": rex_sin_csv(folder_path, to_save, ftype, modalities) - elif mode.lower() == 'multi': + elif mode.lower() == "multi": rex_mul_csv(folder_path, to_save, ftype, modalities) - elif mode.lower() == 'bids': + elif mode.lower() == "bids": rex_bids_csv(folder_path, to_save, ftype) else: print("Sorry, this mode is not supported") diff --git a/BrainMaGe/utils/cyclicLR.py b/BrainMaGe/utils/cyclicLR.py index f11904f..ab78efc 100755 --- a/BrainMaGe/utils/cyclicLR.py +++ b/BrainMaGe/utils/cyclicLR.py @@ -35,11 +35,20 @@ class CyclicCosAnnealingLR(_LRScheduler): https://arxiv.org/abs/1608.03983 """ - def __init__(self, optimizer, milestones, decay_milestones=None, gamma=0.5, - eta_min=1e-6, last_epoch=-1): + def __init__( + self, + optimizer, + milestones, + decay_milestones=None, + gamma=0.5, + eta_min=1e-6, + last_epoch=-1, + ): if not list(milestones) == sorted(milestones): - raise ValueError('Milestones 
should be a list of' - ' increasing integers. Got {}', milestones) + raise ValueError( + "Milestones should be a list of" " increasing integers. Got {}", + milestones, + ) self.eta_min = eta_min self.milestones = milestones self.milestones2 = decay_milestones @@ -52,11 +61,27 @@ def get_lr(self): return [self.eta_min for base_lr in self.base_lrs] idx = bisect_right(self.milestones, self.last_epoch) - left_barrier = 0 if idx == 0 else self.milestones[idx-1] + left_barrier = 0 if idx == 0 else self.milestones[idx - 1] right_barrier = self.milestones[idx] width = right_barrier - left_barrier curr_pos = self.last_epoch - left_barrier if self.milestones2: - return [self.eta_min + (base_lr*self.gamma**bisect_right(self.milestones2, self.last_epoch)-self.eta_min)*(1 + math.cos(math.pi * curr_pos/ width))/2 for base_lr in self.base_lrs] + return [ + self.eta_min + + ( + base_lr + * self.gamma ** bisect_right(self.milestones2, self.last_epoch) + - self.eta_min + ) + * (1 + math.cos(math.pi * curr_pos / width)) + / 2 + for base_lr in self.base_lrs + ] else: - return [self.eta_min + (base_lr - self.eta_min)*(1 + math.cos(math.pi * curr_pos/ width)) / 2 for base_lr in self.base_lrs] + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos(math.pi * curr_pos / width)) + / 2 + for base_lr in self.base_lrs + ] diff --git a/BrainMaGe/utils/data.py b/BrainMaGe/utils/data.py index 9470a36..68e519b 100755 --- a/BrainMaGe/utils/data.py +++ b/BrainMaGe/utils/data.py @@ -17,6 +17,7 @@ class SkullStripDataset(Dataset): """ Skull strip dataloader """ + def __init__(self, csv_file, params, test=False): self.df = pd.read_csv(csv_file, header=0) self.params = params @@ -30,26 +31,31 @@ def __getitem__(self, patient_id): if not self.test: ground_truth_path = os.path.join(self.df.iloc[patient_id, 1]) ground_truth = nib.load(ground_truth_path) - nmods = self.params['num_modalities'] + nmods = self.params["num_modalities"] stack = np.zeros([int(nmods), 128, 128, 128], dtype=np.float32) for i in range(int(nmods)): - image_path = os.path.join(self.df.iloc[patient_id, i+2]) + image_path = os.path.join(self.df.iloc[patient_id, i + 2]) image = nib.load(image_path) image_data = image.get_fdata().astype(np.float32)[np.newaxis, ...] stack[i] = image_data - ground_truth_data = ground_truth.get_data().astype(np.float32)[np.newaxis, ...] + ground_truth_data = ground_truth.get_data().astype(np.float32)[ + np.newaxis, ... + ] affine = image.affine - sample = {'image_name': image_name, 'image_data': stack, - 'ground_truth_data': ground_truth_data, 'affine': affine} + sample = { + "image_name": image_name, + "image_data": stack, + "ground_truth_data": ground_truth_data, + "affine": affine, + } else: - nmods = self.params['num_modalities'] + nmods = self.params["num_modalities"] stack = np.zeros([int(nmods), 128, 128, 128], dtype=np.float32) for i in range(int(nmods)): - image_path = os.path.join(self.df.iloc[patient_id, i+1]) + image_path = os.path.join(self.df.iloc[patient_id, i + 1]) image = nib.load(image_path) image_data = image.get_fdata().astype(np.float32)[np.newaxis, ...] 
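# A short usage sketch for this dataset, assuming a previously generated train.csv and
# a params dict like the one the trainer builds (both hypothetical here): each sample
# is a dict whose "image_data" entry is a float32 array of shape
# [num_modalities, 128, 128, 128], so a default-collated batch comes out as
# [batch_size, num_modalities, 128, 128, 128].
from torch.utils.data import DataLoader
from BrainMaGe.utils.data import SkullStripDataset

params = {"num_modalities": "1"}  # hypothetical; normally parsed from train_params.cfg
dataset = SkullStripDataset("train.csv", params, test=False)
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
for batch in loader:
    print(batch["image_data"].shape)         # torch.Size([2, 1, 128, 128, 128])
    print(batch["ground_truth_data"].shape)  # torch.Size([2, 1, 128, 128, 128])
    break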
stack[i] = image_data affine = image.affine - sample = {'image_name': image_name, 'image_data': stack, - 'affine': affine} + sample = {"image_name": image_name, "image_data": stack, "affine": affine} return sample diff --git a/BrainMaGe/utils/intensity_standardize.py b/BrainMaGe/utils/intensity_standardize.py index f168753..9728712 100755 --- a/BrainMaGe/utils/intensity_standardize.py +++ b/BrainMaGe/utils/intensity_standardize.py @@ -39,6 +39,7 @@ # *** The files will be generated as something_roimask.nii.gz *** # """ + def pad_image(image): """[To pad the image to particular space] [This function will pad the image to a space of [240, 240, 160] and will @@ -52,29 +53,42 @@ def pad_image(image): # Padding on X axes if image.shape[0] < 240: # print("Image was padded on the X-axis on both sides") - padded_image = np.pad(padded_image, ((int((240-image.shape[0])/2), - int((240-image.shape[0])/2)), - (0, 0), (0, 0)), - mode='constant', constant_values=0) + padded_image = np.pad( + padded_image, + ( + (int((240 - image.shape[0]) / 2), int((240 - image.shape[0]) / 2)), + (0, 0), + (0, 0), + ), + mode="constant", + constant_values=0, + ) # Padding on Y axes if image.shape[1] < 240: # print("Image was padded on the Y-axis on both sides") - padded_image = np.pad(padded_image, ((0, 0), - (int((240-image.shape[1])/2), - int((240-image.shape[1])/2)), - (0, 0)), - mode='constant', constant_values=0) + padded_image = np.pad( + padded_image, + ( + (0, 0), + (int((240 - image.shape[1]) / 2), int((240 - image.shape[1]) / 2)), + (0, 0), + ), + mode="constant", + constant_values=0, + ) # Padding on Z axes if image.shape[2] < 160: # print("Image was padded on the Z-axis on top only") - padded_image = np.pad(padded_image, ((0, 0), (0, 0), - (0, int(160-image.shape[2]))), - 'constant', constant_values=0) + padded_image = np.pad( + padded_image, + ((0, 0), (0, 0), (0, int(160 - image.shape[2]))), + "constant", + constant_values=0, + ) return padded_image -def preprocess_image(image, is_mask=False, - target_spacing=(1.875, 1.875, 1.25)): +def preprocess_image(image, is_mask=False, target_spacing=(1.875, 1.875, 1.25)): """[To preprocess an image depending on whether it a mask image or not] [This function in general will try to preprocess a given image to a partic- -ular image resolution and try to return a preprocessed image] @@ -109,8 +123,14 @@ def preprocess_image(image, is_mask=False, with an isotropic resolution of (1.0, 1.0, 1.0), then we would just resize the image to (128, 128, 128)] """ - new_image = resize(new_image, (128, 128, 128), order=3, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=3, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: """[Checking if it is an isotropic image with need to incorrect shape] @@ -126,8 +146,14 @@ def preprocess_image(image, is_mask=False, # print("Image shape wasn't perfect") new_image = pad_image(new_image) # print("Trying to pad the image now") - new_image = resize(new_image, (128, 128, 128), order=3, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=3, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: """[Checking if it is not isotropic image with resolution needed] ________________________________________ @@ -139,40 +165,82 @@ def preprocess_image(image, is_mask=False, a isotropic resolution of (1.0, 1.0, 1.0), then we would just resize the image to (128, 128, 128)] """ - new_shape = 
(int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))), - int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))), - int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2])))) - new_image = resize(new_image, new_shape, order=1, - mode='edge', cval=0, anti_aliasing=False) + new_shape = ( + int(np.round(old_spacing[0] / new_spacing[0] * float(image.shape[0]))), + int(np.round(old_spacing[1] / new_spacing[1] * float(image.shape[1]))), + int(np.round(old_spacing[2] / new_spacing[2] * float(image.shape[2]))), + ) + new_image = resize( + new_image, new_shape, order=1, mode="edge", cval=0, anti_aliasing=False + ) if new_shape == [240, 240, 160]: - new_image = resize(new_image, (128, 128, 128), order=3, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=3, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: new_image = pad_image(new_image) - new_image = resize(new_image, (128, 128, 128), order=3, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=3, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: if old_spacing == (1.0, 1.0, 1.0): if shape == [240, 240, 160]: - new_image = resize(new_image, (128, 128, 128), order=0, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=0, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: new_image = pad_image(new_image) - new_image = resize(new_image, (128, 128, 128), order=0, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=0, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: - new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))), - int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))), - int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2])))) - new_image = resize(new_image, new_shape, order=0, - mode='edge', cval=0, anti_aliasing=False) + new_shape = ( + int(np.round(old_spacing[0] / new_spacing[0] * float(image.shape[0]))), + int(np.round(old_spacing[1] / new_spacing[1] * float(image.shape[1]))), + int(np.round(old_spacing[2] / new_spacing[2] * float(image.shape[2]))), + ) + new_image = resize( + new_image, new_shape, order=0, mode="edge", cval=0, anti_aliasing=False + ) if new_shape == [240, 240, 160]: - new_image = resize(new_image, (128, 128, 128), order=0, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=0, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: new_image = pad_image(new_image) - new_image = resize(new_image, (128, 128, 128), order=0, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=0, + mode="edge", + cval=0, + anti_aliasing=False, + ) if is_mask: # Retrun if mask return new_image.astype(np.int8) @@ -181,9 +249,10 @@ def preprocess_image(image, is_mask=False, p1 = np.percentile(new_image_temp, 2) p2 = np.percentile(new_image_temp, 95) new_image[new_image > p2] = p2 - new_image = (new_image - p1)/p2 + new_image = (new_image - p1) / p2 return new_image.astype(np.float32) + def normalize(folder, dest_folder, patient_name, test=False): """[Function used to pre-process files] [This function is used for the skull stripping preprocessing, @@ -196,16 +265,14 @@ def normalize(folder, dest_folder, patient_name, test=False): """ patient_dest_folder = os.path.join(dest_folder, patient_name) 
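# The non-mask branch of preprocess_image above ends with a percentile rescale: values
# above the 95th percentile are clipped to it, then the volume is shifted by the 2nd
# percentile and divided by the 95th. A standalone sketch, assuming `foreground` is the
# voxel subset the file uses as new_image_temp (its selection sits in an unchanged part
# of the file and is not shown in this hunk):
import numpy as np

def percentile_rescale(volume, foreground):
    p1 = np.percentile(foreground, 2)
    p2 = np.percentile(foreground, 95)
    clipped = np.where(volume > p2, p2, volume)  # clip the bright tail at the 95th percentile
    return ((clipped - p1) / p2).astype(np.float32)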
os.makedirs(patient_dest_folder, exist_ok=True) - t1 = glob.glob(os.path.join(folder, '*t1.nii.gz'))[0] - t2 = glob.glob(os.path.join(folder, '*t2.nii.gz'))[0] - t1ce = glob.glob(os.path.join(folder, '*t1ce.nii.gz'))[0] - flair = glob.glob(os.path.join(folder, '*flair.nii.gz'))[0] + t1 = glob.glob(os.path.join(folder, "*t1.nii.gz"))[0] + t2 = glob.glob(os.path.join(folder, "*t2.nii.gz"))[0] + t1ce = glob.glob(os.path.join(folder, "*t1ce.nii.gz"))[0] + flair = glob.glob(os.path.join(folder, "*flair.nii.gz"))[0] if not test: - gt = glob.glob(os.path.join(folder, '*mask.nii.gz'))[0] + gt = glob.glob(os.path.join(folder, "*mask.nii.gz"))[0] - new_affine = np.array([[1.875, 0, 0], - [0, 1.875, 0], - [0, 0, 1.25]]) + new_affine = np.array([[1.875, 0, 0], [0, 1.875, 0], [0, 0, 1.25]]) # Reading T1 image and storing it t1_image = nib.load(t1) @@ -214,76 +281,109 @@ def normalize(folder, dest_folder, patient_name, test=False): temp_affine[:3, :3] = new_affine resized_t1_image = nib.Nifti1Image(resized_t1_image, temp_affine) print(patient_dest_folder) - print("Saving T1 at : ", os.path.join(patient_dest_folder, patient_name + - "_t1.nii.gz")) - nib.save(resized_t1_image, os.path.join(patient_dest_folder, patient_name + - "_t1.nii.gz")) + print( + "Saving T1 at : ", + os.path.join(patient_dest_folder, patient_name + "_t1.nii.gz"), + ) + nib.save( + resized_t1_image, os.path.join(patient_dest_folder, patient_name + "_t1.nii.gz") + ) t2_image = nib.load(t2) resized_t2_image = preprocess_image(t2_image, is_mask=False) temp_affine = t2_image.affine temp_affine[:3, :3] = new_affine resized_t2_image = nib.Nifti1Image(resized_t2_image, temp_affine) - nib.save(resized_t2_image, os.path.join(patient_dest_folder, patient_name + - "_t2.nii.gz")) + nib.save( + resized_t2_image, os.path.join(patient_dest_folder, patient_name + "_t2.nii.gz") + ) t1ce_image = nib.load(t1ce) resized_t1ce_image = preprocess_image(t1ce_image, is_mask=False) temp_affine = t1ce_image.affine temp_affine[:3, :3] = new_affine - resized_t1ce_image = nib.Nifti1Image(resized_t1ce_image, - t1ce_image.affine) - nib.save(resized_t1ce_image, os.path.join(patient_dest_folder, patient_name + - "_t1ce.nii.gz")) + resized_t1ce_image = nib.Nifti1Image(resized_t1ce_image, t1ce_image.affine) + nib.save( + resized_t1ce_image, + os.path.join(patient_dest_folder, patient_name + "_t1ce.nii.gz"), + ) flair_image = nib.load(flair) resized_flair_image = preprocess_image(flair_image, is_mask=False) temp_affine = flair_image.affine temp_affine[:3, :3] = new_affine - resized_flair_image = nib.Nifti1Image(resized_flair_image, - flair_image.affine) - nib.save(resized_flair_image, os.path.join(patient_dest_folder, patient_name - + - "_flair.nii.gz")) + resized_flair_image = nib.Nifti1Image(resized_flair_image, flair_image.affine) + nib.save( + resized_flair_image, + os.path.join(patient_dest_folder, patient_name + "_flair.nii.gz"), + ) if not test: gt_image = nib.load(gt) resized_gt_image = preprocess_image(gt_image, is_mask=True) - resized_gt_image = nib.Nifti1Image(resized_gt_image, - gt_image.affine) - nib.save(resized_gt_image, os.path.join(patient_dest_folder, patient_name - +"_mask.nii.gz")) + resized_gt_image = nib.Nifti1Image(resized_gt_image, gt_image.affine) + nib.save( + resized_gt_image, + os.path.join(patient_dest_folder, patient_name + "_mask.nii.gz"), + ) return def batch_works(k): if k == n_processes - 1: - sub_patients = patients[k * int(len(patients) / n_processes):] + sub_patients = patients[k * int(len(patients) / n_processes) :] else: - 
sub_patients = patients[k * int(len(patients) / n_processes): - (k + 1) * int(len(patients) / n_processes)] + sub_patients = patients[ + k + * int(len(patients) / n_processes) : (k + 1) + * int(len(patients) / n_processes) + ] for patient in sub_patients: patient_name = os.path.basename(patient) print(patient_name) normalize(patient, output_path, patient_name) -if __name__ == '__main__': - parser = argparse.ArgumentParser(prog='intensity_standardize', formatter_class=argparse.RawTextHelpFormatter, - description='\nThis code was implemented to standardize intensities for skull stripping\n'+ '\n'\ - 'Copyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.\n'\ - 'For questions and feedback contact: software@cbica.upenn.edu') +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="intensity_standardize", + formatter_class=argparse.RawTextHelpFormatter, + description="\nThis code was implemented to standardize intensities for skull stripping\n" + + "\n" + "Copyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.\n" + "For questions and feedback contact: software@cbica.upenn.edu", + ) + + parser.add_argument( + "-i", + "--input_path", + dest="input_path", + help="input path for the tissues", + required=True, + ) + parser.add_argument( + "-o", + "--output_path", + dest="output_path", + help="output path for saving the files", + required=True, + ) + parser.add_argument( + "-t", + "--threads", + dest="threads", + help="number of threads, by default will use all", + ) + + parser.add_argument( + "-v", + "--version", + action="version", + version=pkg_resources.require("BrainMaGe")[0].version + + "\n\nCopyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.", + help="Show program's version number and exit.", + ) - parser.add_argument('-i', '--input_path', dest='input_path', - help="input path for the tissues", required=True) - parser.add_argument('-o', '--output_path', dest='output_path', - help="output path for saving the files", required=True) - parser.add_argument('-t', '--threads', dest='threads', - help="number of threads, by default will use all") - - parser.add_argument('-v', '--version', action='version', - version=pkg_resources.require("BrainMaGe")[0].version + '\n\nCopyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.', help="Show program's version number and exit.") - args = parser.parse_args() if args.threads: @@ -295,7 +395,7 @@ def batch_works(k): input_path = os.path.abspath(args.input_path) output_path = os.path.abspath(args.output_path) os.makedirs(output_path, exist_ok=True) - patients = glob.glob(os.path.abspath(args.input_path)+'/*') + patients = glob.glob(os.path.abspath(args.input_path) + "/*") n_processes = cpu_count() pool = Pool(processes=n_processes) pool.map(batch_works, range(n_processes)) diff --git a/BrainMaGe/utils/losses.py b/BrainMaGe/utils/losses.py index 4a8f330..a8b61a3 100755 --- a/BrainMaGe/utils/losses.py +++ b/BrainMaGe/utils/losses.py @@ -8,61 +8,70 @@ import torch.nn as nn + def dice_loss(inp, target): smooth = 1e-7 iflat = inp.view(-1) tflat = target.view(-1) intersection = (iflat * tflat).sum() - return 1 - ((2. 
* intersection + smooth) / - (iflat.sum() + tflat.sum() + smooth)) + return 1 - ((2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)) + def dice(inp, target): smooth = 1e-7 iflat = inp.view(-1) tflat = target.view(-1) intersection = (iflat * tflat).sum() - return (2*intersection+smooth)/(iflat.sum()+tflat.sum()+smooth) + return (2 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth) + def tversky(inp, target, alpha): smooth = 1e-7 iflat = inp.view(-1) tflat = target.view(-1) - intersection = (iflat*tflat).sum() - fps = (iflat * (1-tflat)).sum() - fns = ((1-iflat) * tflat).sum() - denominator = intersection + (alpha*fps) + ((1-alpha)*fns) + smooth - return (intersection+smooth)/denominator + intersection = (iflat * tflat).sum() + fps = (iflat * (1 - tflat)).sum() + fns = ((1 - iflat) * tflat).sum() + denominator = intersection + (alpha * fps) + ((1 - alpha) * fns) + smooth + return (intersection + smooth) / denominator + def tversky_loss(inp, target, alpha): smooth = 1e-7 iflat = inp.view(-1) tflat = inp.view(-1) - intersection = (iflat*tflat).sum() - fps = (inp * (1-target)).sum() - fns = (inp * (1-target)).sum() - denominator = intersection + (alpha*fps) + ((1-alpha)*fns) + smooth - return 1 - ((intersection+smooth)/denominator) + intersection = (iflat * tflat).sum() + fps = (inp * (1 - target)).sum() + fns = (inp * (1 - target)).sum() + denominator = intersection + (alpha * fps) + ((1 - alpha) * fns) + smooth + return 1 - ((intersection + smooth) / denominator) + def power_loss(inp, target, power): return dice_loss(inp, target) ** power + def pointwise_loss(inp, target, alpha=20, beta=3): iflat = inp.view(-1) tflat = target[:, 0, :, :].contiguous().view(-1) - intersection = (alpha*(iflat * tflat).pow(beta)).sum() + intersection = (alpha * (iflat * tflat).pow(beta)).sum() union = iflat.sum() + tflat.sum() - return 1 - intersection/union + return 1 - intersection / union def bce_loss(inp, target): return nn.BCELoss(inp.view(-1), target.view(-1)) -def focal_tversky_loss(inp, target, alpha=0.3, gamma=4/3): + +def focal_tversky_loss(inp, target, alpha=0.3, gamma=4 / 3): tver = tversky(inp, target, alpha) - tver_loss = 1-tver - return (tver_loss)**(1/gamma) + tver_loss = 1 - tver + return (tver_loss) ** (1 / gamma) + -def ft_bce_loss(inp, target, alpha=0.3, gamma=4/3): +def ft_bce_loss(inp, target, alpha=0.3, gamma=4 / 3): tver = tversky(inp, target, alpha) - tver_loss = 1-tver - return 0.5*((tver_loss)**(1/gamma) + nn.BCELoss(inp.view(-1), target.view(-1))) + tver_loss = 1 - tver + return 0.5 * ( + (tver_loss) ** (1 / gamma) + nn.BCELoss(inp.view(-1), target.view(-1)) + ) diff --git a/BrainMaGe/utils/optimizers.py b/BrainMaGe/utils/optimizers.py index 7799cd7..2bc12e1 100755 --- a/BrainMaGe/utils/optimizers.py +++ b/BrainMaGe/utils/optimizers.py @@ -10,28 +10,32 @@ import torch.optim as optim import sys + def fetch_optimizer(optimizer, lr, model): # Setting up the optimizer - if optimizer.lower() == 'sgd': - optimizer = optim.SGD(model.parameters(), - lr=float(lr), - momentum=0.9, nesterov=True) - elif optimizer.lower() == 'adam': - optimizer = optim.Adam(model.parameters(), - lr=float(lr), - betas=(0.9, 0.999), - weight_decay=0.00005) - elif optimizer.lower() == 'rms': - optimizer = optim.RMSprop(model.parameters(), - lr=float(lr), - momentum=0.9, weight_decay=0.00005) - elif optimizer.lower() == 'adagrad': - optimizer = optim.Adagrad(model.parameters(), - lr=float(lr), - weight_decay=0.00005) + if optimizer.lower() == "sgd": + optimizer = optim.SGD( + 
model.parameters(), lr=float(lr), momentum=0.9, nesterov=True + ) + elif optimizer.lower() == "adam": + optimizer = optim.Adam( + model.parameters(), lr=float(lr), betas=(0.9, 0.999), weight_decay=0.00005 + ) + elif optimizer.lower() == "rms": + optimizer = optim.RMSprop( + model.parameters(), lr=float(lr), momentum=0.9, weight_decay=0.00005 + ) + elif optimizer.lower() == "adagrad": + optimizer = optim.Adagrad( + model.parameters(), lr=float(lr), weight_decay=0.00005 + ) else: - print("Sorry, {} is not supported or some sort of spell error. Please\ - choose from the given options!".format(optimizer)) + print( + "Sorry, {} is not supported or some sort of spell error. Please\ + choose from the given options!".format( + optimizer + ) + ) sys.stdout.flush() sys.exit(0) return optimizer diff --git a/BrainMaGe/utils/preprocess.py b/BrainMaGe/utils/preprocess.py index 369a85b..b0a027d 100755 --- a/BrainMaGe/utils/preprocess.py +++ b/BrainMaGe/utils/preprocess.py @@ -24,29 +24,42 @@ def pad_image(image): # Padding on X axes if image.shape[0] < 240: # print("Image was padded on the X-axis on both sides") - padded_image = np.pad(padded_image, ((int((240-image.shape[0])/2), - int((240-image.shape[0])/2)), - (0, 0), (0, 0)), - mode='constant', constant_values=0) + padded_image = np.pad( + padded_image, + ( + (int((240 - image.shape[0]) / 2), int((240 - image.shape[0]) / 2)), + (0, 0), + (0, 0), + ), + mode="constant", + constant_values=0, + ) # Padding on Y axes if image.shape[1] < 240: # print("Image was padded on the Y-axis on both sides") - padded_image = np.pad(padded_image, ((0, 0), - (int((240-image.shape[1])/2), - int((240-image.shape[1])/2)), - (0, 0)), - mode='constant', constant_values=0) + padded_image = np.pad( + padded_image, + ( + (0, 0), + (int((240 - image.shape[1]) / 2), int((240 - image.shape[1]) / 2)), + (0, 0), + ), + mode="constant", + constant_values=0, + ) # Padding on Z axes if image.shape[2] < 160: # print("Image was padded on the Z-axis on top only") - padded_image = np.pad(padded_image, ((0, 0), (0, 0), - (0, int(160-image.shape[2]))), - 'constant', constant_values=0) + padded_image = np.pad( + padded_image, + ((0, 0), (0, 0), (0, int(160 - image.shape[2]))), + "constant", + constant_values=0, + ) return padded_image -def preprocess_image(image, is_mask=False, - target_spacing=(1.875, 1.875, 1.25)): +def preprocess_image(image, is_mask=False, target_spacing=(1.875, 1.875, 1.25)): """[To preprocess an image depending on whether it a mask image or not] [This function in general will try to preprocess a given image to a partic- -ular image resolution and try to return a preprocessed image] @@ -81,8 +94,14 @@ def preprocess_image(image, is_mask=False, with an isotropic resolution of (1.0, 1.0, 1.0), then we would just resize the image to (128, 128, 128)] """ - new_image = resize(new_image, (128, 128, 128), order=3, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=3, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: """[Checking if it is an isotropic image with need to incorrect shape] @@ -98,8 +117,14 @@ def preprocess_image(image, is_mask=False, # print("Image shape wasn't perfect") new_image = pad_image(new_image) # print("Trying to pad the image now") - new_image = resize(new_image, (128, 128, 128), order=3, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=3, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: """[Checking if it is not isotropic 
image with resolution needed] ________________________________________ @@ -111,40 +136,82 @@ def preprocess_image(image, is_mask=False, a isotropic resolution of (1.0, 1.0, 1.0), then we would just resize the image to (128, 128, 128)] """ - new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))), - int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))), - int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2])))) - new_image = resize(new_image, new_shape, order=1, - mode='edge', cval=0, anti_aliasing=False) + new_shape = ( + int(np.round(old_spacing[0] / new_spacing[0] * float(image.shape[0]))), + int(np.round(old_spacing[1] / new_spacing[1] * float(image.shape[1]))), + int(np.round(old_spacing[2] / new_spacing[2] * float(image.shape[2]))), + ) + new_image = resize( + new_image, new_shape, order=1, mode="edge", cval=0, anti_aliasing=False + ) if new_shape == [240, 240, 160]: - new_image = resize(new_image, (128, 128, 128), order=3, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=3, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: new_image = pad_image(new_image) - new_image = resize(new_image, (128, 128, 128), order=3, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=3, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: if old_spacing == (1.0, 1.0, 1.0): if shape == [240, 240, 160]: - new_image = resize(new_image, (128, 128, 128), order=0, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=0, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: new_image = pad_image(new_image) - new_image = resize(new_image, (128, 128, 128), order=0, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=0, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: - new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))), - int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))), - int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2])))) - new_image = resize(new_image, new_shape, order=0, - mode='edge', cval=0, anti_aliasing=False) + new_shape = ( + int(np.round(old_spacing[0] / new_spacing[0] * float(image.shape[0]))), + int(np.round(old_spacing[1] / new_spacing[1] * float(image.shape[1]))), + int(np.round(old_spacing[2] / new_spacing[2] * float(image.shape[2]))), + ) + new_image = resize( + new_image, new_shape, order=0, mode="edge", cval=0, anti_aliasing=False + ) if new_shape == [240, 240, 160]: - new_image = resize(new_image, (128, 128, 128), order=0, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=0, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: new_image = pad_image(new_image) - new_image = resize(new_image, (128, 128, 128), order=0, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=0, + mode="edge", + cval=0, + anti_aliasing=False, + ) if is_mask: # Retrun if mask return new_image.astype(np.int8) @@ -153,5 +220,5 @@ def preprocess_image(image, is_mask=False, p1 = np.percentile(new_image_temp, 2) p2 = np.percentile(new_image_temp, 95) new_image[new_image > p2] = p2 - new_image = (new_image - p1)/p2 + new_image = (new_image - p1) / p2 return new_image.astype(np.float32) diff --git a/BrainMaGe/utils/utils_test.py b/BrainMaGe/utils/utils_test.py index 
76f14d0..8a8de1e 100755 --- a/BrainMaGe/utils/utils_test.py +++ b/BrainMaGe/utils/utils_test.py @@ -9,6 +9,7 @@ import numpy as np from skimage.transform import resize + def pad_image(image): """ @@ -28,25 +29,38 @@ def pad_image(image): # Padding on X axes if image.shape[0] <= 240: pad_x1 = (240 - image.shape[0]) // 2 - pad_x2 = (240 - image.shape[0] - pad_x1) - padded_image = np.pad(padded_image, ((pad_x1, pad_x2), (0, 0), (0, 0)), - mode='constant', constant_values=0) + pad_x2 = 240 - image.shape[0] - pad_x1 + padded_image = np.pad( + padded_image, + ((pad_x1, pad_x2), (0, 0), (0, 0)), + mode="constant", + constant_values=0, + ) # Padding on Y axes if image.shape[1] <= 240: pad_y1 = (240 - image.shape[1]) // 2 - pad_y2 = (240 - image.shape[1] - pad_y1) - padded_image = np.pad(padded_image, ((0, 0), (pad_y1, pad_y2), (0, 0)), - mode='constant', constant_values=0) + pad_y2 = 240 - image.shape[1] - pad_y1 + padded_image = np.pad( + padded_image, + ((0, 0), (pad_y1, pad_y2), (0, 0)), + mode="constant", + constant_values=0, + ) # Padding on Z axes if image.shape[2] <= 160: pad_z2 = 160 - image.shape[2] - padded_image = np.pad(padded_image, ((0, 0), (0, 0), (pad_z2, 0)), - mode='constant', constant_values=0) + padded_image = np.pad( + padded_image, + ((0, 0), (0, 0), (pad_z2, 0)), + mode="constant", + constant_values=0, + ) return padded_image, ((pad_x1, pad_x2), (pad_y1, pad_y2), (pad_z1, pad_z2)) + def process_image(image): """ special percentile based preprocessing and then apply the stuff on image, @@ -67,9 +81,10 @@ def process_image(image): p1 = np.percentile(new_image_temp, 2) p2 = np.percentile(new_image_temp, 95) image[image > p2] = p2 - image = (image - p1)/p2 + image = (image - p1) / p2 return image + def padder_and_cropper(image, pad_info): (pad_x1, pad_x2), (pad_y1, pad_y2), (pad_z1, pad_z2) = pad_info if pad_x2 == 0: @@ -78,13 +93,15 @@ def padder_and_cropper(image, pad_info): pad_y2 = -image.shape[1] if pad_z2 == 0: pad_z2 = -image.shape[2] - image = image[pad_x1: -pad_x2, pad_y1: -pad_y2, pad_z2:] + image = image[pad_x1:-pad_x2, pad_y1:-pad_y2, pad_z2:] return image + def unpad_image(image): image = image[:, :, :155] return image + def interpolate_image(image, output_shape): - new_image = resize(image, (output_shape), order=3, mode='edge', cval=0) - return new_image \ No newline at end of file + new_image = resize(image, (output_shape), order=3, mode="edge", cval=0) + return new_image diff --git a/brain_mage_intensity_standardize b/brain_mage_intensity_standardize index f168753..9728712 100755 --- a/brain_mage_intensity_standardize +++ b/brain_mage_intensity_standardize @@ -39,6 +39,7 @@ import pkg_resources # *** The files will be generated as something_roimask.nii.gz *** # """ + def pad_image(image): """[To pad the image to particular space] [This function will pad the image to a space of [240, 240, 160] and will @@ -52,29 +53,42 @@ def pad_image(image): # Padding on X axes if image.shape[0] < 240: # print("Image was padded on the X-axis on both sides") - padded_image = np.pad(padded_image, ((int((240-image.shape[0])/2), - int((240-image.shape[0])/2)), - (0, 0), (0, 0)), - mode='constant', constant_values=0) + padded_image = np.pad( + padded_image, + ( + (int((240 - image.shape[0]) / 2), int((240 - image.shape[0]) / 2)), + (0, 0), + (0, 0), + ), + mode="constant", + constant_values=0, + ) # Padding on Y axes if image.shape[1] < 240: # print("Image was padded on the Y-axis on both sides") - padded_image = np.pad(padded_image, ((0, 0), - (int((240-image.shape[1])/2), - 
int((240-image.shape[1])/2)), - (0, 0)), - mode='constant', constant_values=0) + padded_image = np.pad( + padded_image, + ( + (0, 0), + (int((240 - image.shape[1]) / 2), int((240 - image.shape[1]) / 2)), + (0, 0), + ), + mode="constant", + constant_values=0, + ) # Padding on Z axes if image.shape[2] < 160: # print("Image was padded on the Z-axis on top only") - padded_image = np.pad(padded_image, ((0, 0), (0, 0), - (0, int(160-image.shape[2]))), - 'constant', constant_values=0) + padded_image = np.pad( + padded_image, + ((0, 0), (0, 0), (0, int(160 - image.shape[2]))), + "constant", + constant_values=0, + ) return padded_image -def preprocess_image(image, is_mask=False, - target_spacing=(1.875, 1.875, 1.25)): +def preprocess_image(image, is_mask=False, target_spacing=(1.875, 1.875, 1.25)): """[To preprocess an image depending on whether it a mask image or not] [This function in general will try to preprocess a given image to a partic- -ular image resolution and try to return a preprocessed image] @@ -109,8 +123,14 @@ def preprocess_image(image, is_mask=False, with an isotropic resolution of (1.0, 1.0, 1.0), then we would just resize the image to (128, 128, 128)] """ - new_image = resize(new_image, (128, 128, 128), order=3, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=3, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: """[Checking if it is an isotropic image with need to incorrect shape] @@ -126,8 +146,14 @@ def preprocess_image(image, is_mask=False, # print("Image shape wasn't perfect") new_image = pad_image(new_image) # print("Trying to pad the image now") - new_image = resize(new_image, (128, 128, 128), order=3, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=3, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: """[Checking if it is not isotropic image with resolution needed] ________________________________________ @@ -139,40 +165,82 @@ def preprocess_image(image, is_mask=False, a isotropic resolution of (1.0, 1.0, 1.0), then we would just resize the image to (128, 128, 128)] """ - new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))), - int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))), - int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2])))) - new_image = resize(new_image, new_shape, order=1, - mode='edge', cval=0, anti_aliasing=False) + new_shape = ( + int(np.round(old_spacing[0] / new_spacing[0] * float(image.shape[0]))), + int(np.round(old_spacing[1] / new_spacing[1] * float(image.shape[1]))), + int(np.round(old_spacing[2] / new_spacing[2] * float(image.shape[2]))), + ) + new_image = resize( + new_image, new_shape, order=1, mode="edge", cval=0, anti_aliasing=False + ) if new_shape == [240, 240, 160]: - new_image = resize(new_image, (128, 128, 128), order=3, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=3, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: new_image = pad_image(new_image) - new_image = resize(new_image, (128, 128, 128), order=3, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=3, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: if old_spacing == (1.0, 1.0, 1.0): if shape == [240, 240, 160]: - new_image = resize(new_image, (128, 128, 128), order=0, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + 
order=0, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: new_image = pad_image(new_image) - new_image = resize(new_image, (128, 128, 128), order=0, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=0, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: - new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))), - int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))), - int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2])))) - new_image = resize(new_image, new_shape, order=0, - mode='edge', cval=0, anti_aliasing=False) + new_shape = ( + int(np.round(old_spacing[0] / new_spacing[0] * float(image.shape[0]))), + int(np.round(old_spacing[1] / new_spacing[1] * float(image.shape[1]))), + int(np.round(old_spacing[2] / new_spacing[2] * float(image.shape[2]))), + ) + new_image = resize( + new_image, new_shape, order=0, mode="edge", cval=0, anti_aliasing=False + ) if new_shape == [240, 240, 160]: - new_image = resize(new_image, (128, 128, 128), order=0, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=0, + mode="edge", + cval=0, + anti_aliasing=False, + ) else: new_image = pad_image(new_image) - new_image = resize(new_image, (128, 128, 128), order=0, - mode='edge', cval=0, anti_aliasing=False) + new_image = resize( + new_image, + (128, 128, 128), + order=0, + mode="edge", + cval=0, + anti_aliasing=False, + ) if is_mask: # Retrun if mask return new_image.astype(np.int8) @@ -181,9 +249,10 @@ def preprocess_image(image, is_mask=False, p1 = np.percentile(new_image_temp, 2) p2 = np.percentile(new_image_temp, 95) new_image[new_image > p2] = p2 - new_image = (new_image - p1)/p2 + new_image = (new_image - p1) / p2 return new_image.astype(np.float32) + def normalize(folder, dest_folder, patient_name, test=False): """[Function used to pre-process files] [This function is used for the skull stripping preprocessing, @@ -196,16 +265,14 @@ def normalize(folder, dest_folder, patient_name, test=False): """ patient_dest_folder = os.path.join(dest_folder, patient_name) os.makedirs(patient_dest_folder, exist_ok=True) - t1 = glob.glob(os.path.join(folder, '*t1.nii.gz'))[0] - t2 = glob.glob(os.path.join(folder, '*t2.nii.gz'))[0] - t1ce = glob.glob(os.path.join(folder, '*t1ce.nii.gz'))[0] - flair = glob.glob(os.path.join(folder, '*flair.nii.gz'))[0] + t1 = glob.glob(os.path.join(folder, "*t1.nii.gz"))[0] + t2 = glob.glob(os.path.join(folder, "*t2.nii.gz"))[0] + t1ce = glob.glob(os.path.join(folder, "*t1ce.nii.gz"))[0] + flair = glob.glob(os.path.join(folder, "*flair.nii.gz"))[0] if not test: - gt = glob.glob(os.path.join(folder, '*mask.nii.gz'))[0] + gt = glob.glob(os.path.join(folder, "*mask.nii.gz"))[0] - new_affine = np.array([[1.875, 0, 0], - [0, 1.875, 0], - [0, 0, 1.25]]) + new_affine = np.array([[1.875, 0, 0], [0, 1.875, 0], [0, 0, 1.25]]) # Reading T1 image and storing it t1_image = nib.load(t1) @@ -214,76 +281,109 @@ def normalize(folder, dest_folder, patient_name, test=False): temp_affine[:3, :3] = new_affine resized_t1_image = nib.Nifti1Image(resized_t1_image, temp_affine) print(patient_dest_folder) - print("Saving T1 at : ", os.path.join(patient_dest_folder, patient_name + - "_t1.nii.gz")) - nib.save(resized_t1_image, os.path.join(patient_dest_folder, patient_name + - "_t1.nii.gz")) + print( + "Saving T1 at : ", + os.path.join(patient_dest_folder, patient_name + "_t1.nii.gz"), + ) + nib.save( + resized_t1_image, 
os.path.join(patient_dest_folder, patient_name + "_t1.nii.gz") + ) t2_image = nib.load(t2) resized_t2_image = preprocess_image(t2_image, is_mask=False) temp_affine = t2_image.affine temp_affine[:3, :3] = new_affine resized_t2_image = nib.Nifti1Image(resized_t2_image, temp_affine) - nib.save(resized_t2_image, os.path.join(patient_dest_folder, patient_name + - "_t2.nii.gz")) + nib.save( + resized_t2_image, os.path.join(patient_dest_folder, patient_name + "_t2.nii.gz") + ) t1ce_image = nib.load(t1ce) resized_t1ce_image = preprocess_image(t1ce_image, is_mask=False) temp_affine = t1ce_image.affine temp_affine[:3, :3] = new_affine - resized_t1ce_image = nib.Nifti1Image(resized_t1ce_image, - t1ce_image.affine) - nib.save(resized_t1ce_image, os.path.join(patient_dest_folder, patient_name + - "_t1ce.nii.gz")) + resized_t1ce_image = nib.Nifti1Image(resized_t1ce_image, t1ce_image.affine) + nib.save( + resized_t1ce_image, + os.path.join(patient_dest_folder, patient_name + "_t1ce.nii.gz"), + ) flair_image = nib.load(flair) resized_flair_image = preprocess_image(flair_image, is_mask=False) temp_affine = flair_image.affine temp_affine[:3, :3] = new_affine - resized_flair_image = nib.Nifti1Image(resized_flair_image, - flair_image.affine) - nib.save(resized_flair_image, os.path.join(patient_dest_folder, patient_name - + - "_flair.nii.gz")) + resized_flair_image = nib.Nifti1Image(resized_flair_image, flair_image.affine) + nib.save( + resized_flair_image, + os.path.join(patient_dest_folder, patient_name + "_flair.nii.gz"), + ) if not test: gt_image = nib.load(gt) resized_gt_image = preprocess_image(gt_image, is_mask=True) - resized_gt_image = nib.Nifti1Image(resized_gt_image, - gt_image.affine) - nib.save(resized_gt_image, os.path.join(patient_dest_folder, patient_name - +"_mask.nii.gz")) + resized_gt_image = nib.Nifti1Image(resized_gt_image, gt_image.affine) + nib.save( + resized_gt_image, + os.path.join(patient_dest_folder, patient_name + "_mask.nii.gz"), + ) return def batch_works(k): if k == n_processes - 1: - sub_patients = patients[k * int(len(patients) / n_processes):] + sub_patients = patients[k * int(len(patients) / n_processes) :] else: - sub_patients = patients[k * int(len(patients) / n_processes): - (k + 1) * int(len(patients) / n_processes)] + sub_patients = patients[ + k + * int(len(patients) / n_processes) : (k + 1) + * int(len(patients) / n_processes) + ] for patient in sub_patients: patient_name = os.path.basename(patient) print(patient_name) normalize(patient, output_path, patient_name) -if __name__ == '__main__': - parser = argparse.ArgumentParser(prog='intensity_standardize', formatter_class=argparse.RawTextHelpFormatter, - description='\nThis code was implemented to standardize intensities for skull stripping\n'+ '\n'\ - 'Copyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.\n'\ - 'For questions and feedback contact: software@cbica.upenn.edu') +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="intensity_standardize", + formatter_class=argparse.RawTextHelpFormatter, + description="\nThis code was implemented to standardize intensities for skull stripping\n" + + "\n" + "Copyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.\n" + "For questions and feedback contact: software@cbica.upenn.edu", + ) + + parser.add_argument( + "-i", + "--input_path", + dest="input_path", + help="input path for the tissues", + required=True, + ) + parser.add_argument( + "-o", + "--output_path", + 
dest="output_path", + help="output path for saving the files", + required=True, + ) + parser.add_argument( + "-t", + "--threads", + dest="threads", + help="number of threads, by default will use all", + ) + + parser.add_argument( + "-v", + "--version", + action="version", + version=pkg_resources.require("BrainMaGe")[0].version + + "\n\nCopyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.", + help="Show program's version number and exit.", + ) - parser.add_argument('-i', '--input_path', dest='input_path', - help="input path for the tissues", required=True) - parser.add_argument('-o', '--output_path', dest='output_path', - help="output path for saving the files", required=True) - parser.add_argument('-t', '--threads', dest='threads', - help="number of threads, by default will use all") - - parser.add_argument('-v', '--version', action='version', - version=pkg_resources.require("BrainMaGe")[0].version + '\n\nCopyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.', help="Show program's version number and exit.") - args = parser.parse_args() if args.threads: @@ -295,7 +395,7 @@ if __name__ == '__main__': input_path = os.path.abspath(args.input_path) output_path = os.path.abspath(args.output_path) os.makedirs(output_path, exist_ok=True) - patients = glob.glob(os.path.abspath(args.input_path)+'/*') + patients = glob.glob(os.path.abspath(args.input_path) + "/*") n_processes = cpu_count() pool = Pool(processes=n_processes) pool.map(batch_works, range(n_processes)) diff --git a/brain_mage_run b/brain_mage_run index d12fbdf..11b74ba 100755 --- a/brain_mage_run +++ b/brain_mage_run @@ -15,67 +15,111 @@ from BrainMaGe.tester import test_ma, test_multi_4 import pkg_resources -if __name__ == '__main__': - parser = argparse.ArgumentParser(prog='BrainMaGe', formatter_class=argparse.RawTextHelpFormatter, - description='\nThis code was implemented for Deep Learning '+\ - 'based training and inference of 3D-U-Net,\n3D-Res-U-Net models for '+\ - 'Brain Extraction a.k.a Skull Stripping in biomedical NIfTI volumes.\n'+\ - 'The project is hosted at: https://github.com/CBICA/BrainMaGe * \n'+\ - 'See the documentation for details on its use.\n'+\ - 'If you are using this tool, please cite out paper.' - 'This software accompanies the research presented in:\n'+\ - 'Thakur et al., \'Brain Extraction on MRI Scans in Presence of Diffuse\n'+\ - 'Glioma:Multi-institutional Performance Evaluation of Deep Learning Methods'+\ - 'and Robust Modality-Agnostic Training\'.\n'+\ - 'DOI: 10.1016/j.neuroimage.2020.117081\n' +\ - 'We hope our work helps you in your endeavours.\n'+ '\n'\ - 'Copyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.\n'\ - 'For questions and feedback contact: software@cbica.upenn.edu') - - parser.add_argument('-params', dest='params', type=str, - help='Specify the architecture of the model to be used, by providing a\n'+\ - 'config file [PARAMS_CFG]. A sample of the files is stored in\n'+\ - 'BrainMaGe/config folder for the train, test. 
Checkout the parameter\n'+\ - 'explanation in the Readme.md for more details.\n', - required=True) - - parser.add_argument('-train', dest='train', type=str, - help='Should be set to "True" (without the quotes) if you are trying to\n'+\ - 'run training, but make sure you intensity standardize the data \n'+\ - 'before attempting to train.\n', default='False') - - parser.add_argument('-test', dest='test', type=str, - help='Should be set to "False" (without the quotes) if you are trying\n'+\ - 'to train a new model, do not set the training to true as testing\n'+\ - 'will be overridden.\n', default='True') - - parser.add_argument('-dev', default='0', dest='device', type=str, - help='used to set on which device the prediction will run.\n'+ - 'Must be either int or str. Use int for GPU id or\n'+ - '\'cpu\' to run on CPU. Avoid training on CPU. \n'+ - 'Default for selecting first GPU is set to -dev 0\n', - required=False) - - parser.add_argument('-mode', dest='mode', type=str, - help='Should be one of "MA" or "Multi-4" without the quotes so that \n'+ - 'the appropriate weight files are loaded automatically during\n'+\ - 'the test time.') - - parser.add_argument('-save_brain', default=1, type=int, required=False, dest='save_brain', - help='if set to 0 the segmentation mask will be only produced and\n'+\ - 'and the mask will not be applied on the input image to produce\n'+\ - ' a brain. This step is to be only applied if you trust this\n'+\ - 'software and do not feel the need for Manual QC. This will save\n'+\ - ' you some time. This is useless for training though.') - - parser.add_argument('-load', default=None, dest='load', type=str, - help='If the location of the weight file is passed, the internal methods\n'+\ - 'are overridden to apply these weights to the model. We warn against\n'+\ - 'the usage of this unless you know what you are passing. C') - - parser.add_argument('-v', '--version', action='version', - version=pkg_resources.require("BrainMaGe")[0].version + '\n\nCopyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.', help="Show program's version number and exit.") - +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="BrainMaGe", + formatter_class=argparse.RawTextHelpFormatter, + description="\nThis code was implemented for Deep Learning " + + "based training and inference of 3D-U-Net,\n3D-Res-U-Net models for " + + "Brain Extraction a.k.a Skull Stripping in biomedical NIfTI volumes.\n" + + "The project is hosted at: https://github.com/CBICA/BrainMaGe * \n" + + "See the documentation for details on its use.\n" + + "If you are using this tool, please cite out paper." + "This software accompanies the research presented in:\n" + + "Thakur et al., 'Brain Extraction on MRI Scans in Presence of Diffuse\n" + + "Glioma:Multi-institutional Performance Evaluation of Deep Learning Methods" + + "and Robust Modality-Agnostic Training'.\n" + + "DOI: 10.1016/j.neuroimage.2020.117081\n" + + "We hope our work helps you in your endeavours.\n" + + "\n" + "Copyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.\n" + "For questions and feedback contact: software@cbica.upenn.edu", + ) + + parser.add_argument( + "-params", + dest="params", + type=str, + help="Specify the architecture of the model to be used, by providing a\n" + + "config file [PARAMS_CFG]. A sample of the files is stored in\n" + + "BrainMaGe/config folder for the train, test. 
Checkout the parameter\n" + + "explanation in the Readme.md for more details.\n", + required=True, + ) + + parser.add_argument( + "-train", + dest="train", + type=str, + help='Should be set to "True" (without the quotes) if you are trying to\n' + + "run training, but make sure you intensity standardize the data \n" + + "before attempting to train.\n", + default="False", + ) + + parser.add_argument( + "-test", + dest="test", + type=str, + help='Should be set to "False" (without the quotes) if you are trying\n' + + "to train a new model, do not set the training to true as testing\n" + + "will be overridden.\n", + default="True", + ) + + parser.add_argument( + "-dev", + default="0", + dest="device", + type=str, + help="used to set on which device the prediction will run.\n" + + "Must be either int or str. Use int for GPU id or\n" + + "'cpu' to run on CPU. Avoid training on CPU. \n" + + "Default for selecting first GPU is set to -dev 0\n", + required=False, + ) + + parser.add_argument( + "-mode", + dest="mode", + type=str, + help='Should be one of "MA" or "Multi-4" without the quotes so that \n' + + "the appropriate weight files are loaded automatically during\n" + + "the test time.", + ) + + parser.add_argument( + "-save_brain", + default=1, + type=int, + required=False, + dest="save_brain", + help="if set to 0 the segmentation mask will be only produced and\n" + + "and the mask will not be applied on the input image to produce\n" + + " a brain. This step is to be only applied if you trust this\n" + + "software and do not feel the need for Manual QC. This will save\n" + + " you some time. This is useless for training though.", + ) + + parser.add_argument( + "-load", + default=None, + dest="load", + type=str, + help="If the location of the weight file is passed, the internal methods\n" + + "are overridden to apply these weights to the model. We warn against\n" + + "the usage of this unless you know what you are passing. C", + ) + + parser.add_argument( + "-v", + "--version", + action="version", + version=pkg_resources.require("BrainMaGe")[0].version + + "\n\nCopyright: Center for Biomedical Image Computing and Analytics (CBICA), University of Pennsylvania.", + help="Show program's version number and exit.", + ) + args = parser.parse_args() params_file = os.path.abspath(args.params) @@ -86,10 +130,10 @@ if __name__ == '__main__': # some sanity checking if args.train == args.test: - raise ValueError('Please enable either testing or training modes, not both') + raise ValueError("Please enable either testing or training modes, not both") if args.train == False and args.test == False: - raise ValueError('One of the options needs to be enabled.') - + raise ValueError("One of the options needs to be enabled.") + # If weights are given in params, then set weights to given params # else set weights to None if args.load is not None: @@ -102,41 +146,55 @@ if __name__ == '__main__': # Else raise value error if weights is not None: if os.path.exists(weights): - if args.train == 'True': + if args.train == "True": _, ext = os.path.splitext(weights) - if ext != '.ckpt': - raise ValueError("The extension was not a .ckpt file for training to enable proper\n"+\ - "resume during training. Please pass a .ckpt file.") - elif args.test == 'True': + if ext != ".ckpt": + raise ValueError( + "The extension was not a .ckpt file for training to enable proper\n" + + "resume during training. Please pass a .ckpt file." 
+ ) + elif args.test == "True": print(args.mode) - if args.mode.lower() == 'ma' or args.mode.lower() == 'multi_4' or args.mode.lower() == 'bids': + if ( + args.mode.lower() == "ma" + or args.mode.lower() == "multi_4" + or args.mode.lower() == "bids" + ): _, ext = os.path.splitext(weights) - if ext != '.pt': - raise ValueError("Expected a .pt file, got a file with %s extension. If it is a\n"+\ - ".ckpt file, please conver it with our converion script\n"+\ - "mentioned in the Readme.md") - else : - raise ValueError('Unknown value for mode. Expected one of "MA" or "Multi-4" without the quotes.', - 'We received : ', args.mode, - 'Common mistakes include spelling mistakes, check it to make sure.') + if ext != ".pt": + raise ValueError( + "Expected a .pt file, got a file with %s extension. If it is a\n" + + ".ckpt file, please conver it with our converion script\n" + + "mentioned in the Readme.md" + ) + else: + raise ValueError( + 'Unknown value for mode. Expected one of "MA" or "Multi-4" without the quotes.', + "We received : ", + args.mode, + "Common mistakes include spelling mistakes, check it to make sure.", + ) else: - if args.train == 'True': + if args.train == "True": pass - elif args.test == 'True': + elif args.test == "True": base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) - base_dir = os.path.join(os.path.dirname(base_dir), 'BrainMaGe/weights') - if args.mode.lower() == 'ma' or args.mode.lower() == 'bids': - weights = os.path.join(base_dir, 'resunet_ma.pt') - elif args.mode.lower() == 'multi-4': - weights = os.path.join(base_dir, 'resunet_multi_4.pt') + base_dir = os.path.join(os.path.dirname(base_dir), "BrainMaGe/weights") + if args.mode.lower() == "ma" or args.mode.lower() == "bids": + weights = os.path.join(base_dir, "resunet_ma.pt") + elif args.mode.lower() == "multi-4": + weights = os.path.join(base_dir, "resunet_multi_4.pt") else: - raise ValueError('Unknown value for mode. Expected one of "MA" or "Multi-4" without the quotes.', - 'We received : ', args.mode, - 'Common mistakes include spelling mistakes, check it to make sure.') + raise ValueError( + 'Unknown value for mode. Expected one of "MA" or "Multi-4" without the quotes.', + "We received : ", + args.mode, + "Common mistakes include spelling mistakes, check it to make sure.", + ) print("Weight file used :", weights) print(__file__) - if DEVICE == 'cpu': + if DEVICE == "cpu": pass else: DEVICE = int(DEVICE) @@ -146,20 +204,24 @@ if __name__ == '__main__': elif args.save_brain == 1: args.save_brain = True else: - raise ValueError('Unknown value for save brain:') - - if args.train == 'True': + raise ValueError("Unknown value for save brain:") + + if args.train == "True": trainer_main.train_network(params_file, DEVICE, weights) - elif args.test == 'True': - if args.mode.lower() == 'ma' or args.mode.lower() == 'bids': + elif args.test == "True": + if args.mode.lower() == "ma" or args.mode.lower() == "bids": test_ma.infer_ma(params_file, DEVICE, args.save_brain, weights) - elif args.mode.lower() == 'multi-4': + elif args.mode.lower() == "multi-4": test_multi_4.infer_multi_4(params_file, DEVICE, args.save_brain, weights) - else : - raise ValueError('Unknown value for mode. Expected one of "MA" or "Multi-4" without the quotes.', - 'We received : ', args.mode, - 'Common mistakes include spelling mistakes, check it to make sure.') + else: + raise ValueError( + 'Unknown value for mode. 
Expected one of "MA" or "Multi-4" without the quotes.', + "We received : ", + args.mode, + "Common mistakes include spelling mistakes, check it to make sure.", + ) else: - raise ValueError("Expected the modes to be set with either -train True or -test True.\n"+\ - "Please try again!") - + raise ValueError( + "Expected the modes to be set with either -train True or -test True.\n" + + "Please try again!" + ) diff --git a/brain_mage_single_run b/brain_mage_single_run index f58fc1f..9a9cf17 100755 --- a/brain_mage_single_run +++ b/brain_mage_single_run @@ -101,9 +101,10 @@ if __name__ == "__main__": print("Weight file used :", weights) - # Running Inference - test_single_run.infer_single_ma(input_path, output_path, weights, mask_path, device=DEVICE) + test_single_run.infer_single_ma( + input_path, output_path, weights, mask_path, device=DEVICE + ) print("*" * 80) print("Finished") diff --git a/setup.py b/setup.py index 5b73743..4d13844 100755 --- a/setup.py +++ b/setup.py @@ -10,29 +10,34 @@ from setuptools import setup import setuptools -setup(name='BrainMaGe', - version='1.0.1', - description='Skull stripping using multiple and single modalities', - url='https://github.com/CBICA/BrainMaGe', - python_requires='>=3.6', - author='Siddhesh Thakur', - author_email='software@cbica.upenn.edu', - license='BSD-3-Clause', - zip_safe=False, - install_requires=[ - 'numpy', - 'torch>=1.5.1', - 'scikit-image', - 'nibabel', - 'pytorch-lightning==0.8.1' - ], - scripts=['brain_mage_run', 'brain_mage_single_run', 'brain_mage_intensity_standardize'], - classifiers=[ - 'Intended Audience :: Science/Research', - 'Programming Language :: Python', - 'Topic :: Scientific/Engineering', - 'Operating System :: Unix' - ], - packages=setuptools.find_packages(), - include_package_data=True - ) +setup( + name="BrainMaGe", + version="1.0.1", + description="Skull stripping using multiple and single modalities", + url="https://github.com/CBICA/BrainMaGe", + python_requires=">=3.6", + author="Siddhesh Thakur", + author_email="software@cbica.upenn.edu", + license="BSD-3-Clause", + zip_safe=False, + install_requires=[ + "numpy", + "torch>=1.5.1", + "scikit-image", + "nibabel", + "pytorch-lightning==0.8.1", + ], + scripts=[ + "brain_mage_run", + "brain_mage_single_run", + "brain_mage_intensity_standardize", + ], + classifiers=[ + "Intended Audience :: Science/Research", + "Programming Language :: Python", + "Topic :: Scientific/Engineering", + "Operating System :: Unix", + ], + packages=setuptools.find_packages(), + include_package_data=True, +) From 7114b3212408d5c55fae9140ed9e04c79bef2134 Mon Sep 17 00:00:00 2001 From: Siddhesh Thakur Date: Thu, 29 Apr 2021 04:41:40 +0530 Subject: [PATCH 03/13] Update README.md --- README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/README.md b/README.md index 5fc8303..f82f0e8 100755 --- a/README.md +++ b/README.md @@ -75,6 +75,19 @@ python setup.py install # install dependencies and BrainMaGe - ```$mode``` can be ```MA``` for modality agnostic or ```Multi-4```. - ```$device``` refers to the GPU device where you want your code to run or the CPU. +### Steps to run application (Alternative) + +1.Although this method is much slower, and runs for single subject at a time, it works flawlessly on CPU's and GPU's. + + conda activate brainmage + brain_mage_single_run -i $path_to_input.nii.gz -o $path_to_output_mask.nii.gz + \ -m $path_to_output_brain.nii.gz -dev $device + + Where: + - `$path_to_input.nii.gz` can be path to the input file as a nifti. 
+ - `$path_to_output_mask.nii.gz` is the output path where the brain mask nifti will be saved.
+ - `$path_to_output_brain.nii.gz` is the output path where the brain-extracted nifti will be saved.
+
 ## [ADVANCED] Train your own model
 
 1. Co-registration within patient in a common atlas space such as the [SRI-24 atlas](https://www.nitrc.org/projects/sri24/) in the LPS/RAI space.

From 94e4b140595a5263387ca1f651f697cae55db433 Mon Sep 17 00:00:00 2001
From: Siddhesh Thakur
Date: Thu, 29 Apr 2021 14:02:34 +0530
Subject: [PATCH 04/13] added CCA for single_run

---
 BrainMaGe/tester/test_single_run.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/BrainMaGe/tester/test_single_run.py b/BrainMaGe/tester/test_single_run.py
index d312e38..522a56c 100755
--- a/BrainMaGe/tester/test_single_run.py
+++ b/BrainMaGe/tester/test_single_run.py
@@ -72,6 +72,7 @@ def infer_single_ma(input_path, output_path, weights, mask_path=None, device="cp
     to_save = interpolate_image(output, patient_nib.shape)
     to_save[to_save >= 0.9] = 1
     to_save[to_save < 0.9] = 0
+    to_save = postprocess_prediction(to_save)
     to_save_nib = nib.Nifti1Image(to_save, patient_nib.affine)
     nib.save(to_save_nib, os.path.join(output_path))
 

From ae41d7d6fc3f8b94318af954e3cff893987b48fd Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Thu, 29 Apr 2021 09:45:53 -0400
Subject: [PATCH 05/13] added bids as dependency

---
 setup.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/setup.py b/setup.py
index 4d13844..d905cec 100755
--- a/setup.py
+++ b/setup.py
@@ -26,6 +26,7 @@
         "scikit-image",
         "nibabel",
         "pytorch-lightning==0.8.1",
+        "bids"
     ],
     scripts=[
         "brain_mage_run",

From 3cf07d48e63f12d15d7d6fe02e941390813177fa Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Thu, 29 Apr 2021 09:46:03 -0400
Subject: [PATCH 06/13] corrected the location

---
 brain_mage_single_run | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/brain_mage_single_run b/brain_mage_single_run
index 9a9cf17..2ad756e 100755
--- a/brain_mage_single_run
+++ b/brain_mage_single_run
@@ -95,8 +95,8 @@ if __name__ == "__main__":
     mask_path = args.mask_path
     DEVICE = args.device
 
-    base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
-    base_dir = os.path.join(os.path.dirname(base_dir), "BrainMaGe/weights")
+    base_dir = os.path.dirname(os.path.realpath(__file__))
+    base_dir = os.path.join(base_dir, "BrainMaGe/weights")
     weights = os.path.join(base_dir, "resunet_ma.pt")
 
     print("Weight file used :", weights)

From 5aa61ace43bbca1840496b512f68776a238a6b49 Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Thu, 29 Apr 2021 09:55:23 -0400
Subject: [PATCH 07/13] updated ignore

---
 .gitignore | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 4dd072b..81ef4a3 100755
--- a/.gitignore
+++ b/.gitignore
@@ -83,7 +83,7 @@ celerybeat-schedule
 .env
 
 # virtualenv
-venv/
+venv*/
 ENV/
 
 # Spyder project settings

From 75417db1fa33ed82ab65f21fa865d60ebb7bc578 Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Thu, 29 Apr 2021 10:07:02 -0400
Subject: [PATCH 08/13] putting cpu as default

---
 brain_mage_single_run | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/brain_mage_single_run b/brain_mage_single_run
index 2ad756e..82e07ef 100755
--- a/brain_mage_single_run
+++ b/brain_mage_single_run
@@ -69,13 +69,12 @@ if __name__ == "__main__":
 
     parser.add_argument(
         "-dev",
-        default="0",
+        default="cpu",
         dest="device",
         type=str,
         help="used to set on which device the prediction will run.\n"
         + "Must be either int or str. Use int for GPU id or\n"
-        + "'cpu' to run on CPU. 
Avoid training on CPU. \n" - + "Default for selecting first GPU is set to -dev 0\n", + + "'cpu' to run on CPU. Avoid training on CPU. \n", required=False, ) From e6ee1feca17c9f53427ff31ef1e073c9b1cf0ec0 Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 29 Apr 2021 10:31:29 -0400 Subject: [PATCH 09/13] now stuff is working --- BrainMaGe/tester/test_single_run.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/BrainMaGe/tester/test_single_run.py b/BrainMaGe/tester/test_single_run.py index 522a56c..22b14ed 100755 --- a/BrainMaGe/tester/test_single_run.py +++ b/BrainMaGe/tester/test_single_run.py @@ -48,7 +48,8 @@ def infer_single_ma(input_path, output_path, weights, mask_path=None, device="cp model = fetch_model( modelname="resunet", num_channels=1, num_classes=2, num_filters=16 ) - checkpoint = torch.load(weights) + + checkpoint = torch.load(weights, map_location=torch.device('cpu')) model.load_state_dict(checkpoint["model_state_dict"]) if device != "cpu": @@ -56,9 +57,9 @@ def infer_single_ma(input_path, output_path, weights, mask_path=None, device="cp model.eval() patient_nib = nib.load(input_path) - image = patient_nib.get_fdata() + image_data = patient_nib.get_fdata() old_shape = patient_nib.shape - image = process_image(image) + image = process_image(image_data) image = resize( image, (128, 128, 128), order=3, mode="edge", cval=0, anti_aliasing=False ) @@ -82,7 +83,7 @@ def infer_single_ma(input_path, output_path, weights, mask_path=None, device="cp print("You chose to save the brain. We are now saving it with the masks.") brain_data = image_data brain_data[to_save == 0] = 0 - to_save_brain = nib.Nifti1Image(brain_data, image.affine) + to_save_brain = nib.Nifti1Image(brain_data, patient_nib.affine) nib.save(to_save_brain, os.path.join(mask_path)) print("Thank you for using BrainMaGe") From 73a8e6a1f2824c10f0d252b9149f3e63e56644e9 Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 29 Apr 2021 10:34:22 -0400 Subject: [PATCH 10/13] reduced a variable --- BrainMaGe/tester/test_single_run.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/BrainMaGe/tester/test_single_run.py b/BrainMaGe/tester/test_single_run.py index 22b14ed..eedd398 100755 --- a/BrainMaGe/tester/test_single_run.py +++ b/BrainMaGe/tester/test_single_run.py @@ -81,9 +81,8 @@ def infer_single_ma(input_path, output_path, weights, mask_path=None, device="cp if mask_path is not None: print("You chose to save the brain. 
We are now saving it with the masks.") - brain_data = image_data - brain_data[to_save == 0] = 0 - to_save_brain = nib.Nifti1Image(brain_data, patient_nib.affine) + image_data[to_save == 0] = 0 + to_save_brain = nib.Nifti1Image(image_data, patient_nib.affine) nib.save(to_save_brain, os.path.join(mask_path)) print("Thank you for using BrainMaGe") From 5b589a7ca06d801b912e0097dbaaf1b40f7cc517 Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 29 Apr 2021 11:05:51 -0400 Subject: [PATCH 11/13] added weight file parsing and an sanity check before processing begins --- brain_mage_single_run | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/brain_mage_single_run b/brain_mage_single_run index 82e07ef..29fa488 100755 --- a/brain_mage_single_run +++ b/brain_mage_single_run @@ -8,7 +8,7 @@ Created on Sat May 30 01:05:59 2020 from __future__ import absolute_import, print_function, division import argparse -import os +import os, sys from BrainMaGe.tester import test_single_run import pkg_resources @@ -95,10 +95,20 @@ if __name__ == "__main__": DEVICE = args.device base_dir = os.path.dirname(os.path.realpath(__file__)) - base_dir = os.path.join(base_dir, "BrainMaGe/weights") - weights = os.path.join(base_dir, "resunet_ma.pt") - - print("Weight file used :", weights) + base_dir = os.path.join(base_dir, "BrainMaGe", "weights") + if os.path.isdir(base_dir): + weights = os.path.join(base_dir, "resunet_ma.pt") + else: + # this control path is needed if someone installs brainmage into their virtual environment directly + base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + base_dir = os.path.join(os.path.dirname(base_dir), "BrainMaGe/weights") + if os.path.isdir(base_dir): + weights = os.path.join(base_dir, "resunet_ma.pt") + + if os.path.isfile(weights): + print("Weight file used :", weights) + else: + sys.exit('Weights file at \'' + weights + '\' was not found...') # Running Inference test_single_run.infer_single_ma( From 54886b66708fad032a9b22b814aeb9e12109ee66 Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 29 Apr 2021 11:06:23 -0400 Subject: [PATCH 12/13] initializing new "to_write" objects --- BrainMaGe/tester/test_single_run.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/BrainMaGe/tester/test_single_run.py b/BrainMaGe/tester/test_single_run.py index eedd398..ac11d74 100755 --- a/BrainMaGe/tester/test_single_run.py +++ b/BrainMaGe/tester/test_single_run.py @@ -81,8 +81,10 @@ def infer_single_ma(input_path, output_path, weights, mask_path=None, device="cp if mask_path is not None: print("You chose to save the brain. 
We are now saving it with the masks.") - image_data[to_save == 0] = 0 - to_save_brain = nib.Nifti1Image(image_data, patient_nib.affine) + patient_nib_write = nib.load(input_path) + image_data_write = patient_nib_write.get_fdata() + image_data_write[to_save == 0] = 0 + to_save_brain = nib.Nifti1Image(image_data_write, patient_nib_write.affine) nib.save(to_save_brain, os.path.join(mask_path)) print("Thank you for using BrainMaGe") From e582cd773b3dc6ce2db6d7b584cd74a70367e3c5 Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 29 Apr 2021 11:06:41 -0400 Subject: [PATCH 13/13] trying to get input to not change - failed --- BrainMaGe/utils/utils_test.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/BrainMaGe/utils/utils_test.py b/BrainMaGe/utils/utils_test.py index 8a8de1e..f6f07d9 100755 --- a/BrainMaGe/utils/utils_test.py +++ b/BrainMaGe/utils/utils_test.py @@ -77,12 +77,13 @@ def process_image(image): DESCRIPTION. """ + to_return = image new_image_temp = image[image >= image.mean()] p1 = np.percentile(new_image_temp, 2) p2 = np.percentile(new_image_temp, 95) - image[image > p2] = p2 - image = (image - p1) / p2 - return image + to_return[to_return > p2] = p2 + to_return = (to_return - p1) / p2 + return to_return def padder_and_cropper(image, pad_info):
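A note on [PATCH 13/13]: renaming the working variable cannot keep the input untouched, because `to_return = image` only binds a second name to the same NumPy array, so the percentile clamp still writes through to the caller's data; this is consistent with the commit message reporting that the attempt failed. The sketch below is illustrative only and not part of the patch series: the function name `process_image_nondestructive` is hypothetical, and it assumes the same above-mean, 2nd/95th-percentile normalization used in `utils_test.py`, with an explicit copy taken first.

    import numpy as np

    def process_image_nondestructive(image):
        # Plain assignment (to_return = image) aliases the same buffer;
        # an explicit copy is what actually protects the caller's array.
        to_return = image.copy()
        # Same statistics as process_image: voxels at or above the mean.
        new_image_temp = to_return[to_return >= to_return.mean()]
        p1 = np.percentile(new_image_temp, 2)
        p2 = np.percentile(new_image_temp, 95)
        # Clamp and rescale the copy only; `image` is left unmodified.
        to_return[to_return > p2] = p2
        to_return = (to_return - p1) / p2
        return to_return

If the extra allocation is a concern, building the result from operations that already return new arrays (for example `np.clip(image, None, p2)`) avoids the separate copy while still leaving the input unchanged.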