diff --git a/dmriprep/utils/register.py b/dmriprep/utils/register.py index 56b47084..816d5125 100644 --- a/dmriprep/utils/register.py +++ b/dmriprep/utils/register.py @@ -1,6 +1,8 @@ """ Linear affine registration tools for motion correction. """ +import attr + import numpy as np import nibabel as nb from dipy.align.metrics import CCMetric, EMMetric, SSDMetric @@ -72,7 +74,7 @@ def c_of_mass( ): transform = transform_centers_of_mass(static, static_affine, moving, moving_affine) transformed = transform.transform(moving) - return transformed, transform.affine + return transform def translation( @@ -89,7 +91,7 @@ def translation( starting_affine=starting_affine, ) - return translation.transform(moving), translation.affine + return translation def rigid( @@ -105,12 +107,13 @@ def rigid( moving_affine, starting_affine=starting_affine, ) - return rigid.transform(moving), rigid.affine + return rigid -def affine( - moving, static, static_affine, moving_affine, reg, starting_affine, params0=None -): +def affine(moving, static, static_affine, moving_affine, reg, starting_affine, + params0=None): + """ + """ transform = AffineTransform3D() affine = reg.optimize( static, @@ -122,49 +125,62 @@ def affine( starting_affine=starting_affine, ) - return affine.transform(moving), affine.affine - - -def affine_registration( - moving, - static, - nbins, - sampling_prop, - metric, - pipeline, - level_iters, - sigmas, - factors, - params0, -): - - """ - Find the affine transformation between two 3D images. - - Parameters - ---------- - - """ - # Define the Affine registration object we'll use with the chosen metric: - use_metric = affine_metric_dict[metric](nbins, sampling_prop) - affreg = AffineRegistration( - metric=use_metric, level_iters=level_iters, sigmas=sigmas, factors=factors - ) - - if not params0: - starting_affine = np.eye(4) - else: - starting_affine = params0 - - # Go through the selected transformation: - for func in pipeline: - transformed, starting_affine = func( - np.asarray(moving.dataobj), - np.asarray(static.dataobj), - static.affine, - moving.affine, - affreg, - starting_affine, - params0, - ) - return nb.Nifti1Image(np.array(transformed), static.affine), starting_affine + return affine + + +@attr.s(slots=True, frozen=True) +class AffineRegistration(): + def __init__(self): + nbins = attr.ib(default=32) + sampling_prop = attr.ib(default=1.0) + metric = attr.ib(default="MI") + level_iters = attr.ib(default=[10000, 1000, 100]) + sigmas = attr.ib(defaults=[3, 1, 0.0]) + factors = attr.ib(defaults=[4, 2, 1]) + pipeline = attr.ib(defaults=[c_of_mass, translation, rigid, affine]) + + def fit(self, static, moving, params0=None): + """ + static, moving : nib.Nifti1Image class images + """ + if params0 is None: + starting_affine = np.eye(4) + else: + starting_affine = params0 + + use_metric = affine_metric_dict[self.metric](self.nbins, + self.sampling_prop) + affreg = AffineRegistration( + metric=use_metric, + level_iters=self.level_iters, + sigmas=self.sigmas, + factors=self.factors) + + # Go through the selected transformation: + for func in self.pipeline: + transform = func( + np.asarray(moving.dataobj), + np.asarray(static.dataobj), + static.affine, + moving.affine, + affreg, + starting_affine, + params0, + ) + starting_affine = transform.affine + + self.static_affine_ = static.affine + self.moving_affine_ = moving.affine + self.affine_ = starting_affine + self.reg_ = AffineMap(starting_affine, + static.shape, static.affine, + moving.shape, moving.affine) + + def apply(self, moving): + """ + + """ + data = moving.get_fdata() + assert np.all(moving.affine, self.moving_affine_) + return nb.Nifti1Image(np.array(self.reg_.transform(data)), + self.static_affine_)