Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Integrate downsampling in BSplineApprox when the input is high-res #301

Merged
merged 7 commits into from
Nov 18, 2022
130 changes: 93 additions & 37 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
DEFAULT_ZOOMS_MM = (40.0, 40.0, 20.0) # For human adults (mid-frequency), in mm
DEFAULT_LF_ZOOMS_MM = (100.0, 100.0, 40.0) # For human adults (low-frequency), in mm
DEFAULT_HF_ZOOMS_MM = (16.0, 16.0, 10.0) # For human adults (high-frequency), in mm
BSPLINE_SUPPORT = 2 - 1.82e-3 # Disallows weights < 1e-9
LOGGER = logging.getLogger("nipype.interface")


Expand Down Expand Up @@ -76,6 +75,11 @@ class _BSplineApproxInputSpec(BaseInterfaceInputSpec):
usedefault=True,
desc="generate a field, extrapolated outside the brain mask",
)
zooms_min = traits.Float(
1.4,
oesteban marked this conversation as resolved.
Show resolved Hide resolved
usedefault=True,
desc="limit minimum image zooms, set 0.0 to use the original image",
)


class _BSplineApproxOutputSpec(TraitedSpec):
Expand Down Expand Up @@ -123,16 +127,45 @@ class BSplineApprox(SimpleInterface):

def _run_interface(self, runtime):
from sklearn import linear_model as lm
from scipy.sparse import vstack as sparse_vstack

# Output name baseline
out_name = fname_presuffix(
self.inputs.in_data, suffix="_field", newpath=runtime.cwd
)

# Load in the fieldmap
fmapnii = nb.load(self.inputs.in_data)
zooms = fmapnii.header.get_zooms()

# Get a mask (or define on the spot to cover the full extent)
masknii = (
nb.load(self.inputs.in_mask)
if isdefined(self.inputs.in_mask)
else None
)

need_resize = np.any(np.array(zooms) < self.inputs.zooms_min)
if need_resize:
from niworkflows.utils.images import resample_by_spacing

LOGGER.info(
"Resampling image with resolution exceeding 'zooms_min' "
f"({'x'.join(str(s) for s in zooms)})."
)
fmapnii = resample_by_spacing(fmapnii, [self.inputs.zooms_min] * 3)
oesteban marked this conversation as resolved.
Show resolved Hide resolved

if masknii is not None:
masknii = resample_by_spacing(masknii, [self.inputs.zooms_min] * 3)
oesteban marked this conversation as resolved.
Show resolved Hide resolved

data = fmapnii.get_fdata(dtype="float32")

# Generate a numpy array with the mask
mask = (
nb.load(self.inputs.in_mask).get_fdata() > 0
if isdefined(self.inputs.in_mask)
else np.ones_like(data, dtype=bool)
np.ones_like(fmapnii.dataobj, dtype=bool) if masknii is None
oesteban marked this conversation as resolved.
Show resolved Hide resolved
else np.asanyarray(masknii.dataobj) > 1e-4
)

# Convert spacings to numpy arrays
bs_spacing = [np.array(sp, dtype="float32") for sp in self.inputs.bs_spacing]

# Recenter the fieldmap
Expand All @@ -145,45 +178,29 @@ def _run_interface(self, runtime):
elif self.inputs.recenter == "mean":
data -= np.mean(data[mask])

# Calculate the spatial location of control points
bs_levels = []
ncoeff = []
weights = None
for sp in bs_spacing:
level = bspline_grid(fmapnii, control_zooms_mm=sp)
bs_levels.append(level)
ncoeff.append(level.dataobj.size)

weights = (
gbsw(fmapnii, level)
if weights is None
else sparse_vstack((weights, gbsw(fmapnii, level)))
)
# Calculate collocation matrix & the spatial location of control points
colmat, bs_levels = _collocation_matrix(fmapnii, bs_spacing)

