Skip to content

Commit

Permalink
enh: incorporate dense resampling of fieldmap
Browse files Browse the repository at this point in the history
Requires: #355.
  • Loading branch information
oesteban committed May 2, 2023
1 parent ff6cfdf commit 88405a1
Showing 1 changed file with 38 additions and 27 deletions.
65 changes: 38 additions & 27 deletions sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def fit(self, target_reference, affine=None, approx=True):
will be dropped.
affine : :obj:`numpy.ndarray`
Transform that maps coordinates on the target_reference on to the
fieldmap reference.
fieldmap reference (that is, the affine through which the filedmap
be resampled in register with the target_reference).
approx : :obj:`bool`
If ``True``, do not reconstruct the B-Spline field directly on the target
(which will likely not be aligned with the fieldmap's grid), but rather use
Expand All @@ -82,6 +83,7 @@ def fit(self, target_reference, affine=None, approx=True):
if isinstance(target_reference, (str, bytes, Path)):
target_reference = nb.load(target_reference)

approx = approx if affine is not None else False # Approximate iff affine is defined
affine = affine if affine is not None else np.eye(4)
target_affine = target_reference.affine.copy()

Expand All @@ -102,15 +104,19 @@ def fit(self, target_reference, affine=None, approx=True):

weights = []
coeffs = []
_tmp_reference = None

hdr = target_reference.header.copy()
if approx:
# Keep a copy
# target_reference_bak = target_reference
from sdcflows.utils.tools import deoblique_and_zooms

_tmp_reference = target_reference
# Generate a sampling reference on the fieldmap's space that fully covers
# the target_reference's grid.
# target_reference = ...
raise NotImplementedError
target_reference = deoblique_and_zooms(
listify(self.coeffs)[-1],
target_reference,
)

# Generate tensor-product B-Spline weights
for level in listify(self.coeffs):
Expand All @@ -125,12 +131,13 @@ def fit(self, target_reference, affine=None, approx=True):
)

if approx:
from nitransforms.linear import Affine

# Interpolate fmap given on target_reference in target_reference_back
# voxel locations (overwrite fmap)
raise NotImplementedError
fmap = Affine(reference=_tmp_reference).apply(fmap)

# Cache
hdr = target_reference.header.copy()
hdr.set_intent("estimate", name="fieldmap Hz")
hdr.set_data_dtype("float32")
hdr["cal_max"] = max((abs(fmap.min()), fmap.max()))
Expand All @@ -140,7 +147,7 @@ def fit(self, target_reference, affine=None, approx=True):

def apply(
self,
data,
moving,
pe_dir,
ro_time,
xfms=None,
Expand All @@ -149,13 +156,16 @@ def apply(
cval=0.0,
prefilter=True,
output_dtype=None,
num_threads=None,
):
"""
Apply a transformation to an image, resampling on the reference spatial object.
Handles parallelization to resample 4D images.
Parameters
----------
data : `spatialimage`
moving : `spatialimage`
The image object containing the data to be resampled in reference
space
xfms : `None` or :obj:`list`
Expand Down Expand Up @@ -187,12 +197,13 @@ def apply(
from sdcflows.utils.tools import ensure_positive_cosines

# Ensure the fmap has been computed
if isinstance(spatialimage, (str, bytes, Path)):
spatialimage = nb.load(spatialimage)
if isinstance(moving, (str, bytes, Path)):
moving = nb.load(moving)

spatialimage, axcodes = ensure_positive_cosines(spatialimage)
# TODO: not sure this is necessary - instead check it matches self.mapped.
moving, axcodes = ensure_positive_cosines(moving)

self.fit(spatialimage)
self.fit(moving)
fmap = self.mapped.get_fdata().copy()

# Reverse mapped if reversed blips
Expand All @@ -203,19 +214,19 @@ def apply(
pe_axis = "ijk".index(pe_dir[0])

# Map voxel coordinates applying the VSM
if self.xfm is None:
ijk_axis = tuple([np.arange(s) for s in fmap.shape])
voxcoords = np.array(
np.meshgrid(*ijk_axis, indexing="ij"), dtype="float32"
).reshape(3, -1)
else:
ijk_axis = tuple([np.arange(s) for s in fmap.shape])
voxcoords = np.array(
np.meshgrid(*ijk_axis, indexing="ij"), dtype="float32"
).reshape(3, -1)

if xfms is None:
# Map coordinates from reference to time-step
self.xfm.reference = spatialimage
hmc_xyz = self.xfm.map(self.xfm.reference.ndcoords.T)
xfms.reference = moving
hmc_xyz = xfms.map(xfms.reference.ndcoords.T)
# Convert from RAS to voxel coordinates
voxcoords = (
np.linalg.inv(self.xfm.reference.affine)
@ _as_homogeneous(np.vstack(hmc_xyz), dim=self.xfm.reference.ndim).T
np.linalg.inv(xfms.reference.affine)
@ _as_homogeneous(np.vstack(hmc_xyz), dim=xfms.reference.ndim).T
)[:3, ...]

# fmap * ro_time is the voxel-shift map (VSM)
Expand All @@ -224,7 +235,7 @@ def apply(
voxcoords[pe_axis, ...] += fmap.reshape(-1) * ro_time

# Prepare data
data = np.squeeze(np.asanyarray(spatialimage.dataobj))
data = np.squeeze(np.asanyarray(moving.dataobj))
output_dtype = output_dtype or data.dtype

# Resample
Expand All @@ -236,10 +247,10 @@ def apply(
mode=mode,
cval=cval,
prefilter=prefilter,
).reshape(spatialimage.shape)
).reshape(moving.shape)

moved = spatialimage.__class__(
resampled, spatialimage.affine, spatialimage.header
moved = moving.__class__(
resampled, moving.affine, moving.header
)
moved.header.set_data_dtype(output_dtype)
return reorient_image(moved, axcodes)
Expand Down

0 comments on commit 88405a1

Please sign in to comment.