Skip to content

Commit

Permalink
enh(BSplineApprox): integrate downsampling when the input is high-res
Browse files Browse the repository at this point in the history
Adds a new input ``zooms_min`` that will trigger the downsampling of the
input field map when it comes with a very high resolution (typically
this is the case for SDC SyN).

Although the input is resampled, the output field is interpolated at the
original resolution for debugging purposes.

Downsampling dramatically speeds up approximation.
Alternatively, a solution (for the SDC SyN) would be to calculate a
heavily dilated mask of the brain, and use it to limit the number of
voxels that are fit in regression.
  • Loading branch information
oesteban committed Nov 16, 2022
1 parent 83b064e commit 679f09a
Showing 1 changed file with 93 additions and 37 deletions.
130 changes: 93 additions & 37 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
DEFAULT_ZOOMS_MM = (40.0, 40.0, 20.0) # For human adults (mid-frequency), in mm
DEFAULT_LF_ZOOMS_MM = (100.0, 100.0, 40.0) # For human adults (low-frequency), in mm
DEFAULT_HF_ZOOMS_MM = (16.0, 16.0, 10.0) # For human adults (high-frequency), in mm
BSPLINE_SUPPORT = 2 - 1.82e-3 # Disallows weights < 1e-9
LOGGER = logging.getLogger("nipype.interface")


Expand Down Expand Up @@ -76,6 +75,11 @@ class _BSplineApproxInputSpec(BaseInterfaceInputSpec):
usedefault=True,
desc="generate a field, extrapolated outside the brain mask",
)
zooms_min = traits.Float(
1.4,
usedefault=True,
desc="limit minimum image zooms, set 0.0 to use the original image",
)


class _BSplineApproxOutputSpec(TraitedSpec):
Expand Down Expand Up @@ -123,16 +127,45 @@ class BSplineApprox(SimpleInterface):

def _run_interface(self, runtime):
from sklearn import linear_model as lm
from scipy.sparse import vstack as sparse_vstack

# Output name baseline
out_name = fname_presuffix(
self.inputs.in_data, suffix="_field", newpath=runtime.cwd
)

# Load in the fieldmap
fmapnii = nb.load(self.inputs.in_data)
zooms = fmapnii.header.get_zooms()

# Get a mask (or define on the spot to cover the full extent)
masknii = (
nb.load(self.inputs.in_mask)
if isdefined(self.inputs.in_mask)
else None
)

need_resize = np.any(np.array(zooms) < self.inputs.zooms_min)
if need_resize:
from niworkflows.utils.images import resample_by_spacing

LOGGER.info(
"Resampling image with resolution exceeding 'zooms_min' "
f"({'x'.join(str(s) for s in zooms)})."
)
fmapnii = resample_by_spacing(fmapnii, [self.inputs.zooms_min] * 3)

if masknii is not None:
masknii = resample_by_spacing(masknii, [self.inputs.zooms_min] * 3)

data = fmapnii.get_fdata(dtype="float32")

# Generate a numpy array with the mask
mask = (
nb.load(self.inputs.in_mask).get_fdata() > 0
if isdefined(self.inputs.in_mask)
else np.ones_like(data, dtype=bool)
np.ones_like(fmapnii.dataobj, dtype=bool) if masknii is None
else np.asanyarray(masknii.dataobj) > 1e-4
)

# Convert spacings to numpy arrays
bs_spacing = [np.array(sp, dtype="float32") for sp in self.inputs.bs_spacing]

# Recenter the fieldmap
Expand All @@ -145,45 +178,29 @@ def _run_interface(self, runtime):
elif self.inputs.recenter == "mean":
data -= np.mean(data[mask])

# Calculate the spatial location of control points
bs_levels = []
ncoeff = []
weights = None
for sp in bs_spacing:
level = bspline_grid(fmapnii, control_zooms_mm=sp)
bs_levels.append(level)
ncoeff.append(level.dataobj.size)

weights = (
gbsw(fmapnii, level)
if weights is None
else sparse_vstack((weights, gbsw(fmapnii, level)))
)
# Calculate collocation matrix & the spatial location of control points
colmat, bs_levels = _collocation_matrix(fmapnii, bs_spacing)

