diff --git a/nipreps/synthstrip/__main__.py b/nipreps/synthstrip/__main__.py index 6163f49..370884c 100644 --- a/nipreps/synthstrip/__main__.py +++ b/nipreps/synthstrip/__main__.py @@ -21,12 +21,14 @@ # https://www.nipreps.org/community/licensing/ # -if __name__ == "__main__": +if __name__ == '__main__': import sys + from nipreps.synthstrip.cli import main + from . import __name__ as module # `python -m ` typically displays the command as __main__.py - if "__main__.py" in sys.argv[0]: - sys.argv[0] = "%s -m %s" % (sys.executable, module) + if '__main__.py' in sys.argv[0]: + sys.argv[0] = f'{sys.executable} -m {module}' main() diff --git a/nipreps/synthstrip/cli.py b/nipreps/synthstrip/cli.py index a4681c0..76e2159 100644 --- a/nipreps/synthstrip/cli.py +++ b/nipreps/synthstrip/cli.py @@ -49,43 +49,45 @@ def main(): """Entry point to SynthStrip.""" import os from argparse import ArgumentParser + + import nibabel as nb import numpy as np import scipy - import nibabel as nb import torch + from .model import StripModel # parse command line parser = ArgumentParser(description=__doc__) parser.add_argument( - "-i", - "--image", - metavar="file", + '-i', + '--image', + metavar='file', required=True, - help="Input image to skullstrip.", + help='Input image to skullstrip.', ) parser.add_argument( - "-o", "--out", metavar="file", help="Save stripped image to path." + '-o', '--out', metavar='file', help='Save stripped image to path.' ) parser.add_argument( - "-m", "--mask", metavar="file", help="Save binary brain mask to path." + '-m', '--mask', metavar='file', help='Save binary brain mask to path.' ) - parser.add_argument("-g", "--gpu", action="store_true", help="Use the GPU.") + parser.add_argument('-g', '--gpu', action='store_true', help='Use the GPU.') parser.add_argument( - "-n", "--num-threads", action="store", type=int, help="number of threads") + '-n', '--num-threads', action='store', type=int, help='number of threads') parser.add_argument( - "-b", - "--border", + '-b', + '--border', default=1, type=int, - help="Mask border threshold in mm. Default is 1.", + help='Mask border threshold in mm. Default is 1.', ) - parser.add_argument("--model", metavar="file", help="Alternative model weights.") + parser.add_argument('--model', metavar='file', help='Alternative model weights.') args = parser.parse_args() # sanity check on the inputs if not args.out and not args.mask: - parser.fatal("Must provide at least --out or --mask output flags.") + parser.fatal('Must provide at least --out or --mask output flags.') # necessary for speed gains (I think) torch.backends.cudnn.benchmark = True @@ -93,19 +95,19 @@ def main(): # configure GPU device if args.gpu: - os.environ["CUDA_VISIBLE_DEVICES"] = "0" - device = torch.device("cuda") - device_name = "GPU" + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + device = torch.device('cuda') + device_name = 'GPU' else: - os.environ["CUDA_VISIBLE_DEVICES"] = "-1" - device = torch.device("cpu") - device_name = "CPU" + os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + device = torch.device('cpu') + device_name = 'CPU' if args.num_threads and args.num_threads > 0: torch.set_num_threads(args.num_threads) # configure model - print(f"Configuring model on the {device_name}") + print(f'Configuring model on the {device_name}') with torch.no_grad(): model = StripModel() @@ -115,20 +117,20 @@ def main(): # load model weights if args.model is not None: modelfile = args.model - print("Using custom model weights") + print('Using custom model weights') else: - raise RuntimeError("A model must be provided.") + raise RuntimeError('A model must be provided.') checkpoint = torch.load(modelfile, map_location=device) - model.load_state_dict(checkpoint["model_state_dict"]) + model.load_state_dict(checkpoint['model_state_dict']) # load input volume - print(f"Input image read from: {args.image}") + print(f'Input image read from: {args.image}') # normalize intensities image = nb.load(args.image) conformed = conform(image) - in_data = conformed.get_fdata(dtype="float32") + in_data = conformed.get_fdata(dtype='float32') in_data -= in_data.min() in_data = np.clip(in_data / np.percentile(in_data, 99), 0, 1) in_data = in_data[np.newaxis, np.newaxis] @@ -142,10 +144,10 @@ def main(): sdt_target = resample_like( nb.Nifti1Image(sdt, conformed.affine, None), image, - output_dtype="int16", + output_dtype='int16', cval=100, ) - sdt_data = np.asanyarray(sdt_target.dataobj).astype("int16") + sdt_data = np.asanyarray(sdt_target.dataobj).astype('int16') # find largest CC (just do this to be safe for now) components = scipy.ndimage.label(sdt_data.squeeze() < args.border)[0] @@ -161,25 +163,25 @@ def main(): nb.Nifti1Image(img_data, image.affine, image.header).to_filename( args.out, ) - print(f"Masked image saved to: {args.out}") + print(f'Masked image saved to: {args.out}') # write the brain mask if args.mask: hdr = image.header.copy() - hdr.set_data_dtype("uint8") + hdr.set_data_dtype('uint8') nb.Nifti1Image(mask, image.affine, hdr).to_filename(args.mask) - print(f"Binary brain mask saved to: {args.mask}") + print(f'Binary brain mask saved to: {args.mask}') - print("If you use SynthStrip in your analysis, please cite:") - print("----------------------------------------------------") - print("SynthStrip: Skull-Stripping for Any Brain Image.") - print("A Hoopes, JS Mora, AV Dalca, B Fischl, M Hoffmann.") + print('If you use SynthStrip in your analysis, please cite:') + print('----------------------------------------------------') + print('SynthStrip: Skull-Stripping for Any Brain Image.') + print('A Hoopes, JS Mora, AV Dalca, B Fischl, M Hoffmann.') def conform(input_nii): """Resample image as SynthStrip likes it.""" - import numpy as np import nibabel as nb + import numpy as np from nitransforms.linear import Affine shape = np.array(input_nii.shape[:3]) @@ -239,5 +241,5 @@ def resample_like(image, target, output_dtype=None, cval=0): return Affine(reference=target).apply(image, output_dtype=output_dtype, cval=cval) -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/nipreps/synthstrip/model.py b/nipreps/synthstrip/model.py index 075a56f..f7a43d4 100644 --- a/nipreps/synthstrip/model.py +++ b/nipreps/synthstrip/model.py @@ -44,9 +44,9 @@ """ +import numpy as np import torch import torch.nn as nn -import numpy as np class StripModel(nn.Module): @@ -70,7 +70,7 @@ def __init__( if isinstance(nb_features, int): if nb_levels is None: raise ValueError( - "must provide unet nb_levels if nb_features is an integer" + 'must provide unet nb_levels if nb_features is an integer' ) feats = np.round(nb_features * feat_mult ** np.arange(nb_levels)).astype( int @@ -81,7 +81,7 @@ def __init__( np.repeat(np.flip(feats), nb_conv_per_level), ] elif nb_levels is not None: - raise ValueError("cannot use nb_levels if nb_features is not an integer") + raise ValueError('cannot use nb_levels if nb_features is not an integer') # extract any surplus (full resolution) decoder convolutions enc_nf, dec_nf = nb_features @@ -94,10 +94,10 @@ def __init__( max_pool = [max_pool] * self.nb_levels # cache downsampling / upsampling operations - MaxPooling = getattr(nn, "MaxPool%dd" % ndims) + MaxPooling = getattr(nn, 'MaxPool%dd' % ndims) self.pooling = [MaxPooling(s) for s in max_pool] self.upsampling = [ - nn.Upsample(scale_factor=s, mode="nearest") for s in max_pool + nn.Upsample(scale_factor=s, mode='nearest') for s in max_pool ] # configure encoder (down-sampling path) @@ -128,7 +128,7 @@ def __init__( # now we take care of any remaining convolutions self.remaining = nn.ModuleList() - for num, nf in enumerate(final_convs): + for nf in final_convs: self.remaining.append(ConvBlock(ndims, prev_nf, nf)) prev_nf = nf @@ -169,17 +169,17 @@ class ConvBlock(nn.Module): Specific convolutional block followed by leakyrelu for unet. """ - def __init__(self, ndims, in_channels, out_channels, stride=1, activation="leaky"): + def __init__(self, ndims, in_channels, out_channels, stride=1, activation='leaky'): super().__init__() - Conv = getattr(nn, "Conv%dd" % ndims) + Conv = getattr(nn, 'Conv%dd' % ndims) self.conv = Conv(in_channels, out_channels, 3, stride, 1) - if activation == "leaky": + if activation == 'leaky': self.activation = nn.LeakyReLU(0.2) elif activation is None: self.activation = None else: - raise ValueError(f"Unknown activation: {activation}") + raise ValueError(f'Unknown activation: {activation}') def forward(self, x): out = self.conv(x) diff --git a/nipreps/synthstrip/wrappers/nipype.py b/nipreps/synthstrip/wrappers/nipype.py index 0843d12..4cdcca8 100644 --- a/nipreps/synthstrip/wrappers/nipype.py +++ b/nipreps/synthstrip/wrappers/nipype.py @@ -23,17 +23,18 @@ """SynthStrip interface.""" import os from pathlib import Path + from nipype.interfaces.base import ( CommandLine, CommandLineInputSpec, File, TraitedSpec, - traits, Undefined, + traits, ) -_fs_home = os.getenv("FREESURFER_HOME", None) -_default_model_path = Path(_fs_home) / "models" / "synthstrip.1.pt" if _fs_home else Undefined +_fs_home = os.getenv('FREESURFER_HOME', None) +_default_model_path = Path(_fs_home) / 'models' / 'synthstrip.1.pt' if _fs_home else Undefined if _fs_home and not _default_model_path.exists(): _default_model_path = Undefined @@ -43,43 +44,43 @@ class _SynthStripInputSpec(CommandLineInputSpec): in_file = File( exists=True, mandatory=True, - argstr="-i %s", - desc="Input image to be brain extracted", + argstr='-i %s', + desc='Input image to be brain extracted', ) use_gpu = traits.Bool( - False, usedefault=True, argstr="-g", desc="Use GPU", nohash=True + False, usedefault=True, argstr='-g', desc='Use GPU', nohash=True ) model = File( str(_default_model_path), usedefault=True, exists=True, - argstr="--model %s", + argstr='--model %s', desc="file containing model's weights", ) border_mm = traits.Int( - 1, usedefault=True, argstr="-b %d", desc="Mask border threshold in mm" + 1, usedefault=True, argstr='-b %d', desc='Mask border threshold in mm' ) out_file = File( - name_source=["in_file"], - name_template="%s_desc-brain.nii.gz", - argstr="-o %s", - desc="store brain-extracted input to file", + name_source=['in_file'], + name_template='%s_desc-brain.nii.gz', + argstr='-o %s', + desc='store brain-extracted input to file', ) out_mask = File( - name_source=["in_file"], - name_template="%s_desc-brain_mask.nii.gz", - argstr="-m %s", - desc="store brainmask to file", + name_source=['in_file'], + name_template='%s_desc-brain_mask.nii.gz', + argstr='-m %s', + desc='store brainmask to file', ) - num_threads = traits.Int(desc="Number of threads", argstr="-n %d", nohash=True) + num_threads = traits.Int(desc='Number of threads', argstr='-n %d', nohash=True) class _SynthStripOutputSpec(TraitedSpec): - out_file = File(desc="brain-extracted image") - out_mask = File(desc="brain mask") + out_file = File(desc='brain-extracted image') + out_mask = File(desc='brain mask') class SynthStrip(CommandLine): - _cmd = "nipreps-synthstrip" + _cmd = 'nipreps-synthstrip' input_spec = _SynthStripInputSpec output_spec = _SynthStripOutputSpec diff --git a/nipreps/synthstrip/wrappers/pydra.py b/nipreps/synthstrip/wrappers/pydra.py index 2109186..eeff4e8 100644 --- a/nipreps/synthstrip/wrappers/pydra.py +++ b/nipreps/synthstrip/wrappers/pydra.py @@ -23,12 +23,13 @@ """SynthStrip interface.""" import os -import attr from pathlib import Path + +import attr import pydra -_fs_home = os.getenv("FREESURFER_HOME", None) -_default_model_path = Path(_fs_home) / "models" / "synthstrip.1.pt" if _fs_home else None +_fs_home = os.getenv('FREESURFER_HOME', None) +_default_model_path = Path(_fs_home) / 'models' / 'synthstrip.1.pt' if _fs_home else None if _fs_home and not _default_model_path.exists(): _default_model_path = None @@ -41,7 +42,7 @@ attr.ib( type=str, metadata={ - 'argstr': "-i", + 'argstr': '-i', 'help_string': 'Input image to skullstrip', 'mandatory': True, }, @@ -51,18 +52,18 @@ 'out_file', str, { - 'argstr': "-o", - "help_string": "Save stripped image to path", - "output_file_template": "{in_file}_desc-brain.nii.gz", + 'argstr': '-o', + 'help_string': 'Save stripped image to path', + 'output_file_template': '{in_file}_desc-brain.nii.gz', }, ), ( 'out_mask', str, { - 'argstr': "-m", - "help_string": "Save binary brain mask to path", - "output_file_template": "{in_file}_desc-brain_mask.nii.gz", + 'argstr': '-m', + 'help_string': 'Save binary brain mask to path', + 'output_file_template': '{in_file}_desc-brain_mask.nii.gz', }, ), ( @@ -70,7 +71,7 @@ bool, False, { - 'argstr': "-g", + 'argstr': '-g', 'help_string': 'Use the GPU', }, ), @@ -79,8 +80,8 @@ int, 1, { - 'argstr': "-b", - "help_string": "Mask border threshold in mm", + 'argstr': '-b', + 'help_string': 'Mask border threshold in mm', }, ), ( @@ -88,7 +89,7 @@ bool, False, { - 'argstr': "--no-csf", + 'argstr': '--no-csf', 'help_string': 'Exclude CSF from brain border', }, ), @@ -97,16 +98,16 @@ pydra.specs.File, str(_default_model_path), { - 'argstr': "--model", - "help_string": "File containing model's weights", + 'argstr': '--model', + 'help_string': "File containing model's weights", }, ), ( 'num_threads', int, { - 'argstr': "-n", - "help_string": "Number of threads", + 'argstr': '-n', + 'help_string': 'Number of threads', }, ), ], @@ -114,6 +115,6 @@ ) SynthStrip = pydra.ShellCommandTask( - executable="nipreps-synthstrip", + executable='nipreps-synthstrip', input_spec = SynthStripInputSpec )