diff --git a/src/eddymotion/estimator.py b/src/eddymotion/estimator.py index 6c7a4970..4c3f8324 100644 --- a/src/eddymotion/estimator.py +++ b/src/eddymotion/estimator.py @@ -21,17 +21,22 @@ # https://www.nipreps.org/community/licensing/ # """A model-based algorithm for the realignment of dMRI data.""" +import gc from pathlib import Path from tempfile import TemporaryDirectory, mkstemp +from dataclasses import dataclass +from typing import Optional, Dict, Union, List, Tuple import nibabel as nb import nitransforms as nt import numpy as np + from nipype.interfaces.ants.registration import Registration from pkg_resources import resource_filename as pkg_fn from tqdm import tqdm from eddymotion.model import ModelFactory +from eddymotion.data.dmri import DWI class EddyMotionEstimator: @@ -42,7 +47,7 @@ def fit( dwdata, *, align_kwargs=None, - models=("b0", ), + models=("b0",), omp_nthreads=None, n_jobs=None, seed=None, @@ -86,9 +91,9 @@ def fit( bmask_img = None if dwdata.brainmask is not None: _, bmask_img = mkstemp(suffix="_bmask.nii.gz") - nb.Nifti1Image( - dwdata.brainmask.astype("uint8"), dwdata.affine, None - ).to_filename(bmask_img) + nb.Nifti1Image(dwdata.brainmask.astype("uint8"), dwdata.affine, None).to_filename( + bmask_img + ) kwargs["mask"] = dwdata.brainmask kwargs["S0"] = _advanced_clip(dwdata.bzero) @@ -96,20 +101,22 @@ def fit( if "num_threads" not in align_kwargs and omp_nthreads is not None: align_kwargs["num_threads"] = omp_nthreads + aligner = Aligner(dwdata, bmask_img, align_kwargs, models) + n_iter = len(models) for i_iter, model in enumerate(models): - reg_target_type = ( - "dwi" - if model.lower() not in ("b0", "s0", "avg", "average", "mean") - else "b0" - ) index_order = np.arange(len(dwdata)) np.random.shuffle(index_order) - single_model = ( - model.lower() in ("b0", "s0", "avg", "average", "mean") - or model.lower().startswith("full") - ) + aligner.set_model_iter(i_iter) + + single_model = model.lower() in ( + "b0", + "s0", + "avg", + "average", + "mean", + ) or model.lower().startswith("full") dwmodel = None if single_model: @@ -126,14 +133,15 @@ def fit( with TemporaryDirectory() as tmpdir: print(f"Processing in <{tmpdir}>") + with tqdm(total=len(index_order), unit="dwi") as pbar: # run a original-to-synthetic affine registration - for i in index_order: + for b_ix in index_order: pbar.set_description_str( - f"Pass {i_iter + 1}/{n_iter} | Fit and predict b-index <{i}>" + f"Pass {i_iter + 1}/{n_iter} | Fit and predict b-index <{b_ix}>" ) - data_train, data_test = dwdata.logo_split(i, with_b0=True) - grad_str = f"{i}, {data_test[1][:3]}, b={int(data_test[1][3])}" + data_train, data_test = dwdata.logo_split(b_ix, with_b0=True) + grad_str = f"{b_ix}, {data_test[1][:3]}, b={int(data_test[1][3])}" pbar.set_description_str(f"[{grad_str}], {n_jobs} jobs") if not single_model: # A true LOGO estimator @@ -151,67 +159,130 @@ def fit( n_jobs=n_jobs, ) - # generate a synthetic dw volume for the test gradient + # predict the gradient predicted = dwmodel.predict(data_test[1]) - - # prepare data for running ANTs - tmpdir = Path(tmpdir) - moving = tmpdir / f"moving{i:05d}.nii.gz" - fixed = tmpdir / f"fixed{i:05d}.nii.gz" - _to_nifti(data_test[0], dwdata.affine, moving) - _to_nifti( - predicted, - dwdata.affine, - fixed, - clip=reg_target_type == "dwi", - ) - pbar.set_description_str( - f"Pass {i_iter + 1}/{n_iter} | Realign b-index <{i}>" - ) - registration = Registration( - terminal_output="file", - from_file=pkg_fn( - "eddymotion", - f"config/dwi-to-{reg_target_type}_level{i_iter}.json", - ), - fixed_image=str(fixed.absolute()), - moving_image=str(moving.absolute()), - **align_kwargs, - ) - if bmask_img: - registration.inputs.fixed_image_masks = ["NULL", bmask_img] - - if dwdata.em_affines and dwdata.em_affines[i] is not None: - mat_file = tmpdir / f"init_{i_iter}_{i:05d}.mat" - dwdata.em_affines[i].to_filename(mat_file, fmt="itk") - registration.inputs.initial_moving_transform = str(mat_file) - - # execute ants command line - result = registration.run(cwd=str(tmpdir)).outputs - - # read output transform - xform = nt.linear.Affine( - nt.io.itk.ITKLinearTransform.from_filename( - result.forward_transforms[0] - ).to_ras(reference=fixed, moving=moving), - ) - # debugging: generate aligned file for testing - xform.apply(moving, reference=fixed).to_filename( - tmpdir / f"aligned{i:05d}_{int(data_test[1][3]):04d}.nii.gz" + f"Pass {i_iter + 1}/{n_iter} | Realign b-index <{b_ix}>" ) + # Initialize the ANTs registration object for the current model iteration + xform = aligner.transform(Path(tmpdir), data_test, b_ix, predicted) + # update - dwdata.set_transform(i, xform.matrix) + dwdata.set_transform(b_ix, xform.matrix) pbar.update() + # free memory + del xform, predicted, data_train, data_test + gc.collect() + return dwdata.em_affines -def _advanced_clip( - data, p_min=35, p_max=99.98, nonnegative=True, dtype="int16", invert=False -): +@dataclass +class Aligner: + """Convenience dataclass that wraps and tracks ANTs registrations for each gradient prediction. + + Attributes + ---------- + dwdata : :obj:`~eddymotion.data.DWI` + The DWI data object. + bmask_img : :obj:`str` + Path to a brain mask image. + align_kwargs : :obj:`dict` + Additional keyword arguments to pass to the ANTs registration call. + models : :obj:`list` of :obj:`str` + List of model names. + """ + + dwdata: DWI + bmask_img: Optional[str] + align_kwargs: Dict + models: Union[List[str], Tuple[str]] + + def set_model_iter(self, i_iter: int) -> None: + """Set the model iteration.""" + self._model_iter = i_iter + + @property + def model(self) -> str: + """Return the model name.""" + return self.models[self._model_iter] + + @property + def reg_target_type(self) -> str: + """Return the registration target type.""" + return ( + "dwi" + if self.models[self._model_iter].lower() not in ("b0", "s0", "avg", "average", "mean") + else "b0" + ) + + def transform( + self, basedir: Path, data_test: np.ndarray, b_ix: int, predicted: np.ndarray + ) -> nt.linear.Affine: + """Run ANTs registration and return the resulting transform. + + Parameters + ---------- + basedir : :obj:`pathlib.Path` + Path to a working directory. + data_test : :obj:`numpy.ndarray` + The test data. + b_ix : :obj:`int` + The index of the current gradient. + predicted : :obj:`numpy.ndarray` + The predicted dw volume for the test gradient. + + """ + + if self.bmask_img: + self.registration.inputs.fixed_image_masks = ["NULL", self.bmask_img] + + # prepare data for running ANTs + moving = basedir / f"moving{b_ix:05d}.nii.gz" + fixed = basedir / f"fixed{b_ix:05d}.nii.gz" + _to_nifti(data_test[0], self.dwdata.affine, moving) + _to_nifti( + predicted, # generate a synthetic dw volume for the test gradient + self.dwdata.affine, + fixed, + clip=self.reg_target_type == "dwi", + ) + + self.registration = Registration( + terminal_output="file", + from_file=pkg_fn( + "eddymotion", + f"config/dwi-to-{self.reg_target_type}_level{self._model_iter}.json", + ), + fixed_image=str(fixed.absolute()), + moving_image=str(moving.absolute()), + **self.align_kwargs, + ) + + if self.dwdata.em_affines and self.dwdata.em_affines[b_ix] is not None: + mat_file = basedir / f"init_{self._model_iter}_{b_ix:05d}.mat" + self.dwdata.em_affines[b_ix].to_filename(mat_file, fmt="itk") + self.registration.inputs.initial_moving_transform = str(mat_file) + + # read output transform + xform = nt.linear.Affine( + nt.io.itk.ITKLinearTransform.from_filename( + self.registration.run(cwd=str(basedir)).outputs.forward_transforms[0] + ).to_ras(reference=fixed, moving=moving), + ) + # debugging: generate aligned file for testing + xform.apply(moving, reference=fixed).to_filename( + basedir / f"aligned{b_ix:05d}_{int(data_test[1][3]):04d}.nii.gz" + ) + + return xform + + +def _advanced_clip(data, p_min=35, p_max=99.98, nonnegative=True, dtype="int16", invert=False): + r""" Remove outliers at both ends of the intensity distribution and fit into a given dtype. This interface tries to emulate ANTs workflows' massaging that truncate images into @@ -232,6 +303,9 @@ def _advanced_clip( # Calculate stats on denoised version, to preempt outliers from biasing denoised = ndimage.median_filter(data, footprint=ball(3)) + if len(denoised[denoised > 0]) == 0: + return data + a_min = np.percentile(denoised[denoised > 0] if nonnegative else denoised, p_min) a_max = np.percentile(denoised[denoised > 0] if nonnegative else denoised, p_max) diff --git a/src/eddymotion/model/base.py b/src/eddymotion/model/base.py index 65b76dd6..304ed01f 100644 --- a/src/eddymotion/model/base.py +++ b/src/eddymotion/model/base.py @@ -25,6 +25,7 @@ from joblib import Parallel, delayed import numpy as np from dipy.core.gradients import gradient_table +from importlib import import_module def _exec_fit(model, data, chunk=None): @@ -59,6 +60,7 @@ def init(gtab, model="DTI", **kwargs): An model object compliant with DIPY's interface. """ + if model.lower() in ("s0", "b0"): return TrivialB0Model(gtab=gtab, S0=kwargs.pop("S0")) @@ -84,31 +86,38 @@ class BaseModel: """ - __slots__ = ( - "_model", - "_mask", - "_S0", - "_b_max", - "_models", - "_datashape", - ) + __slots__ = ("_model", "_mask", "_S0", "_b_max", "_models", "_datashape", "_n_models") _modelargs = tuple() def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs): """Base initialization.""" - # Setup B0 map - self._S0 = None - if S0 is not None: - self._S0 = np.clip(S0.astype("float32") / S0.max(), a_min=1e-5, a_max=1.0,) - # Setup brain mask self._mask = mask + if "mask" in kwargs: + self._mask = kwargs.pop("mask") if mask is None and S0 is not None: self._mask = self._S0 > np.percentile(self._S0, 35) + # Setup B0 map + if "S0" in kwargs: + S0 = kwargs.pop("S0") + if S0 is not None: + self._S0 = np.clip( + S0.astype("float32") / S0.max(), + a_min=1e-5, + a_max=1.0, + ) + # Select voxels within mask or just unravel 3D if no mask + self._S0 = ( + np.ma.masked_array(self._S0, mask=np.broadcast_to(self._mask, self._S0.shape)).data + if self._mask is not None + else self._S0.reshape(-1, self._S0.shape[-1]) + ) + # Cap b-values, if requested - self._b_max = None + if "b_max" in kwargs: + b_max = kwargs.pop("b_max") if b_max and b_max > 1000: # Saturate b-values at b_max, since signal stops dropping gtab[-1, gtab[-1] > b_max] = b_max @@ -117,6 +126,8 @@ def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs): # data = data[..., bval_mask] # gtab = gtab[:, bval_mask] self._b_max = b_max + else: + self._b_max = None kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs} @@ -124,12 +135,8 @@ def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs): if not model_str: raise TypeError("No model defined") - from importlib import import_module - module_name, class_name = model_str.rsplit(".", 1) - self._model = getattr( - import_module(module_name), class_name - )(_rasb2dipy(gtab), **kwargs) + self._model = getattr(import_module(module_name), class_name)(_rasb2dipy(gtab), **kwargs) def fit(self, data, n_jobs=None, **kwargs): """Fit the model chunk-by-chunk asynchronously""" @@ -137,10 +144,17 @@ def fit(self, data, n_jobs=None, **kwargs): self._datashape = data.shape + # Add fourth axis to mask if missing + mask = ( + self._mask[..., None] + if self._mask is not None and self._mask.ndim == 3 + else self._mask + ) + # Select voxels within mask or just unravel 3D if no mask data = ( - data[self._mask, ...] - if self._mask is not None + np.ma.masked_array(data, mask=np.broadcast_to(mask, data.shape)).data + if mask is not None else data.reshape(-1, data.shape[-1]) ) @@ -149,22 +163,25 @@ def fit(self, data, n_jobs=None, **kwargs): self._model, _ = _exec_fit(self._model, data) return - # Split data into chunks of group of slices - data_chunks = np.array_split(data, n_jobs) - self._models = [None] * n_jobs # Parallelize process with joblib + # Split data into chunks of group of slices with Parallel(n_jobs=n_jobs) as executor: results = executor( delayed(_exec_fit)(self._model, dchunk, i) - for i, dchunk in enumerate(data_chunks) + for i, dchunk in enumerate(np.array_split(data, n_jobs)) ) - for submodel, index in results: - self._models[index] = submodel + if results: + for submodel, index in results: + self._models[index] = submodel + else: + raise RuntimeError("No results from parallel execution across data chunks.") self._model = None # Preempt further actions on the model + self._n_models = len(self._models) if self._model is None and self._models else 1 + def predict(self, gradient, **kwargs): """Predict asynchronously chunk-by-chunk the diffusion signal.""" if self._b_max is not None: @@ -172,44 +189,40 @@ def predict(self, gradient, **kwargs): gradient = _rasb2dipy(gradient) - S0 = None - if self._S0 is not None: - S0 = ( - self._S0[self._mask, ...] - if self._mask is not None - else self._S0.reshape(-1, self._S0.shape[-1]) - ) - - n_models = len(self._models) if self._model is None and self._models else 1 - - if n_models == 1: + if self._n_models == 1: + S0 = self._S0 predicted, _ = _exec_predict(self._model, gradient, S0=S0, **kwargs) else: S0 = ( - np.array_split(S0, n_models) if S0 is not None - else [None] * n_models + np.array_split(self._S0, self._n_models) + if self._S0 is not None + else [None] * self._n_models ) - predicted = [None] * n_models + predicted = [None] * self._n_models # Parallelize process with joblib - with Parallel(n_jobs=n_models) as executor: + with Parallel(n_jobs=self._n_models) as executor: results = executor( delayed(_exec_predict)(model, gradient, S0=S0[i], chunk=i, **kwargs) for i, model in enumerate(self._models) ) - for subprediction, index in results: - predicted[index] = subprediction - - predicted = np.hstack(predicted) - - if self._mask is not None: - retval = np.zeros_like(self._mask, dtype="float32") - retval[self._mask, ...] = predicted - else: - retval = predicted.reshape(self._datashape[:-1]) - return retval + if results: + predicted = np.vstack([r[0] for r in results]) + if self._mask is not None: + retval = np.zeros_like(self._mask, dtype="float32") + if self._mask.ndim == 3: + mask = self._mask.reshape(-1) + retval = retval.reshape(-1) + retval[mask] = predicted.reshape(-1) + retval = retval.reshape(self._datashape[:-1]) + else: + retval = predicted.reshape(self._datashape[:-1]) + return retval + + else: + raise RuntimeError("No results from parallel execution across data chunks.") class TrivialB0Model: diff --git a/test/test_dmri.py b/test/test_dmri.py index f907b5c1..92f8ec18 100644 --- a/test/test_dmri.py +++ b/test/test_dmri.py @@ -54,7 +54,11 @@ def test_load(datadir, tmp_path): assert np.allclose(dwi_h5.gradients, dwi_from_nifti1.gradients) # Try loading NIfTI + b-vecs/vals - dwi_from_nifti2 = load(dwi_nifti_path, bvec_file=bvecs_path, bval_file=bvals_path,) + dwi_from_nifti2 = load( + dwi_nifti_path, + bvec_file=bvecs_path, + bval_file=bvals_path, + ) assert np.allclose(dwi_h5.dataobj, dwi_from_nifti2.dataobj) assert np.allclose(dwi_h5.bzero, dwi_from_nifti2.bzero) diff --git a/test/test_estimator.py b/test/test_estimator.py index b592fe57..b1df32c9 100644 --- a/test/test_estimator.py +++ b/test/test_estimator.py @@ -69,12 +69,10 @@ def test_ANTs_config_b0(datadir, tmp_path, r_x, r_y, r_z, t_x, t_y, t_z): result = registration.run(cwd=str(tmp_path)).outputs xform = nt.linear.Affine( - nt.io.itk.ITKLinearTransform.from_filename( - result.forward_transforms[0] - ).to_ras(), + nt.io.itk.ITKLinearTransform.from_filename(result.forward_transforms[0]).to_ras(), reference=b0nii, ) coords = xfm.reference.ndcoords.T rms = np.sqrt(((xfm.map(coords) - xform.map(coords)) ** 2).sum(1)).mean() - assert rms < 0.8 + assert rms < 0.8 diff --git a/test/test_integration.py b/test/test_integration.py index a7ac01d6..095bccff 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -39,27 +39,13 @@ def test_proximity_estimator_trivial_model(datadir, tmp_path): # Generate a list of large-yet-plausible bulk-head motion. xfms = nt.linear.LinearTransformsMapping( [ - nb.affines.from_matvec( - nb.eulerangles.euler2mat(x=0.03, z=0.005), (0.8, 0.2, 0.2) - ), - nb.affines.from_matvec( - nb.eulerangles.euler2mat(x=0.02, z=0.005), (0.8, 0.2, 0.2) - ), - nb.affines.from_matvec( - nb.eulerangles.euler2mat(x=0.02, z=0.02), (0.4, 0.2, 0.2) - ), - nb.affines.from_matvec( - nb.eulerangles.euler2mat(x=-0.02, z=0.02), (0.4, 0.2, 0.2) - ), - nb.affines.from_matvec( - nb.eulerangles.euler2mat(x=-0.02, z=0.002), (0.0, 0.2, 0.2) - ), - nb.affines.from_matvec( - nb.eulerangles.euler2mat(y=-0.02, z=0.002), (0.0, 0.2, 0.2) - ), - nb.affines.from_matvec( - nb.eulerangles.euler2mat(y=-0.01, z=0.002), (0.0, 0.4, 0.2) - ), + nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.03, z=0.005), (0.8, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.02, z=0.005), (0.8, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(x=0.02, z=0.02), (0.4, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(x=-0.02, z=0.02), (0.4, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(x=-0.02, z=0.002), (0.0, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(y=-0.02, z=0.002), (0.0, 0.2, 0.2)), + nb.affines.from_matvec(nb.eulerangles.euler2mat(y=-0.01, z=0.002), (0.0, 0.4, 0.2)), ], reference=b0nii, ) @@ -81,9 +67,7 @@ def test_proximity_estimator_trivial_model(datadir, tmp_path): ) estimator = EddyMotionEstimator() - em_affines = estimator.fit( - dwdata=dwi_motion, models=("b0", ), align_kwargs=None, seed=None - ) + em_affines = estimator.fit(dwdata=dwi_motion, models=("b0",), align_kwargs=None, seed=None) # Uncomment to see the realigned dataset # nt.linear.LinearTransformsMapping( diff --git a/test/test_model.py b/test/test_model.py index 22c9b75f..cb86aa7b 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -64,11 +64,13 @@ def test_average_model(): tmodel_mean = model.AverageDWModel(gtab=gtab, bias=False, stat="mean") tmodel_median = model.AverageDWModel(gtab=gtab, bias=False, stat="median") - tmodel_1000 = model.AverageDWModel( - gtab=gtab, bias=False, th_high=1000, th_low=900 - ) + tmodel_1000 = model.AverageDWModel(gtab=gtab, bias=False, th_high=1000, th_low=900) tmodel_2000 = model.AverageDWModel( - gtab=gtab, bias=False, th_high=2000, th_low=900, stat="mean", + gtab=gtab, + bias=False, + th_high=2000, + th_low=900, + stat="mean", ) # Verify that fit function returns nothing