From bde69edce1699279f0a6d96509f0a3cf542395b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Sat, 30 Mar 2024 10:10:45 -0400 Subject: [PATCH] ENH: Simplify `eddymotion.estimator.EddyMotionEstimator.fit` Simplify `eddymotion.estimator.EddyMotionEstimator.fit`: break down different parts into separate methods. Fixes: ``` src/eddymotion/estimator.py:43:9: C901 `fit` is too complex (23 > 10) ``` raised by `ruff`. Remove the `C901` error exception rule from the `ruff` linter whitelist. --- pyproject.toml | 1 - src/eddymotion/estimator.py | 318 +++++++++++++++++++++++++++--------- 2 files changed, 239 insertions(+), 80 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 216ef24b..a7faa601 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -146,7 +146,6 @@ inline-quotes = "double" "*/__init__.py" = ["F401"] "docs/conf.py" = ["E265"] "/^\\s*\\.\\. _.*?: http/" = ["E501"] -"src/eddymotion/estimator.py" = ["C901"] [tool.ruff.format] quote-style = "double" diff --git a/src/eddymotion/estimator.py b/src/eddymotion/estimator.py index e99b67e7..ead5f7fe 100644 --- a/src/eddymotion/estimator.py +++ b/src/eddymotion/estimator.py @@ -21,6 +21,7 @@ # https://www.nipreps.org/community/licensing/ # """A model-based algorithm for the realignment of dMRI data.""" + from collections import namedtuple from pathlib import Path from tempfile import TemporaryDirectory, mkstemp @@ -75,9 +76,7 @@ def fit( Number of parallel jobs. seed : :obj:`int` or :obj:`bool` Seed the random number generator (necessary when we want deterministic - estimation). If an integer, the value is used to initialize the - generator; if ``True``, the arbitrary value of ``20210324`` is used - to initialize it. + estimation). See :func:`_sort_dwdata_indices`. Return ------ affines : :obj:`list` of :obj:`numpy.ndarray` @@ -85,13 +84,10 @@ def fit( parameters of the deformations caused by head-motion and eddy-currents. """ - align_kwargs = align_kwargs or {} - _seed = None - if seed or seed == 0: - _seed = 20210324 if seed is True else seed + align_kwargs = align_kwargs or {} - rng = np.random.default_rng(_seed) + index_order = _sort_dwdata_indices(seed, len(dwdata)) if "num_threads" not in align_kwargs and omp_nthreads is not None: align_kwargs["num_threads"] = omp_nthreads @@ -103,28 +99,9 @@ def fit( ) # When downsampling these need to be set per-level - bmask_img = None - if dwdata.brainmask is not None: - _, bmask_img = mkstemp(suffix="_bmask.nii.gz") - nb.Nifti1Image(dwdata.brainmask.astype("uint8"), dwdata.affine, None).to_filename( - bmask_img - ) - kwargs["mask"] = dwdata.brainmask - - if hasattr(dwdata, "bzero") and dwdata.bzero is not None: - kwargs["S0"] = _advanced_clip(dwdata.bzero) - - if hasattr(dwdata, "gradients"): - kwargs["gtab"] = dwdata.gradients + bmask_img = _prepare_brainmask_data(dwdata.brainmask, dwdata.affine) - if hasattr(dwdata, "frame_time"): - kwargs["timepoints"] = dwdata.frame_time - - if hasattr(dwdata, "total_duration"): - kwargs["xlim"] = dwdata.total_duration - - index_order = np.arange(len(dwdata)) - rng.shuffle(index_order) + _prepare_kwargs(dwdata, kwargs) single_model = model.lower() in ( "b0", @@ -148,6 +125,7 @@ def fit( with TemporaryDirectory() as tmp_dir: print(f"Processing in <{tmp_dir}>") + ptmp_dir = Path(tmp_dir) with tqdm(total=len(index_order), unit="dwi") as pbar: # run a original-to-synthetic affine registration for i in index_order: @@ -178,62 +156,28 @@ def fit( predicted = dwmodel.predict(data_test[1]) # prepare data for running ANTs - tmp_dir = Path(tmp_dir) - moving = tmp_dir / f"moving{i:05d}.nii.gz" - fixed = tmp_dir / f"fixed{i:05d}.nii.gz" - _to_nifti(data_test[0], dwdata.affine, moving) - _to_nifti( - predicted, - dwdata.affine, - fixed, - clip=reg_target_type == "dwi", + fixed, moving = _prepare_registration_data( + data_test[0], predicted, dwdata.affine, i, ptmp_dir, reg_target_type ) pbar.set_description_str( f"Pass {i_iter + 1}/{n_iter} | Realign b-index <{i}>" ) - registration = Registration( - terminal_output="file", - from_file=pkg_fn( - "eddymotion", - f"config/dwi-to-{reg_target_type}_level{i_iter}.json", - ), - fixed_image=str(fixed.absolute()), - moving_image=str(moving.absolute()), - **align_kwargs, - ) - if bmask_img: - registration.inputs.fixed_image_masks = ["NULL", bmask_img] - - if dwdata.em_affines is not None and np.any(dwdata.em_affines[i, ...]): - reference = namedtuple("ImageGrid", ("shape", "affine"))( - shape=dwdata.dataobj.shape[:3], affine=dwdata.affine - ) - # create a nitransforms object - if dwdata.fieldmap: - # compose fieldmap into transform - raise NotImplementedError - else: - initial_xform = Affine( - matrix=dwdata.em_affines[i], reference=reference - ) - mat_file = tmp_dir / f"init_{i_iter}_{i:05d}.mat" - initial_xform.to_filename(mat_file, fmt="itk") - registration.inputs.initial_moving_transform = str(mat_file) - - # execute ants command line - result = registration.run(cwd=str(tmp_dir)).outputs - - # read output transform - xform = nt.linear.Affine( - nt.io.itk.ITKLinearTransform.from_filename( - result.forward_transforms[0] - ).to_ras(reference=fixed, moving=moving), - ) - # debugging: generate aligned file for testing - xform.apply(moving, reference=fixed).to_filename( - tmp_dir / f"aligned{i:05d}_{int(data_test[1][3]):04d}.nii.gz" + xform = _run_registration( + fixed, + moving, + bmask_img, + dwdata.em_affines, + dwdata.affine, + dwdata.dataobj.shape[:3], + data_test[1][3], + dwdata.fieldmap, + i_iter, + i, + ptmp_dir, + reg_target_type, + align_kwargs, ) # update @@ -294,3 +238,219 @@ def _to_nifti(data, affine, filename, clip=True): nii.header.set_sform(affine, code=1) nii.header.set_qform(affine, code=1) nii.to_filename(filename) + + +def _sort_dwdata_indices(seed, dwi_vol_count): + """Sort the DWI data volume indices. + + Parameters + ---------- + seed : :obj:`int` or :obj:`bool` + Seed the random number generator. If an integer, the value is used to initialize the + generator; if ``True``, the arbitrary value of ``20210324`` is used to initialize it. + dwi_vol_count : :obj:`int` + Number of DWI volumes. + + Returns + ------- + index_order : :obj:`numpy.ndarray` + Index order. + """ + + _seed = None + if seed or seed == 0: + _seed = 20210324 if seed is True else seed + + rng = np.random.default_rng(_seed) + + index_order = np.arange(dwi_vol_count) + rng.shuffle(index_order) + + return index_order + + +def _prepare_brainmask_data(brainmask, affine): + """Prepare the brainmask data: save the data to disk. + + Parameters + ---------- + brainmask : :obj:`numpy.ndarray` + Brainmask data. + affine : :obj:`numpy.ndarray` + Affine transformation matrix. + + Returns + ------- + bmask_img : :class:`~nibabel.nifti1.Nifti1Image` + Brainmask image. + """ + + bmask_img = None + if brainmask is not None: + _, bmask_img = mkstemp(suffix="_bmask.nii.gz") + nb.Nifti1Image(brainmask.astype("uint8"), affine, None).to_filename(bmask_img) + return bmask_img + + +def _prepare_kwargs(dwdata, kwargs): + """Prepare the keyword arguments depending on the DWI data: add attributes corresponding to + the ``brainmask``, ``bzero``, ``gradients``, ``frame_time``, and ``total_duration`` DWI data + properties. + + Modifies kwargs in-place. + + Parameters + ---------- + dwdata : :class:`eddymotion.data.dmri.DWI` + DWI data object. + kwargs : :obj:`dict` + Keyword arguments. + """ + + if dwdata.brainmask is not None: + kwargs["mask"] = dwdata.brainmask + + if hasattr(dwdata, "bzero") and dwdata.bzero is not None: + kwargs["S0"] = _advanced_clip(dwdata.bzero) + + if hasattr(dwdata, "gradients"): + kwargs["gtab"] = dwdata.gradients + + if hasattr(dwdata, "frame_time"): + kwargs["timepoints"] = dwdata.frame_time + + if hasattr(dwdata, "total_duration"): + kwargs["xlim"] = dwdata.total_duration + + +def _prepare_registration_data(dwframe, predicted, affine, vol_idx, dirname, reg_target_type): + """Prepare the registration data: save the fixed and moving images to disk. + + Parameters + ---------- + dwframe : :obj:`numpy.ndarray` + DWI data object. + predicted : :obj:`numpy.ndarray` + Predicted data. + affine : :obj:`numpy.ndarray` + Affine transformation matrix. + vol_idx : :obj:`int + DWI volume index. + dirname : :obj:`Path` + Directory name where the data is saved. + reg_target_type : :obj:`str` + Target registration type. + + Returns + ------- + fixed : :obj:`Path` + Fixed image filename. + moving : :obj:`Path` + Moving image filename. + """ + + moving = dirname / f"moving{vol_idx:05d}.nii.gz" + fixed = dirname / f"fixed{vol_idx:05d}.nii.gz" + _to_nifti(dwframe, affine, moving) + _to_nifti( + predicted, + affine, + fixed, + clip=reg_target_type == "dwi", + ) + return fixed, moving + + +def _run_registration( + fixed, + moving, + bmask_img, + em_affines, + affine, + shape, + bval, + fieldmap, + i_iter, + vol_idx, + dirname, + reg_target_type, + align_kwargs, +): + """Register the moving image to the fixed image. + + Parameters + ---------- + fixed : :obj:`Path` + Fixed image filename. + moving : :obj:`Path` + Moving image filename. + bmask_img : :class:`~nibabel.nifti1.Nifti1Image` + Brainmask image. + em_affines : :obj:`numpy.ndarray` + Estimated eddy motion affine transformation matrices. + affine : :obj:`numpy.ndarray` + Affine transformation matrix. + shape : :obj:`tuple` + Shape of the DWI frame. + bval : :obj:`int` + b-value of the corresponding DWI volume. + fieldmap : :class:`~nibabel.nifti1.Nifti1Image` + Fieldmap. + i_iter : :obj:`int` + Iteration number. + vol_idx : :obj:`int` + DWI frame index. + dirname : :obj:`Path` + Directory name where the transformation is saved. + reg_target_type : :obj:`str` + Target registration type. + align_kwargs : :obj:`dict` + Parameters to configure the image registration process. + + Returns + ------- + xform : :class:`~nitransforms.linear.Affine` + Registration transformation. + """ + + registration = Registration( + terminal_output="file", + from_file=pkg_fn( + "eddymotion", + f"config/dwi-to-{reg_target_type}_level{i_iter}.json", + ), + fixed_image=str(fixed.absolute()), + moving_image=str(moving.absolute()), + **align_kwargs, + ) + if bmask_img: + registration.inputs.fixed_image_masks = ["NULL", bmask_img] + + if em_affines is not None and np.any(em_affines[vol_idx, ...]): + reference = namedtuple("ImageGrid", ("shape", "affine"))(shape=shape, affine=affine) + + # create a nitransforms object + if fieldmap: + # compose fieldmap into transform + raise NotImplementedError + else: + initial_xform = Affine(matrix=em_affines[vol_idx], reference=reference) + mat_file = dirname / f"init_{i_iter}_{vol_idx:05d}.mat" + initial_xform.to_filename(mat_file, fmt="itk") + registration.inputs.initial_moving_transform = str(mat_file) + + # execute ants command line + result = registration.run(cwd=str(dirname)).outputs + + # read output transform + xform = nt.linear.Affine( + nt.io.itk.ITKLinearTransform.from_filename(result.forward_transforms[0]).to_ras( + reference=fixed, moving=moving + ), + ) + # debugging: generate aligned file for testing + xform.apply(moving, reference=fixed).to_filename( + dirname / f"aligned{vol_idx:05d}_{int(bval):04d}.nii.gz" + ) + + return xform