Skip to content

Commit

Permalink
enh: draft the approx branch
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed May 2, 2023
1 parent 740fa6c commit 2fbad48
Showing 1 changed file with 37 additions and 21 deletions.
58 changes: 37 additions & 21 deletions sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def fit(self, target_reference, affine=None, approx=True):
The image object containing a reference grid (same as that of the data
to be resampled). If a 4D dataset is provided, then the fourth dimension
will be dropped.
affine : :obj:`nitransforms.linear.Affine`
affine : :obj:`numpy.ndarray`
Transform that maps coordinates on the target_reference on to the
fieldmap reference.
approx : :obj:`bool`
Expand All @@ -79,47 +79,63 @@ def fit(self, target_reference, affine=None, approx=True):
"""
# Calculate the physical coordinates of target grid
if isinstance(spatialimage, (str, bytes, Path)):
spatialimage = nb.load(spatialimage)
if isinstance(target_reference, (str, bytes, Path)):
target_reference = nb.load(target_reference)

# Initialize xform (or set identity)
xfm = self.xfm if self.xfm is not None else nt.linear.Affine()
affine = affine if affine is not None else np.eye(4)
target_affine = target_reference.affine.copy()

# Project the reference's grid onto the fieldmap's
target_reference = target_reference.__class__(
target_reference.dataobj,
affine @ target_affine.T,
target_reference.header,
)

if self.mapped is not None:
newaff = spatialimage.affine
newshape = spatialimage.shape
newshape = target_reference.shape

if np.all(newshape == self.mapped.shape) and np.allclose(
newaff, self.mapped.affine
target_affine, self.mapped.affine
):
return False

weights = []
coeffs = []

if approx:
# Keep a copy
# target_reference_bak = target_reference

# Generate a sampling reference on the fieldmap's space that fully covers
# the target_reference's grid.
# target_reference = ...
raise NotImplementedError

# Generate tensor-product B-Spline weights
for level in listify(self.coeffs):
xfm.reference = spatialimage
moved_cs = level.__class__(
level.dataobj, xfm.matrix @ level.affine, level.header
)
wmat = grid_bspline_weights(spatialimage, moved_cs)
wmat = grid_bspline_weights(target_reference, level)
weights.append(wmat)
coeffs.append(level.get_fdata(dtype="float32").reshape(-1))

# Interpolate the VSM (voxel-shift map)
vsm = np.zeros(spatialimage.shape[:3], dtype="float32")
vsm = (np.squeeze(np.hstack(coeffs).T) @ sparse_vstack(weights)).reshape(
vsm.shape
# Reconstruct the fieldmap (in Hz) from coefficients
fmap = np.zeros(target_reference.shape[:3], dtype="float32")
fmap = (np.squeeze(np.hstack(coeffs).T) @ sparse_vstack(weights)).reshape(
fmap.shape
)

if approx:
# Interpolate fmap given on target_reference in target_reference_back
# voxel locations (overwrite fmap)
raise NotImplementedError

# Cache
hdr = spatialimage.header.copy()
hdr.set_intent("estimate", name="Voxel shift")
hdr = target_reference.header.copy()
hdr.set_intent("estimate", name="fieldmap Hz")
hdr.set_data_dtype("float32")
hdr["cal_max"] = max((abs(vsm.min()), vsm.max()))
hdr["cal_max"] = max((abs(fmap.min()), fmap.max()))
hdr["cal_min"] = - hdr["cal_max"]
self.mapped = nb.Nifti1Image(vsm, spatialimage.affine, hdr)
self.mapped = nb.Nifti1Image(fmap, target_affine, hdr)
return True

def apply(
Expand Down

0 comments on commit 2fbad48

Please sign in to comment.