diff --git a/MANIFEST.in b/MANIFEST.in index 30ed8661..8d3d9882 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -8,3 +8,4 @@ include dmriprep/_version.py # data include dmriprep/config/reports-spec.yml +recursive-include * *.json diff --git a/dmriprep/config/emc_coarse_Affine.json b/dmriprep/config/emc_coarse_Affine.json new file mode 100644 index 00000000..5602e499 --- /dev/null +++ b/dmriprep/config/emc_coarse_Affine.json @@ -0,0 +1,25 @@ +{ + "dimension": 3, + "float": true, + "winsorize_lower_quantile": 0.002, + "winsorize_upper_quantile": 0.998, + "collapse_output_transforms": true, + "write_composite_transform": false, + "use_histogram_matching": [ false, false ], + "use_estimate_learning_rate_once": [ true, true ], + "transforms": [ "Rigid", "Affine" ], + "number_of_iterations": [ [ 100, 100 ], [ 100 ] ], + "output_warped_image": true, + "transform_parameters": [ [ 0.2 ], [ 0.15 ] ], + "convergence_threshold": [ 1e-06, 1e-06 ], + "convergence_window_size": [ 20, 20 ], + "metric": [ "Mattes", "Mattes" ], + "sampling_percentage": [ 0.15, 0.2 ], + "sampling_strategy": [ "Random", "Random" ], + "smoothing_sigmas": [ [ 8.0, 2.0 ], [ 2.0 ] ], + "sigma_units": [ "mm", "mm" ], + "metric_weight": [ 1.0, 1.0 ], + "shrink_factors": [ [ 2, 1 ], [ 1 ] ], + "radius_or_number_of_bins": [ 48, 48 ], + "interpolation": "BSpline" +} diff --git a/dmriprep/config/emc_coarse_Rigid.json b/dmriprep/config/emc_coarse_Rigid.json new file mode 100644 index 00000000..f35c7a9c --- /dev/null +++ b/dmriprep/config/emc_coarse_Rigid.json @@ -0,0 +1,25 @@ +{ + "dimension": 3, + "float": true, + "winsorize_lower_quantile": 0.002, + "winsorize_upper_quantile": 0.998, + "collapse_output_transforms": true, + "write_composite_transform": false, + "use_histogram_matching": [ false ], + "use_estimate_learning_rate_once": [ true ], + "transforms": [ "Rigid" ], + "number_of_iterations": [ [ 100, 100 ] ], + "output_warped_image": true, + "transform_parameters": [ [ 0.2 ] ], + "convergence_threshold": [ 1e-06 ], + "convergence_window_size": [ 20 ], + "metric": [ "Mattes" ], + "sampling_percentage": [ 0.15 ], + "sampling_strategy": [ "Random" ], + "smoothing_sigmas": [ [ 8.0, 2.0 ] ], + "sigma_units": [ "mm"], + "metric_weight": [ 1.0 ], + "shrink_factors": [ [ 2, 1 ] ], + "radius_or_number_of_bins": [ 48 ], + "interpolation": "BSpline" +} diff --git a/dmriprep/config/emc_precise_Affine.json b/dmriprep/config/emc_precise_Affine.json new file mode 100644 index 00000000..924a23a4 --- /dev/null +++ b/dmriprep/config/emc_precise_Affine.json @@ -0,0 +1,25 @@ +{ + "dimension": 3, + "float": true, + "winsorize_lower_quantile": 0.002, + "winsorize_upper_quantile": 0.998, + "collapse_output_transforms": true, + "write_composite_transform": false, + "use_histogram_matching": [ false, false ], + "use_estimate_learning_rate_once": [ true, true ], + "transforms": [ "Rigid", "Affine" ], + "number_of_iterations": [ [ 1000, 1000 ], [ 1000 ] ], + "output_warped_image": true, + "transform_parameters": [ [ 0.2 ], [ 0.15 ] ], + "convergence_threshold": [ 1e-08, 1e-08 ], + "convergence_window_size": [ 20, 20 ], + "metric": [ "Mattes", "Mattes" ], + "sampling_percentage": [ 0.15, 0.2 ], + "sampling_strategy": [ "Random", "Random" ], + "smoothing_sigmas": [ [ 8.0, 2.0 ], [ 2.0 ] ], + "sigma_units": [ "mm", "mm" ], + "metric_weight": [ 1.0, 1.0 ], + "shrink_factors": [ [ 2, 1 ], [ 1 ] ], + "radius_or_number_of_bins": [ 48, 48 ], + "interpolation": "BSpline" +} diff --git a/dmriprep/config/emc_precise_Rigid.json b/dmriprep/config/emc_precise_Rigid.json new file mode 100644 index 00000000..78127352 --- /dev/null +++ b/dmriprep/config/emc_precise_Rigid.json @@ -0,0 +1,25 @@ +{ + "dimension": 3, + "float": true, + "winsorize_lower_quantile": 0.002, + "winsorize_upper_quantile": 0.998, + "collapse_output_transforms": true, + "write_composite_transform": false, + "use_histogram_matching": [ false ], + "use_estimate_learning_rate_once": [ true ], + "transforms": [ "Rigid" ], + "number_of_iterations": [ [ 1000, 1000 ] ], + "output_warped_image": true, + "transform_parameters": [ [ 0.2 ], [ 0.15 ] ], + "convergence_threshold": [ 1e-08, 1e-08 ], + "convergence_window_size": [ 20, 20 ], + "metric": [ "Mattes" ], + "sampling_percentage": [ 0.15 ], + "sampling_strategy": [ "Random" ], + "smoothing_sigmas": [ [ 8.0, 2.0 ] ], + "sigma_units": [ "mm" ], + "metric_weight": [ 1.0 ], + "shrink_factors": [ [ 2, 1 ] ], + "radius_or_number_of_bins": [ 48 ], + "interpolation": "BSpline" +} diff --git a/dmriprep/interfaces/bids.py b/dmriprep/interfaces/bids.py new file mode 100644 index 00000000..f64c4484 --- /dev/null +++ b/dmriprep/interfaces/bids.py @@ -0,0 +1,139 @@ +import os +import re +from pathlib import Path +from nipype.interfaces.base import ( + traits, TraitedSpec, BaseInterfaceInputSpec, SimpleInterface, File, isdefined, + InputMultiObject, OutputMultiPath, OutputMultiObject, +) +from ..utils.bids import splitext as _splitext, _copy_any + +__all__ = ['BIDS_NAME'] + +BIDS_NAME = re.compile( + r'^(.*\/)?(?Psub-[a-zA-Z0-9]+)(_(?Pses-[a-zA-Z0-9]+))?' + '(_(?Ptask-[a-zA-Z0-9]+))?(_(?Pacq-[a-zA-Z0-9]+))?' + '(_(?Prec-[a-zA-Z0-9]+))?(_(?Prun-[a-zA-Z0-9]+))?') + + +class DerivativesDataSinkInputSpec(BaseInterfaceInputSpec): + base_directory = traits.Directory( + desc='Path to the base directory for storing data.') + in_file = InputMultiObject(File(exists=True), mandatory=True, + desc='the object to be saved') + source_file = File(exists=False, mandatory=True, desc='the original file') + prefix = traits.Str(mandatory=False, desc='prefix for output files') + space = traits.Str('', usedefault=True, desc='Label for space field') + desc = traits.Str('', usedefault=True, desc='Label for description field') + suffix = traits.Str('', usedefault=True, desc='suffix appended to source_file') + keep_dtype = traits.Bool(False, usedefault=True, desc='keep datatype suffix') + extra_values = traits.List(traits.Str) + compress = traits.Bool(desc="force compression (True) or uncompression (False)" + " of the output file (default: same as input)") + extension = traits.Str() + + +class DerivativesDataSinkOutputSpec(TraitedSpec): + out_file = OutputMultiObject(File(exists=True, desc='written file path')) + compression = OutputMultiPath( + traits.Bool, desc='whether ``in_file`` was compressed/uncompressed ' + 'or `it was copied directly.') + + +class DerivativesDataSink(SimpleInterface): + """ + Saves the `in_file` into a BIDS-Derivatives folder provided + by `base_directory`, given the input reference `source_file`. + >>> from pathlib import Path + >>> import tempfile + >>> from dmriprep.utils.bids import collect_data + >>> tmpdir = Path(tempfile.mkdtemp()) + >>> tmpfile = tmpdir / 'a_temp_file.nii.gz' + >>> tmpfile.open('w').close() # "touch" the file + >>> dsink = DerivativesDataSink(base_directory=str(tmpdir)) + >>> dsink.inputs.in_file = str(tmpfile) + >>> dsink.inputs.source_file = collect_data('ds114', '01')[0]['t1w'][0] + >>> dsink.inputs.keep_dtype = True + >>> dsink.inputs.suffix = 'target-mni' + >>> res = dsink.run() + >>> res.outputs.out_file # doctest: +ELLIPSIS + '.../dmriprep/sub-01/ses-retest/anat/sub-01_ses-retest_target-mni_T1w.nii.gz' + >>> bids_dir = tmpdir / 'bidsroot' / 'sub-02' / 'ses-noanat' / 'func' + >>> bids_dir.mkdir(parents=True, exist_ok=True) + >>> tricky_source = bids_dir / 'sub-02_ses-noanat_task-rest_run-01_bold.nii.gz' + >>> tricky_source.open('w').close() + >>> dsink = DerivativesDataSink(base_directory=str(tmpdir)) + >>> dsink.inputs.in_file = str(tmpfile) + >>> dsink.inputs.source_file = str(tricky_source) + >>> dsink.inputs.keep_dtype = True + >>> dsink.inputs.desc = 'preproc' + >>> res = dsink.run() + >>> res.outputs.out_file # doctest: +ELLIPSIS + '.../dmriprep/sub-02/ses-noanat/func/sub-02_ses-noanat_task-rest_run-01_\ +desc-preproc_bold.nii.gz' + """ + input_spec = DerivativesDataSinkInputSpec + output_spec = DerivativesDataSinkOutputSpec + out_path_base = "dmriprep" + _always_run = True + + def __init__(self, out_path_base=None, **inputs): + super(DerivativesDataSink, self).__init__(**inputs) + self._results['out_file'] = [] + if out_path_base: + self.out_path_base = out_path_base + + def _run_interface(self, runtime): + src_fname, _ = _splitext(self.inputs.source_file) + src_fname, dtype = src_fname.rsplit('_', 1) + _, ext = _splitext(self.inputs.in_file[0]) + if self.inputs.compress is True and not ext.endswith('.gz'): + ext += '.gz' + elif self.inputs.compress is False and ext.endswith('.gz'): + ext = ext[:-3] + + m = BIDS_NAME.search(src_fname) + + mod = os.path.basename(os.path.dirname(self.inputs.source_file)) + + base_directory = runtime.cwd + if isdefined(self.inputs.base_directory): + base_directory = str(self.inputs.base_directory) + + out_path = '{}/{subject_id}'.format(self.out_path_base, **m.groupdict()) + if m.groupdict().get('session_id') is not None: + out_path += '/{session_id}'.format(**m.groupdict()) + out_path += '/{}'.format(mod) + + out_path = os.path.join(base_directory, out_path) + + os.makedirs(out_path, exist_ok=True) + + if isdefined(self.inputs.prefix): + base_fname = os.path.join(out_path, self.inputs.prefix) + else: + base_fname = os.path.join(out_path, src_fname) + + formatstr = '{bname}{space}{desc}{suffix}{dtype}{ext}' + if len(self.inputs.in_file) > 1 and not isdefined(self.inputs.extra_values): + formatstr = '{bname}{space}{desc}{suffix}{i:04d}{dtype}{ext}' + + space = '_space-{}'.format(self.inputs.space) if self.inputs.space else '' + desc = '_desc-{}'.format(self.inputs.desc) if self.inputs.desc else '' + suffix = '_{}'.format(self.inputs.suffix) if self.inputs.suffix else '' + dtype = '' if not self.inputs.keep_dtype else ('_%s' % dtype) + + self._results['compression'] = [] + for i, fname in enumerate(self.inputs.in_file): + out_file = formatstr.format( + bname=base_fname, + space=space, + desc=desc, + suffix=suffix, + i=i, + dtype=dtype, + ext=ext) + if isdefined(self.inputs.extra_values): + out_file = out_file.format(extra_value=self.inputs.extra_values[i]) + self._results['out_file'].append(out_file) + self._results['compression'].append(_copy_any(fname, out_file)) + return runtime \ No newline at end of file diff --git a/dmriprep/interfaces/images.py b/dmriprep/interfaces/images.py index a59fedfb..e17f5b73 100644 --- a/dmriprep/interfaces/images.py +++ b/dmriprep/interfaces/images.py @@ -1,11 +1,17 @@ """Image tools interfaces.""" +import os import numpy as np import nibabel as nb -from nipype.utils.filemanip import fname_presuffix from nipype import logging +from nipype.interfaces import ants +from nipype.utils.filemanip import split_filename from nipype.interfaces.base import ( - traits, TraitedSpec, BaseInterfaceInputSpec, SimpleInterface, File + traits, TraitedSpec, BaseInterfaceInputSpec, SimpleInterface, File, isdefined, + InputMultiObject, OutputMultiObject, CommandLine ) +from ..utils.images import extract_b0, rescale_b0, median, quick_load_images, match_transforms, prune_b0s_from_dwis +from ..utils.vectors import _nonoverlapping_qspace_samples + LOGGER = logging.getLogger('nipype.interface') @@ -45,31 +51,13 @@ def _run_interface(self, runtime): return runtime -def extract_b0(in_file, b0_ixs, newpath=None): - """Extract the *b0* volumes from a DWI dataset.""" - out_file = fname_presuffix( - in_file, suffix='_b0', newpath=newpath) - - img = nb.load(in_file) - data = img.get_fdata(dtype='float32') - - b0 = data[..., b0_ixs] - - hdr = img.header.copy() - hdr.set_data_shape(b0.shape) - hdr.set_xyzt_units('mm') - hdr.set_data_dtype(np.float32) - nb.Nifti1Image(b0, img.affine, hdr).to_filename(out_file) - return out_file - - class _RescaleB0InputSpec(BaseInterfaceInputSpec): in_file = File(exists=True, mandatory=True, desc='b0s file') mask_file = File(exists=True, mandatory=True, desc='mask file') class _RescaleB0OutputSpec(TraitedSpec): - out_ref = File(exists=True, desc='One average b0 file') + out_ref = File(exists=True, desc='One median b0 file') out_b0s = File(exists=True, desc='series of rescaled b0 volumes') @@ -103,43 +91,259 @@ def _run_interface(self, runtime): return runtime -def rescale_b0(in_file, mask_file, newpath=None): - """Rescale the input volumes using the median signal intensity.""" - out_file = fname_presuffix( - in_file, suffix='_rescaled_b0', newpath=newpath) +class MatchTransformsInputSpec(BaseInterfaceInputSpec): + b0_indices = traits.List(mandatory=True) + dwi_files = InputMultiObject(File(exists=True), mandatory=True) + transforms = InputMultiObject(File(exists=True), mandatory=True) + + +class MatchTransformsOutputSpec(TraitedSpec): + transforms = OutputMultiObject(File(exists=True), mandatory=True) + + +class MatchTransforms(SimpleInterface): + input_spec = MatchTransformsInputSpec + output_spec = MatchTransformsOutputSpec + + def _run_interface(self, runtime): + self._results['transforms'] = match_transforms(self.inputs.dwi_files, + self.inputs.transforms, + self.inputs.b0_indices) + return runtime + + +class N3BiasFieldCorrection(ants.N4BiasFieldCorrection): + _cmd = "N3BiasFieldCorrection" + + +class ImageMathInputSpec(BaseInterfaceInputSpec): + in_file = File(exists=True, mandatory=True, position=3, argstr='%s') + dimension = traits.Enum(3, 2, 4, usedefault=True, argstr="%d", position=0) + out_file = File(argstr="%s", genfile=True, position=1) + operation = traits.Str(argstr="%s", position=2) + secondary_arg = traits.Str("", argstr="%s") + secondary_file = File(argstr="%s") + + +class ImageMathOutputSpec(TraitedSpec): + out_file = File(exists=True) + + +class ImageMath(CommandLine): + input_spec = ImageMathInputSpec + output_spec = ImageMathOutputSpec + _cmd = 'ImageMath' + + def _gen_filename(self, name): + if name == 'out_file': + output = self.inputs.out_file + if not isdefined(output): + _, fname, ext = split_filename(self.inputs.in_file) + output = fname + "_" + self.inputs.operation + ext + return output + return None + + def _list_outputs(self): + outputs = self.output_spec().get() + outputs['out_file'] = os.path.abspath(self._gen_filename('out_file')) + return outputs + + +class SignalPredictionInputSpec(BaseInterfaceInputSpec): + aligned_dwi_files = InputMultiObject(File(exists=True), mandatory=True) + aligned_vectors = File(exists=True, mandatory=True) + b0_mask = File(exists=True, mandatory=True) + b0_median = File(exists=True, mandatory=True) + bvec_to_predict = traits.Array() + bval_to_predict = traits.Float() + minimal_q_distance = traits.Float(2.0, usedefault=True) + b0_indices = traits.List() + + +class SignalPredictionOutputSpec(TraitedSpec): + predicted_image = File(exists=True) + + +class SignalPrediction(SimpleInterface): + """ + """ + input_spec = SignalPredictionInputSpec + output_spec = SignalPredictionOutputSpec + + def _run_interface(self, runtime, model_name='tensor'): + import warnings + warnings.filterwarnings("ignore") + from dipy.core.gradients import gradient_table_from_bvals_bvecs + pred_vec = self.inputs.bvec_to_predict + pred_val = self.inputs.bval_to_predict + + # Load the mask image: + mask_img = nb.load(self.inputs.b0_mask) + mask_array = mask_img.get_data() > 1e-6 + + all_images = prune_b0s_from_dwis(self.inputs.aligned_dwi_files, self.inputs.b0_indices) + + # Load the vectors + ras_b_mat = np.genfromtxt(self.inputs.aligned_vectors, delimiter='\t') + all_bvecs = np.row_stack([np.zeros(3), np.delete(ras_b_mat[:, 0:3], self.inputs.b0_indices, axis=0)]) + all_bvals = np.concatenate([np.zeros(1), np.delete(ras_b_mat[:, 3], self.inputs.b0_indices)]) + + # Which sample points are too close to the one we want to predict? + training_mask = _nonoverlapping_qspace_samples( + pred_val, pred_vec, all_bvals, all_bvecs, self.inputs.minimal_q_distance) + training_indices = np.flatnonzero(training_mask[1:]) + training_image_paths = [self.inputs.b0_median] + [ + all_images[idx] for idx in training_indices] + training_bvecs = all_bvecs[training_mask] + training_bvals = all_bvals[training_mask] + # print("Training with volumes: {}".format(str(training_indices))) + + # Load training data and fit the model + training_data = quick_load_images(training_image_paths) + + # Build gradient table object + training_gtab = gradient_table_from_bvals_bvecs(training_bvals, training_bvecs, b0_threshold=0) + + # Checked shelledness + if len(np.unique(training_gtab.bvals)) > 2: + is_shelled = True + else: + is_shelled = False + + # Get the vector for the desired coordinate + prediction_gtab = gradient_table_from_bvals_bvecs(np.array(pred_val)[None], np.array(pred_vec)[None, :], + b0_threshold=0) + + if is_shelled and model_name == '3dshore': + from dipy.reconst.shore import ShoreModel + radial_order = 6 + zeta = 700 + lambdaN = 1e-8 + lambdaL = 1e-8 + estimator_shore = ShoreModel(training_gtab, radial_order=radial_order, + zeta=zeta, lambdaN=lambdaN, lambdaL=lambdaL) + estimator_shore_fit = estimator_shore.fit(training_data, mask=mask_array) + pred_shore_fit = estimator_shore_fit.predict(prediction_gtab) + pred_shore_fit_file = os.path.join(runtime.cwd, + "predicted_shore_b%d_%.2f_%.2f_%.2f.nii.gz" % + ((pred_val,) + tuple(np.round(pred_vec, decimals=2)))) + output_data = pred_shore_fit[..., 0] + nb.Nifti1Image(output_data, mask_img.affine, mask_img.header).to_filename(pred_shore_fit_file) + elif model_name == 'sfm' and not is_shelled: + import dipy.reconst.sfm as sfm + from dipy.data import default_sphere + + estimator_sfm = sfm.SparseFascicleModel(training_gtab, sphere=default_sphere, + l1_ratio=0.5, alpha=0.001) + estimator_sfm_fit = estimator_sfm.fit(training_data, mask=mask_array) + pred_sfm_fit = estimator_sfm_fit.predict(prediction_gtab)[..., 0] + pred_sfm_fit[~mask_array] = 0 + pred_fit_file = os.path.join(runtime.cwd, "predicted_sfm_b%d_%.2f_%.2f_%.2f.npy" % + ((pred_val,) + tuple(np.round(pred_vec, decimals=2)))) + np.save(pred_fit_file, pred_sfm_fit) + pred_fit_file = os.path.join(runtime.cwd, "predicted_sfm_b%d_%.2f_%.2f_%.2f.nii.gz" % + ((pred_val,) + tuple(np.round(pred_vec, decimals=2)))) + nb.Nifti1Image(pred_sfm_fit, mask_img.affine, mask_img.header).to_filename(pred_fit_file) + else: + from dipy.reconst.dti import TensorModel + estimator_ten = TensorModel(training_gtab) + estimator_ten_fit = estimator_ten.fit(training_data, mask=mask_array) + pred_ten_fit = estimator_ten_fit.predict(prediction_gtab)[..., 0] + pred_ten_fit[~mask_array] = 0 + pred_fit_file = os.path.join(runtime.cwd, "predicted_ten_b%d_%.2f_%.2f_%.2f.npy" % + ((pred_val,) + tuple(np.round(pred_vec, decimals=2)))) + np.save(pred_fit_file, pred_ten_fit) + pred_fit_file = os.path.join(runtime.cwd, "predicted_ten_b%d_%.2f_%.2f_%.2f.nii.gz" % + ((pred_val,) + tuple(np.round(pred_vec, decimals=2)))) + nb.Nifti1Image(pred_ten_fit, mask_img.affine, mask_img.header).to_filename(pred_fit_file) + + self._results['predicted_image'] = pred_fit_file + + return runtime + + +class CalculateCNRInputSpec(BaseInterfaceInputSpec): + emc_warped_images = InputMultiObject(File(exists=True)) + predicted_images = InputMultiObject(File(exists=True)) + mask_image = File(exists=True) + - img = nb.load(in_file) - if img.dataobj.ndim == 3: - return in_file +class CalculateCNROutputSpec(TraitedSpec): + cnr_image = File(exists=True) - data = img.get_fdata(dtype='float32') - mask_img = nb.load(mask_file) - mask_data = mask_img.get_fdata(dtype='float32') - median_signal = np.median(data[mask_data > 0, ...], axis=0) - rescaled_data = 1000 * data / median_signal - hdr = img.header.copy() - nb.Nifti1Image(rescaled_data, img.affine, hdr).to_filename(out_file) - return out_file +class CalculateCNR(SimpleInterface): + input_spec = CalculateCNRInputSpec + output_spec = CalculateCNROutputSpec + + def _run_interface(self, runtime): + cnr_file = os.path.join(runtime.cwd, "emc_CNR.nii.gz") + model_images = quick_load_images(self.inputs.predicted_images) + observed_images = quick_load_images(self.inputs.emc_warped_images) + mask_image = nb.load(self.inputs.mask_image) + mask = mask_image.get_data() > 1e-6 + signal_vals = model_images[mask] + b0 = signal_vals[:, 0][:, np.newaxis] + signal_vals = signal_vals / b0 + signal_var = np.var(signal_vals, 1) + observed_vals = observed_images[mask] / b0 + noise_var = np.var(signal_vals - observed_vals, 1) + snr = np.nan_to_num(signal_var / noise_var) + out_mat = np.zeros(mask_image.shape) + out_mat[mask] = snr + nb.Nifti1Image(out_mat, mask_image.affine, + header=mask_image.header).to_filename(cnr_file) + self._results['cnr_image'] = cnr_file + return runtime -def median(in_file, newpath=None): - """Average a 4D dataset across the last dimension using median.""" - out_file = fname_presuffix( - in_file, suffix='_b0ref', newpath=newpath) +class ReorderOutputsInputSpec(BaseInterfaceInputSpec): + b0_indices = traits.List(mandatory=True) + b0_median = File(exists=True, mandatory=True) + warped_b0_images = InputMultiObject(File(exists=True), mandatory=True) + warped_dwi_images = InputMultiObject(File(exists=True), mandatory=True) + initial_transforms = InputMultiObject(File(exists=True), mandatory=True) + model_based_transforms = InputMultiObject(traits.List(), mandatory=True) + model_predicted_images = InputMultiObject(File(exists=True), mandatory=True) - img = nb.load(in_file) - if img.dataobj.ndim == 3: - return in_file - if img.shape[-1] == 1: - nb.squeeze_image(img).to_filename(out_file) - return out_file - median_data = np.median(img.get_fdata(dtype='float32'), - axis=-1) +class ReorderOutputsOutputSpec(TraitedSpec): + full_transforms = OutputMultiObject(traits.List()) + full_predicted_dwi_series = OutputMultiObject(File(exists=True)) + emc_warped_images = OutputMultiObject(File(exists=True)) - hdr = img.header.copy() - hdr.set_xyzt_units('mm') - hdr.set_data_dtype(np.float32) - nb.Nifti1Image(median_data, img.affine, hdr).to_filename(out_file) - return out_file + +class ReorderOutputs(SimpleInterface): + input_spec = ReorderOutputsInputSpec + output_spec = ReorderOutputsOutputSpec + + def _run_interface(self, runtime): + full_transforms = [] + full_predicted_dwi_series = [] + full_warped_images = [] + warped_b0_images = self.inputs.warped_b0_images[::-1] + warped_dwi_images = self.inputs.warped_dwi_images[::-1] + model_transforms = self.inputs.model_based_transforms[::-1] + model_images = self.inputs.model_predicted_images[::-1] + b0_transforms = [self.inputs.initial_transforms[idx] for idx in + self.inputs.b0_indices][::-1] + num_dwis = len(self.inputs.initial_transforms) + + for imagenum in range(num_dwis): + if imagenum in self.inputs.b0_indices: + full_predicted_dwi_series.append(self.inputs.b0_median) + full_transforms.append(b0_transforms.pop()) + full_warped_images.append(warped_b0_images.pop()) + else: + full_transforms.append(model_transforms.pop()) + full_predicted_dwi_series.append(model_images.pop()) + full_warped_images.append(warped_dwi_images.pop()) + + if not len(model_transforms) == len(b0_transforms) == len(model_images) == 0: + raise Exception("Unable to recombine images and transforms") + + self._results['emc_warped_images'] = full_warped_images + self._results['full_transforms'] = full_transforms + self._results['full_predicted_dwi_series'] = full_predicted_dwi_series + + return runtime diff --git a/dmriprep/interfaces/reports.py b/dmriprep/interfaces/reports.py index 0e14c2e0..e1b53fdc 100644 --- a/dmriprep/interfaces/reports.py +++ b/dmriprep/interfaces/reports.py @@ -4,13 +4,13 @@ import os import time - +import pandas as pd from nipype.interfaces.base import ( traits, TraitedSpec, BaseInterfaceInputSpec, File, Directory, InputMultiObject, Str, isdefined, SimpleInterface) from nipype.interfaces import freesurfer as fs - +from ..utils.viz import _iteration_summary_plot, before_after_images SUBJECT_TEMPLATE = """\ \t
    @@ -121,3 +121,73 @@ def _generate_segment(self): return ABOUT_TEMPLATE.format(version=self.inputs.version, command=self.inputs.command, date=time.strftime("%Y-%m-%d %H:%M:%S %z")) + + +class IterationSummaryInputSpec(BaseInterfaceInputSpec): + collected_motion_files = InputMultiObject(File(exists=True)) + + +class IterationSummaryOutputSpec(TraitedSpec): + iteration_summary_file = File(exists=True) + plot_file = File(exists=True) + + +class IterationSummary(SummaryInterface): + input_spec = IterationSummaryInputSpec + output_spec = IterationSummaryOutputSpec + + def _run_interface(self, runtime): + motion_files = self.inputs.collected_motion_files + output_fname = os.path.join(runtime.cwd, "iteration_summary.csv") + fig_output_fname = os.path.join(runtime.cwd, "iterdiffs.svg") + if not isdefined(motion_files): + return runtime + + all_iters = [] + for fnum, fname in enumerate(motion_files): + df = pd.read_csv(fname) + df['iter_num'] = fnum + path_parts = fname.split(os.sep) + itername = '' if 'iter' not in path_parts[-3] else path_parts[-3] + df['iter_name'] = itername + all_iters.append(df) + combined = pd.concat(all_iters, axis=0, ignore_index=True) + + combined.to_csv(output_fname, index=False) + self._results['iteration_summary_file'] = output_fname + + # Create a figure for the report + _iteration_summary_plot(combined, fig_output_fname) + self._results['plot_file'] = fig_output_fname + + return runtime + + +class HMCReportInputSpec(BaseInterfaceInputSpec): + iteration_summary = File(exists=True) + registered_images = InputMultiObject(File(exists=True)) + original_images = InputMultiObject(File(exists=True)) + model_predicted_images = InputMultiObject(File(exists=True)) + + +class HMCReportOutputSpec(SummaryOutputSpec): + plot_file = File(exists=True) + + +class HMCReport(SummaryInterface): + input_spec = HMCReportInputSpec + output_spec = HMCReportOutputSpec + + def _run_interface(self, runtime): + import imageio + images = [] + for imagenum, (orig_file, aligned_file, model_file) in enumerate(zip( + self.inputs.original_images, self.inputs.registered_images, + self.inputs.model_predicted_images)): + + images.extend(before_after_images(orig_file, aligned_file, model_file, imagenum)) + + out_file = os.path.join(runtime.cwd, "emc_reg.gif") + imageio.mimsave(out_file, images, fps=1) + self._results['plot_file'] = out_file + return runtime diff --git a/dmriprep/interfaces/vectors.py b/dmriprep/interfaces/vectors.py index 6d9da11b..6880f623 100644 --- a/dmriprep/interfaces/vectors.py +++ b/dmriprep/interfaces/vectors.py @@ -1,12 +1,21 @@ """Handling the gradient table.""" +import os from pathlib import Path import numpy as np +import pandas as pd from nipype.utils.filemanip import fname_presuffix from nipype.interfaces.base import ( SimpleInterface, BaseInterfaceInputSpec, TraitedSpec, - File, traits, isdefined + File, traits, isdefined, InputMultiObject ) -from ..utils.vectors import DiffusionGradientTable, B0_THRESHOLD, BVEC_NORM_EPSILON +from ..utils.vectors import DiffusionGradientTable, reorient_vecs_from_ras_b, B0_THRESHOLD, BVEC_NORM_EPSILON +from subprocess import Popen, PIPE + +def _undefined(objekt, name, default=None): + value = getattr(objekt, name) + if not isdefined(value): + return default + return value class _CheckGradientTableInputSpec(BaseInterfaceInputSpec): @@ -92,8 +101,128 @@ def _run_interface(self, runtime): return runtime -def _undefined(objekt, name, default=None): - value = getattr(objekt, name) - if not isdefined(value): - return default - return value +class _ReorientVectorsInputSpec(BaseInterfaceInputSpec): + rasb_file = File(exists=True) + affines = traits.List() + b0_threshold = traits.Float(B0_THRESHOLD, usedefault=True) + + +class _ReorientVectorsOutputSpec(TraitedSpec): + out_rasb = File(exists=True) + + +class ReorientVectors(SimpleInterface): + """ + Reorient Vectors + Example + ------- + >>> os.chdir(tmpdir) + >>> oldrasb = str(data_dir / 'dwi.tsv') + >>> oldrasb_mat = np.loadtxt(str(data_dir / 'dwi.tsv'), skiprows=1) + >>> # The simple case: all affines are identity + >>> affine_list = np.zeros((len(oldrasb_mat[:, 3][oldrasb_mat[:, 3] != 0]), 4, 4)) + >>> for i in range(4): + >>> affine_list[:, i, i] = 1 + >>> reor_vecs = ReorientVectors() + >>> reor_vecs = ReorientVectors() + >>> reor_vecs.inputs.affines = affine_list + >>> reor_vecs.inputs.in_rasb = oldrasb + >>> res = reor_vecs.run() + >>> out_rasb = res.outputs.out_rasb + >>> out_rasb_mat = np.loadtxt(out_rasb, skiprows=1) + >>> npt.assert_equal(oldrasb_mat, out_rasb_mat) + True + """ + + input_spec = _ReorientVectorsInputSpec + output_spec = _ReorientVectorsOutputSpec + + def _run_interface(self, runtime): + from nipype.utils.filemanip import fname_presuffix + reor_table = reorient_vecs_from_ras_b( + rasb_file=self.inputs.rasb_file, + affines=self.inputs.affines, + b0_threshold=self.inputs.b0_threshold, + ) + + cwd = Path(runtime.cwd).absolute() + reor_rasb_file = fname_presuffix( + self.inputs.rasb_file, use_ext=False, suffix='_reoriented.tsv', + newpath=str(cwd)) + np.savetxt(str(reor_rasb_file), reor_table, + delimiter='\t', header='\t'.join('RASB'), + fmt=['%.8f'] * 3 + ['%g']) + + self._results['out_rasb'] = reor_rasb_file + return runtime + + +def get_fsl_motion_params(itk_file, src_file, ref_file, working_dir): + tmp_fsl_file = fname_presuffix(itk_file, newpath=working_dir, + suffix='_FSL.xfm', use_ext=False) + fsl_convert_cmd = "c3d_affine_tool " \ + "-ref {ref_file} " \ + "-src {src_file} " \ + "-itk {itk_file} " \ + "-ras2fsl -o {fsl_file}".format( + src_file=src_file, ref_file=ref_file, itk_file=itk_file, + fsl_file=tmp_fsl_file) + os.system(fsl_convert_cmd) + proc = Popen(['avscale', '--allparams', tmp_fsl_file, src_file], stdout=PIPE, + stderr=PIPE) + stdout, _ = proc.communicate() + + def get_measures(line): + line = line.strip().split() + return np.array([float(num) for num in line[-3:]]) + + lines = stdout.decode("utf-8").split("\n") + flip = np.array([1, -1, -1]) + rotation = get_measures(lines[6]) * flip + translation = get_measures(lines[8]) * flip + scale = get_measures(lines[10]) + shear = get_measures(lines[12]) + + return np.concatenate([scale, shear, rotation, translation]) + + +class CombineMotionsInputSpec(BaseInterfaceInputSpec): + transform_files = InputMultiObject(File(exists=True), mandatory=True, + desc='transform files from hmc') + source_files = InputMultiObject(File(exists=True), mandatory=True, + desc='Moving images') + ref_file = File(exists=True, mandatory=True, desc='Fixed Image') + + +class CombineMotionsOututSpec(TraitedSpec): + motion_file = File(exists=True) + spm_motion_file = File(exists=True) + + +class CombineMotions(SimpleInterface): + input_spec = CombineMotionsInputSpec + output_spec = CombineMotionsOututSpec + + def _run_interface(self, runtime): + collected_motion = [] + output_fname = os.path.join(runtime.cwd, "motion_params.csv") + output_spm_fname = os.path.join(runtime.cwd, "spm_movpar.txt") + ref_file = self.inputs.ref_file + for motion_file, src_file in zip(self.inputs.transform_files, + self.inputs.source_files): + collected_motion.append( + get_fsl_motion_params(motion_file, src_file, ref_file, runtime.cwd)) + + final_motion = np.row_stack(collected_motion) + cols = ["scaleX", "scaleY", "scaleZ", "shearXY", "shearXZ", + "shearYZ", "rotateX", "rotateY", "rotateZ", "shiftX", "shiftY", + "shiftZ"] + motion_df = pd.DataFrame(data=final_motion, columns=cols) + motion_df.to_csv(output_fname, index=False) + self._results['motion_file'] = output_fname + + spmcols = motion_df[['shiftX', 'shiftY', 'shiftZ', 'rotateX', 'rotateY', 'rotateZ']] + self._results['spm_motion_file'] = output_spm_fname + np.savetxt(output_spm_fname, spmcols.values) + + return runtime diff --git a/dmriprep/utils/bids.py b/dmriprep/utils/bids.py index 3178fd8c..4212b65a 100644 --- a/dmriprep/utils/bids.py +++ b/dmriprep/utils/bids.py @@ -160,3 +160,50 @@ def validate_input_dir(exec_env, bids_dir, participant_label): def _get_shub_version(singularity_url): return NotImplemented + + +def splitext(fname): + """Splits filename and extension (.gz safe) + >>> splitext('some/file.nii.gz') + ('file', '.nii.gz') + >>> splitext('some/other/file.nii') + ('file', '.nii') + >>> splitext('otherext.tar.gz') + ('otherext', '.tar.gz') + >>> splitext('text.txt') + ('text', '.txt') + """ + from pathlib import Path + basename = str(Path(fname).name) + stem = Path(basename.rstrip('.gz')).stem + return stem, basename[len(stem):] + + +def _copy_any(src, dst): + import os + import gzip + from shutil import copyfileobj + from nipype.utils.filemanip import copyfile + + src_isgz = src.endswith('.gz') + dst_isgz = dst.endswith('.gz') + if not src_isgz and not dst_isgz: + copyfile(src, dst, copy=True, use_hardlink=True) + return False # Make sure we do not reuse the hardlink later + + # Unlink target (should not exist) + if os.path.exists(dst): + os.unlink(dst) + + src_open = gzip.open if src_isgz else open + with src_open(src, 'rb') as f_in: + with open(dst, 'wb') as f_out: + if dst_isgz: + # Remove FNAME header from gzip (poldracklab/fmriprep#1480) + gz_out = gzip.GzipFile('', 'wb', 9, f_out, 0.) + copyfileobj(f_in, gz_out) + gz_out.close() + else: + copyfileobj(f_in, f_out) + + return True \ No newline at end of file diff --git a/dmriprep/utils/images.py b/dmriprep/utils/images.py new file mode 100644 index 00000000..7840bb39 --- /dev/null +++ b/dmriprep/utils/images.py @@ -0,0 +1,125 @@ +import numpy as np +import nibabel as nb +from nipype.utils.filemanip import fname_presuffix + + +def extract_b0(in_file, b0_ixs, newpath=None): + """Extract the *b0* volumes from a DWI dataset.""" + out_file = fname_presuffix( + in_file, suffix='_b0', newpath=newpath) + + img = nb.load(in_file) + data = img.get_fdata(dtype='float32') + + b0 = data[..., b0_ixs] + + hdr = img.header.copy() + hdr.set_data_shape(b0.shape) + hdr.set_xyzt_units('mm') + hdr.set_data_dtype(np.float32) + nb.Nifti1Image(b0, img.affine, hdr).to_filename(out_file) + return out_file + + +def rescale_b0(in_file, mask_file, newpath=None): + """Rescale the input volumes using the median signal intensity.""" + out_file = fname_presuffix( + in_file, suffix='_rescaled_b0', newpath=newpath) + + img = nb.load(in_file) + if img.dataobj.ndim == 3: + return in_file + + data = img.get_fdata(dtype='float32') + mask_img = nb.load(mask_file) + mask_data = mask_img.get_fdata(dtype='float32') + + median_signal = np.median(data[mask_data > 0, ...], axis=0) + rescaled_data = 1000 * data / median_signal + hdr = img.header.copy() + nb.Nifti1Image(rescaled_data, img.affine, hdr).to_filename(out_file) + return out_file + + +def median(in_file, newpath=None): + """Average a 4D dataset across the last dimension using median.""" + out_file = fname_presuffix( + in_file, suffix='_b0ref', newpath=newpath) + + img = nb.load(in_file) + if img.dataobj.ndim == 3: + return in_file + if img.shape[-1] == 1: + nb.squeeze_image(img).to_filename(out_file) + return out_file + + median_data = np.median(img.get_fdata(dtype='float32'), + axis=-1) + + hdr = img.header.copy() + hdr.set_xyzt_units('mm') + hdr.set_data_dtype(np.float32) + nb.Nifti1Image(median_data, img.affine, hdr).to_filename(out_file) + return out_file + + +def quick_load_images(image_list, dtype=np.float32): + example_img = nb.load(image_list[0]) + num_images = len(image_list) + output_matrix = np.zeros(tuple(example_img.shape) + (num_images,), dtype=dtype) + for image_num, image_path in enumerate(image_list): + output_matrix[..., image_num] = nb.load(image_path).get_fdata(dtype=dtype) + return output_matrix + + +def match_transforms(dwi_files, transforms, b0_indices): + original_b0_indices = np.array(b0_indices) + num_dwis = len(dwi_files) + num_transforms = len(transforms) + + if num_dwis == num_transforms: + return transforms + + # Do sanity checks + if not len(transforms) == len(b0_indices): + raise Exception('number of transforms does not match number of b0 images') + + # Create a list of which emc affines go with each of the split images + nearest_affines = [] + for index in range(num_dwis): + nearest_b0_num = np.argmin(np.abs(index - original_b0_indices)) + this_transform = transforms[nearest_b0_num] + nearest_affines.append(this_transform) + + return nearest_affines + + +def save_4d_to_3d(in_file): + files_3d = nb.four_to_three(nb.load(in_file)) + out_files = [] + for i, file_3d in enumerate(files_3d): + out_file = fname_presuffix(in_file, suffix='_tmp_{}'.format(i)) + file_3d.to_filename(out_file) + out_files.append(out_file) + del files_3d + return out_files + + +def prune_b0s_from_dwis(in_files, b0_ixs): + if in_files[0].endswith('_trans.nii.gz'): + out_files = [i for j, i in enumerate(sorted(in_files, + key=lambda x: int(x.split("_")[-2].split('.nii.gz')[0]))) if j + not in b0_ixs] + else: + out_files = [i for j, i in enumerate(sorted(in_files, + key=lambda x: int(x.split("_")[-1].split('.nii.gz')[0]))) if j + not in b0_ixs] + return out_files + + +def save_3d_to_4d(in_files): + img_4d = nb.funcs.concat_images([nb.load(img_3d) for img_3d in in_files]) + out_file = fname_presuffix(in_files[0], suffix='_merged') + img_4d.to_filename(out_file) + del img_4d + return out_file diff --git a/dmriprep/utils/vectors.py b/dmriprep/utils/vectors.py index b63afad6..425b77d0 100644 --- a/dmriprep/utils/vectors.py +++ b/dmriprep/utils/vectors.py @@ -377,7 +377,7 @@ def bvecs2ras(affine, bvecs, norm=True, bvec_norm_epsilon=0.2): return rotated_bvecs -def reorient_bvecs_from_ras_b(ras_b, affines): +def reorient_vecs_from_ras_b(rasb_file, affines, b0_threshold=B0_THRESHOLD): """ Reorient the vectors from a rasb .tsv file. When correcting for motion, rotation of the diffusion-weighted volumes @@ -385,22 +385,99 @@ def reorient_bvecs_from_ras_b(ras_b, affines): and MD, and also cause characteristic biases in tractography, unless the gradient directions are appropriately reoriented to compensate for this effect [Leemans2009]_. - Parameters ---------- rasb_file : str or os.pathlike - File path to a RAS-B gradient table. If rasb_file is provided, - then bvecs and bvals will be dismissed. - + File path to a RAS-B gradient table. affines : list or ndarray of shape (n, 4, 4) or (n, 3, 3) Each entry in this list or array contain either an affine transformation (4,4) or a rotation matrix (3, 3). In both cases, the transformations encode the rotation that was applied to the image corresponding to one of the non-zero gradient directions. + Returns + ------- + Gradients : ndarray of shape (4, n) + A reoriented ndarray where the first three columns correspond to each of + x, y, z directions of the bvecs, in R-A-S image orientation. """ from dipy.core.gradients import gradient_table_from_bvals_bvecs, reorient_bvecs + from scipy.io import loadmat + + ras_b_mat = np.genfromtxt(rasb_file, delimiter='\t') + + # Verify that number of non-B0 volumes corresponds to the number of affines. + # If not, raise an error. + if len(ras_b_mat[:, 3][ras_b_mat[:, 3] <= b0_threshold]) != len(affines): + b0_indices = np.where(ras_b_mat[:, 3] <= b0_threshold)[0].tolist() + for i in sorted(b0_indices, reverse=True): + del affines[i] + if len(ras_b_mat[:, 3][ras_b_mat[:, 3] > b0_threshold]) != len(affines): + raise ValueError('Affine transformations do not correspond to gradients') + + # Build gradient table object + gt = gradient_table_from_bvals_bvecs(ras_b_mat[:, 3], ras_b_mat[:, 0:3], + b0_threshold=b0_threshold) + + # Reorient table + ras_trans = np.ones(shape=(4, 4)) + ras_trans[0, 1] = -ras_trans[0, 1] + ras_trans[1, 0] = -ras_trans[1, 0] + ras_trans[2, 3] = -ras_trans[2, 3] + affines_ras = [] + for aff in affines: + aff_mat = loadmat(aff) + M = np.zeros(shape=(4, 4)) + M[0, 0] = aff_mat['AffineTransform_float_3_3'][0] + M[0, 1] = aff_mat['AffineTransform_float_3_3'][1] + M[0, 2] = aff_mat['AffineTransform_float_3_3'][2] + M[1, 0] = aff_mat['AffineTransform_float_3_3'][3] + M[1, 1] = aff_mat['AffineTransform_float_3_3'][4] + M[1, 2] = aff_mat['AffineTransform_float_3_3'][5] + M[2, 0] = aff_mat['AffineTransform_float_3_3'][6] + M[2, 1] = aff_mat['AffineTransform_float_3_3'][7] + M[2, 2] = aff_mat['AffineTransform_float_3_3'][8] + M[3, 3] = 1 + M[0:3, 3] = aff_mat['fixed'].T + affines_ras.append(np.multiply(M, ras_trans)) + del M + + new_gt = reorient_bvecs(gt, affines_ras) + + return np.hstack((new_gt.bvecs, new_gt.bvals[..., np.newaxis])) + + +def _nonoverlapping_qspace_samples(prediction_bval, prediction_bvec, + all_bvals, all_bvecs, cutoff): + """Ensure that none of the training samples are too close to the sample to predict. + Parameters + """ + min_bval = min(min(all_bvals), prediction_bval) + all_qvals = np.sqrt(all_bvals - min_bval) + prediction_qval = np.sqrt(prediction_bval - min_bval) + + # Convert q values to percent of maximum qval + max_qval = max(max(all_qvals), prediction_qval) + all_qvals_scaled = all_qvals / max_qval * 100 + prediction_qval_scaled = prediction_qval / max_qval * 100 + scaled_qvecs = all_bvecs * all_qvals_scaled[:, np.newaxis] + scaled_prediction_qvec = prediction_bvec * prediction_qval_scaled + + # Calculate the distance between the sampled qvecs and the prediction qvec + distances = np.linalg.norm(scaled_qvecs - scaled_prediction_qvec, axis=1) + distances_flip = np.linalg.norm(scaled_qvecs + scaled_prediction_qvec, axis=1) + ok_samples = (distances > cutoff) * (distances_flip > cutoff) + + return ok_samples + + +def _rasb_to_bvec_list(in_rasb): + import numpy as np + ras_b_mat = np.genfromtxt(in_rasb, delimiter='\t') + bvec = [vec for vec in ras_b_mat[:, 0:3] if not np.isclose(all(vec), 0)] + return list(bvec) - ras_b_mat = np.genfromtxt(ras_b, delimiter='\t') - gt = gradient_table_from_bvals_bvecs(ras_b_mat[:,3], ras_b_mat[:,0:3], b0_threshold=50) - return reorient_bvecs(gt, affines) \ No newline at end of file +def _rasb_to_bval_floats(in_rasb): + import numpy as np + ras_b_mat = np.genfromtxt(in_rasb, delimiter='\t') + return [float(bval) for bval in ras_b_mat[:, 3] if bval > 0] \ No newline at end of file diff --git a/dmriprep/utils/viz.py b/dmriprep/utils/viz.py new file mode 100644 index 00000000..01914133 --- /dev/null +++ b/dmriprep/utils/viz.py @@ -0,0 +1,131 @@ +import numpy as np +import nibabel as nb +from skimage import measure +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns + + +def scaled_mip(img1, img2, img3, axis): + mip1 = img1.max(axis=axis).T + mip2 = img2.max(axis=axis).T + mip3 = img3.max(axis=axis).T + max_obs = max(mip1.max(), mip2.max(), mip3.max()) + vmax = 0.98 * max_obs + return (np.clip(mip1, 0, vmax) / vmax, + np.clip(mip2, 0, vmax) / vmax, + np.clip(mip3, 0, vmax) / vmax) + + +def to_image(fig): + fig.subplots_adjust(hspace=0, left=0, right=1, wspace=0) + fig.canvas.draw() # draw the canvas, cache the renderer + image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8') + image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return image + + +def before_after_images(orig_file, aligned_file, model_file, imagenum): + fig, ax = plt.subplots(ncols=2, figsize=(10, 5)) + fig.subplots_adjust(hspace=0, left=0, right=1, wspace=0) + for _ax in ax: + _ax.clear() + orig_img = nb.load(orig_file).get_fdata() + aligned_img = nb.load(aligned_file).get_fdata() + model_img = nb.load(model_file).get_fdata() + orig_mip, aligned_mip, target_mip = scaled_mip(orig_img, aligned_img, model_img, 0) + + # Get contours for the orig, aligned images + orig_contours = measure.find_contours(orig_mip, 0.7) + aligned_contours = measure.find_contours(aligned_mip, 0.7) + target_contours = measure.find_contours(target_mip, 0.7) + + orig_contours_low = measure.find_contours(orig_mip, 0.05) + aligned_contours_low = measure.find_contours(aligned_mip, 0.05) + target_contours_low = measure.find_contours(target_mip, 0.05) + + # Plot before + ax[0].imshow(orig_mip, vmax=1., vmin=0, origin="lower", cmap="gray", + interpolation="nearest") + ax[1].imshow(target_mip, vmax=1., vmin=0, origin="lower", cmap="gray", + interpolation="nearest") + ax[0].text(1, 1, "%03d: Before" % imagenum, fontsize=16, color='white') + for contour in target_contours + target_contours_low: + ax[0].plot(contour[:, 1], contour[:, 0], linewidth=2, alpha=0.9, color="#e7298a") + ax[1].plot(contour[:, 1], contour[:, 0], linewidth=2, alpha=0.9, color="#e7298a") + for contour in orig_contours + orig_contours_low: + ax[1].plot(contour[:, 1], contour[:, 0], linewidth=2, alpha=0.9, color="#d95f02") + ax[0].plot(contour[:, 1], contour[:, 0], linewidth=2, alpha=0.9, color="#d95f02") + for axis in ax: + axis.set_xticks([]) + axis.set_yticks([]) + + before_image = to_image(fig) + + # Plot after + for _ax in ax: + _ax.clear() + ax[0].imshow(aligned_mip, vmax=1., vmin=0, origin="lower", cmap="gray", + interpolation="nearest") + ax[1].imshow(target_mip, vmax=1., vmin=0, origin="lower", cmap="gray", + interpolation="nearest") + ax[0].text(1, 1, "%03d: After" % imagenum, fontsize=16, color='white') + for contour in target_contours + target_contours_low: + ax[0].plot(contour[:, 1], contour[:, 0], linewidth=2, alpha=0.9, color="#e7298a") + ax[1].plot(contour[:, 1], contour[:, 0], linewidth=2, alpha=0.9, color="#e7298a") + for contour in aligned_contours + aligned_contours_low: + ax[1].plot(contour[:, 1], contour[:, 0], linewidth=2, alpha=0.9, color="#d95f02") + ax[0].plot(contour[:, 1], contour[:, 0], linewidth=2, alpha=0.9, color="#d95f02") + for axis in ax: + axis.set_xticks([]) + axis.set_yticks([]) + after_image = to_image(fig) + + return before_image, after_image + + +def _iteration_summary_plot(iters_df, out_file): + iters = list([item[1] for item in iters_df.groupby('iter_num')]) + shift_cols = ["shiftX", "shiftY", "shiftZ"] + rotate_cols = ["rotateX", "rotateY", "rotateZ"] + shifts = np.stack([df[shift_cols] for df in iters], -1) + rotations = np.stack([df[rotate_cols] for df in iters], -1) + + rot_diffs = np.diff(rotations, axis=2).squeeze() + shift_diffs = np.diff(shifts, axis=2).squeeze() + if len(iters) == 2: + rot_diffs = rot_diffs[..., np.newaxis] + shift_diffs = shift_diffs[..., np.newaxis] + + shiftdiff_dfs = [] + rotdiff_dfs = [] + for diffnum, (rot_diff, shift_diff) in enumerate(zip(rot_diffs.T, shift_diffs.T)): + shiftdiff_df = pd.DataFrame(shift_diff.T, columns=shift_cols) + shiftdiff_df['difference_num'] = "%02d" % diffnum + shiftdiff_dfs.append(shiftdiff_df) + + rotdiff_df = pd.DataFrame(rot_diff.T, columns=rotate_cols) + rotdiff_df['difference_num'] = "%02d" % diffnum + rotdiff_dfs.append(rotdiff_df) + + shift_diffs = pd.concat(shiftdiff_dfs, axis=0) + rotate_diffs = pd.concat(rotdiff_dfs, axis=0) + + # Plot shifts + sns.set() + fig, ax = plt.subplots(ncols=2, figsize=(10, 5)) + sns.violinplot(x="variable", y="value", + hue="difference_num", + ax=ax[0], + data=shift_diffs.melt(id_vars=['difference_num'])) + ax[0].set_ylabel("mm") + ax[0].set_title("Shift") + + # Plot rotations + sns.violinplot(x="variable", y="value", + hue="difference_num", + data=rotate_diffs.melt(id_vars=['difference_num'])) + ax[1].set_ylabel("Degrees") + ax[1].set_title("Rotation") + sns.despine(offset=10, trim=True, fig=fig) + fig.savefig(out_file) diff --git a/dmriprep/workflows/dwi/base.py b/dmriprep/workflows/dwi/base.py index 315132aa..c1bdaa92 100644 --- a/dmriprep/workflows/dwi/base.py +++ b/dmriprep/workflows/dwi/base.py @@ -190,3 +190,7 @@ def _get_wf_name(dwi_fname): ".", "_").replace(" ", "").replace("-", "_").replace("dwi", "wf") return name + + +def _list_squeeze(in_list): + return [item[0] for item in in_list] \ No newline at end of file diff --git a/dmriprep/workflows/dwi/emc.py b/dmriprep/workflows/dwi/emc.py new file mode 100644 index 00000000..56e6a85e --- /dev/null +++ b/dmriprep/workflows/dwi/emc.py @@ -0,0 +1,511 @@ +import nipype.pipeline.engine as pe +from nipype.interfaces import utility as niu, afni, ants +from pkg_resources import resource_filename as pkgrf +from dmriprep.workflows.dwi.base import _list_squeeze +from dmriprep.interfaces.bids import DerivativesDataSink +from dmriprep.interfaces.images import SignalPrediction, CalculateCNR, ReorderOutputs, ImageMath, ExtractB0, \ + MatchTransforms, RescaleB0 +from dmriprep.interfaces.reports import IterationSummary, HMCReport +from dmriprep.interfaces.vectors import ReorientVectors, CheckGradientTable, CombineMotions +from dmriprep.utils.images import save_4d_to_3d, save_3d_to_4d, prune_b0s_from_dwis +from dmriprep.utils.vectors import _rasb_to_bvec_list, _rasb_to_bval_floats + + +def linear_alignment_workflow(transform="Rigid", iternum=0, precision="precise"): + from nipype.interfaces import ants + iteration_wf = pe.Workflow(name="iterative_alignment_%03d" % iternum) + input_node_fields = ["image_paths", "template_image", "iteration_num"] + linear_alignment_inputnode = pe.Node( + niu.IdentityInterface(fields=input_node_fields), name='linear_alignment_inputnode') + linear_alignment_inputnode.inputs.iteration_num = iternum + linear_alignment_outputnode = pe.Node( + niu.IdentityInterface(fields=["registered_image_paths", "affine_transforms", + "updated_template"]), name='linear_alignment_outputnode') + ants_settings = pkgrf( + "dmriprep", + "config/emc_{precision}_{transform}.json".format(precision=precision, transform=transform)) + reg = ants.Registration(from_file=ants_settings) + iter_reg = pe.MapNode( + reg, name="reg_%03d" % iternum, iterfield=["moving_image"]) + + # Run the images through antsRegistration + iteration_wf.connect(linear_alignment_inputnode, "image_paths", iter_reg, "moving_image") + iteration_wf.connect(linear_alignment_inputnode, "template_image", iter_reg, "fixed_image") + + # Average the images + averaged_images = pe.Node( + ants.AverageImages(normalize=True, dimension=3), + name="averaged_images") + iteration_wf.connect(iter_reg, "warped_image", averaged_images, "images") + + # Apply the inverse to the average image + transforms_to_list = pe.Node(niu.Merge(1), name="transforms_to_list") + transforms_to_list.inputs.ravel_inputs = True + iteration_wf.connect(iter_reg, "forward_transforms", transforms_to_list, + "in1") + avg_affines = pe.Node(ants.AverageAffineTransform(), name="avg_affine") + avg_affines.inputs.dimension = 3 + avg_affines.inputs.output_affine_transform = "AveragedAffines.mat" + iteration_wf.connect(transforms_to_list, "out", avg_affines, "transforms") + + invert_average = pe.Node(ants.ApplyTransforms(), name="invert_average") + invert_average.inputs.interpolation = "HammingWindowedSinc" + invert_average.inputs.invert_transform_flags = [True] + + avg_to_list = pe.Node(niu.Merge(1), name="to_list") + iteration_wf.connect(avg_affines, "affine_transform", avg_to_list, "in1") + iteration_wf.connect(avg_to_list, "out", invert_average, "transforms") + iteration_wf.connect(averaged_images, "output_average_image", + invert_average, "input_image") + iteration_wf.connect(averaged_images, "output_average_image", + invert_average, "reference_image") + iteration_wf.connect(invert_average, "output_image", linear_alignment_outputnode, + "updated_template") + iteration_wf.connect(iter_reg, "forward_transforms", linear_alignment_outputnode, + "affine_transforms") + iteration_wf.connect(iter_reg, "warped_image", linear_alignment_outputnode, + "registered_image_paths") + + return iteration_wf + + +def init_b0_emc_wf(transform="Rigid", metric="Mattes", num_iters=3, name="b0_emc_wf"): + b0_emc_wf = pe.Workflow(name=name) + + b0_emc_inputnode = pe.Node( + niu.IdentityInterface(fields=['b0_images', 'initial_template']), + name='b0_emc_inputnode') + + b0_emc_outputnode = pe.Node( + niu.IdentityInterface(fields=[ + "final_template", "forward_transforms", "iteration_templates", + "motion_params", "aligned_images"]), + name='b0_emc_outputnode') + + # Iteratively create a template + # Store the registration targets + iter_templates = pe.Node(niu.Merge(num_iters), name="iteration_templates") + b0_emc_wf.connect(b0_emc_inputnode, "initial_template", iter_templates, "in1") + + initial_reg = linear_alignment_workflow( + transform=transform, + precision="coarse", + iternum=0) + b0_emc_wf.connect(b0_emc_inputnode, "initial_template", initial_reg, "linear_alignment_inputnode.template_image") + b0_emc_wf.connect(b0_emc_inputnode, "b0_images", initial_reg, "linear_alignment_inputnode.image_paths") + reg_iters = [initial_reg] + for iternum in range(1, num_iters): + reg_iters.append( + linear_alignment_workflow( + transform=transform, + precision="precise", + iternum=iternum)) + b0_emc_wf.connect(reg_iters[-2], "linear_alignment_outputnode.updated_template", reg_iters[-1], + "linear_alignment_inputnode.template_image") + b0_emc_wf.connect(b0_emc_inputnode, "b0_images", reg_iters[-1], "linear_alignment_inputnode.image_paths") + b0_emc_wf.connect(reg_iters[-1], + "linear_alignment_outputnode.updated_template", iter_templates, "in%d" % (iternum + 1)) + + # Attach to outputs + # The last iteration aligned to the output from the second-to-last + b0_emc_wf.connect(reg_iters[-2], "linear_alignment_outputnode.updated_template", b0_emc_outputnode, + "final_template") + b0_emc_wf.connect(reg_iters[-1], "linear_alignment_outputnode.affine_transforms", b0_emc_outputnode, + "forward_transforms") + b0_emc_wf.connect(reg_iters[-1], "linear_alignment_outputnode.registered_image_paths", b0_emc_outputnode, + "aligned_images") + b0_emc_wf.connect(iter_templates, "out", b0_emc_outputnode, "iteration_templates") + + return b0_emc_wf + + +def init_enhance_and_skullstrip_template_mask_wf(name): + eastm_workflow = pe.Workflow(name=name) + eastm_inputnode = pe.Node(niu.IdentityInterface(fields=['in_file']), + name='eastm_inputnode') + eastm_outputnode = pe.Node(niu.IdentityInterface(fields=[ + 'mask_file', 'skull_stripped_file', 'bias_corrected_file']), name='eastm_outputnode') + + # Truncate intensity values so they're OK for N4 + truncate_values = pe.Node( + ImageMath(dimension=3, + operation="TruncateImageIntensity", + secondary_arg="0.0 0.98 512"), + name="truncate_values") + + # Truncate intensity values for creating a mask + # (there are many high outliers in b=0 images) + truncate_values_for_masking = pe.Node( + ImageMath(dimension=3, + operation="TruncateImageIntensity", + secondary_arg="0.0 0.9 512"), + name="truncate_values_for_masking") + + # N4 will break if any negative values are present. + rescale_image = pe.Node( + ImageMath(dimension=3, + operation="RescaleImage", + secondary_arg="0 1000"), + name="rescale_image" + ) + + # Run N4 normally, force num_threads=1 for stability (images are small, no need for >1) + n4_correct = pe.Node( + ants.N4BiasFieldCorrection( + dimension=3, + n_iterations=[200, 200], + convergence_threshold=1e-6, + bspline_order=3, + bspline_fitting_distance=150, + copy_header=True), + name='n4_correct', n_procs=1) + + # Sharpen the b0 ref + sharpen_image = pe.Node( + ImageMath(dimension=3, + operation="Sharpen"), + name="sharpen_image") + + # Basic mask + initial_mask = pe.Node(afni.Automask(outputtype="NIFTI_GZ"), + name="initial_mask") + + # Fill holes left by Automask + fill_holes = pe.Node( + ImageMath(dimension=3, + operation='FillHoles', + secondary_arg='2'), + name='fill_holes') + + # Dilate before smoothing + dilate_mask = pe.Node( + ImageMath(dimension=3, + operation='MD', + secondary_arg='1'), + name='dilate_mask') + + # Smooth the mask and use it as a weight for N4 + smooth_mask = pe.Node( + ImageMath(dimension=3, + operation='G', + secondary_arg='4'), + name='smooth_mask') + + # Make a "soft" skull-stripped image + apply_mask = pe.Node( + ants.MultiplyImages(dimension=3, output_product_image="SkullStrippedRef.nii.gz"), + name="apply_mask") + + eastm_workflow.connect([ + (eastm_inputnode, truncate_values, [('in_file', 'in_file')]), + (truncate_values, rescale_image, [('out_file', 'in_file')]), + (eastm_inputnode, truncate_values_for_masking, [('in_file', 'in_file')]), + (truncate_values_for_masking, initial_mask, [('out_file', 'in_file')]), + (initial_mask, fill_holes, [('out_file', 'in_file')]), + (fill_holes, dilate_mask, [('out_file', 'in_file')]), + (dilate_mask, smooth_mask, [('out_file', 'in_file')]), + (rescale_image, n4_correct, [('out_file', 'input_image')]), + (smooth_mask, n4_correct, [('out_file', 'weight_image')]), + (n4_correct, sharpen_image, [('output_image', 'in_file')]), + (sharpen_image, eastm_outputnode, [('out_file', 'bias_corrected_file')]), + (sharpen_image, apply_mask, [('out_file', 'first_input')]), + (smooth_mask, apply_mask, [('out_file', 'second_input')]), + (apply_mask, eastm_outputnode, [('output_product_image', 'skull_stripped_file')]), + (fill_holes, eastm_outputnode, [('out_file', 'mask_file')]) + ]) + + return eastm_workflow + + +def init_emc_model_iteration_wf(transform, precision="coarse", name="emc_model_iter0"): + emc_model_iter_workflow = pe.Workflow(name=name) + emc_model_iteration_inputnode = pe.Node(niu.IdentityInterface(fields=['original_dwi_files', 'original_rasb_file', + 'aligned_dwi_files', 'aligned_vectors', + 'b0_median', 'b0_mask', 'b0_indices']), + name='emc_model_iteration_inputnode') + + emc_model_iteration_outputnode = pe.Node( + niu.IdentityInterface( + fields=['emc_transforms', 'aligned_dwis', 'aligned_vectors', + 'predicted_dwis', 'motion_params']), + name='emc_model_iteration_outputnode') + + ants_settings = pkgrf( + "dmriprep", + "config/emc_{precision}_{transform}.json".format(precision=precision, + transform=transform)) + + predict_dwis = pe.MapNode(SignalPrediction(), + iterfield=['bval_to_predict', 'bvec_to_predict'], + name="predict_dwis") + predict_dwis.synchronize = True + + # Register original images to the predicted images + register_to_predicted = pe.MapNode(ants.Registration(from_file=ants_settings), + iterfield=['fixed_image', 'moving_image'], + name='register_to_predicted') + register_to_predicted.synchronize = True + + # Apply new transforms to vectors + post_vector_transforms = pe.Node(ReorientVectors(), name="post_vector_transforms") + + # Summarize the motion + calculate_motion = pe.Node(CombineMotions(), name="calculate_motion") + + emc_model_iter_workflow.connect([ + # Send inputs to DWI prediction + (emc_model_iteration_inputnode, predict_dwis, [('aligned_dwi_files', 'aligned_dwi_files'), + ('aligned_vectors', 'aligned_vectors'), + ('b0_indices', 'b0_indices'), + ('b0_median', 'b0_median'), + ('b0_mask', 'b0_mask'), + (('aligned_vectors', _rasb_to_bvec_list), 'bvec_to_predict'), + (('aligned_vectors', _rasb_to_bval_floats), 'bval_to_predict')]), + (predict_dwis, register_to_predicted, [('predicted_image', 'fixed_image')]), + (emc_model_iteration_inputnode, register_to_predicted, [ + ('original_dwi_files', 'moving_image'), + ('b0_mask', 'fixed_image_mask')]), + (register_to_predicted, calculate_motion, [ + (('forward_transforms', _list_squeeze), 'transform_files')]), + (emc_model_iteration_inputnode, calculate_motion, [('original_dwi_files', 'source_files'), + ('b0_median', 'ref_file')]), + (calculate_motion, emc_model_iteration_outputnode, [('motion_file', 'motion_params')]), + (register_to_predicted, post_vector_transforms, [ + (('forward_transforms', _list_squeeze), 'affines')]), + (emc_model_iteration_inputnode, post_vector_transforms, [('original_rasb_file', 'rasb_file')]), + (predict_dwis, emc_model_iteration_outputnode, [('predicted_image', 'predicted_dwis')]), + (post_vector_transforms, emc_model_iteration_outputnode, [('out_rasb', 'aligned_vectors')]), + (register_to_predicted, emc_model_iteration_outputnode, [('warped_image', 'aligned_dwis'), + ('forward_transforms', 'emc_transforms')]) + ]) + + return emc_model_iter_workflow + + +def init_dwi_model_emc_wf(transform, num_iters=2, name='dwi_model_emc_wf'): + dwi_model_emc_wf_workflow = pe.Workflow(name=name) + inputnode = pe.Node( + niu.IdentityInterface( + fields=['original_dwi_files', 'original_rasb_file', 'aligned_dwi_files', 'b0_indices', + 'initial_transforms', 'aligned_vectors', 'b0_median', 'b0_mask', 'warped_b0_images']), + name='dwi_model_emc_inputnode') + + outputnode = pe.Node( + niu.IdentityInterface( + fields=['emc_transforms', 'model_predicted_images', 'cnr_image', + 'optimization_data']), + name='dwi_model_emc_outputnode') + + # Start building and connecting the model iterations + initial_model_iteration = init_emc_model_iteration_wf( + transform, precision="coarse", name="initial_model_iteration") + + # Collect motion estimates across iterations + collect_motion_params = pe.Node(niu.Merge(num_iters), + name="collect_motion_params") + + dwi_model_emc_wf_workflow.connect([ + # Connect the first iteration + (inputnode, initial_model_iteration, [ + ('original_dwi_files', 'emc_model_iteration_inputnode.original_dwi_files'), + ('b0_median', 'emc_model_iteration_inputnode.b0_median'), + ('b0_mask', 'emc_model_iteration_inputnode.b0_mask'), + ('b0_indices', 'emc_model_iteration_inputnode.b0_indices'), + ('aligned_vectors', 'emc_model_iteration_inputnode.aligned_vectors'), + ('aligned_dwi_files', 'emc_model_iteration_inputnode.aligned_dwi_files')]), + # (initial_model_iteration, collect_motion_params, [ + # ('emc_model_iteration_outputnode.motion_params', 'in1')]) + ]) + + model_iterations = [initial_model_iteration] + for iteration_num in range(num_iters - 1): + iteration_name = 'HMC_iteration%03d' % (iteration_num + 1) + motion_key = 'in%d' % (iteration_num + 2) + model_iterations.append( + init_emc_model_iteration_wf(transform=transform, + precision="precise", + name=iteration_name) + ) + dwi_model_emc_wf_workflow.connect([ + (model_iterations[-2], model_iterations[-1], [ + ('emc_model_iteration_outputnode.aligned_dwis', + 'emc_model_iteration_inputnode.aligned_dwi_files'), + ('emc_model_iteration_outputnode.aligned_vectors', + 'emc_model_iteration_inputnode.aligned_vectors')]), + (inputnode, model_iterations[-1], [ + ('b0_mask', 'emc_model_iteration_inputnode.b0_mask'), + ('b0_indices', 'emc_model_iteration_inputnode.b0_indices'), + ('original_dwi_files', 'emc_model_iteration_inputnode.original_dwi_files'), + ('original_rasb_file', 'emc_model_iteration_inputnode.original_rasb_file'), + ('b0_median', 'emc_model_iteration_inputnode.b0_median')]), + (model_iterations[-1], collect_motion_params, [ + ('emc_model_iteration_outputnode.motion_params', motion_key)]) + ]) + + # Return to the original, b0-interspersed ordering + reorder_dwi_xforms = pe.Node(ReorderOutputs(), name='reorder_dwi_xforms') + + # Create a report: + emc_report = pe.Node(HMCReport(), name='emc_report') + ds_report_emc_gif = pe.Node( + DerivativesDataSink(suffix="emc_animation"), name='ds_report_emc_gif', + mem_gb=1, run_without_submitting=True) + + calculate_cnr = pe.Node(CalculateCNR(), name='calculate_cnr') + + # if num_iters > 1: + # summarize_iterations = pe.Node(IterationSummary(), name='summarize_iterations') + # ds_report_iteration_plot = pe.Node( + # DerivativesDataSink(suffix="emc_iterdata"), name='ds_report_iteration_plot', + # mem_gb=0.1, run_without_submitting=True) + # dwi_model_emc_wf_workflow.connect([ + # (collect_motion_params, summarize_iterations, [ + # ('out', 'collected_motion_files')]), + # (summarize_iterations, ds_report_iteration_plot, [ + # ('plot_file', 'in_file')]), + # (summarize_iterations, outputnode, [ + # ('iteration_summary_file', 'optimization_data')]), + # (summarize_iterations, emc_report, [ + # ('iteration_summary_file', 'iteration_summary')]) + # ]) + + dwi_model_emc_wf_workflow.connect([ + (model_iterations[-1], reorder_dwi_xforms, [ + ('emc_model_iteration_outputnode.emc_transforms', 'model_based_transforms'), + ('emc_model_iteration_outputnode.predicted_dwis', 'model_predicted_images'), + ('emc_model_iteration_outputnode.aligned_dwis', 'warped_dwi_images')]), + (inputnode, reorder_dwi_xforms, [ + ('b0_median', 'b0_median'), + ('warped_b0_images', 'warped_b0_images'), + ('b0_indices', 'b0_indices'), + ('initial_transforms', 'initial_transforms')]), + (reorder_dwi_xforms, outputnode, [ + ('emc_warped_images', 'aligned_dwis'), + ('full_transforms', 'emc_transforms'), + ('full_predicted_dwi_series', 'model_predicted_images')]), + (inputnode, emc_report, [('original_dwi_files', 'original_images')]), + (reorder_dwi_xforms, calculate_cnr, [ + ('emc_warped_images', 'emc_warped_images'), + ('full_predicted_dwi_series', 'predicted_images')]), + (inputnode, calculate_cnr, [('b0_mask', 'mask_image')]), + (calculate_cnr, outputnode, [('cnr_image', 'cnr_image')]), + (reorder_dwi_xforms, emc_report, [ + ('full_predicted_dwi_series', 'model_predicted_images'), + ('emc_warped_images', 'registered_images')]), + (emc_report, ds_report_emc_gif, [('plot_file', 'in_file')]) + ]) + + return dwi_model_emc_wf_workflow + + +def init_emc_wf(name, mem_gb=3, omp_nthreads=8): + import_list = ["import warnings", "warnings.filterwarnings(\"ignore\")", "import sys", "import os", + "import numpy as np", "import networkx as nx", "import nibabel as nb", + "from nipype.utils.filemanip import fname_presuffix"] + + emc_wf = pe.Workflow(name=name) + + meta_inputnode = pe.Node(niu.IdentityInterface(fields=['dwi_file', 'in_bval', 'in_bvec', 'b0_template']), + name='meta_inputnode') + + vectors_node = pe.Node(CheckGradientTable(), name='emc_vectors_node') + + extract_b0s_node = pe.Node(ExtractB0(), name="extract_b0_node") + + split_b0s_node = pe.Node(niu.Function(input_names=['in_file'], output_names=['out_files'], + function=save_4d_to_3d, imports=import_list), + name="split_b0s_node") + + prune_b0s_from_dwis_node = pe.Node(niu.Function(input_names=['in_files', 'b0_ixs'], output_names=['out_files'], + function=prune_b0s_from_dwis, imports=import_list), + name="prune_b0s_from_dwis_node") + + split_dwis_node = pe.Node(niu.Function(input_names=['in_file'], output_names=['out_files'], + function=save_4d_to_3d, imports=import_list), + name="split_dwis_node") + + merge_b0s_node = pe.Node(niu.Function(input_names=['in_files'], output_names=['out_file'], + function=save_3d_to_4d, imports=import_list), + name="merge_b0s_node") + + match_transforms_node = pe.Node(MatchTransforms(), name="match_transforms_node") + + b0_emc_wf = init_b0_emc_wf(transform="Rigid") + + eastm_wf = init_enhance_and_skullstrip_template_mask_wf(name='enhance_and_skullstrip_template_mask_wf') + + # Initialize with the transforms provided + b0_based_image_transforms = pe.MapNode(ants.ApplyTransforms(interpolation="BSpline"), + iterfield=['input_image', 'transforms'], + name="b0_based_image_transforms") + # Rotate vectors + b0_based_vector_transforms = pe.Node(ReorientVectors(), name="b0_based_vector_transforms") + + # Grab the median of the aligned B0 images + b0_median = pe.Node(RescaleB0(), name='b0_median') + + # Do model-based motion correction + dwi_model_emc_wf = init_dwi_model_emc_wf(transform='Affine', num_iters=2) + + # Warp the modeled images into non-motion-corrected space + uncorrect_model_images = pe.MapNode( + ants.ApplyTransforms(invert_transform_flags=[True], + interpolation='LanczosWindowedSinc'), + iterfield=['input_image', 'reference_image', 'transforms'], + name='uncorrect_model_images') + + meta_outputnode = pe.Node( + niu.IdentityInterface( + fields=["final_template", "forward_transforms", "noise_free_dwis", + "cnr_image", "optimization_data"]), + name='meta_outputnode') + + emc_wf.connect([ + # Create initial transforms to apply to bias-reduced B0 template + (meta_inputnode, vectors_node, [('dwi_file', 'dwi_file'), + ('in_bval', 'in_bval'), + ('in_bvec', 'in_bvec')]), + (meta_inputnode, extract_b0s_node, [('dwi_file', 'in_file')]), + (vectors_node, extract_b0s_node, [('b0_ixs', 'b0_ixs')]), + (extract_b0s_node, split_b0s_node, [('out_file', 'in_file')]), + (meta_inputnode, split_dwis_node, [('dwi_file', 'in_file')]), + (split_b0s_node, b0_emc_wf, [('out_files', 'b0_emc_inputnode.b0_images')]), + (meta_inputnode, b0_emc_wf, [('b0_template', 'b0_emc_inputnode.initial_template')]), + (b0_emc_wf, match_transforms_node, [(('b0_emc_outputnode.forward_transforms', _list_squeeze), 'transforms')]), + (vectors_node, match_transforms_node, [('b0_ixs', 'b0_indices')]), + (split_dwis_node, match_transforms_node, [('out_files', 'dwi_files')]), + (b0_emc_wf, eastm_wf, + [('b0_emc_outputnode.final_template', 'eastm_inputnode.in_file')]), + (match_transforms_node, b0_based_vector_transforms, [('transforms', 'affines')]), + (vectors_node, b0_based_vector_transforms, [('out_rasb', 'rasb_file')]), + (split_dwis_node, b0_based_image_transforms, [('out_files', 'input_image')]), + (match_transforms_node, b0_based_image_transforms, [('transforms', 'transforms')]), + (b0_emc_wf, merge_b0s_node, [('b0_emc_outputnode.aligned_images', 'in_files')]), + (merge_b0s_node, b0_median, [('out_file', 'in_file')]), + (eastm_wf, b0_median, [('eastm_outputnode.mask_file', 'mask_file')]), + (b0_median, b0_based_image_transforms, [('out_ref', 'reference_image')]), + + # Perform signal prediction from vectors + (vectors_node, prune_b0s_from_dwis_node, [('b0_ixs', 'b0_ixs')]), + (split_dwis_node, prune_b0s_from_dwis_node, [('out_files', 'in_files')]), + (prune_b0s_from_dwis_node, dwi_model_emc_wf, [('out_files', 'dwi_model_emc_inputnode.original_dwi_files')]), + (vectors_node, dwi_model_emc_wf, [('out_rasb', 'dwi_model_emc_inputnode.original_rasb_file')]), + (b0_based_image_transforms, dwi_model_emc_wf, [('output_image', 'dwi_model_emc_inputnode.aligned_dwi_files')]), + (b0_based_vector_transforms, dwi_model_emc_wf, [('out_rasb', 'dwi_model_emc_inputnode.aligned_vectors')]), + (b0_median, dwi_model_emc_wf, [('out_ref', 'dwi_model_emc_inputnode.b0_median')]), + (eastm_wf, dwi_model_emc_wf, + [('eastm_outputnode.mask_file', 'dwi_model_emc_inputnode.b0_mask')]), + (vectors_node, dwi_model_emc_wf, [('b0_ixs', 'dwi_model_emc_inputnode.b0_indices')]), + (match_transforms_node, dwi_model_emc_wf, [('transforms', 'dwi_model_emc_inputnode.initial_transforms')]), + (b0_emc_wf, dwi_model_emc_wf, + [('b0_emc_outputnode.aligned_images', 'dwi_model_emc_inputnode.warped_b0_images')]), + (b0_emc_wf, meta_outputnode, [('b0_emc_outputnode.final_template', 'final_template')]), + (dwi_model_emc_wf, meta_outputnode, [('dwi_model_emc_outputnode.emc_transforms', 'forward_transforms'), + ('dwi_model_emc_outputnode.optimization_data', 'optimization_data'), + ('dwi_model_emc_outputnode.cnr_image', 'cnr_image')]), + (b0_emc_wf, uncorrect_model_images, [('b0_emc_outputnode.final_template', 'reference_image')]), + (dwi_model_emc_wf, uncorrect_model_images, [('dwi_model_emc_outputnode.emc_transforms', 'transforms')]), + (split_dwis_node, uncorrect_model_images, [('out_files', 'input_image')]), + (uncorrect_model_images, meta_outputnode, [('output_image', 'noise_free_dwis')]) + ]) + return emc_wf