Skip to content

Commit

Permalink
ref: cleanup code in preparation of head motion
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Jun 30, 2023
1 parent 50a8e64 commit 55e0505
Showing 1 changed file with 60 additions and 64 deletions.
124 changes: 60 additions & 64 deletions sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,43 +36,64 @@

import nibabel as nb
import nitransforms as nt
from nitransforms.base import _as_homogeneous
from bids.utils import listify

from niworkflows.interfaces.nibabel import reorient_image


async def worker(data: np.ndarray, coordinates: np.ndarray, func: Callable) -> np.ndarray:
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, func, data, coordinates)
return result
async def worker(
data: np.ndarray,
coordinates: np.ndarray,
func: Callable,
semaphore: asyncio.Semaphore,
) -> np.ndarray:
async with semaphore:
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, func, data, coordinates)
return result


async def map_coordinates_thread_pool(
async def unwarp_parallel(
fulldataset: np.ndarray,
coordinates: np.ndarray,
num_workers: int,
func: Callable = ndi.map_coordinates,
voxelshift_map: np.ndarray,
pe_axis: int,
xfms: Sequence[np.array],
order: int = 3,
mode: str = "constant",
cval: float = 0.0,
prefilter: bool = True,
output_dtype: Union[str, np.dtype] = None,
max_concurrent: int = min(os.cpu_count(), 12),
) -> List[np.ndarray]:
results = []
tasks = []
semaphore = asyncio.Semaphore(max_concurrent)

if fulldataset.ndim == 3:
fulldataset = fulldataset[..., np.newaxis]

out_shape = fulldataset.shape[:-1]
out_dtype = fulldataset.dtype
# Map voxel coordinates applying the VSM
coordinates[pe_axis, ...] += voxelshift_map

func = partial(
ndi.map_coordinates,
output=output_dtype,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)

# Create a worker task for each chunk
tasks = []
for volume in np.rollaxis(fulldataset, 3, 0):
task = asyncio.create_task(worker(volume, coordinates, func))
task = asyncio.create_task(worker(volume, coordinates, func, semaphore))
tasks.append(task)

# Wait for all tasks to complete
await asyncio.gather(*tasks)

# Collect the results an
results = np.rollaxis(np.array([
np.array(task.result(), dtype=out_dtype).reshape(out_shape)
for task in tasks
]), 0, 4)
# Collect the results and stack along last dimension
results = np.stack([task.result() for task in tasks], -1)

return results

Expand Down Expand Up @@ -257,69 +278,44 @@ def apply(
# Make sure the data array has all cosines positive (i.e., no axes are flipped)
moving, axcodes = ensure_positive_cosines(moving)

self.fit(moving)
fmap = self.mapped.get_fdata().copy()

# Generate warp field
pe_axis = "ijk".index(pe_dir[0])

axis_flip = axcodes[pe_axis] in ("LPI")
pe_flip = pe_dir.endswith("-")

# Displacements are reversed if either is true (after ensuring positive cosines)
if axis_flip ^ pe_flip:
fmap *= -1.0

# Map voxel coordinates applying the VSM
ijk_axis = tuple([np.arange(s) for s in fmap.shape])
voxcoords = np.array(
np.meshgrid(*ijk_axis, indexing="ij"), dtype="float32"
).reshape(3, -1)

# TODO: we probably want to do this within each resampling thread
if xfms is not None:
mov_ras2vox = np.linalg.inv(moving.affine)
# Map coordinates from reference to time-step
xfms.reference = moving
hmc_xyz = xfms.map(xfms.reference.ndcoords.T)
# Convert from RAS to voxel coordinates
voxcoords = (
mov_ras2vox
@ _as_homogeneous(np.vstack(hmc_xyz), dim=xfms.reference.ndim).T
)[:3, ...]

# fmap * ro_time is the voxel-shift map (VSM)
# The VSM is just the displacements field given in index coordinates
# voxcoords is the deformation field, i.e., the target position of each voxel
voxcoords[pe_axis, ...] += fmap.reshape(-1) * ro_time
ro_time *= -1.0

# Generate warp field
self.fit(moving)

# Prepare data
data = np.squeeze(np.asanyarray(moving.dataobj))
output_dtype = output_dtype or moving.header.get_data_dtype()

if data.ndim == 3:
data = data[..., np.newaxis]
# Reference image's voxel coordinates (in voxel units)
voxcoords = nt.linear.Affine(
reference=moving
).reference.ndindex.reshape((data.ndim - 1, *data.shape[:-1])).astype("float32")

output_dtype = output_dtype or moving.header.get_data_dtype()
# The VSM is just the displacements field given in index coordinates
# voxcoords is the deformation field, i.e., the target position of each voxel
vsm = self.mapped.get_fdata(dtype="float32").copy() * ro_time

# Resample
map_coordinates = partial(
ndi.map_coordinates,
output=output_dtype,
resampled = asyncio.run(unwarp_parallel(
data,
voxcoords,
vsm,
pe_axis,
xfms,
output_dtype=output_dtype,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)

resampled = np.array(
asyncio.run(map_coordinates_thread_pool(
data,
voxcoords,
num_threads,
map_coordinates
)),
dtype=output_dtype,
).reshape(moving.shape)
max_concurrent=num_threads,
))

if not allow_negative:
resampled[resampled < 0] = cval
Expand Down

0 comments on commit 55e0505

Please sign in to comment.