diff --git a/test/test_integration.py b/test/test_integration.py index 11753e94..1b1ba85a 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -30,13 +30,15 @@ from eddymotion.data.dmri import DWI from eddymotion.estimator import EddyMotionEstimator +from eddymotion.registration.utils import displacements_within_mask -def test_proximity_estimator_trivial_model(datadir): +def test_proximity_estimator_trivial_model(datadir, tmp_path): """Check the proximity of transforms estimated by the estimator with a trivial B0 model.""" dwdata = DWI.from_filename(datadir / "dwi.h5") b0nii = nb.Nifti1Image(dwdata.bzero, dwdata.affine, None) + masknii = nb.Nifti1Image(dwdata.brainmask.astype(np.uint8), dwdata.affine, None) # Generate a list of large-yet-plausible bulk-head motion. xfms = nt.linear.LinearTransformsMapping( @@ -56,8 +58,8 @@ def test_proximity_estimator_trivial_model(datadir): moved_nii = (~xfms).apply(b0nii, reference=b0nii) # Uncomment to see the moved dataset - # moved_nii.to_filename(tmp_path / "test.nii.gz") - # xfms.apply(moved_nii).to_filename(tmp_path / "ground_truth.nii.gz") + moved_nii.to_filename(tmp_path / "test.nii.gz") + xfms.apply(moved_nii).to_filename(tmp_path / "ground_truth.nii.gz") # Wrap into dataset object dwi_motion = DWI( @@ -70,7 +72,7 @@ def test_proximity_estimator_trivial_model(datadir): estimator = EddyMotionEstimator() em_affines = estimator.estimate( - dwdata=dwi_motion, + data=dwi_motion, models=("b0",), seed=None, align_kwargs={ @@ -81,14 +83,18 @@ def test_proximity_estimator_trivial_model(datadir): ) # Uncomment to see the realigned dataset - # nt.linear.LinearTransformsMapping( - # em_affines, - # reference=b0nii, - # ).apply(moved_nii).to_filename(tmp_path / "realigned.nii.gz") + nt.linear.LinearTransformsMapping( + em_affines, + reference=b0nii, + ).apply(moved_nii).to_filename(tmp_path / "realigned.nii.gz") # For each moved b0 volume - coords = xfms.reference.ndcoords.T for i, est in enumerate(em_affines): - xfm = nt.linear.Affine(xfms.matrix[i], reference=b0nii) - est = nt.linear.Affine(est, reference=b0nii) - assert np.sqrt(((xfm.map(coords) - est.map(coords)) ** 2).sum(1)).mean() < 0.2 + assert ( + displacements_within_mask( + masknii, + nt.linear.Affine(est), + xfms[i], + ).max() + < 0.2 + )