Skip to content

Commit

Permalink
fix: parallelize resampling of 4D images
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Jun 30, 2023
1 parent fe4faf1 commit 50a8e64
Showing 1 changed file with 70 additions and 11 deletions.
81 changes: 70 additions & 11 deletions sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
# https://www.nipreps.org/community/licensing/
#
"""The :math:`B_0` unwarping transform formalism."""
import os
from functools import partial
import asyncio
from pathlib import Path
from typing import Sequence, Union
from typing import Callable, List, Sequence, Union

import attr
import numpy as np
Expand All @@ -39,6 +42,41 @@
from niworkflows.interfaces.nibabel import reorient_image


async def worker(data: np.ndarray, coordinates: np.ndarray, func: Callable) -> np.ndarray:
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, func, data, coordinates)
return result


async def map_coordinates_thread_pool(
fulldataset: np.ndarray,
coordinates: np.ndarray,
num_workers: int,
func: Callable = ndi.map_coordinates,
) -> List[np.ndarray]:
results = []
tasks = []

out_shape = fulldataset.shape[:-1]
out_dtype = fulldataset.dtype

# Create a worker task for each chunk
for volume in np.rollaxis(fulldataset, 3, 0):
task = asyncio.create_task(worker(volume, coordinates, func))
tasks.append(task)

# Wait for all tasks to complete
await asyncio.gather(*tasks)

# Collect the results an
results = np.rollaxis(np.array([
np.array(task.result(), dtype=out_dtype).reshape(out_shape)
for task in tasks
]), 0, 4)

return results


@attr.s(slots=True)
class B0FieldTransform:
"""Represents and applies the transform to correct for susceptibility distortions."""
Expand Down Expand Up @@ -164,7 +202,8 @@ def apply(
cval: float = 0.0,
prefilter: bool = True,
output_dtype: Union[str, np.dtype] = None,
# num_threads: int = None,
num_threads: int = os.cpu_count(),
allow_negative: bool = False,
):
"""
Apply a transformation to an image, resampling on the reference spatial object.
Expand Down Expand Up @@ -215,25 +254,29 @@ def apply(
if isinstance(moving, (str, bytes, Path)):
moving = nb.load(moving)

# TODO: not sure this is necessary - instead check it matches self.mapped.
# Make sure the data array has all cosines positive (i.e., no axes are flipped)
moving, axcodes = ensure_positive_cosines(moving)

self.fit(moving)
fmap = self.mapped.get_fdata().copy()

# Reverse mapped if reversed blips
if pe_dir.endswith("-"):
fmap *= -1.0

# Generate warp field
pe_axis = "ijk".index(pe_dir[0])

axis_flip = axcodes[pe_axis] in ("LPI")
pe_flip = pe_dir.endswith("-")

# Displacements are reversed if either is true (after ensuring positive cosines)
if axis_flip ^ pe_flip:
fmap *= -1.0

# Map voxel coordinates applying the VSM
ijk_axis = tuple([np.arange(s) for s in fmap.shape])
voxcoords = np.array(
np.meshgrid(*ijk_axis, indexing="ij"), dtype="float32"
).reshape(3, -1)

# TODO: we probably want to do this within each resampling thread
if xfms is not None:
mov_ras2vox = np.linalg.inv(moving.affine)
# Map coordinates from reference to time-step
Expand All @@ -252,19 +295,35 @@ def apply(

# Prepare data
data = np.squeeze(np.asanyarray(moving.dataobj))
output_dtype = output_dtype or data.dtype

if data.ndim == 3:
data = data[..., np.newaxis]

output_dtype = output_dtype or moving.header.get_data_dtype()

# Resample
resampled = ndi.map_coordinates(
data,
voxcoords,
map_coordinates = partial(
ndi.map_coordinates,
output=output_dtype,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)

resampled = np.array(
asyncio.run(map_coordinates_thread_pool(
data,
voxcoords,
num_threads,
map_coordinates
)),
dtype=output_dtype,
).reshape(moving.shape)

if not allow_negative:
resampled[resampled < 0] = cval

moved = moving.__class__(resampled, moving.affine, moving.header)
moved.header.set_data_dtype(output_dtype)
return reorient_image(moved, axcodes)
Expand Down

0 comments on commit 50a8e64

Please sign in to comment.