regressors = weights.T.tocsr()[mask.reshape(-1), :]
bs_levels_str = ['x'.join(str(s) for s in level.shape) for level in bs_levels]
bs_levels_str[-1] = f"and {bs_levels_str[-1]}"
LOGGER.info(
f"Approximating B-Splines grids ({', '.join(bs_levels_str)} [knots]) on a map of "
f"shape {fmapnii.dataobj.size} ({'x'.join(str(s) for s in fmapnii.shape)} voxels)"
f" of which {mask.sum()} fall within the mask."
)

# Fit the model
model = lm.Ridge(alpha=self.inputs.ridge_alpha, fit_intercept=False)
model.fit(regressors, data[mask])

interp_data = np.zeros_like(data)
interp_data[mask] = np.array(model.coef_) @ regressors.T # Interpolation

# Store outputs
out_name = fname_presuffix(
self.inputs.in_data, suffix="_field", newpath=runtime.cwd
)
hdr = fmapnii.header.copy()
hdr.set_data_dtype("float32")
fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(out_name)
self._results["out_field"] = out_name
model.fit(colmat[mask.reshape(-1), :], data[mask])

# Store coefficients
index = 0
self._results["out_coeff"] = []
for i, (n, bsl) in enumerate(zip(ncoeff, bs_levels)):
for i, bsl in enumerate(bs_levels):
n = bsl.dataobj.size
out_level = out_name.replace("_field.", f"_coeff{i:03}.")
bsl.__class__(
np.array(model.coef_, dtype="float32")[index : index + n].reshape(
np.array(model.coef_, dtype="float32")[index:index + n].reshape(
bsl.shape
),
bsl.affine,
Expand All @@ -192,6 +209,27 @@ def _run_interface(self, runtime):
index += n
self._results["out_coeff"].append(out_level)

# Interpolating in the original grid will require a new collocation matrix
if need_resize:
fmapnii = nb.load(self.inputs.in_data)
data = fmapnii.get_fdata(dtype="float32")
mask = (
np.ones_like(fmapnii.dataobj, dtype=bool) if masknii is None
else np.asanyarray(nb.load(self.inputs.in_mask).dataobj) > 1e-4
)
colmat, _ = _collocation_matrix(fmapnii, bs_spacing)

regressors = colmat[mask.reshape(-1), :]
interp_data = np.zeros_like(data)
# Interpolate the field from the coefficients just calculated
interp_data[mask] = regressors @ model.coef_

# Store interpolated field
hdr = fmapnii.header.copy()
hdr.set_data_dtype("float32")
fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(out_name)
self._results["out_field"] = out_name

# Write out fitting-error map
self._results["out_error"] = out_name.replace("_field.", "_error.")
fmapnii.__class__(
Expand All @@ -205,8 +243,8 @@ def _run_interface(self, runtime):
self._results["out_extrapolated"] = self._results["out_field"]
return runtime

extrapolators = weights.tocsc()[:, ~mask.reshape(-1)]
interp_data[~mask] = np.array(model.coef_) @ extrapolators # Extrapolation
extrapolators = colmat[~mask.reshape(-1), :]
interp_data[~mask] = extrapolators @ model.coef_ # Extrapolation
self._results["out_extrapolated"] = out_name.replace("_field.", "_extra.")
fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(
self._results["out_extrapolated"]
Expand Down Expand Up @@ -457,6 +495,24 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
return img.__class__(np.zeros(bs_shape, dtype="float32"), bs_affine)


def _collocation_matrix(image, knot_spacing):
from scipy.sparse import vstack as sparse_vstack

bs_levels = []
weights = None
for sp in knot_spacing:
level = bspline_grid(image, control_zooms_mm=sp)
bs_levels.append(level)

weights = (
gbsw(image, level)
if weights is None
else sparse_vstack((weights, gbsw(image, level)))
)

return weights.T.tocsr(), bs_levels


def _fix_topup_fieldcoeff(in_coeff, fmap_ref, pe_dir, out_file=None):
"""Read in a coefficients file generated by TOPUP and fix x-form headers."""
from pathlib import Path
Expand Down

0 comments on commit 679f09a

Please sign in to comment.