diff --git a/dmriprep/config/emc_coarse_Affine.json b/dmriprep/config/emc_coarse_Affine.json new file mode 100644 index 00000000..f7663871 --- /dev/null +++ b/dmriprep/config/emc_coarse_Affine.json @@ -0,0 +1,8 @@ +{ + "level_iters": [1000, 100], + "metric": "MI", + "sigmas": [8.0, 2.0], + "factors": [2, 1], + "sampling_prop": 0.15, + "nbins": 48 +} \ No newline at end of file diff --git a/dmriprep/config/emc_coarse_Rigid.json b/dmriprep/config/emc_coarse_Rigid.json new file mode 100644 index 00000000..5d23fba6 --- /dev/null +++ b/dmriprep/config/emc_coarse_Rigid.json @@ -0,0 +1,8 @@ +{ + "level_iters": [100, 100], + "metric": "MI", + "sigmas": [8.0, 2.0], + "factors": [2, 1], + "sampling_prop": 0.15, + "nbins": 48 +} diff --git a/dmriprep/config/emc_precise_Affine.json b/dmriprep/config/emc_precise_Affine.json new file mode 100644 index 00000000..99fbb1e2 --- /dev/null +++ b/dmriprep/config/emc_precise_Affine.json @@ -0,0 +1,8 @@ +{ + "level_iters": [1000, 1000], + "metric": "MI", + "sigmas": [8.0, 2.0], + "factors": [2, 1], + "sampling_prop": 0.15, + "nbins": 48 +} \ No newline at end of file diff --git a/dmriprep/config/emc_precise_Rigid.json b/dmriprep/config/emc_precise_Rigid.json new file mode 100644 index 00000000..e62cd12a --- /dev/null +++ b/dmriprep/config/emc_precise_Rigid.json @@ -0,0 +1,8 @@ +{ + "level_iters": [1000, 1000], + "metric": "MI", + "sigmas": [8.0, 2.0], + "factors": [2, 1], + "sampling_prop": 0.15, + "nbins": 48 +} diff --git a/dmriprep/interfaces/registration.py b/dmriprep/interfaces/registration.py new file mode 100644 index 00000000..44f3256f --- /dev/null +++ b/dmriprep/interfaces/registration.py @@ -0,0 +1,147 @@ +"""Register tools interfaces.""" +import numpy as np +import nibabel as nb +import dmriprep +from nipype import logging +from pathlib import Path +from nipype.utils.filemanip import fname_presuffix +from nipype.interfaces.base import ( + traits, + TraitedSpec, + BaseInterfaceInputSpec, + InputMultiObject, + SimpleInterface, + File, +) + + +LOGGER = logging.getLogger("nipype.interface") +REG_TYPES = ("c_of_mass", "translation", "rigid", "affine") + + +class _ApplyAffineInputSpec(BaseInterfaceInputSpec): + moving_image = File( + exists=True, mandatory=True, desc="image that will be resampled into the reference" + ) + fixed_image = File( + exists=True, mandatory=True, desc="image defining the reference coordinate system" + ) + transform_affine = InputMultiObject( + File(exists=True), mandatory=True, desc="transformation affine" + ) + invert_transform = traits.Bool(False, usedefault=True) + + +class _ApplyAffineOutputSpec(TraitedSpec): + warped_image = File(exists=True, desc="Outputs warped image") + + +class ApplyAffine(SimpleInterface): + """ + Interface to apply an affine transformation to an image. + """ + + input_spec = _ApplyAffineInputSpec + output_spec = _ApplyAffineOutputSpec + + def _run_interface(self, runtime): + from dmriprep.utils.registration import apply_affine + + warped_image_nifti = apply_affine( + nb.load(self.inputs.moving_image), + nb.load(self.inputs.fixed_image), + np.load(self.inputs.transform_affine[0]), + self.inputs.invert_transform, + ) + cwd = Path(runtime.cwd).absolute() + warped_file = fname_presuffix( + self.inputs.moving_image, + use_ext=False, + suffix="_warped.nii.gz", + newpath=str(cwd), + ) + + warped_image_nifti.to_filename(warped_file) + + self._results["warped_image"] = warped_file + return runtime + + +class _RegisterInputSpec(BaseInterfaceInputSpec): + moving_image = File( + exists=True, mandatory=True, desc="image to apply transformation from" + ) + fixed_image = File( + exists=True, mandatory=True, desc="image to apply transformation to" + ) + nbins = traits.Int(default_value=32, usedefault=True) + sampling_prop = traits.Float(default_value=1, usedefault=True) + metric = traits.Str(default_value="MI", usedefault=True) + level_iters = traits.List( + trait=traits.Any(), value=[10000, 1000, 100], usedefault=True + ) + sigmas = traits.List(trait=traits.Any(), value=[5.0, 2.5, 0.0], usedefault=True) + factors = traits.List(trait=traits.Any(), value=[4, 2, 1], usedefault=True) + params0 = traits.ArrayOrNone(value=None, usedefault=True) + pipeline = traits.List( + traits.Enum(*REG_TYPES), + value=list(REG_TYPES), + usedefault=True, + ) + + +class _RegisterOutputSpec(TraitedSpec): + forward_transforms = traits.List( + File(exists=True), desc="List of output transforms for forward registration" + ) + warped_image = File(exists=True, desc="Outputs warped image") + + +class Register(SimpleInterface): + """ + Interface to perform affine registration. + """ + + input_spec = _RegisterInputSpec + output_spec = _RegisterOutputSpec + + def _run_interface(self, runtime): + from dmriprep.utils.registration import affine_registration + + pipeline = [ + getattr(dmriprep.utils.register, i) + for i in self.inputs.pipeline + if i in REG_TYPES + ] + + warped_image_nifti, forward_transform_mat = affine_registration( + nb.load(self.inputs.moving_image), + nb.load(self.inputs.fixed_image), + self.inputs.nbins, + self.inputs.sampling_prop, + self.inputs.metric, + pipeline, + self.inputs.level_iters, + self.inputs.sigmas, + self.inputs.factors, + self.inputs.params0, + ) + cwd = Path(runtime.cwd).absolute() + warped_file = fname_presuffix( + self.inputs.moving_image, + use_ext=False, + suffix="_warped.nii.gz", + newpath=str(cwd), + ) + forward_transform_file = fname_presuffix( + self.inputs.moving_image, + use_ext=False, + suffix="_forward_transform.npy", + newpath=str(cwd), + ) + warped_image_nifti.to_filename(warped_file) + + np.save(forward_transform_file, forward_transform_mat) + self._results["warped_image"] = warped_file + self._results["forward_transforms"] = [forward_transform_file] + return runtime diff --git a/dmriprep/utils/registration.py b/dmriprep/utils/registration.py new file mode 100644 index 00000000..56b47084 --- /dev/null +++ b/dmriprep/utils/registration.py @@ -0,0 +1,170 @@ +""" +Linear affine registration tools for motion correction. +""" +import numpy as np +import nibabel as nb +from dipy.align.metrics import CCMetric, EMMetric, SSDMetric +from dipy.align.imaffine import ( + transform_centers_of_mass, + AffineMap, + MutualInformationMetric, + AffineRegistration, +) +from dipy.align.transforms import ( + TranslationTransform3D, + RigidTransform3D, + AffineTransform3D, +) +from nipype.utils.filemanip import fname_presuffix + +syn_metric_dict = {"CC": CCMetric, "EM": EMMetric, "SSD": SSDMetric} + +__all__ = [ + "c_of_mass", + "translation", + "rigid", + "affine", + "affine_registration", +] + + +def apply_affine(moving, static, transform_affine, invert=False): + """Apply an affine to transform an image from one space to another. + + Parameters + ---------- + moving : array + The image to be resampled + + static : array + + Returns + ------- + warped_img : the moving array warped into the static array's space. + + """ + affine_map = AffineMap( + transform_affine, static.shape, static.affine, moving.shape, moving.affine + ) + if invert is True: + warped_arr = affine_map.transform_inverse(np.asarray(moving.dataobj)) + else: + warped_arr = affine_map.transform(np.asarray(moving.dataobj)) + + return nb.Nifti1Image(warped_arr, static.affine) + + +def average_affines(transforms): + affine_list = [np.load(aff) for aff in transforms] + average_affine_file = fname_presuffix( + transforms[0], use_ext=False, suffix="_average.npy" + ) + np.save(average_affine_file, np.mean(affine_list, axis=0)) + return average_affine_file + + +# Affine registration pipeline: +affine_metric_dict = {"MI": MutualInformationMetric, "CC": CCMetric} + + +def c_of_mass( + moving, static, static_affine, moving_affine, reg, starting_affine, params0=None +): + transform = transform_centers_of_mass(static, static_affine, moving, moving_affine) + transformed = transform.transform(moving) + return transformed, transform.affine + + +def translation( + moving, static, static_affine, moving_affine, reg, starting_affine, params0=None +): + transform = TranslationTransform3D() + translation = reg.optimize( + static, + moving, + transform, + params0, + static_affine, + moving_affine, + starting_affine=starting_affine, + ) + + return translation.transform(moving), translation.affine + + +def rigid( + moving, static, static_affine, moving_affine, reg, starting_affine, params0=None +): + transform = RigidTransform3D() + rigid = reg.optimize( + static, + moving, + transform, + params0, + static_affine, + moving_affine, + starting_affine=starting_affine, + ) + return rigid.transform(moving), rigid.affine + + +def affine( + moving, static, static_affine, moving_affine, reg, starting_affine, params0=None +): + transform = AffineTransform3D() + affine = reg.optimize( + static, + moving, + transform, + params0, + static_affine, + moving_affine, + starting_affine=starting_affine, + ) + + return affine.transform(moving), affine.affine + + +def affine_registration( + moving, + static, + nbins, + sampling_prop, + metric, + pipeline, + level_iters, + sigmas, + factors, + params0, +): + + """ + Find the affine transformation between two 3D images. + + Parameters + ---------- + + """ + # Define the Affine registration object we'll use with the chosen metric: + use_metric = affine_metric_dict[metric](nbins, sampling_prop) + affreg = AffineRegistration( + metric=use_metric, level_iters=level_iters, sigmas=sigmas, factors=factors + ) + + if not params0: + starting_affine = np.eye(4) + else: + starting_affine = params0 + + # Go through the selected transformation: + for func in pipeline: + transformed, starting_affine = func( + np.asarray(moving.dataobj), + np.asarray(static.dataobj), + static.affine, + moving.affine, + affreg, + starting_affine, + params0, + ) + return nb.Nifti1Image(np.array(transformed), static.affine), starting_affine diff --git a/dmriprep/utils/tests/test_registration.py b/dmriprep/utils/tests/test_registration.py new file mode 100644 index 00000000..926ebd12 --- /dev/null +++ b/dmriprep/utils/tests/test_registration.py @@ -0,0 +1,66 @@ +import pytest +import numpy as np +import numpy.testing as npt +import nibabel as nb +import dipy.data as dpd +from dmriprep.utils.registration import affine_registration + + +def setup_module(): + global subset_b0, subset_dwi_data, subset_t2, subset_b0_img, \ + subset_t2_img, gtab, hardi_affine, MNI_T2_affine + MNI_T2 = dpd.read_mni_template() + hardi_img, gtab = dpd.read_stanford_hardi() + MNI_T2_data = MNI_T2.get_fdata() + MNI_T2_affine = MNI_T2.affine + hardi_data = hardi_img.get_fdata() + hardi_affine = hardi_img.affine + b0 = hardi_data[..., gtab.b0s_mask] + mean_b0 = np.mean(b0, -1) + + # Select some arbitrary chunks of data so this goes quicker + subset_b0 = mean_b0[40:50, 40:50, 40:50] + subset_dwi_data = nb.Nifti1Image(hardi_data[40:50, 40:50, 40:50], hardi_affine) + subset_t2 = MNI_T2_data[40:60, 40:60, 40:60] + subset_b0_img = nb.Nifti1Image(subset_b0, hardi_affine) + subset_t2_img = nb.Nifti1Image(subset_t2, MNI_T2_affine) + + +@pytest.mark.parametrize("nbins", [32, 22]) +@pytest.mark.parametrize("sampling_prop", [1, 2]) +@pytest.mark.parametrize("metric", ["MI", "CC"]) +@pytest.mark.parametrize("level_iters", [[10000, 100], [1]]) +@pytest.mark.parametrize("sigmas", [[5.0, 2.5], [0.0]]) +@pytest.mark.parametrize("factors", [[4, 2], [1]]) +@pytest.mark.parametrize("params0", [np.eye(4), None]) +@pytest.mark.parametrize("pipeline", [["rigid"], ["affine"], ["rigid", "affine"]]) +def test_affine_registration( + nbins, sampling_prop, metric, level_iters, sigmas, factors, params0, pipeline +): + moving = subset_b0 + static = subset_b0 + moving_affine = static_affine = np.eye(4) + xformed, affine = affine_registration(moving, static) + # We don't ask for much: + npt.assert_almost_equal(affine[:3, :3], np.eye(3), decimal=1) + + with pytest.raises(ValueError): + # For array input, must provide affines: + xformed, affine = affine_registration(moving, static) + + # If providing nifti image objects, don't need to provide affines: + moving_img = nb.Nifti1Image(moving, moving_affine) + static_img = nb.Nifti1Image(static, static_affine) + xformed, affine = affine_registration( + moving_img, + static_img, + nbins, + sampling_prop, + metric, + pipeline, + level_iters, + sigmas, + factors, + params0, + ) + npt.assert_almost_equal(affine[:3, :3], np.eye(3), decimal=1)