diff --git a/sdcflows/transform.py b/sdcflows/transform.py index 9bd2b07881..3dce1d3baf 100644 --- a/sdcflows/transform.py +++ b/sdcflows/transform.py @@ -41,15 +41,51 @@ from niworkflows.interfaces.nibabel import reorient_image +def _sdc_unwarp( + data: np.ndarray, + coordinates: np.ndarray, + hmc_xfm: np.ndarray, + voxshift: np.ndarray, + pe_axis: int, + output_dtype: Union[type, np.dtype] = None, + order: int = 3, + mode: str = "constant", + cval: float = 0.0, + prefilter: bool = True, +) -> np.ndarray: + if hmc_xfm is not None: + # Move image with the head + coords_shape = coordinates.shape + coordinates = nb.affines.apply_affine( + hmc_xfm, coordinates.reshape(coords_shape[0], -1).T + ).T.reshape(coords_shape) + + # Map voxel coordinates applying the VSM + coordinates[pe_axis, ...] += voxshift + + resampled = ndi.map_coordinates( + data, + coordinates, + output=output_dtype, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + ) + + return resampled + + async def worker( data: np.ndarray, coordinates: np.ndarray, + hmc_xfm: np.ndarray, func: Callable, semaphore: asyncio.Semaphore, ) -> np.ndarray: async with semaphore: loop = asyncio.get_running_loop() - result = await loop.run_in_executor(None, func, data, coordinates) + result = await loop.run_in_executor(None, func, data, coordinates, hmc_xfm) return result @@ -71,12 +107,11 @@ async def unwarp_parallel( if fulldataset.ndim == 3: fulldataset = fulldataset[..., np.newaxis] - # Map voxel coordinates applying the VSM - coordinates[pe_axis, ...] += voxelshift_map - func = partial( - ndi.map_coordinates, - output=output_dtype, + _sdc_unwarp, + voxshift=voxelshift_map, + pe_axis=pe_axis, + output_dtype=output_dtype, order=order, mode=mode, cval=cval, @@ -85,8 +120,11 @@ async def unwarp_parallel( # Create a worker task for each chunk tasks = [] - for volume in np.rollaxis(fulldataset, 3, 0): - task = asyncio.create_task(worker(volume, coordinates, func, semaphore)) + for volid, volume in enumerate(np.rollaxis(fulldataset, -1, 0)): + xfm = None if xfms is None else xfms[volid] + + # IMPORTANT - the coordinates array must be copied every time anew per thread + task = asyncio.create_task(worker(volume, coordinates.copy(), xfm, func, semaphore)) tasks.append(task) # Wait for all tasks to complete @@ -275,6 +313,9 @@ def apply( if isinstance(moving, (str, bytes, Path)): moving = nb.load(moving) + # Generate warp field (before ensuring positive cosines) + self.fit(moving) + # Make sure the data array has all cosines positive (i.e., no axes are flipped) moving, axcodes = ensure_positive_cosines(moving) @@ -286,9 +327,6 @@ def apply( if axis_flip ^ pe_flip: ro_time *= -1.0 - # Generate warp field - self.fit(moving) - # Prepare data data = np.squeeze(np.asanyarray(moving.dataobj)) output_dtype = output_dtype or moving.header.get_data_dtype() @@ -300,7 +338,13 @@ def apply( # The VSM is just the displacements field given in index coordinates # voxcoords is the deformation field, i.e., the target position of each voxel - vsm = self.mapped.get_fdata(dtype="float32").copy() * ro_time + vsm = self.mapped.get_fdata(dtype="float32") * ro_time + + # Convert head-motion transforms to voxel-to-voxel: + if xfms is not None: + vox2ras = moving.affine.copy() + ras2vox = np.linalg.inv(vox2ras) + xfms = [ras2vox @ xfm @ vox2ras for xfm in xfms] # Resample resampled = asyncio.run(unwarp_parallel(