regressors = weights.T.tocsr()[mask.reshape(-1), :]
bs_levels_str = ['x'.join(str(s) for s in level.shape) for level in bs_levels]
bs_levels_str[-1] = f"and {bs_levels_str[-1]}"
LOGGER.info(
f"Approximating B-Splines grids ({', '.join(bs_levels_str)} [knots]) on a grid of "
f"{'x'.join(str(s) for s in fmapnii.shape)} ({np.prod(fmapnii.shape)}) voxels,"
f" of which {mask.sum()} fall within the mask."
)

# Fit the model
model = lm.Ridge(alpha=self.inputs.ridge_alpha, fit_intercept=False)
model.fit(regressors, data[mask])

interp_data = np.zeros_like(data)
interp_data[mask] = np.array(model.coef_) @ regressors.T # Interpolation

# Store outputs
out_name = fname_presuffix(
self.inputs.in_data, suffix="_field", newpath=runtime.cwd
)
hdr = fmapnii.header.copy()
hdr.set_data_dtype("float32")
fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(out_name)
self._results["out_field"] = out_name
model.fit(colmat[mask.reshape(-1), :], data[mask])

# Store coefficients
index = 0
self._results["out_coeff"] = []
for i, (n, bsl) in enumerate(zip(ncoeff, bs_levels)):
for i, bsl in enumerate(bs_levels):
n = bsl.dataobj.size
out_level = out_name.replace("_field.", f"_coeff{i:03}.")
bsl.__class__(
np.array(model.coef_, dtype="float32")[index : index + n].reshape(
np.array(model.coef_, dtype="float32")[index:index + n].reshape(
bsl.shape
),
bsl.affine,
Expand All @@ -192,6 +209,27 @@ def _run_interface(self, runtime):
index += n
self._results["out_coeff"].append(out_level)

# Interpolating in the original grid will require a new collocation matrix
if need_resize:
fmapnii = nb.load(self.inputs.in_data)
data = fmapnii.get_fdata(dtype="float32")
mask = (
np.ones_like(fmapnii.dataobj, dtype=bool) if masknii is None
else np.asanyarray(nb.load(self.inputs.in_mask).dataobj) > 1e-4
)
colmat, _ = _collocation_matrix(fmapnii, bs_spacing)
oesteban marked this conversation as resolved.
Show resolved Hide resolved

regressors = colmat[mask.reshape(-1), :]
interp_data = np.zeros_like(data)
# Interpolate the field from the coefficients just calculated
interp_data[mask] = regressors @ model.coef_

# Store interpolated field
hdr = fmapnii.header.copy()
hdr.set_data_dtype("float32")
fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(out_name)
self._results["out_field"] = out_name

# Write out fitting-error map
self._results["out_error"] = out_name.replace("_field.", "_error.")
fmapnii.__class__(
Expand All @@ -205,8 +243,8 @@ def _run_interface(self, runtime):
self._results["out_extrapolated"] = self._results["out_field"]
return runtime

extrapolators = weights.tocsc()[:, ~mask.reshape(-1)]
interp_data[~mask] = np.array(model.coef_) @ extrapolators # Extrapolation
extrapolators = colmat[~mask.reshape(-1), :]
interp_data[~mask] = extrapolators @ model.coef_ # Extrapolation
self._results["out_extrapolated"] = out_name.replace("_field.", "_extra.")
fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(
self._results["out_extrapolated"]
Expand Down Expand Up @@ -457,6 +495,24 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
return img.__class__(np.zeros(bs_shape, dtype="float32"), bs_affine)


def _collocation_matrix(image, knot_spacing):
from scipy.sparse import vstack as sparse_vstack

bs_levels = []
weights = None
for sp in knot_spacing:
level = bspline_grid(image, control_zooms_mm=sp)
bs_levels.append(level)

weights = (
gbsw(image, level)
oesteban marked this conversation as resolved.
Show resolved Hide resolved
if weights is None
else sparse_vstack((weights, gbsw(image, level)))
)

return weights.T.tocsr(), bs_levels


def _fix_topup_fieldcoeff(in_coeff, fmap_ref, pe_dir, out_file=None):
"""Read in a coefficients file generated by TOPUP and fix x-form headers."""
from pathlib import Path
Expand Down