From 5e4bfd9f32a2441a37bd601ff829c47f0bc03a1d Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 27 Oct 2022 17:42:05 +0200 Subject: [PATCH] wip: cornering the problem interpolating the field --- sdcflows/interfaces/bspline.py | 31 +++++-- sdcflows/transform.py | 112 +++++++++++++++-------- sdcflows/workflows/apply/correction.py | 4 +- sdcflows/workflows/apply/registration.py | 7 ++ 4 files changed, 107 insertions(+), 47 deletions(-) diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index 46c27db90e..92bf9fee93 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -269,6 +269,7 @@ def _run_interface(self, runtime): unwarp = None hmc_mats = [None] * n + if isdefined(self.inputs.in_xfms): # Split ITK matrices in separate files if they come collated hmc_mats = ( @@ -334,6 +335,7 @@ class _TransformCoefficientsInputSpec(BaseInterfaceInputSpec): ) fmap_ref = File(exists=True, mandatory=True, desc="the fieldmap reference") transform = File(exists=True, mandatory=True, desc="rigid-body transform file") + fmap_target = File(exists=True, desc="the distorted EPI target (feed to set debug mode on)") class _TransformCoefficientsOutputSpec(TraitedSpec): @@ -356,12 +358,17 @@ def _run_interface(self, runtime): level, self.inputs.fmap_ref, self.inputs.transform, + fmap_target=( + self.inputs.fmap_target if isdefined(self.inputs.fmap_target) + else None + ), ) out_file = fname_presuffix( level, suffix="_space-target", newpath=runtime.cwd ) movednii.to_filename(out_file) self._results["out_coeff"].append(out_file) + return runtime @@ -512,6 +519,12 @@ def _fix_topup_fieldcoeff(in_coeff, fmap_ref, pe_dir, out_file=None): header = coeffnii.header.copy() header.set_qform(newaff, code=1) header.set_sform(newaff, code=1) + header["cal_max"] = max(( + abs(np.asanyarray(coeffnii.dataobj).min()), + np.asanyarray(coeffnii.dataobj).max(), + )) + header["cal_min"] = - header["cal_max"] + header.set_intent("estimate", tuple(), name="B-Spline coefficients") # Write out fixed (generalized) coefficients coeffnii.__class__(coeffnii.dataobj, newaff, header).to_filename(out_file) @@ -537,8 +550,8 @@ def _chunks(inlist, chunksize): def _b0_resampler(in_file, coeffs, pe, ro, hmc_xfm=None, unwarp=None, newpath=None): """Outsource the resampler into a separate callable function to allow parallelization.""" from functools import partial - from niworkflows.interfaces.nibabel import reorient_image - from sdcflows.utils.tools import ensure_positive_cosines + # from niworkflows.interfaces.nibabel import reorient_image + # from sdcflows.utils.tools import ensure_positive_cosines # Prepare output names filename = partial(fname_presuffix, newpath=newpath) @@ -558,19 +571,17 @@ def _b0_resampler(in_file, coeffs, pe, ro, hmc_xfm=None, unwarp=None, newpath=No unwarp.xfm = Affine(XFMLoader.from_filename(hmc_xfm).to_ras()) - # Reorient input to match that of the coefficients, i.e., to have positive director cosines - reoriented_img, axcodes = ensure_positive_cosines(nb.load(in_file)) + # Load distorted image + distorted_img = nb.load(in_file) - if unwarp.fit(reoriented_img): + if unwarp.fit(distorted_img): unwarp.mapped.to_filename(retval[2]) else: retval[2] = None - # Unwarp the reoriented image, and restore original orientation - unwarped_img = reorient_image( - unwarp.apply(reoriented_img, ro_time=ro, pe_dir=pe), - axcodes, - ) + # Unwarp + unwarped_img = unwarp.apply(distorted_img, ro_time=ro, pe_dir=pe) + # Write out to disk unwarped_img.to_filename(retval[0]) diff --git a/sdcflows/transform.py b/sdcflows/transform.py index d202a33fe7..de36dd827f 100644 --- a/sdcflows/transform.py +++ b/sdcflows/transform.py @@ -34,7 +34,6 @@ from bids.utils import listify from niworkflows.interfaces.nibabel import reorient_image -from sdcflows.utils.tools import ensure_positive_cosines def _clear_mapped(instance, attribute, value): @@ -48,7 +47,7 @@ class B0FieldTransform: coeffs = attr.ib(default=None) """B-Spline coefficients (one value per control point).""" - xfm = attr.ib(default=nt.linear.Affine(), on_setattr=_clear_mapped) + xfm = attr.ib(default=None, on_setattr=_clear_mapped) """A rigid-body transform to prepend to the unwarping displacements field.""" mapped = attr.ib(default=None, init=False) """ @@ -83,29 +82,23 @@ def fit(self, spatialimage): ): return False - weights = [] - coeffs = [] - + # Initialize a zero-displacement vsm + vsm = np.zeros(spatialimage.shape[:3], dtype="float32") # Generate tensor-product B-Spline weights for level in listify(self.coeffs): - self.xfm.reference = spatialimage - moved_cs = level.__class__( - level.dataobj, self.xfm.matrix @ level.affine, level.header - ) - wmat = grid_bspline_weights(spatialimage, moved_cs) - weights.append(wmat) - coeffs.append(level.get_fdata(dtype="float32").reshape(-1)) - - # Interpolate the VSM (voxel-shift map) - vsm = np.zeros(spatialimage.shape[:3], dtype="float32") - vsm = (np.squeeze(np.hstack(coeffs).T) @ sparse_vstack(weights)).reshape( - vsm.shape - ) + wmat = grid_bspline_weights(spatialimage, level) + coeffs = level.get_fdata(dtype="float32").reshape(-1) + + # Interpolate the VSM (voxel-shift map) + vsm += (coeffs.T @ wmat).reshape(vsm.shape) # Cache - self.mapped = nb.Nifti1Image(vsm, spatialimage.affine, None) - self.mapped.header.set_intent("estimate", name="Voxel shift") - self.mapped.header.set_xyzt_units(*spatialimage.header.get_xyzt_units()) + hdr = spatialimage.header.copy() + hdr.set_intent("estimate", name="Voxel shift") + hdr.set_data_dtype("float32") + hdr["cal_max"] = max((abs(vsm.min()), vsm.max())) + hdr["cal_min"] = - hdr["cal_max"] + self.mapped = nb.Nifti1Image(vsm, spatialimage.affine, hdr) return True def apply( @@ -153,6 +146,8 @@ def apply( The data imaged after resampling to reference space. """ + from sdcflows.utils.tools import ensure_positive_cosines + # Ensure the fmap has been computed if isinstance(spatialimage, (str, bytes, Path)): spatialimage = nb.load(spatialimage) @@ -177,6 +172,7 @@ def apply( ).reshape(3, -1) else: # Map coordinates from reference to time-step + self.xfm.reference = spatialimage hmc_xyz = self.xfm.map(self.xfm.reference.ndcoords.T) # Convert from RAS to voxel coordinates voxcoords = ( @@ -323,6 +319,11 @@ def disp_to_fmap(xyz_nii, ro_time, pe_dir, itk_format=True): fmap_nii = nb.Nifti1Image(vsm / scale_factor, xyz_nii.affine) fmap_nii.header.set_intent("estimate", name="Delta_B0 [Hz]") fmap_nii.header.set_xyzt_units("mm") + fmap_nii.header["cal_max"] = max(( + abs(np.asanyarray(fmap_nii.dataobj).min()), + np.asanyarray(fmap_nii.dataobj).max(), + )) + fmap_nii.header["cal_min"] = - fmap_nii.header["cal_max"] return fmap_nii @@ -338,9 +339,9 @@ def _cubic_bspline(d): ) -def grid_bspline_weights(target_nii, ctrl_nii): +def grid_bspline_weights(target_nii, ctrl_nii, dtype="float16"): r""" - Evaluate tensor-product B-Spline weights on a grid. + Evaluate tensor-product B-Spline weights on a grid aligned to that of the knots. For each of the *N* input samples :math:`(s_1, s_2, s_3)` and *K* control points or *knots* :math:`\mathbf{k} =(k_1, k_2, k_3)`, the tensor-product @@ -384,35 +385,74 @@ def grid_bspline_weights(target_nii, ctrl_nii): step of approximation/extrapolation. """ - shape = target_nii.shape[:3] - ctrl_sp = ctrl_nii.header.get_zooms()[:3] - ras2ijk = np.linalg.inv(ctrl_nii.affine) - origin = nb.affines.apply_affine(ras2ijk, [tuple(target_nii.affine[:3, 3])])[0] + sample_shape = target_nii.shape[:3] + knots_shape = ctrl_nii.shape[:3] + + # Ensure the cross-product of affines is near zero (i.e., both coordinate systems are aligned) + if not np.allclose(np.linalg.norm( + np.cross(ctrl_nii.affine[:-1, :-1].T, target_nii.affine[:-1, :-1].T), + axis=1, + ), 0, atol=1e-3): + raise RuntimeError("Image's and B-Spline's grids are not aligned.") + + target_to_grid = np.linalg.inv(ctrl_nii.affine) @ target_nii.affine wd = [] - for i, (o, n, sp) in enumerate( - zip(origin, shape, target_nii.header.get_zooms()[:3]) - ): - locations = np.arange(0, n, dtype="float16") * sp / ctrl_sp[i] + o - knots = np.arange(0, ctrl_nii.shape[i], dtype="float16") - distance = np.abs(locations[np.newaxis, ...] - knots[..., np.newaxis]) + for axis in range(3): + # 3D ijk coordinates of current axis + coords = np.vstack(( + np.zeros((3, sample_shape[axis]), dtype=dtype), + np.ones((1, sample_shape[axis]), dtype=dtype), + )) + coords[axis] = np.arange(sample_shape[axis], dtype=dtype) + + # Calculate the index component of samples w.r.t. B-Spline knots along current axis + ijk_samples = (target_to_grid @ coords)[axis] + # Indexes along current axis of the knots + ijk_knots = np.arange(0, knots_shape[axis], dtype=dtype) + + # Distances of every sample w.r.t. every knot + distance = np.abs(ijk_samples[np.newaxis, ...] - ijk_knots[..., np.newaxis]) + + # Optimization: calculate B-Splines only for samples within the kernel spread within_support = distance < 2.0 d_vals, d_idxs = np.unique(distance[within_support], return_inverse=True) + + # Calculate univariate B-Spline weight for samples within kernel spread bs_w = _cubic_bspline(d_vals) + + # Densify and build csr matrix (probably, we could avoid densifying) weights = np.zeros_like(distance, dtype="float32") weights[within_support] = bs_w[d_idxs] wd.append(csr_matrix(weights)) + # Efficiently calculate tensor product of the weights return kron(kron(wd[0], wd[1]), wd[2]) -def _move_coeff(in_coeff, fmap_ref, transform): +def _move_coeff(in_coeff, fmap_ref, transform, fmap_target=None): """Read in a rigid transform from ANTs, and update the coefficients field affine.""" xfm = nt.linear.Affine( nt.io.itk.ITKLinearTransform.from_filename(transform).to_ras(), reference=fmap_ref, ) coeff = nb.load(in_coeff) - newaff = xfm.matrix @ coeff.affine - return coeff.__class__(coeff.dataobj, newaff, coeff.header) + hdr = coeff.header.copy() + + if fmap_target is not None: # Debug mode + nii_target = nb.load(fmap_target) + debug_ref = (~xfm).apply(fmap_ref, reference=nii_target) + debug_ref.header.set_qform(nii_target.affine, code=1) + debug_ref.header.set_sform(nii_target.affine, code=1) + debug_ref.to_filename(Path() / "debug_fmapref.nii.gz") + + newaff = np.linalg.inv(np.linalg.inv(coeff.affine) @ (~xfm).matrix) + hdr.set_qform(newaff, code=1) + hdr.set_sform(newaff, code=1) + hdr["cal_max"] = max(( + abs(np.asanyarray(coeff.dataobj).min()), + np.asanyarray(coeff.dataobj).max(), + )) + hdr["cal_min"] = - hdr["cal_max"] + return coeff.__class__(coeff.dataobj, newaff, hdr) diff --git a/sdcflows/workflows/apply/correction.py b/sdcflows/workflows/apply/correction.py index 91380b6492..2ec7de8555 100644 --- a/sdcflows/workflows/apply/correction.py +++ b/sdcflows/workflows/apply/correction.py @@ -103,7 +103,9 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): rotime = pe.Node(GetReadoutTime(), name="rotime") rotime.interface._always_run = debug - resample = pe.Node(ApplyCoeffsField(num_threads=omp_nthreads), name="resample") + resample = pe.Node(ApplyCoeffsField( + num_threads=omp_nthreads if not debug else 1 + ), name="resample") merge = pe.Node(MergeSeries(), name="merge") average = pe.Node(RobustAverage(mc_method=None), name="average") diff --git a/sdcflows/workflows/apply/registration.py b/sdcflows/workflows/apply/registration.py index fb344f46e9..0e90efdf84 100644 --- a/sdcflows/workflows/apply/registration.py +++ b/sdcflows/workflows/apply/registration.py @@ -156,4 +156,11 @@ def init_coeff2epi_wf( ]) # fmt: on + if debug: + # fmt: off + workflow.connect([ + (inputnode, map_coeff, [("target_ref", "fmap_target")]), + ]) + # fmt: on + return workflow