Skip to content

Commit

Permalink
sty: mass ruffication
Browse files Browse the repository at this point in the history
  • Loading branch information
mgxd committed Aug 29, 2024
1 parent 436c903 commit 82e80ac
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 115 deletions.
8 changes: 5 additions & 3 deletions nipreps/synthstrip/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
# https://www.nipreps.org/community/licensing/
#

if __name__ == "__main__":
if __name__ == '__main__':
import sys

from nipreps.synthstrip.cli import main

from . import __name__ as module

# `python -m <module>` typically displays the command as __main__.py
if "__main__.py" in sys.argv[0]:
sys.argv[0] = "%s -m %s" % (sys.executable, module)
if '__main__.py' in sys.argv[0]:
sys.argv[0] = f'{sys.executable} -m {module}'
main()
90 changes: 41 additions & 49 deletions nipreps/synthstrip/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,63 +49,60 @@ def main():
"""Entry point to SynthStrip."""
import os
from argparse import ArgumentParser

import nibabel as nb
import numpy as np
import scipy
import nibabel as nb
import torch

from .model import StripModel

# parse command line
parser = ArgumentParser(description=__doc__)
parser.add_argument(
"-i",
"--image",
metavar="file",
'-i',
'--image',
metavar='file',
required=True,
help="Input image to skullstrip.",
help='Input image to skullstrip.',
)
parser.add_argument('-o', '--out', metavar='file', help='Save stripped image to path.')
parser.add_argument('-m', '--mask', metavar='file', help='Save binary brain mask to path.')
parser.add_argument('-g', '--gpu', action='store_true', help='Use the GPU.')
parser.add_argument('-n', '--num-threads', action='store', type=int, help='number of threads')
parser.add_argument(
"-o", "--out", metavar="file", help="Save stripped image to path."
)
parser.add_argument(
"-m", "--mask", metavar="file", help="Save binary brain mask to path."
)
parser.add_argument("-g", "--gpu", action="store_true", help="Use the GPU.")
parser.add_argument(
"-n", "--num-threads", action="store", type=int, help="number of threads")
parser.add_argument(
"-b",
"--border",
'-b',
'--border',
default=1,
type=int,
help="Mask border threshold in mm. Default is 1.",
help='Mask border threshold in mm. Default is 1.',
)
parser.add_argument("--model", metavar="file", help="Alternative model weights.")
parser.add_argument('--model', metavar='file', help='Alternative model weights.')
args = parser.parse_args()

# sanity check on the inputs
if not args.out and not args.mask:
parser.fatal("Must provide at least --out or --mask output flags.")
parser.fatal('Must provide at least --out or --mask output flags.')

# necessary for speed gains (I think)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

# configure GPU device
if args.gpu:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda")
device_name = "GPU"
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device('cuda')
device_name = 'GPU'
else:
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
device = torch.device("cpu")
device_name = "CPU"
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
device = torch.device('cpu')
device_name = 'CPU'

if args.num_threads and args.num_threads > 0:
torch.set_num_threads(args.num_threads)

# configure model
print(f"Configuring model on the {device_name}")
print(f'Configuring model on the {device_name}')

with torch.no_grad():
model = StripModel()
Expand All @@ -115,20 +112,20 @@ def main():
# load model weights
if args.model is not None:
modelfile = args.model
print("Using custom model weights")
print('Using custom model weights')
else:
raise RuntimeError("A model must be provided.")
raise RuntimeError('A model must be provided.')

