Skip to content

Commit

Permalink
Merge pull request #393 from effigies/rf/scipy-bspline-take-3
Browse files Browse the repository at this point in the history
RF: Use scipy.interpolate.BSpline to construct spline basis
  • Loading branch information
effigies authored Sep 29, 2023
2 parents c90c4ed + 1d715a9 commit 71a9ae3
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 31 deletions.
14 changes: 7 additions & 7 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class BSplineApprox(SimpleInterface):

def _run_interface(self, runtime):
from sklearn import linear_model as lm
from scipy.sparse import vstack as sparse_vstack
from scipy.sparse import hstack as sparse_hstack

# Output name baseline
out_name = fname_presuffix(
Expand Down Expand Up @@ -197,9 +197,9 @@ def _run_interface(self, runtime):
data -= center

# Calculate collocation matrix from (possibly resized) image and knot grids
colmat = sparse_vstack(
grid_bspline_weights(fmapnii, grid) for grid in bs_grids
).T.tocsr()
colmat = sparse_hstack(
[grid_bspline_weights(fmapnii, grid) for grid in bs_grids]
).tocsr()

bs_grids_str = ["x".join(str(s) for s in grid.shape) for grid in bs_grids]
bs_grids_str[-1] = f"and {bs_grids_str[-1]}"
Expand Down Expand Up @@ -254,9 +254,9 @@ def _run_interface(self, runtime):
mask = np.asanyarray(masknii.dataobj) > 1e-4
else:
mask = np.ones_like(fmapnii.dataobj, dtype=bool)
colmat = sparse_vstack(
grid_bspline_weights(fmapnii, grid) for grid in bs_grids
).T.tocsr()
colmat = sparse_hstack(
[grid_bspline_weights(fmapnii, grid) for grid in bs_grids]
).tocsr()

regressors = colmat[mask.reshape(-1), :]
interp_data = np.zeros_like(data)
Expand Down
6 changes: 3 additions & 3 deletions sdcflows/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,11 +348,11 @@ def test_grid_bspline_weights():
nb.Nifti1Image(np.zeros(target_shape), target_aff),
nb.Nifti1Image(np.zeros(ctrl_shape), ctrl_aff),
).tocsr()
assert weights.shape == (64, 1000)
assert weights.shape == (1000, 64)
# Empirically determined numbers intended to indicate that something
# significant has changed. If it turns out we've been doing this wrong,
# these numbers will probably change.
assert np.isclose(weights[0, 0], 0.00089725334)
assert np.isclose(weights[-1, -1], 0.18919244)
assert np.isclose(weights.sum(axis=1).max(), 129.3907)
assert np.isclose(weights.sum(axis=1).min(), 0.0052327816)
assert np.isclose(weights.sum(axis=0).max(), 129.3907)
assert np.isclose(weights.sum(axis=0).min(), 0.0052327816)
44 changes: 23 additions & 21 deletions sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@
import numpy as np
from warnings import warn
from scipy import ndimage as ndi
from scipy.signal import cubic
from scipy.sparse import vstack as sparse_vstack, kron, lil_array
from scipy.interpolate import BSpline
from scipy.sparse import hstack as sparse_hstack, kron, lil_array

import nibabel as nb
import nitransforms as nt
Expand Down Expand Up @@ -309,7 +309,6 @@ def fit(
atol=1e-3,
)

weights = []
if approx:
from sdcflows.utils.tools import deoblique_and_zooms

Expand All @@ -321,17 +320,15 @@ def fit(
)

# Generate tensor-product B-Spline weights
coeffs_data = []
for level in coeffs:
wmat = grid_bspline_weights(target_reference, level)
weights.append(wmat)
coeffs_data.append(level.get_fdata(dtype="float32").reshape(-1))
colmat = sparse_hstack(
[grid_bspline_weights(projected_reference, level) for level in coeffs]
).tocsr()
coefficients = np.hstack(
[level.get_fdata(dtype="float32").reshape(-1) for level in coeffs]
)

# Reconstruct the fieldmap (in Hz) from coefficients
fmap = np.zeros(projected_reference.shape[:3], dtype="float32")
fmap = (np.squeeze(np.hstack(coeffs_data).T) @ sparse_vstack(weights)).reshape(
fmap.shape
)
fmap = np.reshape(colmat @ coefficients, projected_reference.shape[:3])

# Generate a NIfTI object
hdr = target_reference.header.copy()
Expand Down Expand Up @@ -703,7 +700,7 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"):
Returns
-------
weights : :obj:`numpy.ndarray` (:math:`K \times N`)
weights : :obj:`numpy.ndarray` (:math:`N \times K`)
A sparse matrix of interpolating weights :math:`\Psi^3(\mathbf{k}, \mathbf{s})`
for the *N* voxels of the target EPI, for each of the total *K* knots.
This sparse matrix can be directly used as design matrix for the fitting
Expand Down Expand Up @@ -732,21 +729,26 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"):
coords[axis] = np.arange(sample_shape[axis], dtype=dtype)

# Calculate the index component of samples w.r.t. B-Spline knots along current axis
# Size of locations is L
locs = nb.affines.apply_affine(target_to_grid, coords.T)[:, axis]
knots = np.arange(knots_shape[axis], dtype=dtype)

distance = np.abs(locs[np.newaxis, ...] - knots[..., np.newaxis])
# Size of knots is K + 6 so that all locations are fully covered by basis
knots = np.arange(-3, knots_shape[axis] + 3, dtype=dtype)

bspl = BSpline(knots, np.eye(len(knots) - 3 - 1), 3)

# Construct a sparse design matrix (L, K)
distance = np.abs(locs[..., np.newaxis] - knots[np.newaxis, 3:-3])
within_support = distance < 2.0
d_vals, d_idxs = np.unique(distance[within_support], return_inverse=True)
bs_w = cubic(d_vals)

colloc_ax = lil_array((knots_shape[axis], sample_shape[axis]), dtype=dtype)
colloc_ax[within_support] = bs_w[d_idxs]
colloc_ax = lil_array(distance.shape, dtype=dtype)
colloc_ax[within_support] = bspl(locs)[:, 1:-1][within_support]

wd.append(colloc_ax)
# Convert to CSR for efficient multiplication
wd.append(colloc_ax.tocsr())

# Calculate the tensor product of the three design matrices
return kron(kron(wd[0], wd[1]), wd[2]).astype(dtype)
return kron(kron(wd[0], wd[1]), wd[2])


def _move_coeff(in_coeff, fmap_ref, transform, fmap_target=None):
Expand Down

0 comments on commit 71a9ae3

Please sign in to comment.