Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Fix masking add composability #114

Closed
wants to merge 11 commits into from
208 changes: 141 additions & 67 deletions src/eddymotion/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,22 @@
# https://www.nipreps.org/community/licensing/
#
"""A model-based algorithm for the realignment of dMRI data."""
import gc
from pathlib import Path
from tempfile import TemporaryDirectory, mkstemp
from dataclasses import dataclass
from typing import Optional, Dict, Union, List, Tuple

import nibabel as nb
import nitransforms as nt
import numpy as np

from nipype.interfaces.ants.registration import Registration
from pkg_resources import resource_filename as pkg_fn
from tqdm import tqdm

from eddymotion.model import ModelFactory
from eddymotion.data.dmri import DWI


class EddyMotionEstimator:
Expand All @@ -42,7 +47,7 @@ def fit(
dwdata,
*,
align_kwargs=None,
models=("b0", ),
models=("b0",),
omp_nthreads=None,
n_jobs=None,
seed=None,
Expand Down Expand Up @@ -86,30 +91,32 @@ def fit(
bmask_img = None
if dwdata.brainmask is not None:
_, bmask_img = mkstemp(suffix="_bmask.nii.gz")
nb.Nifti1Image(
dwdata.brainmask.astype("uint8"), dwdata.affine, None
).to_filename(bmask_img)
nb.Nifti1Image(dwdata.brainmask.astype("uint8"), dwdata.affine, None).to_filename(
bmask_img
)
kwargs["mask"] = dwdata.brainmask

kwargs["S0"] = _advanced_clip(dwdata.bzero)

if "num_threads" not in align_kwargs and omp_nthreads is not None:
align_kwargs["num_threads"] = omp_nthreads

aligner = Aligner(dwdata, bmask_img, align_kwargs, models)

n_iter = len(models)
for i_iter, model in enumerate(models):
reg_target_type = (
"dwi"
if model.lower() not in ("b0", "s0", "avg", "average", "mean")
else "b0"
)
index_order = np.arange(len(dwdata))
np.random.shuffle(index_order)

single_model = (
model.lower() in ("b0", "s0", "avg", "average", "mean")
or model.lower().startswith("full")
)
aligner.set_model_iter(i_iter)

single_model = model.lower() in (
"b0",
"s0",
"avg",
"average",
"mean",
) or model.lower().startswith("full")

dwmodel = None
if single_model:
Expand All @@ -126,14 +133,15 @@ def fit(

with TemporaryDirectory() as tmpdir:
print(f"Processing in <{tmpdir}>")

with tqdm(total=len(index_order), unit="dwi") as pbar:
# run a original-to-synthetic affine registration
for i in index_order:
for b_ix in index_order:
pbar.set_description_str(
f"Pass {i_iter + 1}/{n_iter} | Fit and predict b-index <{i}>"
f"Pass {i_iter + 1}/{n_iter} | Fit and predict b-index <{b_ix}>"
)
data_train, data_test = dwdata.logo_split(i, with_b0=True)
grad_str = f"{i}, {data_test[1][:3]}, b={int(data_test[1][3])}"
data_train, data_test = dwdata.logo_split(b_ix, with_b0=True)
grad_str = f"{b_ix}, {data_test[1][:3]}, b={int(data_test[1][3])}"
pbar.set_description_str(f"[{grad_str}], {n_jobs} jobs")

if not single_model: # A true LOGO estimator
Expand All @@ -151,67 +159,130 @@ def fit(
n_jobs=n_jobs,
)

# generate a synthetic dw volume for the test gradient
# predict the gradient
predicted = dwmodel.predict(data_test[1])

# prepare data for running ANTs
tmpdir = Path(tmpdir)
moving = tmpdir / f"moving{i:05d}.nii.gz"
fixed = tmpdir / f"fixed{i:05d}.nii.gz"
_to_nifti(data_test[0], dwdata.affine, moving)
_to_nifti(
predicted,
dwdata.affine,
fixed,
clip=reg_target_type == "dwi",
)

pbar.set_description_str(
f"Pass {i_iter + 1}/{n_iter} | Realign b-index <{i}>"
)
registration = Registration(
terminal_output="file",
from_file=pkg_fn(
"eddymotion",
f"config/dwi-to-{reg_target_type}_level{i_iter}.json",
),
fixed_image=str(fixed.absolute()),
moving_image=str(moving.absolute()),
**align_kwargs,
)
if bmask_img:
registration.inputs.fixed_image_masks = ["NULL", bmask_img]

if dwdata.em_affines and dwdata.em_affines[i] is not None:
mat_file = tmpdir / f"init_{i_iter}_{i:05d}.mat"
dwdata.em_affines[i].to_filename(mat_file, fmt="itk")
registration.inputs.initial_moving_transform = str(mat_file)

# execute ants command line
result = registration.run(cwd=str(tmpdir)).outputs

# read output transform
xform = nt.linear.Affine(
nt.io.itk.ITKLinearTransform.from_filename(
result.forward_transforms[0]
).to_ras(reference=fixed, moving=moving),
)
# debugging: generate aligned file for testing
xform.apply(moving, reference=fixed).to_filename(
tmpdir / f"aligned{i:05d}_{int(data_test[1][3]):04d}.nii.gz"
f"Pass {i_iter + 1}/{n_iter} | Realign b-index <{b_ix}>"
)

# Initialize the ANTs registration object for the current model iteration
xform = aligner.transform(Path(tmpdir), data_test, b_ix, predicted)

# update
dwdata.set_transform(i, xform.matrix)
dwdata.set_transform(b_ix, xform.matrix)
pbar.update()

# free memory
del xform, predicted, data_train, data_test
gc.collect()

return dwdata.em_affines


def _advanced_clip(
data, p_min=35, p_max=99.98, nonnegative=True, dtype="int16", invert=False
):
@dataclass
class Aligner:
"""Convenience dataclass that wraps and tracks ANTs registrations for each gradient prediction.

Attributes
----------
dwdata : :obj:`~eddymotion.data.DWI`
The DWI data object.
bmask_img : :obj:`str`
Path to a brain mask image.
align_kwargs : :obj:`dict`
Additional keyword arguments to pass to the ANTs registration call.
models : :obj:`list` of :obj:`str`
List of model names.

"""

dwdata: DWI
bmask_img: Optional[str]
align_kwargs: Dict
models: Union[List[str], Tuple[str]]

def set_model_iter(self, i_iter: int) -> None:
"""Set the model iteration."""
self._model_iter = i_iter

@property
def model(self) -> str:
"""Return the model name."""
return self.models[self._model_iter]

@property
def reg_target_type(self) -> str:
"""Return the registration target type."""
return (
"dwi"
if self.models[self._model_iter].lower() not in ("b0", "s0", "avg", "average", "mean")
else "b0"
)

def transform(
self, basedir: Path, data_test: np.ndarray, b_ix: int, predicted: np.ndarray
) -> nt.linear.Affine:
"""Run ANTs registration and return the resulting transform.

Parameters
----------
basedir : :obj:`pathlib.Path`
Path to a working directory.
data_test : :obj:`numpy.ndarray`
The test data.
b_ix : :obj:`int`
The index of the current gradient.
predicted : :obj:`numpy.ndarray`
The predicted dw volume for the test gradient.

"""

if self.bmask_img:
self.registration.inputs.fixed_image_masks = ["NULL", self.bmask_img]

# prepare data for running ANTs
moving = basedir / f"moving{b_ix:05d}.nii.gz"
fixed = basedir / f"fixed{b_ix:05d}.nii.gz"
_to_nifti(data_test[0], self.dwdata.affine, moving)
_to_nifti(
predicted, # generate a synthetic dw volume for the test gradient
self.dwdata.affine,
fixed,
clip=self.reg_target_type == "dwi",
)

self.registration = Registration(
terminal_output="file",
from_file=pkg_fn(
"eddymotion",
f"config/dwi-to-{self.reg_target_type}_level{self._model_iter}.json",
),
fixed_image=str(fixed.absolute()),
moving_image=str(moving.absolute()),
**self.align_kwargs,
)

if self.dwdata.em_affines and self.dwdata.em_affines[b_ix] is not None:
mat_file = basedir / f"init_{self._model_iter}_{b_ix:05d}.mat"
self.dwdata.em_affines[b_ix].to_filename(mat_file, fmt="itk")
self.registration.inputs.initial_moving_transform = str(mat_file)

# read output transform
xform = nt.linear.Affine(
nt.io.itk.ITKLinearTransform.from_filename(
self.registration.run(cwd=str(basedir)).outputs.forward_transforms[0]
).to_ras(reference=fixed, moving=moving),
)
# debugging: generate aligned file for testing
xform.apply(moving, reference=fixed).to_filename(
basedir / f"aligned{b_ix:05d}_{int(data_test[1][3]):04d}.nii.gz"
)

return xform


def _advanced_clip(data, p_min=35, p_max=99.98, nonnegative=True, dtype="int16", invert=False):
r"""
Remove outliers at both ends of the intensity distribution and fit into a given dtype.

This interface tries to emulate ANTs workflows' massaging that truncate images into
Expand All @@ -232,6 +303,9 @@ def _advanced_clip(
# Calculate stats on denoised version, to preempt outliers from biasing
denoised = ndimage.median_filter(data, footprint=ball(3))

if len(denoised[denoised > 0]) == 0:
return data

a_min = np.percentile(denoised[denoised > 0] if nonnegative else denoised, p_min)
a_max = np.percentile(denoised[denoised > 0] if nonnegative else denoised, p_max)

Expand Down
Loading