checkpoint = torch.load(modelfile, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.load_state_dict(checkpoint['model_state_dict'])

# load input volume
print(f"Input image read from: {args.image}")
print(f'Input image read from: {args.image}')

# normalize intensities
image = nb.load(args.image)
conformed = conform(image)
in_data = conformed.get_fdata(dtype="float32")
in_data = conformed.get_fdata(dtype='float32')
in_data -= in_data.min()
in_data = np.clip(in_data / np.percentile(in_data, 99), 0, 1)
in_data = in_data[np.newaxis, np.newaxis]
Expand All @@ -142,10 +139,10 @@ def main():
sdt_target = resample_like(
nb.Nifti1Image(sdt, conformed.affine, None),
image,
output_dtype="int16",
output_dtype='int16',
cval=100,
)
sdt_data = np.asanyarray(sdt_target.dataobj).astype("int16")
sdt_data = np.asanyarray(sdt_target.dataobj).astype('int16')

# find largest CC (just do this to be safe for now)
components = scipy.ndimage.label(sdt_data.squeeze() < args.border)[0]
Expand All @@ -161,25 +158,25 @@ def main():
nb.Nifti1Image(img_data, image.affine, image.header).to_filename(
args.out,
)
print(f"Masked image saved to: {args.out}")
print(f'Masked image saved to: {args.out}')

# write the brain mask
if args.mask:
hdr = image.header.copy()
hdr.set_data_dtype("uint8")
hdr.set_data_dtype('uint8')
nb.Nifti1Image(mask, image.affine, hdr).to_filename(args.mask)
print(f"Binary brain mask saved to: {args.mask}")
print(f'Binary brain mask saved to: {args.mask}')

print("If you use SynthStrip in your analysis, please cite:")
print("----------------------------------------------------")
print("SynthStrip: Skull-Stripping for Any Brain Image.")
print("A Hoopes, JS Mora, AV Dalca, B Fischl, M Hoffmann.")
print('If you use SynthStrip in your analysis, please cite:')
print('----------------------------------------------------')
print('SynthStrip: Skull-Stripping for Any Brain Image.')
print('A Hoopes, JS Mora, AV Dalca, B Fischl, M Hoffmann.')


def conform(input_nii):
"""Resample image as SynthStrip likes it."""
import numpy as np
import nibabel as nb
import numpy as np
from nitransforms.linear import Affine

shape = np.array(input_nii.shape[:3])
Expand All @@ -199,10 +196,7 @@ def conform(input_nii):
)

# Get corner voxel centers in mm
corners_xyz = (
affine
@ np.hstack((corner_centers_ijk, np.ones((len(corner_centers_ijk), 1)))).T
)
corners_xyz = affine @ np.hstack((corner_centers_ijk, np.ones((len(corner_centers_ijk), 1)))).T

# Target affine is 1mm voxels in LIA orientation
target_affine = np.diag([-1.0, 1.0, -1.0, 1.0])[:, (0, 2, 1, 3)]
Expand All @@ -212,9 +206,7 @@ def conform(input_nii):
target_shape = ((extent[1] - extent[0]) / 1.0 + 0.999).astype(int)

# SynthStrip likes dimensions be multiple of 64 (192, 256, or 320)
target_shape = np.clip(
np.ceil(np.array(target_shape) / 64).astype(int) * 64, 192, 320
)
target_shape = np.clip(np.ceil(np.array(target_shape) / 64).astype(int) * 64, 192, 320)

# Ensure shape ordering is LIA too
target_shape[2], target_shape[1] = target_shape[1:3]
Expand All @@ -239,5 +231,5 @@ def resample_like(image, target, output_dtype=None, cval=0):
return Affine(reference=target).apply(image, output_dtype=output_dtype, cval=cval)


if __name__ == "__main__":
if __name__ == '__main__':
main()
30 changes: 11 additions & 19 deletions nipreps/synthstrip/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@
"""

import numpy as np
import torch
import torch.nn as nn
import numpy as np


class StripModel(nn.Module):
Expand All @@ -60,7 +60,6 @@ def __init__(
max_pool=2,
return_mask=False,
):

super().__init__()

# dimensionality
Expand All @@ -69,19 +68,15 @@ def __init__(
# build feature list automatically
if isinstance(nb_features, int):
if nb_levels is None:
raise ValueError(
"must provide unet nb_levels if nb_features is an integer"
)
feats = np.round(nb_features * feat_mult ** np.arange(nb_levels)).astype(
int
)
raise ValueError('must provide unet nb_levels if nb_features is an integer')
feats = np.round(nb_features * feat_mult ** np.arange(nb_levels)).astype(int)
feats = np.clip(feats, 1, max_features)
nb_features = [
np.repeat(feats[:-1], nb_conv_per_level),
np.repeat(np.flip(feats), nb_conv_per_level),
]
elif nb_levels is not None:
raise ValueError("cannot use nb_levels if nb_features is not an integer")
raise ValueError('cannot use nb_levels if nb_features is not an integer')

# extract any surplus (full resolution) decoder convolutions
enc_nf, dec_nf = nb_features
Expand All @@ -94,11 +89,9 @@ def __init__(
max_pool = [max_pool] * self.nb_levels

# cache downsampling / upsampling operations
MaxPooling = getattr(nn, "MaxPool%dd" % ndims)
MaxPooling = getattr(nn, 'MaxPool%dd' % ndims)
self.pooling = [MaxPooling(s) for s in max_pool]
self.upsampling = [
nn.Upsample(scale_factor=s, mode="nearest") for s in max_pool
]
self.upsampling = [nn.Upsample(scale_factor=s, mode='nearest') for s in max_pool]

# configure encoder (down-sampling path)
prev_nf = 1
Expand Down Expand Up @@ -128,7 +121,7 @@ def __init__(

# now we take care of any remaining convolutions
self.remaining = nn.ModuleList()
for num, nf in enumerate(final_convs):
for nf in final_convs:
self.remaining.append(ConvBlock(ndims, prev_nf, nf))
prev_nf = nf

Expand All @@ -140,7 +133,6 @@ def __init__(
self.remaining.append(ConvBlock(ndims, prev_nf, 1, activation=None))

def forward(self, x):

# encoder forward pass
x_history = [x]
for level, convs in enumerate(self.encoder):
Expand Down Expand Up @@ -169,17 +161,17 @@ class ConvBlock(nn.Module):
Specific convolutional block followed by leakyrelu for unet.
"""

