From ced8001450e5fb4e3790bda7ab2d297e1309e20e Mon Sep 17 00:00:00 2001 From: Chris Markiewicz Date: Mon, 5 Dec 2022 15:56:15 -0500 Subject: [PATCH] RF: Calculate bspline grids separately from colocation matrices --- sdcflows/interfaces/bspline.py | 42 ++++++++++++---------------------- 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index 31ecad7b1a..132d00be5a 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -130,6 +130,7 @@ class BSplineApprox(SimpleInterface): def _run_interface(self, runtime): from sklearn import linear_model as lm + from scipy.sparse import vstack as sparse_vstack # Output name baseline out_name = fname_presuffix( @@ -147,6 +148,10 @@ def _run_interface(self, runtime): else None ) + # Determine the shape of bspline coefficients + # This should not change with resizing, so do it first + bs_grids = [bspline_grid(fmapnii, control_zooms_mm=sp) for sp in self.inputs.bs_spacing] + need_resize = np.any(np.array(zooms) < self.inputs.zooms_min) if need_resize: from sdcflows.utils.tools import resample_to_zooms @@ -171,9 +176,6 @@ def _run_interface(self, runtime): else np.asanyarray(masknii.dataobj) > 1e-4 ) - # Convert spacings to numpy arrays - bs_spacing = [np.array(sp, dtype="float32") for sp in self.inputs.bs_spacing] - # Recenter the fieldmap if self.inputs.recenter == "mode": from scipy.stats import mode @@ -184,13 +186,13 @@ def _run_interface(self, runtime): elif self.inputs.recenter == "mean": data -= np.mean(data[mask]) - # Calculate collocation matrix & the spatial location of control points - colmat, bs_levels = _collocation_matrix(fmapnii, bs_spacing) + # Calculate collocation matrix from (possibly resized) image and knot grids + colmat = sparse_vstack(grid_bspline_weights(fmapnii, grid) for grid in bs_grids).T.tocsr() - bs_levels_str = ['x'.join(str(s) for s in level.shape) for level in bs_levels] - bs_levels_str[-1] = f"and {bs_levels_str[-1]}" + bs_grids_str = ['x'.join(str(s) for s in grid.shape) for grid in bs_grids] + bs_grids_str[-1] = f"and {bs_grids_str[-1]}" LOGGER.info( - f"Approximating B-Splines grids ({', '.join(bs_levels_str)} [knots]) on a grid of " + f"Approximating B-Splines grids ({', '.join(bs_grids_str)} [knots]) on a grid of " f"{'x'.join(str(s) for s in fmapnii.shape)} ({np.prod(fmapnii.shape)}) voxels," f" of which {mask.sum()} fall within the mask." ) @@ -202,7 +204,7 @@ def _run_interface(self, runtime): # Store coefficients index = 0 self._results["out_coeff"] = [] - for i, bsl in enumerate(bs_levels): + for i, bsl in enumerate(bs_grids): n = bsl.dataobj.size out_level = out_name.replace("_field.", f"_coeff{i:03}.") bsl.__class__( @@ -223,7 +225,9 @@ def _run_interface(self, runtime): np.ones_like(fmapnii.dataobj, dtype=bool) if masknii is None else np.asanyarray(nb.load(self.inputs.in_mask).dataobj) > 1e-4 ) - colmat, _ = _collocation_matrix(fmapnii, bs_spacing) + colmat = sparse_vstack( + grid_bspline_weights(fmapnii, grid) for grid in bs_grids + ).T.tocsr() regressors = colmat[mask.reshape(-1), :] interp_data = np.zeros_like(data) @@ -506,24 +510,6 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM): return img.__class__(np.zeros(bs_shape, dtype="float32"), bs_affine) -def _collocation_matrix(image, knot_spacing): - from scipy.sparse import vstack as sparse_vstack - - bs_levels = [] - weights = None - for sp in knot_spacing: - level = bspline_grid(image, control_zooms_mm=sp) - bs_levels.append(level) - - weights = ( - grid_bspline_weights(image, level) - if weights is None - else sparse_vstack((weights, grid_bspline_weights(image, level))) - ) - - return weights.T.tocsr(), bs_levels - - def _fix_topup_fieldcoeff(in_coeff, fmap_ref, pe_dir, out_file=None): """Read in a coefficients file generated by TOPUP and fix x-form headers.""" from pathlib import Path