From fe6cfe9d79b6ecb957ba39e566bc32729d4acb50 Mon Sep 17 00:00:00 2001 From: dPys Date: Tue, 4 Feb 2020 01:03:18 -0600 Subject: [PATCH] Write out final noise-free images to 4d file Clean up motion plotting --- dmriprep/config/emc_coarse_Affine.json | 31 +- dmriprep/config/emc_coarse_Rigid.json | 29 +- dmriprep/config/emc_precise_Affine.json | 31 +- dmriprep/config/emc_precise_Rigid.json | 29 +- dmriprep/interfaces/bids.py | 139 --- dmriprep/interfaces/images.py | 251 +++-- dmriprep/interfaces/register.py | 147 +++ dmriprep/interfaces/reports.py | 10 +- dmriprep/interfaces/vectors.py | 145 +-- dmriprep/utils/images.py | 97 +- dmriprep/utils/register.py | 436 +++++++++ dmriprep/utils/vectors.py | 239 ++--- dmriprep/workflows/dwi/emc.py | 1155 ++++++++++++++++------- 13 files changed, 1851 insertions(+), 888 deletions(-) delete mode 100644 dmriprep/interfaces/bids.py create mode 100644 dmriprep/interfaces/register.py create mode 100644 dmriprep/utils/register.py diff --git a/dmriprep/config/emc_coarse_Affine.json b/dmriprep/config/emc_coarse_Affine.json index 5602e499..f7663871 100644 --- a/dmriprep/config/emc_coarse_Affine.json +++ b/dmriprep/config/emc_coarse_Affine.json @@ -1,25 +1,8 @@ { - "dimension": 3, - "float": true, - "winsorize_lower_quantile": 0.002, - "winsorize_upper_quantile": 0.998, - "collapse_output_transforms": true, - "write_composite_transform": false, - "use_histogram_matching": [ false, false ], - "use_estimate_learning_rate_once": [ true, true ], - "transforms": [ "Rigid", "Affine" ], - "number_of_iterations": [ [ 100, 100 ], [ 100 ] ], - "output_warped_image": true, - "transform_parameters": [ [ 0.2 ], [ 0.15 ] ], - "convergence_threshold": [ 1e-06, 1e-06 ], - "convergence_window_size": [ 20, 20 ], - "metric": [ "Mattes", "Mattes" ], - "sampling_percentage": [ 0.15, 0.2 ], - "sampling_strategy": [ "Random", "Random" ], - "smoothing_sigmas": [ [ 8.0, 2.0 ], [ 2.0 ] ], - "sigma_units": [ "mm", "mm" ], - "metric_weight": [ 1.0, 1.0 ], - "shrink_factors": [ [ 2, 1 ], [ 1 ] ], - "radius_or_number_of_bins": [ 48, 48 ], - "interpolation": "BSpline" -} + "level_iters": [1000, 100], + "metric": "MI", + "sigmas": [8.0, 2.0], + "factors": [2, 1], + "sampling_prop": 0.15, + "nbins": 48 +} \ No newline at end of file diff --git a/dmriprep/config/emc_coarse_Rigid.json b/dmriprep/config/emc_coarse_Rigid.json index f35c7a9c..5d23fba6 100644 --- a/dmriprep/config/emc_coarse_Rigid.json +++ b/dmriprep/config/emc_coarse_Rigid.json @@ -1,25 +1,8 @@ { - "dimension": 3, - "float": true, - "winsorize_lower_quantile": 0.002, - "winsorize_upper_quantile": 0.998, - "collapse_output_transforms": true, - "write_composite_transform": false, - "use_histogram_matching": [ false ], - "use_estimate_learning_rate_once": [ true ], - "transforms": [ "Rigid" ], - "number_of_iterations": [ [ 100, 100 ] ], - "output_warped_image": true, - "transform_parameters": [ [ 0.2 ] ], - "convergence_threshold": [ 1e-06 ], - "convergence_window_size": [ 20 ], - "metric": [ "Mattes" ], - "sampling_percentage": [ 0.15 ], - "sampling_strategy": [ "Random" ], - "smoothing_sigmas": [ [ 8.0, 2.0 ] ], - "sigma_units": [ "mm"], - "metric_weight": [ 1.0 ], - "shrink_factors": [ [ 2, 1 ] ], - "radius_or_number_of_bins": [ 48 ], - "interpolation": "BSpline" + "level_iters": [100, 100], + "metric": "MI", + "sigmas": [8.0, 2.0], + "factors": [2, 1], + "sampling_prop": 0.15, + "nbins": 48 } diff --git a/dmriprep/config/emc_precise_Affine.json b/dmriprep/config/emc_precise_Affine.json index 924a23a4..99fbb1e2 100644 --- a/dmriprep/config/emc_precise_Affine.json +++ b/dmriprep/config/emc_precise_Affine.json @@ -1,25 +1,8 @@ { - "dimension": 3, - "float": true, - "winsorize_lower_quantile": 0.002, - "winsorize_upper_quantile": 0.998, - "collapse_output_transforms": true, - "write_composite_transform": false, - "use_histogram_matching": [ false, false ], - "use_estimate_learning_rate_once": [ true, true ], - "transforms": [ "Rigid", "Affine" ], - "number_of_iterations": [ [ 1000, 1000 ], [ 1000 ] ], - "output_warped_image": true, - "transform_parameters": [ [ 0.2 ], [ 0.15 ] ], - "convergence_threshold": [ 1e-08, 1e-08 ], - "convergence_window_size": [ 20, 20 ], - "metric": [ "Mattes", "Mattes" ], - "sampling_percentage": [ 0.15, 0.2 ], - "sampling_strategy": [ "Random", "Random" ], - "smoothing_sigmas": [ [ 8.0, 2.0 ], [ 2.0 ] ], - "sigma_units": [ "mm", "mm" ], - "metric_weight": [ 1.0, 1.0 ], - "shrink_factors": [ [ 2, 1 ], [ 1 ] ], - "radius_or_number_of_bins": [ 48, 48 ], - "interpolation": "BSpline" -} + "level_iters": [1000, 1000], + "metric": "MI", + "sigmas": [8.0, 2.0], + "factors": [2, 1], + "sampling_prop": 0.15, + "nbins": 48 +} \ No newline at end of file diff --git a/dmriprep/config/emc_precise_Rigid.json b/dmriprep/config/emc_precise_Rigid.json index 78127352..e62cd12a 100644 --- a/dmriprep/config/emc_precise_Rigid.json +++ b/dmriprep/config/emc_precise_Rigid.json @@ -1,25 +1,8 @@ { - "dimension": 3, - "float": true, - "winsorize_lower_quantile": 0.002, - "winsorize_upper_quantile": 0.998, - "collapse_output_transforms": true, - "write_composite_transform": false, - "use_histogram_matching": [ false ], - "use_estimate_learning_rate_once": [ true ], - "transforms": [ "Rigid" ], - "number_of_iterations": [ [ 1000, 1000 ] ], - "output_warped_image": true, - "transform_parameters": [ [ 0.2 ], [ 0.15 ] ], - "convergence_threshold": [ 1e-08, 1e-08 ], - "convergence_window_size": [ 20, 20 ], - "metric": [ "Mattes" ], - "sampling_percentage": [ 0.15 ], - "sampling_strategy": [ "Random" ], - "smoothing_sigmas": [ [ 8.0, 2.0 ] ], - "sigma_units": [ "mm" ], - "metric_weight": [ 1.0 ], - "shrink_factors": [ [ 2, 1 ] ], - "radius_or_number_of_bins": [ 48 ], - "interpolation": "BSpline" + "level_iters": [1000, 1000], + "metric": "MI", + "sigmas": [8.0, 2.0], + "factors": [2, 1], + "sampling_prop": 0.15, + "nbins": 48 } diff --git a/dmriprep/interfaces/bids.py b/dmriprep/interfaces/bids.py deleted file mode 100644 index f64c4484..00000000 --- a/dmriprep/interfaces/bids.py +++ /dev/null @@ -1,139 +0,0 @@ -import os -import re -from pathlib import Path -from nipype.interfaces.base import ( - traits, TraitedSpec, BaseInterfaceInputSpec, SimpleInterface, File, isdefined, - InputMultiObject, OutputMultiPath, OutputMultiObject, -) -from ..utils.bids import splitext as _splitext, _copy_any - -__all__ = ['BIDS_NAME'] - -BIDS_NAME = re.compile( - r'^(.*\/)?(?Psub-[a-zA-Z0-9]+)(_(?Pses-[a-zA-Z0-9]+))?' - '(_(?Ptask-[a-zA-Z0-9]+))?(_(?Pacq-[a-zA-Z0-9]+))?' - '(_(?Prec-[a-zA-Z0-9]+))?(_(?Prun-[a-zA-Z0-9]+))?') - - -class DerivativesDataSinkInputSpec(BaseInterfaceInputSpec): - base_directory = traits.Directory( - desc='Path to the base directory for storing data.') - in_file = InputMultiObject(File(exists=True), mandatory=True, - desc='the object to be saved') - source_file = File(exists=False, mandatory=True, desc='the original file') - prefix = traits.Str(mandatory=False, desc='prefix for output files') - space = traits.Str('', usedefault=True, desc='Label for space field') - desc = traits.Str('', usedefault=True, desc='Label for description field') - suffix = traits.Str('', usedefault=True, desc='suffix appended to source_file') - keep_dtype = traits.Bool(False, usedefault=True, desc='keep datatype suffix') - extra_values = traits.List(traits.Str) - compress = traits.Bool(desc="force compression (True) or uncompression (False)" - " of the output file (default: same as input)") - extension = traits.Str() - - -class DerivativesDataSinkOutputSpec(TraitedSpec): - out_file = OutputMultiObject(File(exists=True, desc='written file path')) - compression = OutputMultiPath( - traits.Bool, desc='whether ``in_file`` was compressed/uncompressed ' - 'or `it was copied directly.') - - -class DerivativesDataSink(SimpleInterface): - """ - Saves the `in_file` into a BIDS-Derivatives folder provided - by `base_directory`, given the input reference `source_file`. - >>> from pathlib import Path - >>> import tempfile - >>> from dmriprep.utils.bids import collect_data - >>> tmpdir = Path(tempfile.mkdtemp()) - >>> tmpfile = tmpdir / 'a_temp_file.nii.gz' - >>> tmpfile.open('w').close() # "touch" the file - >>> dsink = DerivativesDataSink(base_directory=str(tmpdir)) - >>> dsink.inputs.in_file = str(tmpfile) - >>> dsink.inputs.source_file = collect_data('ds114', '01')[0]['t1w'][0] - >>> dsink.inputs.keep_dtype = True - >>> dsink.inputs.suffix = 'target-mni' - >>> res = dsink.run() - >>> res.outputs.out_file # doctest: +ELLIPSIS - '.../dmriprep/sub-01/ses-retest/anat/sub-01_ses-retest_target-mni_T1w.nii.gz' - >>> bids_dir = tmpdir / 'bidsroot' / 'sub-02' / 'ses-noanat' / 'func' - >>> bids_dir.mkdir(parents=True, exist_ok=True) - >>> tricky_source = bids_dir / 'sub-02_ses-noanat_task-rest_run-01_bold.nii.gz' - >>> tricky_source.open('w').close() - >>> dsink = DerivativesDataSink(base_directory=str(tmpdir)) - >>> dsink.inputs.in_file = str(tmpfile) - >>> dsink.inputs.source_file = str(tricky_source) - >>> dsink.inputs.keep_dtype = True - >>> dsink.inputs.desc = 'preproc' - >>> res = dsink.run() - >>> res.outputs.out_file # doctest: +ELLIPSIS - '.../dmriprep/sub-02/ses-noanat/func/sub-02_ses-noanat_task-rest_run-01_\ -desc-preproc_bold.nii.gz' - """ - input_spec = DerivativesDataSinkInputSpec - output_spec = DerivativesDataSinkOutputSpec - out_path_base = "dmriprep" - _always_run = True - - def __init__(self, out_path_base=None, **inputs): - super(DerivativesDataSink, self).__init__(**inputs) - self._results['out_file'] = [] - if out_path_base: - self.out_path_base = out_path_base - - def _run_interface(self, runtime): - src_fname, _ = _splitext(self.inputs.source_file) - src_fname, dtype = src_fname.rsplit('_', 1) - _, ext = _splitext(self.inputs.in_file[0]) - if self.inputs.compress is True and not ext.endswith('.gz'): - ext += '.gz' - elif self.inputs.compress is False and ext.endswith('.gz'): - ext = ext[:-3] - - m = BIDS_NAME.search(src_fname) - - mod = os.path.basename(os.path.dirname(self.inputs.source_file)) - - base_directory = runtime.cwd - if isdefined(self.inputs.base_directory): - base_directory = str(self.inputs.base_directory) - - out_path = '{}/{subject_id}'.format(self.out_path_base, **m.groupdict()) - if m.groupdict().get('session_id') is not None: - out_path += '/{session_id}'.format(**m.groupdict()) - out_path += '/{}'.format(mod) - - out_path = os.path.join(base_directory, out_path) - - os.makedirs(out_path, exist_ok=True) - - if isdefined(self.inputs.prefix): - base_fname = os.path.join(out_path, self.inputs.prefix) - else: - base_fname = os.path.join(out_path, src_fname) - - formatstr = '{bname}{space}{desc}{suffix}{dtype}{ext}' - if len(self.inputs.in_file) > 1 and not isdefined(self.inputs.extra_values): - formatstr = '{bname}{space}{desc}{suffix}{i:04d}{dtype}{ext}' - - space = '_space-{}'.format(self.inputs.space) if self.inputs.space else '' - desc = '_desc-{}'.format(self.inputs.desc) if self.inputs.desc else '' - suffix = '_{}'.format(self.inputs.suffix) if self.inputs.suffix else '' - dtype = '' if not self.inputs.keep_dtype else ('_%s' % dtype) - - self._results['compression'] = [] - for i, fname in enumerate(self.inputs.in_file): - out_file = formatstr.format( - bname=base_fname, - space=space, - desc=desc, - suffix=suffix, - i=i, - dtype=dtype, - ext=ext) - if isdefined(self.inputs.extra_values): - out_file = out_file.format(extra_value=self.inputs.extra_values[i]) - self._results['out_file'].append(out_file) - self._results['compression'].append(_copy_any(fname, out_file)) - return runtime \ No newline at end of file diff --git a/dmriprep/interfaces/images.py b/dmriprep/interfaces/images.py index e17f5b73..cfea8d7a 100644 --- a/dmriprep/interfaces/images.py +++ b/dmriprep/interfaces/images.py @@ -6,24 +6,37 @@ from nipype.interfaces import ants from nipype.utils.filemanip import split_filename from nipype.interfaces.base import ( - traits, TraitedSpec, BaseInterfaceInputSpec, SimpleInterface, File, isdefined, - InputMultiObject, OutputMultiObject, CommandLine + traits, + TraitedSpec, + BaseInterfaceInputSpec, + SimpleInterface, + File, + isdefined, + InputMultiObject, + OutputMultiObject, + CommandLine, +) +from ..utils.images import ( + extract_b0, + rescale_b0, + median, + quick_load_images, + match_transforms, + prune_b0s_from_dwis, ) -from ..utils.images import extract_b0, rescale_b0, median, quick_load_images, match_transforms, prune_b0s_from_dwis from ..utils.vectors import _nonoverlapping_qspace_samples -LOGGER = logging.getLogger('nipype.interface') +LOGGER = logging.getLogger("nipype.interface") class _ExtractB0InputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, desc='dwi file') - b0_ixs = traits.List(traits.Int, mandatory=True, - desc='Index of b0s') + in_file = File(exists=True, mandatory=True, desc="dwi file") + b0_ixs = traits.List(traits.Int, mandatory=True, desc="Index of b0s") class _ExtractB0OutputSpec(TraitedSpec): - out_file = File(exists=True, desc='b0 file') + out_file = File(exists=True, desc="b0 file") class ExtractB0(SimpleInterface): @@ -44,21 +57,20 @@ class ExtractB0(SimpleInterface): output_spec = _ExtractB0OutputSpec def _run_interface(self, runtime): - self._results['out_file'] = extract_b0( - self.inputs.in_file, - self.inputs.b0_ixs, - newpath=runtime.cwd) + self._results["out_file"] = extract_b0( + self.inputs.in_file, self.inputs.b0_ixs, newpath=runtime.cwd + ) return runtime class _RescaleB0InputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, desc='b0s file') - mask_file = File(exists=True, mandatory=True, desc='mask file') + in_file = File(exists=True, mandatory=True, desc="b0s file") + mask_file = File(exists=True, mandatory=True, desc="mask file") class _RescaleB0OutputSpec(TraitedSpec): - out_ref = File(exists=True, desc='One median b0 file') - out_b0s = File(exists=True, desc='series of rescaled b0 volumes') + out_ref = File(exists=True, desc="One average b0 file") + out_b0s = File(exists=True, desc="series of rescaled b0 volumes") class RescaleB0(SimpleInterface): @@ -79,15 +91,10 @@ class RescaleB0(SimpleInterface): output_spec = _RescaleB0OutputSpec def _run_interface(self, runtime): - self._results['out_b0s'] = rescale_b0( - self.inputs.in_file, - self.inputs.mask_file, - newpath=runtime.cwd - ) - self._results['out_ref'] = median( - self._results['out_b0s'], - newpath=runtime.cwd + self._results["out_b0s"] = rescale_b0( + self.inputs.in_file, self.inputs.mask_file, newpath=runtime.cwd ) + self._results["out_ref"] = median(self._results["out_b0s"], newpath=runtime.cwd) return runtime @@ -106,9 +113,9 @@ class MatchTransforms(SimpleInterface): output_spec = MatchTransformsOutputSpec def _run_interface(self, runtime): - self._results['transforms'] = match_transforms(self.inputs.dwi_files, - self.inputs.transforms, - self.inputs.b0_indices) + self._results["transforms"] = match_transforms( + self.inputs.dwi_files, self.inputs.transforms, self.inputs.b0_indices + ) return runtime @@ -117,7 +124,7 @@ class N3BiasFieldCorrection(ants.N4BiasFieldCorrection): class ImageMathInputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, position=3, argstr='%s') + in_file = File(exists=True, mandatory=True, position=3, argstr="%s") dimension = traits.Enum(3, 2, 4, usedefault=True, argstr="%d", position=0) out_file = File(argstr="%s", genfile=True, position=1) operation = traits.Str(argstr="%s", position=2) @@ -132,10 +139,10 @@ class ImageMathOutputSpec(TraitedSpec): class ImageMath(CommandLine): input_spec = ImageMathInputSpec output_spec = ImageMathOutputSpec - _cmd = 'ImageMath' + _cmd = "ImageMath" def _gen_filename(self, name): - if name == 'out_file': + if name == "out_file": output = self.inputs.out_file if not isdefined(output): _, fname, ext = split_filename(self.inputs.in_file) @@ -145,7 +152,7 @@ def _gen_filename(self, name): def _list_outputs(self): outputs = self.output_spec().get() - outputs['out_file'] = os.path.abspath(self._gen_filename('out_file')) + outputs["out_file"] = os.path.abspath(self._gen_filename("out_file")) return outputs @@ -158,6 +165,8 @@ class SignalPredictionInputSpec(BaseInterfaceInputSpec): bval_to_predict = traits.Float() minimal_q_distance = traits.Float(2.0, usedefault=True) b0_indices = traits.List() + prune_b0s = traits.Bool(False, usedefault=True) + model_name = traits.Str(default_value="sfm", userdefault=True) class SignalPredictionOutputSpec(TraitedSpec): @@ -167,13 +176,16 @@ class SignalPredictionOutputSpec(TraitedSpec): class SignalPrediction(SimpleInterface): """ """ + input_spec = SignalPredictionInputSpec output_spec = SignalPredictionOutputSpec - def _run_interface(self, runtime, model_name='tensor'): + def _run_interface(self, runtime): import warnings + warnings.filterwarnings("ignore") from dipy.core.gradients import gradient_table_from_bvals_bvecs + pred_vec = self.inputs.bvec_to_predict pred_val = self.inputs.bval_to_predict @@ -181,19 +193,30 @@ def _run_interface(self, runtime, model_name='tensor'): mask_img = nb.load(self.inputs.b0_mask) mask_array = mask_img.get_data() > 1e-6 - all_images = prune_b0s_from_dwis(self.inputs.aligned_dwi_files, self.inputs.b0_indices) + if self.inputs.prune_b0s is True: + all_images = prune_b0s_from_dwis( + self.inputs.aligned_dwi_files, self.inputs.b0_indices + ) + else: + all_images = self.inputs.aligned_dwi_files - # Load the vectors - ras_b_mat = np.genfromtxt(self.inputs.aligned_vectors, delimiter='\t') - all_bvecs = np.row_stack([np.zeros(3), np.delete(ras_b_mat[:, 0:3], self.inputs.b0_indices, axis=0)]) - all_bvals = np.concatenate([np.zeros(1), np.delete(ras_b_mat[:, 3], self.inputs.b0_indices)]) + # Load the vectors + ras_b_mat = np.genfromtxt(self.inputs.aligned_vectors, delimiter="\t") + all_bvecs = np.row_stack( + [np.zeros(3), np.delete(ras_b_mat[:, 0:3], self.inputs.b0_indices, axis=0)] + ) + all_bvals = np.concatenate( + [np.zeros(1), np.delete(ras_b_mat[:, 3], self.inputs.b0_indices)] + ) # Which sample points are too close to the one we want to predict? training_mask = _nonoverlapping_qspace_samples( - pred_val, pred_vec, all_bvals, all_bvecs, self.inputs.minimal_q_distance) + pred_val, pred_vec, all_bvals, all_bvecs, self.inputs.minimal_q_distance + ) training_indices = np.flatnonzero(training_mask[1:]) training_image_paths = [self.inputs.b0_median] + [ - all_images[idx] for idx in training_indices] + all_images[idx] for idx in training_indices + ] training_bvecs = all_bvecs[training_mask] training_bvals = all_bvals[training_mask] # print("Training with volumes: {}".format(str(training_indices))) @@ -202,7 +225,9 @@ def _run_interface(self, runtime, model_name='tensor'): training_data = quick_load_images(training_image_paths) # Build gradient table object - training_gtab = gradient_table_from_bvals_bvecs(training_bvals, training_bvecs, b0_threshold=0) + training_gtab = gradient_table_from_bvals_bvecs( + training_bvals, training_bvecs, b0_threshold=0 + ) # Checked shelledness if len(np.unique(training_gtab.bvals)) > 2: @@ -211,53 +236,131 @@ def _run_interface(self, runtime, model_name='tensor'): is_shelled = False # Get the vector for the desired coordinate - prediction_gtab = gradient_table_from_bvals_bvecs(np.array(pred_val)[None], np.array(pred_vec)[None, :], - b0_threshold=0) + prediction_gtab = gradient_table_from_bvals_bvecs( + np.array(pred_val)[None], np.array(pred_vec)[None, :], b0_threshold=0 + ) - if is_shelled and model_name == '3dshore': + if is_shelled and self.inputs.model_name == "3dshore": from dipy.reconst.shore import ShoreModel + radial_order = 6 zeta = 700 lambdaN = 1e-8 lambdaL = 1e-8 - estimator_shore = ShoreModel(training_gtab, radial_order=radial_order, - zeta=zeta, lambdaN=lambdaN, lambdaL=lambdaL) + estimator_shore = ShoreModel( + training_gtab, + radial_order=radial_order, + zeta=zeta, + lambdaN=lambdaN, + lambdaL=lambdaL, + ) estimator_shore_fit = estimator_shore.fit(training_data, mask=mask_array) pred_shore_fit = estimator_shore_fit.predict(prediction_gtab) - pred_shore_fit_file = os.path.join(runtime.cwd, - "predicted_shore_b%d_%.2f_%.2f_%.2f.nii.gz" % - ((pred_val,) + tuple(np.round(pred_vec, decimals=2)))) + pred_shore_fit_file = os.path.join( + runtime.cwd, + "predicted_shore_b%d_%.2f_%.2f_%.2f.nii.gz" + % ((pred_val,) + tuple(np.round(pred_vec, decimals=2))), + ) output_data = pred_shore_fit[..., 0] - nb.Nifti1Image(output_data, mask_img.affine, mask_img.header).to_filename(pred_shore_fit_file) - elif model_name == 'sfm' and not is_shelled: + nb.Nifti1Image(output_data, mask_img.affine, mask_img.header).to_filename( + pred_shore_fit_file + ) + elif self.inputs.model_name == "sfm": + from sklearn.linear_model import Ridge import dipy.reconst.sfm as sfm from dipy.data import default_sphere - estimator_sfm = sfm.SparseFascicleModel(training_gtab, sphere=default_sphere, - l1_ratio=0.5, alpha=0.001) + estimator_sfm = sfm.SparseFascicleModel( + training_gtab, + sphere=default_sphere, + solver=Ridge(alpha=0.001, solver="lsqr"), + ) estimator_sfm_fit = estimator_sfm.fit(training_data, mask=mask_array) - pred_sfm_fit = estimator_sfm_fit.predict(prediction_gtab)[..., 0] + pred_sfm_fit = estimator_sfm_fit.predict(prediction_gtab) pred_sfm_fit[~mask_array] = 0 - pred_fit_file = os.path.join(runtime.cwd, "predicted_sfm_b%d_%.2f_%.2f_%.2f.npy" % - ((pred_val,) + tuple(np.round(pred_vec, decimals=2)))) + pred_fit_file = os.path.join( + runtime.cwd, + "predicted_sfm_b%d_%.2f_%.2f_%.2f.npy" + % ((pred_val,) + tuple(np.round(pred_vec, decimals=2))), + ) np.save(pred_fit_file, pred_sfm_fit) - pred_fit_file = os.path.join(runtime.cwd, "predicted_sfm_b%d_%.2f_%.2f_%.2f.nii.gz" % - ((pred_val,) + tuple(np.round(pred_vec, decimals=2)))) - nb.Nifti1Image(pred_sfm_fit, mask_img.affine, mask_img.header).to_filename(pred_fit_file) - else: + pred_fit_file = os.path.join( + runtime.cwd, + "predicted_sfm_b%d_%.2f_%.2f_%.2f.nii.gz" + % ((pred_val,) + tuple(np.round(pred_vec, decimals=2))), + ) + nb.Nifti1Image(pred_sfm_fit, mask_img.affine, mask_img.header).to_filename( + pred_fit_file + ) + elif self.inputs.model_name == "tensor": from dipy.reconst.dti import TensorModel + estimator_ten = TensorModel(training_gtab) estimator_ten_fit = estimator_ten.fit(training_data, mask=mask_array) pred_ten_fit = estimator_ten_fit.predict(prediction_gtab)[..., 0] pred_ten_fit[~mask_array] = 0 - pred_fit_file = os.path.join(runtime.cwd, "predicted_ten_b%d_%.2f_%.2f_%.2f.npy" % - ((pred_val,) + tuple(np.round(pred_vec, decimals=2)))) + pred_fit_file = os.path.join( + runtime.cwd, + "predicted_ten_b%d_%.2f_%.2f_%.2f.npy" + % ((pred_val,) + tuple(np.round(pred_vec, decimals=2))), + ) np.save(pred_fit_file, pred_ten_fit) - pred_fit_file = os.path.join(runtime.cwd, "predicted_ten_b%d_%.2f_%.2f_%.2f.nii.gz" % - ((pred_val,) + tuple(np.round(pred_vec, decimals=2)))) - nb.Nifti1Image(pred_ten_fit, mask_img.affine, mask_img.header).to_filename(pred_fit_file) + pred_fit_file = os.path.join( + runtime.cwd, + "predicted_ten_b%d_%.2f_%.2f_%.2f.nii.gz" + % ((pred_val,) + tuple(np.round(pred_vec, decimals=2))), + ) + nb.Nifti1Image(pred_ten_fit, mask_img.affine, mask_img.header).to_filename( + pred_fit_file + ) + else: + raise ValueError("Model not supported.") + + self._results["predicted_image"] = pred_fit_file + + return runtime - self._results['predicted_image'] = pred_fit_file + +class CombineMotionsInputSpec(BaseInterfaceInputSpec): + transform_files = InputMultiObject( + File(exists=True), mandatory=True, desc="transform files from emc" + ) + source_files = InputMultiObject( + File(exists=True), mandatory=True, desc="Moving images" + ) + ref_file = File(exists=True, mandatory=True, desc="Fixed Image") + + +class CombineMotionsOututSpec(TraitedSpec): + motion_file = File(exists=True) + + +class CombineMotions(SimpleInterface): + input_spec = CombineMotionsInputSpec + output_spec = CombineMotionsOututSpec + + def _run_interface(self, runtime): + import pandas as pd + from dmriprep.utils.images import get_params + + output_fname = os.path.join(runtime.cwd, "motion_params.csv") + motion_parms_path = os.path.join(runtime.cwd, "movpar.txt") + motion_params = open(os.path.abspath(motion_parms_path), "w") + + collected_motion = [] + for aff in self.inputs.transform_files: + rotations, translations = get_params(np.load(aff)) + collected_motion.append(rotations + translations) + for i in rotations + translations: + motion_params.write("%f " % i) + motion_params.write("\n") + motion_params.close() + + final_motion = np.row_stack(collected_motion) + cols = ["rotateX", "rotateY", "rotateZ", "shiftX", "shiftY", "shiftZ"] + motion_df = pd.DataFrame(data=final_motion, columns=cols) + motion_df.to_csv(output_fname, index=False) + self._results['motion_file'] = output_fname return runtime @@ -291,9 +394,10 @@ def _run_interface(self, runtime): snr = np.nan_to_num(signal_var / noise_var) out_mat = np.zeros(mask_image.shape) out_mat[mask] = snr - nb.Nifti1Image(out_mat, mask_image.affine, - header=mask_image.header).to_filename(cnr_file) - self._results['cnr_image'] = cnr_file + nb.Nifti1Image( + out_mat, mask_image.affine, header=mask_image.header + ).to_filename(cnr_file) + self._results["cnr_image"] = cnr_file return runtime @@ -325,8 +429,9 @@ def _run_interface(self, runtime): warped_dwi_images = self.inputs.warped_dwi_images[::-1] model_transforms = self.inputs.model_based_transforms[::-1] model_images = self.inputs.model_predicted_images[::-1] - b0_transforms = [self.inputs.initial_transforms[idx] for idx in - self.inputs.b0_indices][::-1] + b0_transforms = [ + self.inputs.initial_transforms[idx] for idx in self.inputs.b0_indices + ][::-1] num_dwis = len(self.inputs.initial_transforms) for imagenum in range(num_dwis): @@ -342,8 +447,8 @@ def _run_interface(self, runtime): if not len(model_transforms) == len(b0_transforms) == len(model_images) == 0: raise Exception("Unable to recombine images and transforms") - self._results['emc_warped_images'] = full_warped_images - self._results['full_transforms'] = full_transforms - self._results['full_predicted_dwi_series'] = full_predicted_dwi_series + self._results["emc_warped_images"] = full_warped_images + self._results["full_transforms"] = full_transforms + self._results["full_predicted_dwi_series"] = full_predicted_dwi_series return runtime diff --git a/dmriprep/interfaces/register.py b/dmriprep/interfaces/register.py new file mode 100644 index 00000000..6e18af01 --- /dev/null +++ b/dmriprep/interfaces/register.py @@ -0,0 +1,147 @@ +"""Register tools interfaces.""" +import numpy as np +import nibabel as nb +import dmriprep +from nipype import logging +from pathlib import Path +from nipype.utils.filemanip import fname_presuffix +from nipype.interfaces.base import ( + traits, + TraitedSpec, + BaseInterfaceInputSpec, + InputMultiObject, + SimpleInterface, + File, +) + + +LOGGER = logging.getLogger("nipype.interface") + + +class _ApplyAffineInputSpec(BaseInterfaceInputSpec): + moving_image = File( + exists=True, mandatory=True, desc="image to apply transformation from" + ) + fixed_image = File( + exists=True, mandatory=True, desc="image to apply transformation to" + ) + transform_affine = InputMultiObject( + File(exists=True), mandatory=True, desc="transformation affine" + ) + invert_transform = traits.Bool(False, usedefault=True) + + +class _ApplyAffineOutputSpec(TraitedSpec): + warped_image = File(exists=True, desc="Outputs warped image") + + +class ApplyAffine(SimpleInterface): + """ + Interface to apply an affine transformation to an image. + """ + + input_spec = _ApplyAffineInputSpec + output_spec = _ApplyAffineOutputSpec + + def _run_interface(self, runtime): + from dmriprep.utils.register import apply_affine + + warped_image_nifti = apply_affine( + nb.load(self.inputs.moving_image), + nb.load(self.inputs.fixed_image), + np.load(self.inputs.transform_affine[0]), + self.inputs.invert_transform, + ) + cwd = Path(runtime.cwd).absolute() + warped_file = fname_presuffix( + self.inputs.moving_image, + use_ext=False, + suffix="_warped.nii.gz", + newpath=str(cwd), + ) + + warped_image_nifti.to_filename(warped_file) + + self._results["warped_image"] = warped_file + return runtime + + +class _RegisterInputSpec(BaseInterfaceInputSpec): + moving_image = File( + exists=True, mandatory=True, desc="image to apply transformation from" + ) + fixed_image = File( + exists=True, mandatory=True, desc="image to apply transformation to" + ) + nbins = traits.Int(default_value=32, usedefault=True) + sampling_prop = traits.Float(defualt_value=1, usedefault=True) + metric = traits.Str(default_value="MI", usedefault=True) + level_iters = traits.List( + trait=traits.Any(), value=[10000, 1000, 100], usedefault=True + ) + sigmas = traits.List(trait=traits.Any(), value=[5.0, 2.5, 0.0], usedefault=True) + factors = traits.List(trait=traits.Any(), value=[4, 2, 1], usedefault=True) + params0 = traits.ArrayOrNone(value=None, usedefault=True) + pipeline = traits.List( + trait=traits.Any(), + value=["c_of_mass", "translation", "rigid", "affine"], + usedefault=True, + ) + + +class _RegisterOutputSpec(TraitedSpec): + forward_transforms = traits.List( + File(exists=True), desc="List of output transforms for forward registration" + ) + warped_image = File(exists=True, desc="Outputs warped image") + + +class Register(SimpleInterface): + """ + Interface to perform affine registration. + """ + + input_spec = _RegisterInputSpec + output_spec = _RegisterOutputSpec + + def _run_interface(self, runtime): + from dmriprep.utils.register import affine_registration + + reg_types = ["c_of_mass", "translation", "rigid", "affine"] + pipeline = [ + getattr(dmriprep.utils.register, i) + for i in self.inputs.pipeline + if i in reg_types + ] + + warped_image_nifti, forward_transform_mat = affine_registration( + nb.load(self.inputs.moving_image), + nb.load(self.inputs.fixed_image), + self.inputs.nbins, + self.inputs.sampling_prop, + self.inputs.metric, + pipeline, + self.inputs.level_iters, + self.inputs.sigmas, + self.inputs.factors, + self.inputs.params0, + ) + cwd = Path(runtime.cwd).absolute() + warped_file = fname_presuffix( + self.inputs.moving_image, + use_ext=False, + suffix="_warped.nii.gz", + newpath=str(cwd), + ) + forward_transform_file = fname_presuffix( + self.inputs.moving_image, + use_ext=False, + suffix="_forward_transform.npy", + newpath=str(cwd), + ) + warped_image_nifti.to_filename(warped_file) + + np.save(forward_transform_file, forward_transform_mat) + self._results["warped_image"] = warped_file + self._results["forward_transforms"] = [forward_transform_file] + return runtime diff --git a/dmriprep/interfaces/reports.py b/dmriprep/interfaces/reports.py index e1b53fdc..dd85557e 100644 --- a/dmriprep/interfaces/reports.py +++ b/dmriprep/interfaces/reports.py @@ -163,20 +163,20 @@ def _run_interface(self, runtime): return runtime -class HMCReportInputSpec(BaseInterfaceInputSpec): +class EMCReportInputSpec(BaseInterfaceInputSpec): iteration_summary = File(exists=True) registered_images = InputMultiObject(File(exists=True)) original_images = InputMultiObject(File(exists=True)) model_predicted_images = InputMultiObject(File(exists=True)) -class HMCReportOutputSpec(SummaryOutputSpec): +class EMCReportOutputSpec(SummaryOutputSpec): plot_file = File(exists=True) -class HMCReport(SummaryInterface): - input_spec = HMCReportInputSpec - output_spec = HMCReportOutputSpec +class EMCReport(SummaryInterface): + input_spec = EMCReportInputSpec + output_spec = EMCReportOutputSpec def _run_interface(self, runtime): import imageio diff --git a/dmriprep/interfaces/vectors.py b/dmriprep/interfaces/vectors.py index 6880f623..fe7dc320 100644 --- a/dmriprep/interfaces/vectors.py +++ b/dmriprep/interfaces/vectors.py @@ -2,14 +2,17 @@ import os from pathlib import Path import numpy as np -import pandas as pd from nipype.utils.filemanip import fname_presuffix from nipype.interfaces.base import ( - SimpleInterface, BaseInterfaceInputSpec, TraitedSpec, - File, traits, isdefined, InputMultiObject + SimpleInterface, + BaseInterfaceInputSpec, + TraitedSpec, + File, + traits, + isdefined, ) -from ..utils.vectors import DiffusionGradientTable, reorient_vecs_from_ras_b, B0_THRESHOLD, BVEC_NORM_EPSILON -from subprocess import Popen, PIPE +from ..utils.vectors import DiffusionGradientTable, B0_THRESHOLD, BVEC_NORM_EPSILON + def _undefined(objekt, name, default=None): value = getattr(objekt, name) @@ -20,9 +23,9 @@ def _undefined(objekt, name, default=None): class _CheckGradientTableInputSpec(BaseInterfaceInputSpec): dwi_file = File(exists=True, mandatory=True) - in_bvec = File(exists=True, xor=['in_rasb']) - in_bval = File(exists=True, xor=['in_rasb']) - in_rasb = File(exists=True, xor=['in_bval', 'in_bvec']) + in_bvec = File(exists=True, xor=["in_rasb"]) + in_bval = File(exists=True, xor=["in_rasb"]) + in_rasb = File(exists=True, xor=["in_bval", "in_bvec"]) b0_threshold = traits.Float(B0_THRESHOLD, usedefault=True) bvec_norm_epsilon = traits.Float(BVEC_NORM_EPSILON, usedefault=True) b_scale = traits.Bool(True, usedefault=True) @@ -72,36 +75,37 @@ class CheckGradientTable(SimpleInterface): output_spec = _CheckGradientTableOutputSpec def _run_interface(self, runtime): - rasb_file = _undefined(self.inputs, 'in_rasb') + rasb_file = _undefined(self.inputs, "in_rasb") table = DiffusionGradientTable( self.inputs.dwi_file, - bvecs=_undefined(self.inputs, 'in_bvec'), - bvals=_undefined(self.inputs, 'in_bval'), + bvecs=_undefined(self.inputs, "in_bvec"), + bvals=_undefined(self.inputs, "in_bval"), rasb_file=rasb_file, b_scale=self.inputs.b_scale, bvec_norm_epsilon=self.inputs.bvec_norm_epsilon, b0_threshold=self.inputs.b0_threshold, ) pole = table.pole - self._results['pole'] = tuple(pole) - self._results['full_sphere'] = np.all(pole == 0.0) - self._results['b0_ixs'] = np.where(table.b0mask)[0].tolist() + self._results["pole"] = tuple(pole) + self._results["full_sphere"] = np.all(pole == 0.0) + self._results["b0_ixs"] = np.where(table.b0mask)[0].tolist() cwd = Path(runtime.cwd).absolute() if rasb_file is None: rasb_file = fname_presuffix( - self.inputs.dwi_file, use_ext=False, suffix='.tsv', - newpath=str(cwd)) + self.inputs.dwi_file, use_ext=False, suffix=".tsv", newpath=str(cwd) + ) table.to_filename(rasb_file) - self._results['out_rasb'] = rasb_file - table.to_filename('%s/dwi' % cwd, filetype='fsl') - self._results['out_bval'] = str(cwd / 'dwi.bval') - self._results['out_bvec'] = str(cwd / 'dwi.bvec') + self._results["out_rasb"] = rasb_file + table.to_filename("%s/dwi" % cwd, filetype="fsl") + self._results["out_bval"] = str(cwd / "dwi.bval") + self._results["out_bvec"] = str(cwd / "dwi.bvec") return runtime class _ReorientVectorsInputSpec(BaseInterfaceInputSpec): + dwi_file = File(exists=True) rasb_file = File(exists=True) affines = traits.List() b0_threshold = traits.Float(B0_THRESHOLD, usedefault=True) @@ -130,7 +134,7 @@ class ReorientVectors(SimpleInterface): >>> res = reor_vecs.run() >>> out_rasb = res.outputs.out_rasb >>> out_rasb_mat = np.loadtxt(out_rasb, skiprows=1) - >>> npt.assert_equal(oldrasb_mat, out_rasb_mat) + >>> assert oldrasb_mat == out_rasb_mat True """ @@ -139,90 +143,29 @@ class ReorientVectors(SimpleInterface): def _run_interface(self, runtime): from nipype.utils.filemanip import fname_presuffix - reor_table = reorient_vecs_from_ras_b( + + table = DiffusionGradientTable( + dwi_file=self.inputs.dwi_file, rasb_file=self.inputs.rasb_file, - affines=self.inputs.affines, - b0_threshold=self.inputs.b0_threshold, + transforms=self.inputs.affines, ) + table.generate_vecval() + reor_table = table.reorient_rasb() cwd = Path(runtime.cwd).absolute() reor_rasb_file = fname_presuffix( - self.inputs.rasb_file, use_ext=False, suffix='_reoriented.tsv', - newpath=str(cwd)) - np.savetxt(str(reor_rasb_file), reor_table, - delimiter='\t', header='\t'.join('RASB'), - fmt=['%.8f'] * 3 + ['%g']) - - self._results['out_rasb'] = reor_rasb_file - return runtime - - -def get_fsl_motion_params(itk_file, src_file, ref_file, working_dir): - tmp_fsl_file = fname_presuffix(itk_file, newpath=working_dir, - suffix='_FSL.xfm', use_ext=False) - fsl_convert_cmd = "c3d_affine_tool " \ - "-ref {ref_file} " \ - "-src {src_file} " \ - "-itk {itk_file} " \ - "-ras2fsl -o {fsl_file}".format( - src_file=src_file, ref_file=ref_file, itk_file=itk_file, - fsl_file=tmp_fsl_file) - os.system(fsl_convert_cmd) - proc = Popen(['avscale', '--allparams', tmp_fsl_file, src_file], stdout=PIPE, - stderr=PIPE) - stdout, _ = proc.communicate() - - def get_measures(line): - line = line.strip().split() - return np.array([float(num) for num in line[-3:]]) - - lines = stdout.decode("utf-8").split("\n") - flip = np.array([1, -1, -1]) - rotation = get_measures(lines[6]) * flip - translation = get_measures(lines[8]) * flip - scale = get_measures(lines[10]) - shear = get_measures(lines[12]) - - return np.concatenate([scale, shear, rotation, translation]) - - -class CombineMotionsInputSpec(BaseInterfaceInputSpec): - transform_files = InputMultiObject(File(exists=True), mandatory=True, - desc='transform files from hmc') - source_files = InputMultiObject(File(exists=True), mandatory=True, - desc='Moving images') - ref_file = File(exists=True, mandatory=True, desc='Fixed Image') - - -class CombineMotionsOututSpec(TraitedSpec): - motion_file = File(exists=True) - spm_motion_file = File(exists=True) - - -class CombineMotions(SimpleInterface): - input_spec = CombineMotionsInputSpec - output_spec = CombineMotionsOututSpec - - def _run_interface(self, runtime): - collected_motion = [] - output_fname = os.path.join(runtime.cwd, "motion_params.csv") - output_spm_fname = os.path.join(runtime.cwd, "spm_movpar.txt") - ref_file = self.inputs.ref_file - for motion_file, src_file in zip(self.inputs.transform_files, - self.inputs.source_files): - collected_motion.append( - get_fsl_motion_params(motion_file, src_file, ref_file, runtime.cwd)) - - final_motion = np.row_stack(collected_motion) - cols = ["scaleX", "scaleY", "scaleZ", "shearXY", "shearXZ", - "shearYZ", "rotateX", "rotateY", "rotateZ", "shiftX", "shiftY", - "shiftZ"] - motion_df = pd.DataFrame(data=final_motion, columns=cols) - motion_df.to_csv(output_fname, index=False) - self._results['motion_file'] = output_fname - - spmcols = motion_df[['shiftX', 'shiftY', 'shiftZ', 'rotateX', 'rotateY', 'rotateZ']] - self._results['spm_motion_file'] = output_spm_fname - np.savetxt(output_spm_fname, spmcols.values) + self.inputs.rasb_file, + use_ext=False, + suffix="_reoriented.tsv", + newpath=str(cwd), + ) + np.savetxt( + str(reor_rasb_file), + reor_table, + delimiter="\t", + header="\t".join("RASB"), + fmt=["%.8f"] * 3 + ["%g"], + ) + self._results["out_rasb"] = reor_rasb_file return runtime diff --git a/dmriprep/utils/images.py b/dmriprep/utils/images.py index 7840bb39..2ce1c681 100644 --- a/dmriprep/utils/images.py +++ b/dmriprep/utils/images.py @@ -5,17 +5,16 @@ def extract_b0(in_file, b0_ixs, newpath=None): """Extract the *b0* volumes from a DWI dataset.""" - out_file = fname_presuffix( - in_file, suffix='_b0', newpath=newpath) + out_file = fname_presuffix(in_file, suffix="_b0", newpath=newpath) img = nb.load(in_file) - data = img.get_fdata(dtype='float32') + data = img.get_fdata(dtype="float32") b0 = data[..., b0_ixs] hdr = img.header.copy() hdr.set_data_shape(b0.shape) - hdr.set_xyzt_units('mm') + hdr.set_xyzt_units("mm") hdr.set_data_dtype(np.float32) nb.Nifti1Image(b0, img.affine, hdr).to_filename(out_file) return out_file @@ -23,16 +22,15 @@ def extract_b0(in_file, b0_ixs, newpath=None): def rescale_b0(in_file, mask_file, newpath=None): """Rescale the input volumes using the median signal intensity.""" - out_file = fname_presuffix( - in_file, suffix='_rescaled_b0', newpath=newpath) + out_file = fname_presuffix(in_file, suffix="_rescaled_b0", newpath=newpath) img = nb.load(in_file) if img.dataobj.ndim == 3: return in_file - data = img.get_fdata(dtype='float32') + data = img.get_fdata(dtype="float32") mask_img = nb.load(mask_file) - mask_data = mask_img.get_fdata(dtype='float32') + mask_data = mask_img.get_fdata(dtype="float32") median_signal = np.median(data[mask_data > 0, ...], axis=0) rescaled_data = 1000 * data / median_signal @@ -43,8 +41,7 @@ def rescale_b0(in_file, mask_file, newpath=None): def median(in_file, newpath=None): """Average a 4D dataset across the last dimension using median.""" - out_file = fname_presuffix( - in_file, suffix='_b0ref', newpath=newpath) + out_file = fname_presuffix(in_file, suffix="_b0ref", newpath=newpath) img = nb.load(in_file) if img.dataobj.ndim == 3: @@ -53,16 +50,26 @@ def median(in_file, newpath=None): nb.squeeze_image(img).to_filename(out_file) return out_file - median_data = np.median(img.get_fdata(dtype='float32'), - axis=-1) + median_data = np.median(img.get_fdata(dtype="float32"), axis=-1) hdr = img.header.copy() - hdr.set_xyzt_units('mm') + hdr.set_xyzt_units("mm") hdr.set_data_dtype(np.float32) nb.Nifti1Image(median_data, img.affine, hdr).to_filename(out_file) return out_file +def average_images(images): + from nilearn.image import mean_img + + average_img = mean_img([nb.load(img) for img in images]) + output_average_image = fname_presuffix( + images[0], use_ext=False, suffix="_mean.nii.gz" + ) + average_img.to_filename(output_average_image) + return output_average_image + + def quick_load_images(image_list, dtype=np.float32): example_img = nb.load(image_list[0]) num_images = len(image_list) @@ -82,7 +89,7 @@ def match_transforms(dwi_files, transforms, b0_indices): # Do sanity checks if not len(transforms) == len(b0_indices): - raise Exception('number of transforms does not match number of b0 images') + raise Exception("number of transforms does not match number of b0 images") # Create a list of which emc affines go with each of the split images nearest_affines = [] @@ -98,7 +105,7 @@ def save_4d_to_3d(in_file): files_3d = nb.four_to_three(nb.load(in_file)) out_files = [] for i, file_3d in enumerate(files_3d): - out_file = fname_presuffix(in_file, suffix='_tmp_{}'.format(i)) + out_file = fname_presuffix(in_file, suffix="_tmp_{}".format(i)) file_3d.to_filename(out_file) out_files.append(out_file) del files_3d @@ -106,20 +113,64 @@ def save_4d_to_3d(in_file): def prune_b0s_from_dwis(in_files, b0_ixs): - if in_files[0].endswith('_trans.nii.gz'): - out_files = [i for j, i in enumerate(sorted(in_files, - key=lambda x: int(x.split("_")[-2].split('.nii.gz')[0]))) if j - not in b0_ixs] + if in_files[0].endswith("_warped.nii.gz"): + out_files = [ + i + for j, i in enumerate( + sorted( + in_files, key=lambda x: int(x.split("_")[-2].split(".nii.gz")[0]) + ) + ) + if j not in b0_ixs + ] else: - out_files = [i for j, i in enumerate(sorted(in_files, - key=lambda x: int(x.split("_")[-1].split('.nii.gz')[0]))) if j - not in b0_ixs] + out_files = [ + i + for j, i in enumerate( + sorted( + in_files, key=lambda x: int(x.split("_")[-1].split(".nii.gz")[0]) + ) + ) + if j not in b0_ixs + ] return out_files def save_3d_to_4d(in_files): img_4d = nb.funcs.concat_images([nb.load(img_3d) for img_3d in in_files]) - out_file = fname_presuffix(in_files[0], suffix='_merged') + out_file = fname_presuffix(in_files[0], suffix="_merged") img_4d.to_filename(out_file) del img_4d return out_file + + +def get_params(A): + """This is a copy of spm's spm_imatrix where + we already know the rotations and translations matrix, + shears and zooms (as outputs from fsl FLIRT/avscale) + Let A = the 4x4 rotation and translation matrix + R = [ c5*c6, c5*s6, s5] + [-s4*s5*c6-c4*s6, -s4*s5*s6+c4*c6, s4*c5] + [-c4*s5*c6+s4*s6, -c4*s5*s6-s4*c6, c4*c5] + """ + + def rang(b): + a = min(max(b, -1), 1) + return a + + Ry = np.arcsin(A[0, 2]) + # Rx = np.arcsin(A[1, 2] / np.cos(Ry)) + # Rz = np.arccos(A[0, 1] / np.sin(Ry)) + + if (abs(Ry) - np.pi / 2) ** 2 < 1e-9: + Rx = 0 + Rz = np.arctan2(-rang(A[1, 0]), rang(-A[2, 0] / A[0, 2])) + else: + c = np.cos(Ry) + Rx = np.arctan2(rang(A[1, 2] / c), rang(A[2, 2] / c)) + Rz = np.arctan2(rang(A[0, 1] / c), rang(A[0, 0] / c)) + + rotations = [Rx, Ry, Rz] + translations = [A[0, 3], A[1, 3], A[2, 3]] + + return rotations, translations diff --git a/dmriprep/utils/register.py b/dmriprep/utils/register.py new file mode 100644 index 00000000..7388ae8c --- /dev/null +++ b/dmriprep/utils/register.py @@ -0,0 +1,436 @@ +""" + +Registration API: simplified API for registration of MRI data and of +streamlines + + +""" +import numpy as np +import nibabel as nb +from dipy.align.metrics import CCMetric, EMMetric, SSDMetric +from dipy.align.imwarp import SymmetricDiffeomorphicRegistration, DiffeomorphicMap + +from dipy.align.imaffine import ( + transform_centers_of_mass, + AffineMap, + MutualInformationMetric, + AffineRegistration, +) + +from dipy.align.transforms import ( + TranslationTransform3D, + RigidTransform3D, + AffineTransform3D, +) +from nipype.utils.filemanip import fname_presuffix +import dipy.core.gradients as dpg +import dipy.data as dpd + +syn_metric_dict = {"CC": CCMetric, "EM": EMMetric, "SSD": SSDMetric} + +__all__ = [ + "syn_registration", + "syn_register_dwi", + "write_mapping", + "read_mapping", + "c_of_mass", + "translation", + "rigid", + "affine", + "affine_registration", + "register_series", + "register_dwi", +] + + +def syn_registration( + moving, + static, + moving_affine=None, + static_affine=None, + step_length=0.25, + metric="CC", + dim=3, + level_iters=[10, 10, 5], + sigma_diff=2.0, + radius=4, + prealign=None, +): + """Register a source image (moving) to a target image (static). + + Parameters + ---------- + moving : ndarray + The source image data to be registered + moving_affine : array, shape (4,4) + The affine matrix associated with the moving (source) data. + static : ndarray + The target image data for registration + static_affine : array, shape (4,4) + The affine matrix associated with the static (target) data + metric : string, optional + The metric to be optimized. One of `CC`, `EM`, `SSD`, + Default: CCMetric. + dim: int (either 2 or 3), optional + The dimensions of the image domain. Default: 3 + level_iters : list of int, optional + the number of iterations at each level of the Gaussian Pyramid (the + length of the list defines the number of pyramid levels to be + used). + sigma_diff, radius : float + Parameters for initialization of the metric. + + Returns + ------- + warped_moving : ndarray + The data in `moving`, warped towards the `static` data. + forward : ndarray (..., 3) + The vector field describing the forward warping from the source to the + target. + backward : ndarray (..., 3) + The vector field describing the backward warping from the target to the + source. + """ + use_metric = syn_metric_dict[metric](dim, sigma_diff=sigma_diff, radius=radius) + + sdr = SymmetricDiffeomorphicRegistration( + use_metric, level_iters, step_length=step_length + ) + + mapping = sdr.optimize( + static, + moving, + static_grid2world=static_affine, + moving_grid2world=moving_affine, + prealign=prealign, + ) + + warped_moving = mapping.transform(moving) + return warped_moving, mapping + + +def syn_register_dwi(dwi, gtab, template=None, **syn_kwargs): + """ + Register DWI data to a template. + + Parameters + ----------- + dwi : nifti image or str + Image containing DWI data, or full path to a nifti file with DWI. + gtab : GradientTable or list of strings + The gradients associated with the DWI data, or a string with [fbcal, ] + template : nifti image or str, optional + + syn_kwargs : key-word arguments for :func:`syn_registration` + + Returns + ------- + DiffeomorphicMap object + """ + if template is None: + template = dpd.read_mni_template() + if isinstance(template, str): + template = nb.load(template) + + template_data = template.get_fdata() + template_affine = template.affine + + if isinstance(dwi, str): + dwi = nb.load(dwi) + + if not isinstance(gtab, dpg.GradientTable): + gtab = dpg.gradient_table(*gtab) + + dwi_affine = dwi.affine + dwi_data = dwi.get_fdata() + mean_b0 = np.mean(dwi_data[..., gtab.b0s_mask], -1) + warped_b0, mapping = syn_registration( + mean_b0, + template_data, + moving_affine=dwi_affine, + static_affine=template_affine, + **syn_kwargs + ) + return warped_b0, mapping + + +def write_mapping(mapping, fname): + """ + Write out a syn registration mapping to file + + Parameters + ---------- + mapping : a DiffeomorphicMap object derived from :func:`syn_registration` + fname : str + Full path to the nifti file storing the mapping + + """ + mapping_data = np.array([mapping.forward.T, mapping.backward.T]).T + nb.save(nb.Nifti1Image(mapping_data, mapping.codomain_world2grid), fname) + + +def read_mapping(disp, domain_img, codomain_img, prealign=None): + """ + Read a syn registration mapping from a nifti file + + Parameters + ---------- + disp : str or Nifti1Image + A file of image containing the mapping displacement field in each voxel + Shape (x, y, z, 3, 2) + + domain_img : str or Nifti1Image + + codomain_img : str or Nifti1Image + + Returns + ------- + A :class:`DiffeomorphicMap` object + """ + if isinstance(disp, str): + disp = nb.load(disp) + + if isinstance(domain_img, str): + domain_img = nb.load(domain_img) + + if isinstance(codomain_img, str): + codomain_img = nb.load(codomain_img) + + mapping = DiffeomorphicMap( + 3, + disp.shape[:3], + disp_grid2world=np.linalg.inv(disp.affine), + domain_shape=domain_img.shape[:3], + domain_grid2world=domain_img.affine, + codomain_shape=codomain_img.shape, + codomain_grid2world=codomain_img.affine, + prealign=prealign, + ) + + disp_data = disp.get_fdata().astype(np.float32) + mapping.forward = disp_data[..., 0] + mapping.backward = disp_data[..., 1] + mapping.is_inverse = True + + return mapping + + +def apply_affine(moving, static, transform_affine, invert=False): + """Apply an affine to transform an image from one space to another. + + Parameters + ---------- + moving : array + The image to be resampled + + static : array + + Returns + ------- + warped_img : the moving array warped into the static array's space. + + """ + affine_map = AffineMap( + transform_affine, static.shape, static.affine, moving.shape, moving.affine + ) + if invert is True: + warped_arr = affine_map.transform_inverse(np.asarray(moving.dataobj)) + else: + warped_arr = affine_map.transform(np.asarray(moving.dataobj)) + + return nb.Nifti1Image(warped_arr, static.affine) + + +def average_affines(transforms): + affine_list = [np.load(aff) for aff in transforms] + average_affine_file = fname_presuffix( + transforms[0], use_ext=False, suffix="_average.npy" + ) + np.save(average_affine_file, np.mean(affine_list, axis=0)) + return average_affine_file + + +# Affine registration pipeline: +affine_metric_dict = {"MI": MutualInformationMetric, "CC": CCMetric} + + +def c_of_mass( + moving, static, static_affine, moving_affine, reg, starting_affine, params0=None +): + transform = transform_centers_of_mass(static, static_affine, moving, moving_affine) + transformed = transform.transform(moving) + return transformed, transform.affine + + +def translation( + moving, static, static_affine, moving_affine, reg, starting_affine, params0=None +): + transform = TranslationTransform3D() + translation = reg.optimize( + static, + moving, + transform, + params0, + static_affine, + moving_affine, + starting_affine=starting_affine, + ) + + return translation.transform(moving), translation.affine + + +def rigid( + moving, static, static_affine, moving_affine, reg, starting_affine, params0=None +): + transform = RigidTransform3D() + rigid = reg.optimize( + static, + moving, + transform, + params0, + static_affine, + moving_affine, + starting_affine=starting_affine, + ) + return rigid.transform(moving), rigid.affine + + +def affine( + moving, static, static_affine, moving_affine, reg, starting_affine, params0=None +): + transform = AffineTransform3D() + affine = reg.optimize( + static, + moving, + transform, + params0, + static_affine, + moving_affine, + starting_affine=starting_affine, + ) + + return affine.transform(moving), affine.affine + + +def affine_registration( + moving, + static, + nbins, + sampling_prop, + metric, + pipeline, + level_iters, + sigmas, + factors, + params0, +): + + """ + Find the affine transformation between two 3D images. + + Parameters + ---------- + + """ + # Define the Affine registration object we'll use with the chosen metric: + use_metric = affine_metric_dict[metric](nbins, sampling_prop) + affreg = AffineRegistration( + metric=use_metric, level_iters=level_iters, sigmas=sigmas, factors=factors + ) + + if not params0: + starting_affine = np.eye(4) + else: + starting_affine = params0 + + # Go through the selected transformation: + for func in pipeline: + transformed, starting_affine = func( + np.asarray(moving.dataobj), + np.asarray(static.dataobj), + static.affine, + moving.affine, + affreg, + starting_affine, + params0, + ) + return nb.Nifti1Image(np.array(transformed), static.affine), starting_affine + + +def register_series(series, ref, pipeline): + """Register a series to a reference image. + + Parameters + ---------- + series : Nifti1Image object + The data is 4D with the last dimension separating different 3D volumes + ref : Nifti1Image or integer or iterable + + Returns + ------- + transformed_list, affine_list + """ + if isinstance(ref, nb.Nifti1Image): + static = ref + static_data = static.get_fdata() + s_aff = static.affine + moving = series + moving_data = moving.get_fdata() + m_aff = moving.affine + + elif isinstance(ref, int) or np.iterable(ref): + data = series.get_fdata() + idxer = np.zeros(data.shape[-1]).astype(bool) + idxer[ref] = True + static_data = data[..., idxer] + if len(static_data.shape) > 3: + static_data = np.mean(static_data, -1) + moving_data = data[..., ~idxer] + m_aff = s_aff = series.affine + + affine_list = [] + transformed_list = [] + for ii in range(moving_data.shape[-1]): + this_moving = moving_data[..., ii] + transformed, affine = affine_registration( + this_moving, + static_data, + moving_affine=m_aff, + static_affine=s_aff, + pipeline=pipeline, + ) + transformed_list.append(transformed) + affine_list.append(affine) + + return transformed_list, affine_list + + +def register_dwi( + data, gtab, affine, b0_ref=0, pipeline=[c_of_mass, translation, rigid, affine] +): + """ + Register a DWI data-set + + Parameters + ---------- + data : 4D array + Diffusion data. + + gtab : a GradientTable class instance. + + """ + if np.sum(gtab.b0s_mask) > 1: + # First, register the b0s into one image: + b0_img = nb.Nifti1Image(data[..., gtab.b0s_mask], affine) + trans_b0 = register_series(b0_img, ref=0, pipeline=pipeline) + ref_data = np.mean(trans_b0, -1) + else: + ref_data = data[..., gtab.b0s_mask] + + # Construct a series out of the DWI and the registered mean B0: + series = nb.Nifti1Image( + np.concatenate([ref_data, data[..., ~gtab.b0s_mask]], -1), affine + ) + + transformed_list, affine_list = register_series(series, ref=0, pipeline=pipeline) + return nb.Nifti1Image(np.array(transformed_list), affine) diff --git a/dmriprep/utils/vectors.py b/dmriprep/utils/vectors.py index 425b77d0..2be4fa7c 100644 --- a/dmriprep/utils/vectors.py +++ b/dmriprep/utils/vectors.py @@ -12,11 +12,29 @@ class DiffusionGradientTable: """Data structure for DWI gradients.""" - __slots__ = ['_affine', '_gradients', '_b_scale', '_bvecs', '_bvals', '_normalized', - '_b0_thres', '_bvec_norm_epsilon'] - - def __init__(self, dwi_file=None, bvecs=None, bvals=None, rasb_file=None, - b_scale=True, b0_threshold=B0_THRESHOLD, bvec_norm_epsilon=BVEC_NORM_EPSILON): + __slots__ = [ + "_affine", + "_gradients", + "_b_scale", + "_bvecs", + "_bvals", + "_normalized", + "_transforms", + "_b0_thres", + "_bvec_norm_epsilon", + ] + + def __init__( + self, + dwi_file=None, + bvecs=None, + bvals=None, + rasb_file=None, + b_scale=True, + transforms=None, + b0_threshold=B0_THRESHOLD, + bvec_norm_epsilon=BVEC_NORM_EPSILON, + ): """ Create a new table of diffusion gradients. @@ -34,8 +52,8 @@ def __init__(self, dwi_file=None, bvecs=None, bvals=None, rasb_file=None, then bvecs and bvals will be dismissed. b_scale : bool Whether b-values should be normalized. - """ + self._transforms = transforms self._b_scale = b_scale self._b0_thres = b0_threshold self._bvec_norm_epsilon = bvec_norm_epsilon @@ -87,7 +105,7 @@ def affine(self, value): dwi_file = nb.load(str(value)) self._affine = dwi_file.affine.copy() return - if hasattr(value, 'affine'): + if hasattr(value, "affine"): self._affine = value.affine self._affine = np.array(value) @@ -102,12 +120,12 @@ def bvecs(self, value): if isinstance(value, (str, Path)): value = np.loadtxt(str(value)).T else: - value = np.array(value, dtype='float32') + value = np.array(value, dtype="float32") # Correct any b0's in bvecs misstated as 10's. value[np.any(abs(value) >= 10, axis=1)] = np.zeros(3) if self.bvals is not None and value.shape[0] != self.bvals.shape[0]: - raise ValueError('The number of b-vectors and b-values do not match') + raise ValueError("The number of b-vectors and b-values do not match") self._bvecs = value @bvals.setter @@ -115,7 +133,7 @@ def bvals(self, value): if isinstance(value, (str, Path)): value = np.loadtxt(str(value)).flatten() if self.bvecs is not None and value.shape[0] != self.bvecs.shape[0]: - raise ValueError('The number of b-vectors and b-values do not match') + raise ValueError("The number of b-vectors and b-values do not match") self._bvals = np.array(value) @property @@ -129,10 +147,12 @@ def normalize(self): return self._bvecs, self._bvals = normalize_gradients( - self.bvecs, self.bvals, + self.bvecs, + self.bvals, b0_threshold=self._b0_thres, bvec_norm_epsilon=self._bvec_norm_epsilon, - b_scale=self._b_scale) + b_scale=self._b_scale, + ) self._normalized = True def generate_rasb(self): @@ -142,14 +162,50 @@ def generate_rasb(self): _ras = bvecs2ras(self.affine, self.bvecs) self.gradients = np.hstack((_ras, self.bvals[..., np.newaxis])) + def reorient_rasb(self): + """Reorient the vectors based o a list of affine transforms.""" + from dipy.core.gradients import gradient_table_from_bvals_bvecs, reorient_bvecs + + affines = self._transforms.copy() + bvals = self._bvals + bvecs = self._bvecs + + # Verify that number of non-B0 volumes corresponds to the number of affines. + # If not, raise an error. + if len(self._bvals[self._bvals >= self._b0_thres]) != len(affines): + b0_indices = np.where(self._bvals <= self._b0_thres)[0].tolist() + if len(self._bvals[self._bvals >= self._b0_thres]) < len(affines): + for i in sorted(b0_indices, reverse=True): + del affines[i] + if len(self._bvals[self._bvals >= self._b0_thres]) > len(affines): + ras_b_mat = self._gradients.copy() + ras_b_mat = np.delete(ras_b_mat, tuple(b0_indices), axis=0) + bvals = ras_b_mat[:, 3] + bvecs = ras_b_mat[:, 0:3] + if len(self._bvals[self._bvals > self._b0_thres]) != len(affines): + raise ValueError( + "Affine transformations do not correspond to gradients" + ) + + # Build gradient table object + gt = gradient_table_from_bvals_bvecs(bvals, bvecs, b0_threshold=self._b0_thres) + + # Reorient table + new_gt = reorient_bvecs(gt, [np.load(aff) for aff in affines]) + + return np.hstack((new_gt.bvecs, new_gt.bvals[..., np.newaxis])) + def generate_vecval(self): """Compose a bvec/bval pair in image coordinates.""" if self.bvecs is None or self.bvals is None: if self.affine is None: raise TypeError( "Cannot generate b-vectors & b-values in image coordinates. " - "Please set the corresponding DWI image's affine matrix.") - self._bvecs = bvecs2ras(np.linalg.inv(self.affine), self.gradients[..., :-1]) + "Please set the corresponding DWI image's affine matrix." + ) + self._bvecs = bvecs2ras( + np.linalg.inv(self.affine), self.gradients[..., :-1] + ) self._bvals = self.gradients[..., -1].flatten() @property @@ -161,25 +217,36 @@ def pole(self): """ self.generate_rasb() - return calculate_pole(self.gradients[..., :-1], bvec_norm_epsilon=self._bvec_norm_epsilon) + return calculate_pole( + self.gradients[..., :-1], bvec_norm_epsilon=self._bvec_norm_epsilon + ) - def to_filename(self, filename, filetype='rasb'): + def to_filename(self, filename, filetype="rasb"): """Write files (RASB, bvecs/bvals) to a given path.""" - if filetype.lower() == 'rasb': + if filetype.lower() == "rasb": self.generate_rasb() - np.savetxt(str(filename), self.gradients, - delimiter='\t', header='\t'.join('RASB'), - fmt=['%.8f'] * 3 + ['%g']) - elif filetype.lower() == 'fsl': + np.savetxt( + str(filename), + self.gradients, + delimiter="\t", + header="\t".join("RASB"), + fmt=["%.8f"] * 3 + ["%g"], + ) + elif filetype.lower() == "fsl": self.generate_vecval() - np.savetxt('%s.bvec' % filename, self.bvecs.T, fmt='%.6f') - np.savetxt('%s.bval' % filename, self.bvals, fmt='%.6f') + np.savetxt("%s.bvec" % filename, self.bvecs.T, fmt="%.6f") + np.savetxt("%s.bval" % filename, self.bvals, fmt="%.6f") else: raise ValueError('Unknown filetype "%s"' % filetype) -def normalize_gradients(bvecs, bvals, b0_threshold=B0_THRESHOLD, - bvec_norm_epsilon=BVEC_NORM_EPSILON, b_scale=True): +def normalize_gradients( + bvecs, + bvals, + b0_threshold=B0_THRESHOLD, + bvec_norm_epsilon=BVEC_NORM_EPSILON, + b_scale=True, +): """ Normalize b-vectors and b-values. @@ -235,8 +302,8 @@ def normalize_gradients(bvecs, bvals, b0_threshold=B0_THRESHOLD, True """ - bvals = np.array(bvals, dtype='float32') - bvecs = np.array(bvecs, dtype='float32') + bvals = np.array(bvals, dtype="float32") + bvecs = np.array(bvecs, dtype="float32") b0s = bvals < b0_threshold b0_vecs = np.linalg.norm(bvecs, axis=1) < bvec_norm_epsilon @@ -244,8 +311,9 @@ def normalize_gradients(bvecs, bvals, b0_threshold=B0_THRESHOLD, # Check for bval-bvec discrepancy. if not np.all(b0s == b0_vecs): raise ValueError( - 'Inconsistent bvals and bvecs (%d, %d low-b, respectively).' % - (b0s.sum(), b0_vecs.sum())) + "Inconsistent bvals and bvecs (%d, %d low-b, respectively)." + % (b0s.sum(), b0_vecs.sum()) + ) # Rescale b-vals if requested if b_scale: @@ -259,7 +327,7 @@ def normalize_gradients(bvecs, bvals, b0_threshold=B0_THRESHOLD, # Rescale b-vecs, skipping b0's, on the appropriate axis to unit-norm length. bvecs[~b0s] /= np.linalg.norm(bvecs[~b0s], axis=1)[..., np.newaxis] - return bvecs, bvals.astype('uint16') + return bvecs, bvals.astype("uint16") def calculate_pole(bvecs, bvec_norm_epsilon=BVEC_NORM_EPSILON): @@ -295,7 +363,7 @@ def calculate_pole(bvecs, bvec_norm_epsilon=BVEC_NORM_EPSILON): https://rstudio-pubs-static.s3.amazonaws.com/27121_a22e51b47c544980bad594d5e0bb2d04.html """ - bvecs = np.array(bvecs, dtype='float32') # Normalize inputs + bvecs = np.array(bvecs, dtype="float32") # Normalize inputs b0s = np.linalg.norm(bvecs, axis=1) < bvec_norm_epsilon bvecs = bvecs[~b0s] @@ -367,7 +435,7 @@ def bvecs2ras(affine, bvecs, norm=True, bvec_norm_epsilon=0.2): if affine.shape == (4, 4): affine = affine[:3, :3] - bvecs = np.array(bvecs, dtype='float32') # Normalize inputs + bvecs = np.array(bvecs, dtype="float32") # Normalize inputs rotated_bvecs = affine[np.newaxis, ...].dot(bvecs.T)[0].T if norm is True: norms_bvecs = np.linalg.norm(rotated_bvecs, axis=1) @@ -377,77 +445,9 @@ def bvecs2ras(affine, bvecs, norm=True, bvec_norm_epsilon=0.2): return rotated_bvecs -def reorient_vecs_from_ras_b(rasb_file, affines, b0_threshold=B0_THRESHOLD): - """ - Reorient the vectors from a rasb .tsv file. - When correcting for motion, rotation of the diffusion-weighted volumes - might cause systematic bias in rotationally invariant measures, such as FA - and MD, and also cause characteristic biases in tractography, unless the - gradient directions are appropriately reoriented to compensate for this - effect [Leemans2009]_. - Parameters - ---------- - rasb_file : str or os.pathlike - File path to a RAS-B gradient table. - affines : list or ndarray of shape (n, 4, 4) or (n, 3, 3) - Each entry in this list or array contain either an affine - transformation (4,4) or a rotation matrix (3, 3). - In both cases, the transformations encode the rotation that was applied - to the image corresponding to one of the non-zero gradient directions. - Returns - ------- - Gradients : ndarray of shape (4, n) - A reoriented ndarray where the first three columns correspond to each of - x, y, z directions of the bvecs, in R-A-S image orientation. - """ - from dipy.core.gradients import gradient_table_from_bvals_bvecs, reorient_bvecs - from scipy.io import loadmat - - ras_b_mat = np.genfromtxt(rasb_file, delimiter='\t') - - # Verify that number of non-B0 volumes corresponds to the number of affines. - # If not, raise an error. - if len(ras_b_mat[:, 3][ras_b_mat[:, 3] <= b0_threshold]) != len(affines): - b0_indices = np.where(ras_b_mat[:, 3] <= b0_threshold)[0].tolist() - for i in sorted(b0_indices, reverse=True): - del affines[i] - if len(ras_b_mat[:, 3][ras_b_mat[:, 3] > b0_threshold]) != len(affines): - raise ValueError('Affine transformations do not correspond to gradients') - - # Build gradient table object - gt = gradient_table_from_bvals_bvecs(ras_b_mat[:, 3], ras_b_mat[:, 0:3], - b0_threshold=b0_threshold) - - # Reorient table - ras_trans = np.ones(shape=(4, 4)) - ras_trans[0, 1] = -ras_trans[0, 1] - ras_trans[1, 0] = -ras_trans[1, 0] - ras_trans[2, 3] = -ras_trans[2, 3] - affines_ras = [] - for aff in affines: - aff_mat = loadmat(aff) - M = np.zeros(shape=(4, 4)) - M[0, 0] = aff_mat['AffineTransform_float_3_3'][0] - M[0, 1] = aff_mat['AffineTransform_float_3_3'][1] - M[0, 2] = aff_mat['AffineTransform_float_3_3'][2] - M[1, 0] = aff_mat['AffineTransform_float_3_3'][3] - M[1, 1] = aff_mat['AffineTransform_float_3_3'][4] - M[1, 2] = aff_mat['AffineTransform_float_3_3'][5] - M[2, 0] = aff_mat['AffineTransform_float_3_3'][6] - M[2, 1] = aff_mat['AffineTransform_float_3_3'][7] - M[2, 2] = aff_mat['AffineTransform_float_3_3'][8] - M[3, 3] = 1 - M[0:3, 3] = aff_mat['fixed'].T - affines_ras.append(np.multiply(M, ras_trans)) - del M - - new_gt = reorient_bvecs(gt, affines_ras) - - return np.hstack((new_gt.bvecs, new_gt.bvals[..., np.newaxis])) - - -def _nonoverlapping_qspace_samples(prediction_bval, prediction_bvec, - all_bvals, all_bvecs, cutoff): +def _nonoverlapping_qspace_samples( + prediction_bval, prediction_bvec, all_bvals, all_bvecs, cutoff +): """Ensure that none of the training samples are too close to the sample to predict. Parameters """ @@ -458,26 +458,43 @@ def _nonoverlapping_qspace_samples(prediction_bval, prediction_bvec, # Convert q values to percent of maximum qval max_qval = max(max(all_qvals), prediction_qval) all_qvals_scaled = all_qvals / max_qval * 100 - prediction_qval_scaled = prediction_qval / max_qval * 100 scaled_qvecs = all_bvecs * all_qvals_scaled[:, np.newaxis] - scaled_prediction_qvec = prediction_bvec * prediction_qval_scaled + scaled_prediction_qvec = prediction_bvec * (prediction_qval / max_qval * 100) # Calculate the distance between the sampled qvecs and the prediction qvec - distances = np.linalg.norm(scaled_qvecs - scaled_prediction_qvec, axis=1) - distances_flip = np.linalg.norm(scaled_qvecs + scaled_prediction_qvec, axis=1) - ok_samples = (distances > cutoff) * (distances_flip > cutoff) + ok_samples = ( + np.linalg.norm(scaled_qvecs - scaled_prediction_qvec, axis=1) > cutoff + ) * (np.linalg.norm(scaled_qvecs + scaled_prediction_qvec, axis=1) > cutoff) return ok_samples def _rasb_to_bvec_list(in_rasb): + """ + Create a list of b-vectors from a rasb gradient table. + + Parameters + ---------- + in_rasb : str or os.pathlike + File path to a RAS-B gradient table. + """ import numpy as np - ras_b_mat = np.genfromtxt(in_rasb, delimiter='\t') + + ras_b_mat = np.genfromtxt(in_rasb, delimiter="\t") bvec = [vec for vec in ras_b_mat[:, 0:3] if not np.isclose(all(vec), 0)] return list(bvec) def _rasb_to_bval_floats(in_rasb): + """ + Create a list of b-values from a rasb gradient table. + + Parameters + ---------- + in_rasb : str or os.pathlike + File path to a RAS-B gradient table. + """ import numpy as np - ras_b_mat = np.genfromtxt(in_rasb, delimiter='\t') - return [float(bval) for bval in ras_b_mat[:, 3] if bval > 0] \ No newline at end of file + + ras_b_mat = np.genfromtxt(in_rasb, delimiter="\t") + return [float(bval) for bval in ras_b_mat[:, 3] if bval > 0] diff --git a/dmriprep/workflows/dwi/emc.py b/dmriprep/workflows/dwi/emc.py index 56e6a85e..860651f6 100644 --- a/dmriprep/workflows/dwi/emc.py +++ b/dmriprep/workflows/dwi/emc.py @@ -1,119 +1,219 @@ import nipype.pipeline.engine as pe from nipype.interfaces import utility as niu, afni, ants -from pkg_resources import resource_filename as pkgrf from dmriprep.workflows.dwi.base import _list_squeeze -from dmriprep.interfaces.bids import DerivativesDataSink -from dmriprep.interfaces.images import SignalPrediction, CalculateCNR, ReorderOutputs, ImageMath, ExtractB0, \ - MatchTransforms, RescaleB0 -from dmriprep.interfaces.reports import IterationSummary, HMCReport -from dmriprep.interfaces.vectors import ReorientVectors, CheckGradientTable, CombineMotions -from dmriprep.utils.images import save_4d_to_3d, save_3d_to_4d, prune_b0s_from_dwis +from dmriprep.interfaces.images import ( + SignalPrediction, + CalculateCNR, + ReorderOutputs, + ExtractB0, + MatchTransforms, + RescaleB0, + CombineMotions, + ImageMath, +) +from dmriprep.interfaces.reports import IterationSummary, EMCReport +from dmriprep.interfaces.vectors import ReorientVectors, CheckGradientTable +from dmriprep.interfaces.register import Register, ApplyAffine +from dmriprep.utils.images import ( + save_4d_to_3d, + save_3d_to_4d, + prune_b0s_from_dwis, + average_images, +) from dmriprep.utils.vectors import _rasb_to_bvec_list, _rasb_to_bval_floats +from dmriprep.utils.register import average_affines +from pkg_resources import resource_filename as pkgrf + +def linear_alignment_workflow(transform, precision, iternum=0): + import_list = [ + "import warnings", + 'warnings.filterwarnings("ignore")', + "import sys", + "import os", + "import numpy as np", + "import networkx as nx", + "import nibabel as nb", + "from nipype.utils.filemanip import fname_presuffix", + ] -def linear_alignment_workflow(transform="Rigid", iternum=0, precision="precise"): - from nipype.interfaces import ants iteration_wf = pe.Workflow(name="iterative_alignment_%03d" % iternum) input_node_fields = ["image_paths", "template_image", "iteration_num"] linear_alignment_inputnode = pe.Node( - niu.IdentityInterface(fields=input_node_fields), name='linear_alignment_inputnode') + niu.IdentityInterface(fields=input_node_fields), + name="linear_alignment_inputnode", + ) linear_alignment_inputnode.inputs.iteration_num = iternum linear_alignment_outputnode = pe.Node( - niu.IdentityInterface(fields=["registered_image_paths", "affine_transforms", - "updated_template"]), name='linear_alignment_outputnode') - ants_settings = pkgrf( + niu.IdentityInterface( + fields=["registered_image_paths", "affine_transforms", "updated_template"] + ), + name="linear_alignment_outputnode", + ) + + settings = pkgrf( "dmriprep", - "config/emc_{precision}_{transform}.json".format(precision=precision, transform=transform)) - reg = ants.Registration(from_file=ants_settings) + "config/emc_{precision}_{transform}.json".format( + precision=precision, transform=transform + ), + ) + iter_reg = pe.MapNode( - reg, name="reg_%03d" % iternum, iterfield=["moving_image"]) + Register(settings, pipeline=[transform]), + name="reg_%03d" % iternum, + iterfield=["moving_image"], + ) - # Run the images through antsRegistration - iteration_wf.connect(linear_alignment_inputnode, "image_paths", iter_reg, "moving_image") - iteration_wf.connect(linear_alignment_inputnode, "template_image", iter_reg, "fixed_image") + # Run the images through affine registration + iteration_wf.connect( + linear_alignment_inputnode, "image_paths", iter_reg, "moving_image" + ) + iteration_wf.connect( + linear_alignment_inputnode, "template_image", iter_reg, "fixed_image" + ) # Average the images averaged_images = pe.Node( - ants.AverageImages(normalize=True, dimension=3), - name="averaged_images") + niu.Function( + input_names=["images"], + output_names=["output_average_image"], + function=average_images, + imports=import_list, + ), + name="averaged_images", + ) iteration_wf.connect(iter_reg, "warped_image", averaged_images, "images") # Apply the inverse to the average image transforms_to_list = pe.Node(niu.Merge(1), name="transforms_to_list") transforms_to_list.inputs.ravel_inputs = True - iteration_wf.connect(iter_reg, "forward_transforms", transforms_to_list, - "in1") - avg_affines = pe.Node(ants.AverageAffineTransform(), name="avg_affine") - avg_affines.inputs.dimension = 3 - avg_affines.inputs.output_affine_transform = "AveragedAffines.mat" + iteration_wf.connect(iter_reg, "forward_transforms", transforms_to_list, "in1") + avg_affines = pe.Node( + niu.Function( + input_names=["transforms"], + output_names=["average_affine_file"], + function=average_affines, + imports=import_list, + ), + name="avg_affine", + ) iteration_wf.connect(transforms_to_list, "out", avg_affines, "transforms") - invert_average = pe.Node(ants.ApplyTransforms(), name="invert_average") - invert_average.inputs.interpolation = "HammingWindowedSinc" - invert_average.inputs.invert_transform_flags = [True] + invert_average = pe.Node(ApplyAffine(), name="invert_average") + invert_average.inputs.invert_transform = True avg_to_list = pe.Node(niu.Merge(1), name="to_list") - iteration_wf.connect(avg_affines, "affine_transform", avg_to_list, "in1") - iteration_wf.connect(avg_to_list, "out", invert_average, "transforms") - iteration_wf.connect(averaged_images, "output_average_image", - invert_average, "input_image") - iteration_wf.connect(averaged_images, "output_average_image", - invert_average, "reference_image") - iteration_wf.connect(invert_average, "output_image", linear_alignment_outputnode, - "updated_template") - iteration_wf.connect(iter_reg, "forward_transforms", linear_alignment_outputnode, - "affine_transforms") - iteration_wf.connect(iter_reg, "warped_image", linear_alignment_outputnode, - "registered_image_paths") + iteration_wf.connect(avg_affines, "average_affine_file", avg_to_list, "in1") + iteration_wf.connect(avg_to_list, "out", invert_average, "transform_affine") + iteration_wf.connect( + averaged_images, "output_average_image", invert_average, "moving_image" + ) + iteration_wf.connect( + averaged_images, "output_average_image", invert_average, "fixed_image" + ) + iteration_wf.connect( + invert_average, "warped_image", linear_alignment_outputnode, "updated_template" + ) + iteration_wf.connect( + iter_reg, "forward_transforms", linear_alignment_outputnode, "affine_transforms" + ) + iteration_wf.connect( + iter_reg, "warped_image", linear_alignment_outputnode, "registered_image_paths" + ) return iteration_wf -def init_b0_emc_wf(transform="Rigid", metric="Mattes", num_iters=3, name="b0_emc_wf"): +def init_b0_emc_wf(num_iters=3, transform="rigid", name="b0_emc_wf"): b0_emc_wf = pe.Workflow(name=name) b0_emc_inputnode = pe.Node( - niu.IdentityInterface(fields=['b0_images', 'initial_template']), - name='b0_emc_inputnode') + niu.IdentityInterface(fields=["b0_images", "initial_template"]), + name="b0_emc_inputnode", + ) b0_emc_outputnode = pe.Node( - niu.IdentityInterface(fields=[ - "final_template", "forward_transforms", "iteration_templates", - "motion_params", "aligned_images"]), - name='b0_emc_outputnode') + niu.IdentityInterface( + fields=[ + "final_template", + "forward_transforms", + "iteration_templates", + "motion_params", + "aligned_images", + ] + ), + name="b0_emc_outputnode", + ) # Iteratively create a template # Store the registration targets iter_templates = pe.Node(niu.Merge(num_iters), name="iteration_templates") b0_emc_wf.connect(b0_emc_inputnode, "initial_template", iter_templates, "in1") + # Perform an initial coarse, rigid alignment to the b0 template initial_reg = linear_alignment_workflow( - transform=transform, - precision="coarse", - iternum=0) - b0_emc_wf.connect(b0_emc_inputnode, "initial_template", initial_reg, "linear_alignment_inputnode.template_image") - b0_emc_wf.connect(b0_emc_inputnode, "b0_images", initial_reg, "linear_alignment_inputnode.image_paths") + transform=transform, precision="coarse", iternum=0 + ) + b0_emc_wf.connect( + b0_emc_inputnode, + "initial_template", + initial_reg, + "linear_alignment_inputnode.template_image", + ) + b0_emc_wf.connect( + b0_emc_inputnode, + "b0_images", + initial_reg, + "linear_alignment_inputnode.image_paths", + ) reg_iters = [initial_reg] + + # Perform subsequent rigid alignment iterations for iternum in range(1, num_iters): reg_iters.append( linear_alignment_workflow( - transform=transform, - precision="precise", - iternum=iternum)) - b0_emc_wf.connect(reg_iters[-2], "linear_alignment_outputnode.updated_template", reg_iters[-1], - "linear_alignment_inputnode.template_image") - b0_emc_wf.connect(b0_emc_inputnode, "b0_images", reg_iters[-1], "linear_alignment_inputnode.image_paths") - b0_emc_wf.connect(reg_iters[-1], - "linear_alignment_outputnode.updated_template", iter_templates, "in%d" % (iternum + 1)) + transform=transform, precision="precise", iternum=iternum + ) + ) + b0_emc_wf.connect( + reg_iters[-2], + "linear_alignment_outputnode.updated_template", + reg_iters[-1], + "linear_alignment_inputnode.template_image", + ) + b0_emc_wf.connect( + b0_emc_inputnode, + "b0_images", + reg_iters[-1], + "linear_alignment_inputnode.image_paths", + ) + b0_emc_wf.connect( + reg_iters[-1], + "linear_alignment_outputnode.updated_template", + iter_templates, + "in%d" % (iternum + 1), + ) # Attach to outputs # The last iteration aligned to the output from the second-to-last - b0_emc_wf.connect(reg_iters[-2], "linear_alignment_outputnode.updated_template", b0_emc_outputnode, - "final_template") - b0_emc_wf.connect(reg_iters[-1], "linear_alignment_outputnode.affine_transforms", b0_emc_outputnode, - "forward_transforms") - b0_emc_wf.connect(reg_iters[-1], "linear_alignment_outputnode.registered_image_paths", b0_emc_outputnode, - "aligned_images") + b0_emc_wf.connect( + reg_iters[-2], + "linear_alignment_outputnode.updated_template", + b0_emc_outputnode, + "final_template", + ) + b0_emc_wf.connect( + reg_iters[-1], + "linear_alignment_outputnode.affine_transforms", + b0_emc_outputnode, + "forward_transforms", + ) + b0_emc_wf.connect( + reg_iters[-1], + "linear_alignment_outputnode.registered_image_paths", + b0_emc_outputnode, + "aligned_images", + ) b0_emc_wf.connect(iter_templates, "out", b0_emc_outputnode, "iteration_templates") return b0_emc_wf @@ -121,32 +221,39 @@ def init_b0_emc_wf(transform="Rigid", metric="Mattes", num_iters=3, name="b0_emc def init_enhance_and_skullstrip_template_mask_wf(name): eastm_workflow = pe.Workflow(name=name) - eastm_inputnode = pe.Node(niu.IdentityInterface(fields=['in_file']), - name='eastm_inputnode') - eastm_outputnode = pe.Node(niu.IdentityInterface(fields=[ - 'mask_file', 'skull_stripped_file', 'bias_corrected_file']), name='eastm_outputnode') + eastm_inputnode = pe.Node( + niu.IdentityInterface(fields=["in_file"]), name="eastm_inputnode" + ) + eastm_outputnode = pe.Node( + niu.IdentityInterface( + fields=["mask_file", "skull_stripped_file", "bias_corrected_file"] + ), + name="eastm_outputnode", + ) # Truncate intensity values so they're OK for N4 truncate_values = pe.Node( - ImageMath(dimension=3, - operation="TruncateImageIntensity", - secondary_arg="0.0 0.98 512"), - name="truncate_values") + ImageMath( + dimension=3, + operation="TruncateImageIntensity", + secondary_arg="0.0 0.98 512", + ), + name="truncate_values", + ) # Truncate intensity values for creating a mask # (there are many high outliers in b=0 images) truncate_values_for_masking = pe.Node( - ImageMath(dimension=3, - operation="TruncateImageIntensity", - secondary_arg="0.0 0.9 512"), - name="truncate_values_for_masking") + ImageMath( + dimension=3, operation="TruncateImageIntensity", secondary_arg="0.0 0.9 512" + ), + name="truncate_values_for_masking", + ) # N4 will break if any negative values are present. rescale_image = pe.Node( - ImageMath(dimension=3, - operation="RescaleImage", - secondary_arg="0 1000"), - name="rescale_image" + ImageMath(dimension=3, operation="RescaleImage", secondary_arg="0 1000"), + name="rescale_image", ) # Run N4 normally, force num_threads=1 for stability (images are small, no need for >1) @@ -157,93 +264,123 @@ def init_enhance_and_skullstrip_template_mask_wf(name): convergence_threshold=1e-6, bspline_order=3, bspline_fitting_distance=150, - copy_header=True), - name='n4_correct', n_procs=1) + copy_header=True, + ), + name="n4_correct", + n_procs=1, + ) # Sharpen the b0 ref sharpen_image = pe.Node( - ImageMath(dimension=3, - operation="Sharpen"), - name="sharpen_image") + ImageMath(dimension=3, operation="Sharpen"), name="sharpen_image" + ) # Basic mask - initial_mask = pe.Node(afni.Automask(outputtype="NIFTI_GZ"), - name="initial_mask") + initial_mask = pe.Node(afni.Automask(outputtype="NIFTI_GZ"), name="initial_mask") # Fill holes left by Automask fill_holes = pe.Node( - ImageMath(dimension=3, - operation='FillHoles', - secondary_arg='2'), - name='fill_holes') + ImageMath(dimension=3, operation="FillHoles", secondary_arg="2"), + name="fill_holes", + ) # Dilate before smoothing dilate_mask = pe.Node( - ImageMath(dimension=3, - operation='MD', - secondary_arg='1'), - name='dilate_mask') + ImageMath(dimension=3, operation="MD", secondary_arg="1"), name="dilate_mask" + ) # Smooth the mask and use it as a weight for N4 smooth_mask = pe.Node( - ImageMath(dimension=3, - operation='G', - secondary_arg='4'), - name='smooth_mask') + ImageMath(dimension=3, operation="G", secondary_arg="4"), name="smooth_mask" + ) # Make a "soft" skull-stripped image apply_mask = pe.Node( - ants.MultiplyImages(dimension=3, output_product_image="SkullStrippedRef.nii.gz"), - name="apply_mask") - - eastm_workflow.connect([ - (eastm_inputnode, truncate_values, [('in_file', 'in_file')]), - (truncate_values, rescale_image, [('out_file', 'in_file')]), - (eastm_inputnode, truncate_values_for_masking, [('in_file', 'in_file')]), - (truncate_values_for_masking, initial_mask, [('out_file', 'in_file')]), - (initial_mask, fill_holes, [('out_file', 'in_file')]), - (fill_holes, dilate_mask, [('out_file', 'in_file')]), - (dilate_mask, smooth_mask, [('out_file', 'in_file')]), - (rescale_image, n4_correct, [('out_file', 'input_image')]), - (smooth_mask, n4_correct, [('out_file', 'weight_image')]), - (n4_correct, sharpen_image, [('output_image', 'in_file')]), - (sharpen_image, eastm_outputnode, [('out_file', 'bias_corrected_file')]), - (sharpen_image, apply_mask, [('out_file', 'first_input')]), - (smooth_mask, apply_mask, [('out_file', 'second_input')]), - (apply_mask, eastm_outputnode, [('output_product_image', 'skull_stripped_file')]), - (fill_holes, eastm_outputnode, [('out_file', 'mask_file')]) - ]) + ants.MultiplyImages( + dimension=3, output_product_image="SkullStrippedRef.nii.gz" + ), + name="apply_mask", + ) + + eastm_workflow.connect( + [ + (eastm_inputnode, truncate_values, [("in_file", "in_file")]), + (truncate_values, rescale_image, [("out_file", "in_file")]), + (eastm_inputnode, truncate_values_for_masking, [("in_file", "in_file")]), + (truncate_values_for_masking, initial_mask, [("out_file", "in_file")]), + (initial_mask, fill_holes, [("out_file", "in_file")]), + (fill_holes, dilate_mask, [("out_file", "in_file")]), + (dilate_mask, smooth_mask, [("out_file", "in_file")]), + (rescale_image, n4_correct, [("out_file", "input_image")]), + (smooth_mask, n4_correct, [("out_file", "weight_image")]), + (n4_correct, sharpen_image, [("output_image", "in_file")]), + (sharpen_image, eastm_outputnode, [("out_file", "bias_corrected_file")]), + (sharpen_image, apply_mask, [("out_file", "first_input")]), + (smooth_mask, apply_mask, [("out_file", "second_input")]), + ( + apply_mask, + eastm_outputnode, + [("output_product_image", "skull_stripped_file")], + ), + (fill_holes, eastm_outputnode, [("out_file", "mask_file")]), + ] + ) return eastm_workflow -def init_emc_model_iteration_wf(transform, precision="coarse", name="emc_model_iter0"): +def init_emc_model_iteration_wf( + precision, transform, prune_b0s, model_name, name="emc_model_iter0" +): emc_model_iter_workflow = pe.Workflow(name=name) - emc_model_iteration_inputnode = pe.Node(niu.IdentityInterface(fields=['original_dwi_files', 'original_rasb_file', - 'aligned_dwi_files', 'aligned_vectors', - 'b0_median', 'b0_mask', 'b0_indices']), - name='emc_model_iteration_inputnode') + emc_model_iteration_inputnode = pe.Node( + niu.IdentityInterface( + fields=[ + "original_dwi_files", + "original_rasb_file", + "aligned_dwi_files", + "aligned_vectors", + "b0_median", + "b0_mask", + "b0_indices", + ] + ), + name="emc_model_iteration_inputnode", + ) emc_model_iteration_outputnode = pe.Node( niu.IdentityInterface( - fields=['emc_transforms', 'aligned_dwis', 'aligned_vectors', - 'predicted_dwis', 'motion_params']), - name='emc_model_iteration_outputnode') - - ants_settings = pkgrf( - "dmriprep", - "config/emc_{precision}_{transform}.json".format(precision=precision, - transform=transform)) + fields=[ + "emc_transforms", + "aligned_dwis", + "aligned_vectors", + "predicted_dwis", + "motion_params", + ] + ), + name="emc_model_iteration_outputnode", + ) - predict_dwis = pe.MapNode(SignalPrediction(), - iterfield=['bval_to_predict', 'bvec_to_predict'], - name="predict_dwis") + # Predict signal from a given coordinate on the sphere + predict_dwis = pe.MapNode( + SignalPrediction(prune_b0s=prune_b0s, model_name=model_name), + iterfield=["bval_to_predict", "bvec_to_predict"], + name="predict_dwis", + ) predict_dwis.synchronize = True # Register original images to the predicted images - register_to_predicted = pe.MapNode(ants.Registration(from_file=ants_settings), - iterfield=['fixed_image', 'moving_image'], - name='register_to_predicted') + settings = pkgrf( + "dmriprep", + "config/emc_{precision}_{transform}.json".format( + precision=precision, transform="affine" + ), + ) + register_to_predicted = pe.MapNode( + Register(settings, pipeline=transform), + iterfield=["moving_image", "fixed_image"], + name="register_to_predicted", + ) register_to_predicted.synchronize = True # Apply new transforms to vectors @@ -252,260 +389,594 @@ def init_emc_model_iteration_wf(transform, precision="coarse", name="emc_model_i # Summarize the motion calculate_motion = pe.Node(CombineMotions(), name="calculate_motion") - emc_model_iter_workflow.connect([ - # Send inputs to DWI prediction - (emc_model_iteration_inputnode, predict_dwis, [('aligned_dwi_files', 'aligned_dwi_files'), - ('aligned_vectors', 'aligned_vectors'), - ('b0_indices', 'b0_indices'), - ('b0_median', 'b0_median'), - ('b0_mask', 'b0_mask'), - (('aligned_vectors', _rasb_to_bvec_list), 'bvec_to_predict'), - (('aligned_vectors', _rasb_to_bval_floats), 'bval_to_predict')]), - (predict_dwis, register_to_predicted, [('predicted_image', 'fixed_image')]), - (emc_model_iteration_inputnode, register_to_predicted, [ - ('original_dwi_files', 'moving_image'), - ('b0_mask', 'fixed_image_mask')]), - (register_to_predicted, calculate_motion, [ - (('forward_transforms', _list_squeeze), 'transform_files')]), - (emc_model_iteration_inputnode, calculate_motion, [('original_dwi_files', 'source_files'), - ('b0_median', 'ref_file')]), - (calculate_motion, emc_model_iteration_outputnode, [('motion_file', 'motion_params')]), - (register_to_predicted, post_vector_transforms, [ - (('forward_transforms', _list_squeeze), 'affines')]), - (emc_model_iteration_inputnode, post_vector_transforms, [('original_rasb_file', 'rasb_file')]), - (predict_dwis, emc_model_iteration_outputnode, [('predicted_image', 'predicted_dwis')]), - (post_vector_transforms, emc_model_iteration_outputnode, [('out_rasb', 'aligned_vectors')]), - (register_to_predicted, emc_model_iteration_outputnode, [('warped_image', 'aligned_dwis'), - ('forward_transforms', 'emc_transforms')]) - ]) + emc_model_iter_workflow.connect( + [ + # Send inputs to DWI prediction + ( + emc_model_iteration_inputnode, + predict_dwis, + [ + ("aligned_dwi_files", "aligned_dwi_files"), + ("aligned_vectors", "aligned_vectors"), + ("b0_indices", "b0_indices"), + ("b0_median", "b0_median"), + ("b0_mask", "b0_mask"), + (("aligned_vectors", _rasb_to_bvec_list), "bvec_to_predict"), + (("aligned_vectors", _rasb_to_bval_floats), "bval_to_predict"), + ], + ), + (predict_dwis, register_to_predicted, [("predicted_image", "fixed_image")]), + ( + emc_model_iteration_inputnode, + register_to_predicted, + [("original_dwi_files", "moving_image")], + ), + ( + register_to_predicted, + calculate_motion, + [(("forward_transforms", _list_squeeze), "transform_files")], + ), + ( + emc_model_iteration_inputnode, + calculate_motion, + [("original_dwi_files", "source_files"), ("b0_median", "ref_file")], + ), + ( + calculate_motion, + emc_model_iteration_outputnode, + [("motion_file", "motion_params")], + ), + ( + register_to_predicted, + post_vector_transforms, + [(("forward_transforms", _list_squeeze), "affines")], + ), + ( + emc_model_iteration_inputnode, + post_vector_transforms, + [("aligned_vectors", "rasb_file"), ("b0_median", "dwi_file")], + ), + ( + predict_dwis, + emc_model_iteration_outputnode, + [("predicted_image", "predicted_dwis")], + ), + ( + post_vector_transforms, + emc_model_iteration_outputnode, + [("out_rasb", "aligned_vectors")], + ), + ( + register_to_predicted, + emc_model_iteration_outputnode, + [ + ("warped_image", "aligned_dwis"), + ("forward_transforms", "emc_transforms"), + ], + ), + ] + ) return emc_model_iter_workflow -def init_dwi_model_emc_wf(transform, num_iters=2, name='dwi_model_emc_wf'): +def init_dwi_model_emc_wf(num_iters=2, name="dwi_model_emc_wf"): dwi_model_emc_wf_workflow = pe.Workflow(name=name) inputnode = pe.Node( niu.IdentityInterface( - fields=['original_dwi_files', 'original_rasb_file', 'aligned_dwi_files', 'b0_indices', - 'initial_transforms', 'aligned_vectors', 'b0_median', 'b0_mask', 'warped_b0_images']), - name='dwi_model_emc_inputnode') + fields=[ + "original_dwi_files", + "original_rasb_file", + "aligned_dwi_files", + "b0_indices", + "initial_transforms", + "aligned_vectors", + "b0_median", + "b0_mask", + "warped_b0_images", + ] + ), + name="dwi_model_emc_inputnode", + ) outputnode = pe.Node( niu.IdentityInterface( - fields=['emc_transforms', 'model_predicted_images', 'cnr_image', - 'optimization_data']), - name='dwi_model_emc_outputnode') + fields=[ + "emc_transforms", + "model_predicted_images", + "cnr_image", + "optimization_data", + ] + ), + name="dwi_model_emc_outputnode", + ) - # Start building and connecting the model iterations + # Instantiate an initial LOO prediction workflow initial_model_iteration = init_emc_model_iteration_wf( - transform, precision="coarse", name="initial_model_iteration") + precision="coarse", + transform=["rigid", "affine"], + prune_b0s=True, + model_name="tensor", + name="initial_model_iteration", + ) # Collect motion estimates across iterations - collect_motion_params = pe.Node(niu.Merge(num_iters), - name="collect_motion_params") - - dwi_model_emc_wf_workflow.connect([ - # Connect the first iteration - (inputnode, initial_model_iteration, [ - ('original_dwi_files', 'emc_model_iteration_inputnode.original_dwi_files'), - ('b0_median', 'emc_model_iteration_inputnode.b0_median'), - ('b0_mask', 'emc_model_iteration_inputnode.b0_mask'), - ('b0_indices', 'emc_model_iteration_inputnode.b0_indices'), - ('aligned_vectors', 'emc_model_iteration_inputnode.aligned_vectors'), - ('aligned_dwi_files', 'emc_model_iteration_inputnode.aligned_dwi_files')]), - # (initial_model_iteration, collect_motion_params, [ - # ('emc_model_iteration_outputnode.motion_params', 'in1')]) - ]) + collect_motion_params = pe.Node(niu.Merge(num_iters), name="collect_motion_params") + + dwi_model_emc_wf_workflow.connect( + [ + # Connect the first iteration + ( + inputnode, + initial_model_iteration, + [ + ( + "original_dwi_files", + "emc_model_iteration_inputnode.original_dwi_files", + ), + ( + "original_rasb_file", + "emc_model_iteration_inputnode.original_rasb_file", + ), + ("b0_median", "emc_model_iteration_inputnode.b0_median"), + ("b0_mask", "emc_model_iteration_inputnode.b0_mask"), + ("b0_indices", "emc_model_iteration_inputnode.b0_indices"), + ( + "aligned_vectors", + "emc_model_iteration_inputnode.aligned_vectors", + ), + ( + "aligned_dwi_files", + "emc_model_iteration_inputnode.aligned_dwi_files", + ), + ], + ), + ( + initial_model_iteration, + collect_motion_params, + [("emc_model_iteration_outputnode.motion_params", "in1")], + ), + ] + ) model_iterations = [initial_model_iteration] + # Perform additional iterations of LOO prediction for iteration_num in range(num_iters - 1): - iteration_name = 'HMC_iteration%03d' % (iteration_num + 1) - motion_key = 'in%d' % (iteration_num + 2) + iteration_name = "EMC_iteration%03d" % (iteration_num + 1) + motion_key = "in%d" % (iteration_num + 2) model_iterations.append( - init_emc_model_iteration_wf(transform=transform, - precision="precise", - name=iteration_name) + init_emc_model_iteration_wf( + precision="precise", + transform=["rigid", "affine"], + prune_b0s=False, + model_name="sfm", + name=iteration_name, + ) + ) + dwi_model_emc_wf_workflow.connect( + [ + ( + model_iterations[-2], + model_iterations[-1], + [ + ( + "emc_model_iteration_outputnode.aligned_dwis", + "emc_model_iteration_inputnode.aligned_dwi_files", + ), + ( + "emc_model_iteration_outputnode.aligned_vectors", + "emc_model_iteration_inputnode.aligned_vectors", + ), + ], + ), + ( + inputnode, + model_iterations[-1], + [ + ("b0_mask", "emc_model_iteration_inputnode.b0_mask"), + ("b0_indices", "emc_model_iteration_inputnode.b0_indices"), + ( + "original_dwi_files", + "emc_model_iteration_inputnode.original_dwi_files", + ), + ( + "original_rasb_file", + "emc_model_iteration_inputnode.original_rasb_file", + ), + ("b0_median", "emc_model_iteration_inputnode.b0_median"), + ], + ), + ( + model_iterations[-1], + collect_motion_params, + [("emc_model_iteration_outputnode.motion_params", motion_key)], + ), + ] ) - dwi_model_emc_wf_workflow.connect([ - (model_iterations[-2], model_iterations[-1], [ - ('emc_model_iteration_outputnode.aligned_dwis', - 'emc_model_iteration_inputnode.aligned_dwi_files'), - ('emc_model_iteration_outputnode.aligned_vectors', - 'emc_model_iteration_inputnode.aligned_vectors')]), - (inputnode, model_iterations[-1], [ - ('b0_mask', 'emc_model_iteration_inputnode.b0_mask'), - ('b0_indices', 'emc_model_iteration_inputnode.b0_indices'), - ('original_dwi_files', 'emc_model_iteration_inputnode.original_dwi_files'), - ('original_rasb_file', 'emc_model_iteration_inputnode.original_rasb_file'), - ('b0_median', 'emc_model_iteration_inputnode.b0_median')]), - (model_iterations[-1], collect_motion_params, [ - ('emc_model_iteration_outputnode.motion_params', motion_key)]) - ]) # Return to the original, b0-interspersed ordering - reorder_dwi_xforms = pe.Node(ReorderOutputs(), name='reorder_dwi_xforms') + reorder_dwi_xforms = pe.Node(ReorderOutputs(), name="reorder_dwi_xforms") # Create a report: - emc_report = pe.Node(HMCReport(), name='emc_report') - ds_report_emc_gif = pe.Node( - DerivativesDataSink(suffix="emc_animation"), name='ds_report_emc_gif', - mem_gb=1, run_without_submitting=True) - - calculate_cnr = pe.Node(CalculateCNR(), name='calculate_cnr') - - # if num_iters > 1: - # summarize_iterations = pe.Node(IterationSummary(), name='summarize_iterations') - # ds_report_iteration_plot = pe.Node( - # DerivativesDataSink(suffix="emc_iterdata"), name='ds_report_iteration_plot', - # mem_gb=0.1, run_without_submitting=True) - # dwi_model_emc_wf_workflow.connect([ - # (collect_motion_params, summarize_iterations, [ - # ('out', 'collected_motion_files')]), - # (summarize_iterations, ds_report_iteration_plot, [ - # ('plot_file', 'in_file')]), - # (summarize_iterations, outputnode, [ - # ('iteration_summary_file', 'optimization_data')]), - # (summarize_iterations, emc_report, [ - # ('iteration_summary_file', 'iteration_summary')]) - # ]) - - dwi_model_emc_wf_workflow.connect([ - (model_iterations[-1], reorder_dwi_xforms, [ - ('emc_model_iteration_outputnode.emc_transforms', 'model_based_transforms'), - ('emc_model_iteration_outputnode.predicted_dwis', 'model_predicted_images'), - ('emc_model_iteration_outputnode.aligned_dwis', 'warped_dwi_images')]), - (inputnode, reorder_dwi_xforms, [ - ('b0_median', 'b0_median'), - ('warped_b0_images', 'warped_b0_images'), - ('b0_indices', 'b0_indices'), - ('initial_transforms', 'initial_transforms')]), - (reorder_dwi_xforms, outputnode, [ - ('emc_warped_images', 'aligned_dwis'), - ('full_transforms', 'emc_transforms'), - ('full_predicted_dwi_series', 'model_predicted_images')]), - (inputnode, emc_report, [('original_dwi_files', 'original_images')]), - (reorder_dwi_xforms, calculate_cnr, [ - ('emc_warped_images', 'emc_warped_images'), - ('full_predicted_dwi_series', 'predicted_images')]), - (inputnode, calculate_cnr, [('b0_mask', 'mask_image')]), - (calculate_cnr, outputnode, [('cnr_image', 'cnr_image')]), - (reorder_dwi_xforms, emc_report, [ - ('full_predicted_dwi_series', 'model_predicted_images'), - ('emc_warped_images', 'registered_images')]), - (emc_report, ds_report_emc_gif, [('plot_file', 'in_file')]) - ]) + emc_report = pe.Node(EMCReport(), name="emc_report") + + calculate_cnr = pe.Node(CalculateCNR(), name="calculate_cnr") + + if num_iters > 1: + summarize_iterations = pe.Node(IterationSummary(), name="summarize_iterations") + dwi_model_emc_wf_workflow.connect( + [ + ( + collect_motion_params, + summarize_iterations, + [("out", "collected_motion_files")], + ), + ( + summarize_iterations, + outputnode, + [("iteration_summary_file", "optimization_data")], + ), + ( + summarize_iterations, + emc_report, + [("iteration_summary_file", "iteration_summary")], + ), + ] + ) + + dwi_model_emc_wf_workflow.connect( + [ + ( + model_iterations[-1], + reorder_dwi_xforms, + [ + ( + "emc_model_iteration_outputnode.emc_transforms", + "model_based_transforms", + ), + ( + "emc_model_iteration_outputnode.predicted_dwis", + "model_predicted_images", + ), + ( + "emc_model_iteration_outputnode.aligned_dwis", + "warped_dwi_images", + ), + ], + ), + ( + inputnode, + reorder_dwi_xforms, + [ + ("b0_median", "b0_median"), + ("warped_b0_images", "warped_b0_images"), + ("b0_indices", "b0_indices"), + ("initial_transforms", "initial_transforms"), + ], + ), + ( + reorder_dwi_xforms, + outputnode, + [ + ("emc_warped_images", "aligned_dwis"), + ("full_transforms", "emc_transforms"), + ("full_predicted_dwi_series", "model_predicted_images"), + ], + ), + (inputnode, emc_report, [("original_dwi_files", "original_images")]), + ( + reorder_dwi_xforms, + calculate_cnr, + [ + ("emc_warped_images", "emc_warped_images"), + ("full_predicted_dwi_series", "predicted_images"), + ], + ), + (inputnode, calculate_cnr, [("b0_mask", "mask_image")]), + (calculate_cnr, outputnode, [("cnr_image", "cnr_image")]), + ( + reorder_dwi_xforms, + emc_report, + [ + ("full_predicted_dwi_series", "model_predicted_images"), + ("emc_warped_images", "registered_images"), + ], + ) + ] + ) return dwi_model_emc_wf_workflow def init_emc_wf(name, mem_gb=3, omp_nthreads=8): - import_list = ["import warnings", "warnings.filterwarnings(\"ignore\")", "import sys", "import os", - "import numpy as np", "import networkx as nx", "import nibabel as nb", - "from nipype.utils.filemanip import fname_presuffix"] + import_list = [ + "import warnings", + 'warnings.filterwarnings("ignore")', + "import sys", + "import os", + "import numpy as np", + "import networkx as nx", + "import nibabel as nb", + "from nipype.utils.filemanip import fname_presuffix", + ] emc_wf = pe.Workflow(name=name) - meta_inputnode = pe.Node(niu.IdentityInterface(fields=['dwi_file', 'in_bval', 'in_bvec', 'b0_template']), - name='meta_inputnode') + meta_inputnode = pe.Node( + niu.IdentityInterface(fields=["dwi_file", "in_bval", "in_bvec", "b0_template"]), + name="meta_inputnode", + ) - vectors_node = pe.Node(CheckGradientTable(), name='emc_vectors_node') + # Instantiate vectors object + vectors_node = pe.Node(CheckGradientTable(), name="emc_vectors_node") + # Extract B0s extract_b0s_node = pe.Node(ExtractB0(), name="extract_b0_node") - split_b0s_node = pe.Node(niu.Function(input_names=['in_file'], output_names=['out_files'], - function=save_4d_to_3d, imports=import_list), - name="split_b0s_node") + # Split B0s into separate images + split_b0s_node = pe.Node( + niu.Function( + input_names=["in_file"], + output_names=["out_files"], + function=save_4d_to_3d, + imports=import_list, + ), + name="split_b0s_node", + ) - prune_b0s_from_dwis_node = pe.Node(niu.Function(input_names=['in_files', 'b0_ixs'], output_names=['out_files'], - function=prune_b0s_from_dwis, imports=import_list), - name="prune_b0s_from_dwis_node") + # Remove b0s from dwi series + prune_b0s_from_dwis_node = pe.Node( + niu.Function( + input_names=["in_files", "b0_ixs"], + output_names=["out_files"], + function=prune_b0s_from_dwis, + imports=import_list, + ), + name="prune_b0s_from_dwis_node", + ) - split_dwis_node = pe.Node(niu.Function(input_names=['in_file'], output_names=['out_files'], - function=save_4d_to_3d, imports=import_list), - name="split_dwis_node") + # Split dwi series into 3d images + split_dwis_node = pe.Node( + niu.Function( + input_names=["in_file"], + output_names=["out_files"], + function=save_4d_to_3d, + imports=import_list, + ), + name="split_dwis_node", + ) - merge_b0s_node = pe.Node(niu.Function(input_names=['in_files'], output_names=['out_file'], - function=save_3d_to_4d, imports=import_list), - name="merge_b0s_node") + # Merge B0s into a single 4d image + merge_b0s_node = pe.Node( + niu.Function( + input_names=["in_files"], + output_names=["out_file"], + function=save_3d_to_4d, + imports=import_list, + ), + name="merge_b0s_node", + ) + # Order affine transforms to correspond with split dwi files match_transforms_node = pe.Node(MatchTransforms(), name="match_transforms_node") - b0_emc_wf = init_b0_emc_wf(transform="Rigid") + # Create a skull-stripped and enhanced b0 + eastm_wf = init_enhance_and_skullstrip_template_mask_wf( + name="enhance_and_skullstrip_template_mask_wf" + ) - eastm_wf = init_enhance_and_skullstrip_template_mask_wf(name='enhance_and_skullstrip_template_mask_wf') + # Instantiate b0 eddy correction + b0_emc_wf = init_b0_emc_wf() # Initialize with the transforms provided - b0_based_image_transforms = pe.MapNode(ants.ApplyTransforms(interpolation="BSpline"), - iterfield=['input_image', 'transforms'], - name="b0_based_image_transforms") + b0_based_image_transforms = pe.MapNode( + ApplyAffine(), + iterfield=["moving_image", "transform_affine"], + name="b0_based_image_transforms", + ) + # Rotate vectors - b0_based_vector_transforms = pe.Node(ReorientVectors(), name="b0_based_vector_transforms") + b0_based_vector_transforms = pe.Node( + ReorientVectors(), name="b0_based_vector_transforms" + ) # Grab the median of the aligned B0 images - b0_median = pe.Node(RescaleB0(), name='b0_median') + b0_median = pe.Node(RescaleB0(), name="b0_median") # Do model-based motion correction - dwi_model_emc_wf = init_dwi_model_emc_wf(transform='Affine', num_iters=2) + dwi_model_emc_wf = init_dwi_model_emc_wf(num_iters=2) # Warp the modeled images into non-motion-corrected space uncorrect_model_images = pe.MapNode( - ants.ApplyTransforms(invert_transform_flags=[True], - interpolation='LanczosWindowedSinc'), - iterfield=['input_image', 'reference_image', 'transforms'], - name='uncorrect_model_images') + ApplyAffine(), + iterfield=["moving_image", "transform_affine"], + name="uncorrect_model_images", + ) + uncorrect_model_images.inputs.invert_transform = True + + # Save to 4d image + merge_EMC_corrected_dwis_node = pe.Node( + niu.Function( + input_names=["in_files"], + output_names=["out_file"], + function=save_3d_to_4d, + imports=import_list, + ), + name="merge_EMC_corrected_dwis_node", + ) meta_outputnode = pe.Node( niu.IdentityInterface( - fields=["final_template", "forward_transforms", "noise_free_dwis", - "cnr_image", "optimization_data"]), - name='meta_outputnode') - - emc_wf.connect([ - # Create initial transforms to apply to bias-reduced B0 template - (meta_inputnode, vectors_node, [('dwi_file', 'dwi_file'), - ('in_bval', 'in_bval'), - ('in_bvec', 'in_bvec')]), - (meta_inputnode, extract_b0s_node, [('dwi_file', 'in_file')]), - (vectors_node, extract_b0s_node, [('b0_ixs', 'b0_ixs')]), - (extract_b0s_node, split_b0s_node, [('out_file', 'in_file')]), - (meta_inputnode, split_dwis_node, [('dwi_file', 'in_file')]), - (split_b0s_node, b0_emc_wf, [('out_files', 'b0_emc_inputnode.b0_images')]), - (meta_inputnode, b0_emc_wf, [('b0_template', 'b0_emc_inputnode.initial_template')]), - (b0_emc_wf, match_transforms_node, [(('b0_emc_outputnode.forward_transforms', _list_squeeze), 'transforms')]), - (vectors_node, match_transforms_node, [('b0_ixs', 'b0_indices')]), - (split_dwis_node, match_transforms_node, [('out_files', 'dwi_files')]), - (b0_emc_wf, eastm_wf, - [('b0_emc_outputnode.final_template', 'eastm_inputnode.in_file')]), - (match_transforms_node, b0_based_vector_transforms, [('transforms', 'affines')]), - (vectors_node, b0_based_vector_transforms, [('out_rasb', 'rasb_file')]), - (split_dwis_node, b0_based_image_transforms, [('out_files', 'input_image')]), - (match_transforms_node, b0_based_image_transforms, [('transforms', 'transforms')]), - (b0_emc_wf, merge_b0s_node, [('b0_emc_outputnode.aligned_images', 'in_files')]), - (merge_b0s_node, b0_median, [('out_file', 'in_file')]), - (eastm_wf, b0_median, [('eastm_outputnode.mask_file', 'mask_file')]), - (b0_median, b0_based_image_transforms, [('out_ref', 'reference_image')]), - - # Perform signal prediction from vectors - (vectors_node, prune_b0s_from_dwis_node, [('b0_ixs', 'b0_ixs')]), - (split_dwis_node, prune_b0s_from_dwis_node, [('out_files', 'in_files')]), - (prune_b0s_from_dwis_node, dwi_model_emc_wf, [('out_files', 'dwi_model_emc_inputnode.original_dwi_files')]), - (vectors_node, dwi_model_emc_wf, [('out_rasb', 'dwi_model_emc_inputnode.original_rasb_file')]), - (b0_based_image_transforms, dwi_model_emc_wf, [('output_image', 'dwi_model_emc_inputnode.aligned_dwi_files')]), - (b0_based_vector_transforms, dwi_model_emc_wf, [('out_rasb', 'dwi_model_emc_inputnode.aligned_vectors')]), - (b0_median, dwi_model_emc_wf, [('out_ref', 'dwi_model_emc_inputnode.b0_median')]), - (eastm_wf, dwi_model_emc_wf, - [('eastm_outputnode.mask_file', 'dwi_model_emc_inputnode.b0_mask')]), - (vectors_node, dwi_model_emc_wf, [('b0_ixs', 'dwi_model_emc_inputnode.b0_indices')]), - (match_transforms_node, dwi_model_emc_wf, [('transforms', 'dwi_model_emc_inputnode.initial_transforms')]), - (b0_emc_wf, dwi_model_emc_wf, - [('b0_emc_outputnode.aligned_images', 'dwi_model_emc_inputnode.warped_b0_images')]), - (b0_emc_wf, meta_outputnode, [('b0_emc_outputnode.final_template', 'final_template')]), - (dwi_model_emc_wf, meta_outputnode, [('dwi_model_emc_outputnode.emc_transforms', 'forward_transforms'), - ('dwi_model_emc_outputnode.optimization_data', 'optimization_data'), - ('dwi_model_emc_outputnode.cnr_image', 'cnr_image')]), - (b0_emc_wf, uncorrect_model_images, [('b0_emc_outputnode.final_template', 'reference_image')]), - (dwi_model_emc_wf, uncorrect_model_images, [('dwi_model_emc_outputnode.emc_transforms', 'transforms')]), - (split_dwis_node, uncorrect_model_images, [('out_files', 'input_image')]), - (uncorrect_model_images, meta_outputnode, [('output_image', 'noise_free_dwis')]) - ]) + fields=[ + "final_emc_4d_series" + "final_template", + "forward_transforms", + "noise_free_dwis", + "cnr_image", + "optimization_data", + ] + ), + name="meta_outputnode", + ) + + emc_wf.connect( + [ + # Create initial transforms to apply to bias-reduced B0 template + ( + meta_inputnode, + vectors_node, + [ + ("dwi_file", "dwi_file"), + ("in_bval", "in_bval"), + ("in_bvec", "in_bvec"), + ], + ), + (meta_inputnode, extract_b0s_node, [("dwi_file", "in_file")]), + (vectors_node, extract_b0s_node, [("b0_ixs", "b0_ixs")]), + (extract_b0s_node, split_b0s_node, [("out_file", "in_file")]), + (meta_inputnode, split_dwis_node, [("dwi_file", "in_file")]), + (split_b0s_node, b0_emc_wf, [("out_files", "b0_emc_inputnode.b0_images")]), + ( + meta_inputnode, + b0_emc_wf, + [("b0_template", "b0_emc_inputnode.initial_template")], + ), + ( + b0_emc_wf, + match_transforms_node, + [ + ( + ("b0_emc_outputnode.forward_transforms", _list_squeeze), + "transforms", + ) + ], + ), + (vectors_node, match_transforms_node, [("b0_ixs", "b0_indices")]), + (split_dwis_node, match_transforms_node, [("out_files", "dwi_files")]), + ( + b0_emc_wf, + eastm_wf, + [("b0_emc_outputnode.final_template", "eastm_inputnode.in_file")], + ), + (meta_inputnode, b0_based_vector_transforms, [("dwi_file", "dwi_file")]), + ( + match_transforms_node, + b0_based_vector_transforms, + [("transforms", "affines")], + ), + (vectors_node, b0_based_vector_transforms, [("out_rasb", "rasb_file")]), + ( + split_dwis_node, + b0_based_image_transforms, + [("out_files", "moving_image")], + ), + ( + match_transforms_node, + b0_based_image_transforms, + [("transforms", "transform_affine")], + ), + ( + b0_emc_wf, + merge_b0s_node, + [("b0_emc_outputnode.aligned_images", "in_files")], + ), + (merge_b0s_node, b0_median, [("out_file", "in_file")]), + (eastm_wf, b0_median, [("eastm_outputnode.mask_file", "mask_file")]), + (b0_median, b0_based_image_transforms, [("out_ref", "fixed_image")]), + + # Perform signal prediction from vectors + (vectors_node, prune_b0s_from_dwis_node, [("b0_ixs", "b0_ixs")]), + (split_dwis_node, prune_b0s_from_dwis_node, [("out_files", "in_files")]), + ( + prune_b0s_from_dwis_node, + dwi_model_emc_wf, + [("out_files", "dwi_model_emc_inputnode.original_dwi_files")], + ), + ( + vectors_node, + dwi_model_emc_wf, + [("out_rasb", "dwi_model_emc_inputnode.original_rasb_file")], + ), + ( + b0_based_image_transforms, + dwi_model_emc_wf, + [("warped_image", "dwi_model_emc_inputnode.aligned_dwi_files")], + ), + ( + b0_based_vector_transforms, + dwi_model_emc_wf, + [("out_rasb", "dwi_model_emc_inputnode.aligned_vectors")], + ), + ( + b0_median, + dwi_model_emc_wf, + [("out_ref", "dwi_model_emc_inputnode.b0_median")], + ), + ( + eastm_wf, + dwi_model_emc_wf, + [("eastm_outputnode.mask_file", "dwi_model_emc_inputnode.b0_mask")], + ), + ( + vectors_node, + dwi_model_emc_wf, + [("b0_ixs", "dwi_model_emc_inputnode.b0_indices")], + ), + ( + match_transforms_node, + dwi_model_emc_wf, + [("transforms", "dwi_model_emc_inputnode.initial_transforms")], + ), + ( + b0_emc_wf, + dwi_model_emc_wf, + [ + ( + "b0_emc_outputnode.aligned_images", + "dwi_model_emc_inputnode.warped_b0_images", + ) + ], + ), + ( + b0_emc_wf, + meta_outputnode, + [("b0_emc_outputnode.final_template", "final_template")], + ), + ( + dwi_model_emc_wf, + meta_outputnode, + [ + ("dwi_model_emc_outputnode.emc_transforms", "forward_transforms"), + ("dwi_model_emc_outputnode.optimization_data", "optimization_data"), + ("dwi_model_emc_outputnode.cnr_image", "cnr_image"), + ], + ), + ( + b0_emc_wf, + uncorrect_model_images, + [("b0_emc_outputnode.final_template", "fixed_image")], + ), + ( + dwi_model_emc_wf, + uncorrect_model_images, + [("dwi_model_emc_outputnode.emc_transforms", "transform_affine")], + ), + (split_dwis_node, uncorrect_model_images, [("out_files", "moving_image")]), + ( + uncorrect_model_images, + meta_outputnode, + [("warped_image", "noise_free_dwis")], + ), + ( + uncorrect_model_images, + merge_EMC_corrected_dwis_node, + [("warped_image", "in_files")], + ), + ( + merge_EMC_corrected_dwis_node, + meta_outputnode, + [("out_file", "final_emc_4d_series")], + ), + ] + ) return emc_wf