def __init__(self, ndims, in_channels, out_channels, stride=1, activation="leaky"):
def __init__(self, ndims, in_channels, out_channels, stride=1, activation='leaky'):
super().__init__()

Conv = getattr(nn, "Conv%dd" % ndims)
Conv = getattr(nn, 'Conv%dd' % ndims)
self.conv = Conv(in_channels, out_channels, 3, stride, 1)
if activation == "leaky":
if activation == 'leaky':
self.activation = nn.LeakyReLU(0.2)
elif activation is None:
self.activation = None
else:
raise ValueError(f"Unknown activation: {activation}")
raise ValueError(f'Unknown activation: {activation}')

def forward(self, x):
out = self.conv(x)
Expand Down
46 changes: 22 additions & 24 deletions nipreps/synthstrip/wrappers/nipype.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,21 @@
# https://www.nipreps.org/community/licensing/
#
"""SynthStrip interface."""

import os
from pathlib import Path

from nipype.interfaces.base import (
CommandLine,
CommandLineInputSpec,
File,
TraitedSpec,
traits,
Undefined,
traits,
)

_fs_home = os.getenv("FREESURFER_HOME", None)
_default_model_path = Path(_fs_home) / "models" / "synthstrip.1.pt" if _fs_home else Undefined
_fs_home = os.getenv('FREESURFER_HOME', None)
_default_model_path = Path(_fs_home) / 'models' / 'synthstrip.1.pt' if _fs_home else Undefined

if _fs_home and not _default_model_path.exists():
_default_model_path = Undefined
Expand All @@ -43,43 +45,39 @@ class _SynthStripInputSpec(CommandLineInputSpec):
in_file = File(
exists=True,
mandatory=True,
argstr="-i %s",
desc="Input image to be brain extracted",
)
use_gpu = traits.Bool(
False, usedefault=True, argstr="-g", desc="Use GPU", nohash=True
argstr='-i %s',
desc='Input image to be brain extracted',
)
use_gpu = traits.Bool(False, usedefault=True, argstr='-g', desc='Use GPU', nohash=True)
model = File(
str(_default_model_path),
usedefault=True,
exists=True,
argstr="--model %s",
argstr='--model %s',
desc="file containing model's weights",
)
border_mm = traits.Int(
1, usedefault=True, argstr="-b %d", desc="Mask border threshold in mm"
)
border_mm = traits.Int(1, usedefault=True, argstr='-b %d', desc='Mask border threshold in mm')
out_file = File(
name_source=["in_file"],
name_template="%s_desc-brain.nii.gz",
argstr="-o %s",
desc="store brain-extracted input to file",
name_source=['in_file'],
name_template='%s_desc-brain.nii.gz',
argstr='-o %s',
desc='store brain-extracted input to file',
)
out_mask = File(
name_source=["in_file"],
name_template="%s_desc-brain_mask.nii.gz",
argstr="-m %s",
desc="store brainmask to file",
name_source=['in_file'],
name_template='%s_desc-brain_mask.nii.gz',
argstr='-m %s',
desc='store brainmask to file',
)
num_threads = traits.Int(desc="Number of threads", argstr="-n %d", nohash=True)
num_threads = traits.Int(desc='Number of threads', argstr='-n %d', nohash=True)


class _SynthStripOutputSpec(TraitedSpec):
out_file = File(desc="brain-extracted image")
out_mask = File(desc="brain mask")
out_file = File(desc='brain-extracted image')
out_mask = File(desc='brain mask')


class SynthStrip(CommandLine):
_cmd = "nipreps-synthstrip"
_cmd = 'nipreps-synthstrip'
input_spec = _SynthStripInputSpec
output_spec = _SynthStripOutputSpec
Loading

0 comments on commit 82e80ac

Please sign in to comment.