Skip to content

Commit

Permalink
feat: defer head-motion correction into parallel unwarp
Browse files Browse the repository at this point in the history
This commit defers the handling of head-motion correction transforms to
the separate threads.

It has been tested with transforms having random translations and
everything looks correct.

If passed a constant, zero-fieldmap (coeffs), but a list of realignment
headmotion correction matrices, then the transform should output a
realigned file.
  • Loading branch information
oesteban committed Jun 30, 2023
1 parent 55e0505 commit 19513f6
Showing 1 changed file with 56 additions and 12 deletions.
68 changes: 56 additions & 12 deletions sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,51 @@
from niworkflows.interfaces.nibabel import reorient_image


def _sdc_unwarp(
data: np.ndarray,
coordinates: np.ndarray,
hmc_xfm: np.ndarray,
voxshift: np.ndarray,
pe_axis: int,
output_dtype: Union[type, np.dtype] = None,
order: int = 3,
mode: str = "constant",
cval: float = 0.0,
prefilter: bool = True,
) -> np.ndarray:
if hmc_xfm is not None:
# Move image with the head
coords_shape = coordinates.shape
coordinates = nb.affines.apply_affine(
hmc_xfm, coordinates.reshape(coords_shape[0], -1).T
).T.reshape(coords_shape)

# Map voxel coordinates applying the VSM
coordinates[pe_axis, ...] += voxshift

resampled = ndi.map_coordinates(
data,
coordinates,
output=output_dtype,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)

return resampled


async def worker(
data: np.ndarray,
coordinates: np.ndarray,
hmc_xfm: np.ndarray,
func: Callable,
semaphore: asyncio.Semaphore,
) -> np.ndarray:
async with semaphore:
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, func, data, coordinates)
result = await loop.run_in_executor(None, func, data, coordinates, hmc_xfm)
return result


Expand All @@ -71,12 +107,11 @@ async def unwarp_parallel(
if fulldataset.ndim == 3:
fulldataset = fulldataset[..., np.newaxis]

# Map voxel coordinates applying the VSM
coordinates[pe_axis, ...] += voxelshift_map

func = partial(
ndi.map_coordinates,
output=output_dtype,
_sdc_unwarp,
voxshift=voxelshift_map,
pe_axis=pe_axis,
output_dtype=output_dtype,
order=order,
mode=mode,
cval=cval,
Expand All @@ -85,8 +120,11 @@ async def unwarp_parallel(

# Create a worker task for each chunk
tasks = []
for volume in np.rollaxis(fulldataset, 3, 0):
task = asyncio.create_task(worker(volume, coordinates, func, semaphore))
for volid, volume in enumerate(np.rollaxis(fulldataset, -1, 0)):
xfm = None if xfms is None else xfms[volid]

# IMPORTANT - the coordinates array must be copied every time anew per thread
task = asyncio.create_task(worker(volume, coordinates.copy(), xfm, func, semaphore))
tasks.append(task)

# Wait for all tasks to complete
Expand Down Expand Up @@ -275,6 +313,9 @@ def apply(
if isinstance(moving, (str, bytes, Path)):
moving = nb.load(moving)

# Generate warp field (before ensuring positive cosines)
self.fit(moving)

# Make sure the data array has all cosines positive (i.e., no axes are flipped)
moving, axcodes = ensure_positive_cosines(moving)

Expand All @@ -286,9 +327,6 @@ def apply(
if axis_flip ^ pe_flip:
ro_time *= -1.0

# Generate warp field
self.fit(moving)

# Prepare data
data = np.squeeze(np.asanyarray(moving.dataobj))
output_dtype = output_dtype or moving.header.get_data_dtype()
Expand All @@ -300,7 +338,13 @@ def apply(

# The VSM is just the displacements field given in index coordinates
# voxcoords is the deformation field, i.e., the target position of each voxel
vsm = self.mapped.get_fdata(dtype="float32").copy() * ro_time
vsm = self.mapped.get_fdata(dtype="float32") * ro_time

# Convert head-motion transforms to voxel-to-voxel:
if xfms is not None:
vox2ras = moving.affine.copy()
ras2vox = np.linalg.inv(vox2ras)
xfms = [ras2vox @ xfm @ vox2ras for xfm in xfms]

# Resample
resampled = asyncio.run(unwarp_parallel(
Expand Down

0 comments on commit 19513f6

Please sign in to comment.