From a75ba6a698b4a1d50b9afe1146b59f6652aff882 Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Sat, 4 May 2024 10:49:37 -0400 Subject: [PATCH 1/9] Work on incorporating resampler. --- pyproject.toml | 9 +- src/fmripost_aroma/interfaces/resampler.py | 64 ++ src/fmripost_aroma/utils/resampler.py | 785 +++++++++++++++++++++ 3 files changed, 854 insertions(+), 4 deletions(-) create mode 100644 src/fmripost_aroma/interfaces/resampler.py create mode 100644 src/fmripost_aroma/utils/resampler.py diff --git a/pyproject.toml b/pyproject.toml index b8e87ac..bbd4092 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,12 +19,13 @@ classifiers = [ "Programming Language :: Python :: 3.11", ] dependencies = [ - "fmriprep", + "fmriprep @ git+https://github.com/nipreps/fmriprep.git@master", "nipype >= 1.8.5", - "nireports", - "niworkflows", + "nireports @ git+https://github.com/nipreps/nireports.git@main", + "niworkflows @ git+https://github.com/nipreps/niworkflows.git@master", "pybids >= 0.15.6", - "smriprep", + "sdcflows @ git+https://github.com/nipreps/sdcflows.git@master", + "smriprep @ git+https://github.com/nipreps/smriprep.git@master", "typer", ] dynamic = ["version"] diff --git a/src/fmripost_aroma/interfaces/resampler.py b/src/fmripost_aroma/interfaces/resampler.py new file mode 100644 index 0000000..7913064 --- /dev/null +++ b/src/fmripost_aroma/interfaces/resampler.py @@ -0,0 +1,64 @@ +"""Interfaces for resampling.""" + +from nipype.interfaces.base import ( + BaseInterfaceInputSpec, + File, + InputMultiObject, + SimpleInterface, + TraitedSpec, + isdefined, + traits, +) + + +class _ResamplerInputSpec(BaseInterfaceInputSpec): + bold_file = File(exists=True, desc="BOLD file to resample.") + derivs_path = traits.Directory( + exists=True, + desc="Path to derivatives.", + ) + output_dir = traits.Directory( + exists=True, + desc="Output directory.", + ) + space = traits.Str( + "MNI152NLin6Asym", + usedefault=True, + desc="Output space.", + ) + resolution = traits.Str( + "2", + usedefault=True, + desc="Output resolution.", + ) + + +class _ResamplerOutputSpec(TraitedSpec): + output_file = File(exists=True, desc="Resampled BOLD file.") + + +class Resampler(SimpleInterface): + """Extract timeseries and compute connectivity matrices. + + Write out time series using Nilearn's NiftiLabelMasker + Then write out functional correlation matrix of + timeseries using numpy. + """ + + input_spec = _ResamplerInputSpec + output_spec = _ResamplerOutputSpec + + def _run_interface(self, runtime): + from fmripost_aroma.utils import resampler + + output_file = resampler.main( + bold_file=self.inputs.bold_file, + derivs_path=self.inputs.derivs_path, + output_dir=self.inputs.output_dir, + space=self.inputs.space, + resolution=self.inputs.resolution, + ) + + self._results["output_file"] = output_file + + return runtime diff --git a/src/fmripost_aroma/utils/resampler.py b/src/fmripost_aroma/utils/resampler.py new file mode 100644 index 0000000..00dcc28 --- /dev/null +++ b/src/fmripost_aroma/utils/resampler.py @@ -0,0 +1,785 @@ +import asyncio +import os +from functools import partial +from pathlib import Path +from typing import Callable, TypeVar + +import h5py +import nibabel as nb +import nitransforms as nt +import niworkflows.data +import numpy as np +import typer +from bids import BIDSLayout +from nitransforms.io.itk import ITKCompositeH5 +from scipy import ndimage as ndi +from scipy.sparse import hstack as sparse_hstack +from sdcflows.transform import grid_bspline_weights +from sdcflows.utils.tools import ensure_positive_cosines +from templateflow import api as tf +from typing_extensions import Annotated + + +R = TypeVar('R') + +nipreps_cfg = niworkflows.data.load('nipreps.json') + + +def find_bids_root(path: Path) -> Path: + for parent in path.parents: + if Path.exists(parent / 'dataset_description.json'): + return parent + raise ValueError(f'Cannot detect BIDS dataset containing {path}') + + +def resample_vol( + data: np.ndarray, + coordinates: np.ndarray, + pe_info: tuple[int, float], + hmc_xfm: np.ndarray | None, + fmap_hz: np.ndarray, + output: np.dtype | np.ndarray | None = None, + order: int = 3, + mode: str = 'constant', + cval: float = 0.0, + prefilter: bool = True, +) -> np.ndarray: + """Resample a volume at specified coordinates + + This function implements simultaneous head-motion correction and + susceptibility-distortion correction. It accepts coordinates in + the source voxel space. It is the responsibility of the caller to + transform coordinates from any other target space. + + Parameters + ---------- + data + The data array to resample + coordinates + The first-approximation voxel coordinates to sample from ``data`` + The first dimension should have length ``data.ndim``. The further + dimensions have the shape of the target array. + pe_info + The readout vector in the form of (axis, signed-readout-time) + ``(1, -0.04)`` becomes ``[0, -0.04, 0]``, which indicates that a + +1 Hz deflection in the field shifts 0.04 voxels toward the start + of the data array in the second dimension. + hmc_xfm + Affine transformation accounting for head motion from the individual + volume into the BOLD reference space. This affine must be in VOX2VOX + form. + fmap_hz + The fieldmap, sampled to the target space, in Hz + output + The dtype or a pre-allocated array for sampling into the target space. + If pre-allocated, ``output.shape == coordinates.shape[1:]``. + order + Order of interpolation (default: 3 = cubic) + mode + How ``data`` is extended beyond its boundaries. See + :func:`scipy.ndimage.map_coordinates` for more details. + cval + Value to fill past edges of ``data`` if ``mode`` is ``'constant'``. + prefilter + Determines if ``data`` is pre-filtered before interpolation. + + Returns + ------- + resampled_array + The resampled array, with shape ``coordinates.shape[1:]``. + """ + if hmc_xfm is not None: + # Move image with the head + coords_shape = coordinates.shape + coordinates = nb.affines.apply_affine( + hmc_xfm, coordinates.reshape(coords_shape[0], -1).T + ).T.reshape(coords_shape) + else: + # Copy coordinates to avoid interfering with other calls + coordinates = coordinates.copy() + + vsm = fmap_hz * pe_info[1] + coordinates[pe_info[0], ...] += vsm + + jacobian = 1 + np.gradient(vsm, axis=pe_info[0]) + + result = ndi.map_coordinates( + data, + coordinates, + output=output, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + ) + result *= jacobian + return result + + +async def worker(job: Callable[[], R], semaphore) -> R: + async with semaphore: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, job) + + +async def resample_series_async( + data: np.ndarray, + coordinates: np.ndarray, + pe_info: list[tuple[int, float]], + hmc_xfms: list[np.ndarray] | None, + fmap_hz: np.ndarray, + output_dtype: np.dtype | None = None, + order: int = 3, + mode: str = 'constant', + cval: float = 0.0, + prefilter: bool = True, + max_concurrent: int = min(os.cpu_count(), 12), +) -> np.ndarray: + """Resample a 4D time series at specified coordinates + + This function implements simultaneous head-motion correction and + susceptibility-distortion correction. It accepts coordinates in + the source voxel space. It is the responsibility of the caller to + transform coordinates from any other target space. + + Parameters + ---------- + data + The data array to resample + coordinates + The first-approximation voxel coordinates to sample from ``data``. + The first dimension should have length 3. + The further dimensions determine the shape of the target array. + pe_info + A list of readout vectors in the form of (axis, signed-readout-time) + ``(1, -0.04)`` becomes ``[0, -0.04, 0]``, which indicates that a + +1 Hz deflection in the field shifts 0.04 voxels toward the start + of the data array in the second dimension. + hmc_xfm + A sequence of affine transformations accounting for head motion from + the individual volume into the BOLD reference space. + These affines must be in VOX2VOX form. + fmap_hz + The fieldmap, sampled to the target space, in Hz + output_dtype + The dtype of the output array. + order + Order of interpolation (default: 3 = cubic) + mode + How ``data`` is extended beyond its boundaries. See + :func:`scipy.ndimage.map_coordinates` for more details. + cval + Value to fill past edges of ``data`` if ``mode`` is ``'constant'``. + prefilter + Determines if ``data`` is pre-filtered before interpolation. + max_concurrent + Maximum number of volumes to resample concurrently + + Returns + ------- + resampled_array + The resampled array, with shape ``coordinates.shape[1:] + (N,)``, + where N is the number of volumes in ``data``. + """ + if data.ndim == 3: + return resample_vol( + data, + coordinates, + pe_info[0], + hmc_xfms[0] if hmc_xfms else None, + fmap_hz, + output_dtype, + order, + mode, + cval, + prefilter, + ) + + semaphore = asyncio.Semaphore(max_concurrent) + + # Order F ensures individual volumes are contiguous in memory + # Also matches NIfTI, making final save more efficient + out_array = np.zeros(coordinates.shape[1:] + data.shape[-1:], dtype=output_dtype, order='F') + + tasks = [ + asyncio.create_task( + worker( + partial( + resample_vol, + data=volume, + coordinates=coordinates, + pe_info=pe_info[volid], + hmc_xfm=hmc_xfms[volid] if hmc_xfms else None, + fmap_hz=fmap_hz, + output=out_array[..., volid], + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + ), + semaphore, + ) + ) + for volid, volume in enumerate(np.rollaxis(data, -1, 0)) + ] + + await asyncio.gather(*tasks) + + return out_array + + +def resample_series( + data: np.ndarray, + coordinates: np.ndarray, + pe_info: list[tuple[int, float]], + hmc_xfms: list[np.ndarray] | None, + fmap_hz: np.ndarray, + output_dtype: np.dtype | None = None, + order: int = 3, + mode: str = 'constant', + cval: float = 0.0, + prefilter: bool = True, + nthreads: int = 1, +) -> np.ndarray: + """Resample a 4D time series at specified coordinates + + This function implements simultaneous head-motion correction and + susceptibility-distortion correction. It accepts coordinates in + the source voxel space. It is the responsibility of the caller to + transform coordinates from any other target space. + + Parameters + ---------- + data + The data array to resample + coordinates + The first-approximation voxel coordinates to sample from ``data``. + The first dimension should have length 3. + The further dimensions determine the shape of the target array. + pe_info + A list of readout vectors in the form of (axis, signed-readout-time) + ``(1, -0.04)`` becomes ``[0, -0.04, 0]``, which indicates that a + +1 Hz deflection in the field shifts 0.04 voxels toward the start + of the data array in the second dimension. + hmc_xfm + A sequence of affine transformations accounting for head motion from + the individual volume into the BOLD reference space. + These affines must be in VOX2VOX form. + fmap_hz + The fieldmap, sampled to the target space, in Hz + output_dtype + The dtype of the output array. + order + Order of interpolation (default: 3 = cubic) + mode + How ``data`` is extended beyond its boundaries. See + :func:`scipy.ndimage.map_coordinates` for more details. + cval + Value to fill past edges of ``data`` if ``mode`` is ``'constant'``. + prefilter + Determines if ``data`` is pre-filtered before interpolation. + + Returns + ------- + resampled_array + The resampled array, with shape ``coordinates.shape[1:] + (N,)``, + where N is the number of volumes in ``data``. + """ + return asyncio.run( + resample_series_async( + data=data, + coordinates=coordinates, + pe_info=pe_info, + hmc_xfms=hmc_xfms, + fmap_hz=fmap_hz, + output_dtype=output_dtype, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + max_concurrent=nthreads, + ) + ) + + +def parse_combined_hdf5(h5_fn, to_ras=True): + # Borrowed from https://github.com/feilong/process + # process.resample.parse_combined_hdf5() + h = h5py.File(h5_fn) + xform = ITKCompositeH5.from_h5obj(h) + affine = xform[0].to_ras() + # Confirm these transformations are applicable + assert ( + h['TransformGroup']['2']['TransformType'][:][0] + == b'DisplacementFieldTransform_float_3_3' + ) + assert np.array_equal( + h['TransformGroup']['2']['TransformFixedParameters'][:], + np.array( + [ + 193.0, + 229.0, + 193.0, + 96.0, + 132.0, + -78.0, + 1.0, + 1.0, + 1.0, + -1.0, + 0.0, + 0.0, + 0.0, + -1.0, + 0.0, + 0.0, + 0.0, + 1.0, + ] + ), + ) + warp = h['TransformGroup']['2']['TransformParameters'][:] + warp = warp.reshape((193, 229, 193, 3)).transpose(2, 1, 0, 3) + warp *= np.array([-1, -1, 1]) + warp_affine = np.array( + [ + [1.0, 0.0, 0.0, -96.0], + [0.0, 1.0, 0.0, -132.0], + [0.0, 0.0, 1.0, -78.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + return affine, warp, warp_affine + + +def load_ants_h5(filename: Path) -> nt.TransformChain: + """Load ANTs H5 files as a nitransforms TransformChain""" + affine, warp, warp_affine = parse_combined_hdf5(filename) + warp_transform = nt.DenseFieldTransform(nb.Nifti1Image(warp, warp_affine)) + return nt.TransformChain([warp_transform, nt.Affine(affine)]) + + +def load_transforms(xfm_paths: list[Path]) -> nt.base.TransformBase: + """Load a series of transforms as a nitransforms TransformChain + + An empty list will return an identity transform + """ + chain = None + for path in xfm_paths[::-1]: + path = Path(path) + if path.suffix == '.h5': + xfm = load_ants_h5(path) + else: + xfm = nt.linear.load(path) + if chain is None: + chain = xfm + else: + chain += xfm + if chain is None: + chain = nt.base.TransformBase() + return chain + + +def aligned(aff1: np.ndarray, aff2: np.ndarray) -> bool: + """Determine if two affines have aligned grids""" + return np.allclose( + np.linalg.norm(np.cross(aff1[:-1, :-1].T, aff2[:-1, :-1].T), axis=1), + 0, + atol=1e-3, + ) + + +def as_affine(xfm: nt.base.TransformBase) -> nt.Affine | None: + # Identity transform + if type(xfm) is nt.base.TransformBase: + return nt.Affine() + + if isinstance(xfm, nt.Affine): + return xfm + + if isinstance(xfm, nt.TransformChain) and all( + isinstance(x, nt.Affine) for x in xfm + ): + return xfm.asaffine() + + return None + + +def resample_fieldmap( + coefficients: list[nb.Nifti1Image], + fmap_reference: nb.Nifti1Image, + target: nb.Nifti1Image, + transforms: nt.TransformChain, +) -> nb.Nifti1Image: + """Resample a fieldmap from B-Spline coefficients into a target space + + If the coefficients and target are aligned, the field is reconstructed + directly in the target space. + If not, then the field is reconstructed to the ``fmap_reference`` + resolution, and then resampled according to transforms. + + The former method only applies if the transform chain can be + collapsed to a single affine transform. + + Parameters + ---------- + coefficients + list of B-spline coefficient files. The affine matrices are used + to reconstruct the knot locations. + fmap_reference + The intermediate reference to reconstruct the fieldmap in, if + it cannot be reconstructed directly in the target space. + target + The target space to to resample the fieldmap into. + transforms + A nitransforms TransformChain that maps images from the fieldmap + space into the target space. + + Returns + ------- + fieldmap + The fieldmap encoded in ``coefficients``, resampled in the same + space as ``target`` + """ + + direct = False + affine_xfm = as_affine(transforms) + if affine_xfm is not None: + # Transforms maps RAS coordinates in the target to RAS coordinates in + # the fieldmap space. Composed with target.affine, we have a target voxel + # to fieldmap RAS affine. Hence, this is projected into fieldmap space. + projected_affine = affine_xfm.matrix @ target.affine + # If the coordinates have the same rotation from voxels, we can construct + # bspline weights efficiently. + direct = aligned(projected_affine, coefficients[-1].affine) + + if direct: + reference, _ = ensure_positive_cosines( + target.__class__(target.dataobj, projected_affine, target.header), + ) + else: + if not aligned(fmap_reference.affine, coefficients[-1].affine): + raise ValueError( + 'Reference passed is not aligned with spline grids' + ) + reference, _ = ensure_positive_cosines(fmap_reference) + + # Generate tensor-product B-Spline weights + colmat = sparse_hstack( + [grid_bspline_weights(reference, level) for level in coefficients] + ).tocsr() + coefficients = np.hstack( + [ + level.get_fdata(dtype='float32').reshape(-1) + for level in coefficients + ] + ) + + # Reconstruct the fieldmap (in Hz) from coefficients + fmap_img = nb.Nifti1Image( + np.reshape(colmat @ coefficients, reference.shape[:3]), + reference.affine, + ) + + if not direct: + fmap_img = transforms.apply(fmap_img, reference=target) + + fmap_img.header.set_intent('estimate', name='fieldmap Hz') + fmap_img.header.set_data_dtype('float32') + fmap_img.header['cal_max'] = max( + (abs(fmap_img.dataobj.min()), fmap_img.dataobj.max()) + ) + fmap_img.header['cal_min'] = -fmap_img.header['cal_max'] + + return fmap_img + + +def resample_bold( + source: nb.Nifti1Image, + target: nb.Nifti1Image, + transforms: nt.TransformChain, + fieldmap: nb.Nifti1Image | None, + pe_info: list[tuple[int, float]] | None, + nthreads: int = 1, +) -> nb.Nifti1Image: + """Resample a 4D bold series into a target space, applying head-motion + and susceptibility-distortion correction simultaneously. + + Parameters + ---------- + source + The 4D bold series to resample. + target + An image sampled in the target space. + transforms + A nitransforms TransformChain that maps images from the individual + BOLD volume space into the target space. + fieldmap + The fieldmap, in Hz, sampled in the target space + pe_info + A list of readout vectors in the form of (axis, signed-readout-time) + ``(1, -0.04)`` becomes ``[0, -0.04, 0]``, which indicates that a + +1 Hz deflection in the field shifts 0.04 voxels toward the start + of the data array in the second dimension. + nthreads + Number of threads to use for parallel resampling + + Returns + ------- + resampled_bold + The BOLD series resampled into the target space + """ + # HMC goes last + assert isinstance(transforms[-1], nt.linear.LinearTransformsMapping) + + # Retrieve the RAS coordinates of the target space + coordinates = ( + nt.base.SpatialReference.factory(target).ndcoords.astype('f4').T + ) + + # We will operate in voxel space, so get the source affine + vox2ras = source.affine + ras2vox = np.linalg.inv(vox2ras) + # Transform RAS2RAS head motion transforms to VOX2VOX + hmc_xfms = [ras2vox @ xfm.matrix @ vox2ras for xfm in transforms[-1]] + + # Remove the head-motion transforms and add a mapping from boldref + # world space to voxels. This new transform maps from world coordinates + # in the target space to voxel coordinates in the source space. + ref2vox = nt.TransformChain(transforms[:-1] + [nt.Affine(ras2vox)]) + mapped_coordinates = ref2vox.map(coordinates) + + # Some identities to reduce special casing downstream + if fieldmap is None: + fieldmap = nb.Nifti1Image( + np.zeros(target.shape[:3], dtype='f4'), target.affine + ) + if pe_info is None: + pe_info = [[0, 0] for _ in range(source.shape[-1])] + + resampled_data = resample_series( + data=source.get_fdata(dtype='f4'), + coordinates=mapped_coordinates.T.reshape((3, *target.shape[:3])), + pe_info=pe_info, + hmc_xfms=hmc_xfms, + fmap_hz=fieldmap.get_fdata(dtype='f4'), + output_dtype='f4', + nthreads=nthreads, + ) + resampled_img = nb.Nifti1Image( + resampled_data, target.affine, target.header + ) + resampled_img.set_data_dtype('f4') + + return resampled_img + + +def genref( + source_img: nb.Nifti1Image, + target_zooms: float | tuple[float, float, float], +) -> nb.Nifti1Image: + """Create a reference image with target voxel sizes, preserving + the original field of view + """ + factor = np.array(target_zooms) / source_img.header.get_zooms()[:3] + # Generally round up to the nearest voxel, but not for slivers of voxels + target_shape = np.ceil(np.array(source_img.shape[:3]) / factor - 0.01) + target_affine = nb.affines.rescale_affine( + source_img.affine, source_img.shape, target_zooms, target_shape + ) + return nb.Nifti1Image( + nb.fileslice.strided_scalar(target_shape.astype(int)), + target_affine, + source_img.header, + ) + + +def mkents(source, target, **entities): + """Helper to create entity query for transforms""" + return {'from': source, 'to': target, 'suffix': 'xfm', **entities} + + +def main( + bold_file: Path, + derivs_path: Path, + output_dir: Path, + space: Annotated[str, typer.Option(help='Target space to resample to')], + resolution: Annotated[str, typer.Option(help='Target resolution')] = None, + nthreads: Annotated[ + int, + typer.Option(help='Number of resampling threads (0 for all cores)'), + ] = 1, +): + """Resample a bold file to a target space using the transforms found + in a derivatives directory. + """ + bids_root = find_bids_root(bold_file) + raw = BIDSLayout(bids_root) + derivs = BIDSLayout(derivs_path, config=[nipreps_cfg], validate=False) + + if resolution is not None: + zooms = tuple(int(dim) for dim in resolution.split('x')) + if len(zooms) not in (1, 3): + raise ValueError(f'Unknown resolution: {resolution}') + + cpu_count = os.cpu_count() + if nthreads < 1: + nthreads = cpu_count + elif nthreads > cpu_count: + print(f'Warning: More threads requested ({nthreads}) than cores ({cpu_count})') + + bold = raw.files[str(bold_file)] + bold_meta = bold.get_metadata() + entities = bold.get_entities() + entities.pop('datatype') + entities.pop('extension') + entities.pop('suffix') + + bold_xfms = [] + fmap_xfms = [] + + try: + hmc = derivs.get( + extension='.txt', **mkents('orig', 'boldref', **entities) + )[0] + except IndexError: + raise ValueError('Could not find HMC transforms') + + bold_xfms.append(hmc) + + if space == 'boldref': + reference = derivs.get( + desc='coreg', suffix='boldref', extension='.nii.gz', **entities + )[0] + else: + try: + coreg = derivs.get( + extension='.txt', **mkents('boldref', 'T1w', **entities) + )[0] + except IndexError: + raise ValueError('Could not find coregistration transform') + + bold_xfms.append(coreg) + fmap_xfms.append(coreg) + + if space in ('anat', 'T1w'): + reference = derivs.get( + subject=entities['subject'], + desc='preproc', + suffix='T1w', + extension='.nii.gz', + )[0] + if resolution is not None: + ref_img = genref(nb.load(reference), zooms) + elif space not in ('anat', 'boldref', 'T1w'): + try: + template_reg = derivs.get( + datatype='anat', + extension='.h5', + subject=entities['subject'], + **mkents('T1w', space), + )[0] + except IndexError: + raise ValueError( + f'Could not find template registration for {space}' + ) + + bold_xfms.append(template_reg) + fmap_xfms.append(template_reg) + + # Get mask, as shape/affine is all we need + reference = tf.get( + template=space, + extension='.nii.gz', + desc='brain', + suffix='mask', + resolution=resolution, + ) + if not reference: + # Get a hires image to resample + reference = tf.get( + template=space, + extension='.nii.gz', + desc='brain', + suffix='mask', + resolution='1', + ) + ref_img = genref(nb.load(reference), zooms) + + fmapregs = derivs.get( + extension='.txt', **mkents('boldref', derivs.get_fmapids(), **entities) + ) + if not fmapregs: + print('No fieldmap registrations found') + elif len(fmapregs) > 1: + raise ValueError( + f'Found fieldmap registrations: {fmapregs}\nPass one as an argument.' + ) + + fieldmap = None + if fmapregs: + fmapreg = fmapregs[0] + fmapid = fmapregs[0].entities['to'] + fieldmap_coeffs = derivs.get( + fmapid=fmapid, + desc=['coeff', 'coeff0', 'coeff1'], + extension='.nii.gz', + ) + fmapref = derivs.get( + fmapid=fmapid, + desc='preproc', + extension='.nii.gz', + )[0] + transforms = load_transforms(fmap_xfms) + # We get an inverse transform, so need to add it separately + fmap_xfms.insert(0, fmapreg) + transforms += ~nt.linear.load(Path(fmapreg)) + print(transforms.transforms) + + print(f'Resampling fieldmap {fmapid} into {space}:{resolution}') + print('Coefficients:') + print('\n'.join(f'\t{Path(c).name}' for c in fieldmap_coeffs)) + print(f'Reference: {Path(reference).name}') + print('Transforms:') + print('\n'.join(f'\t{Path(xfm).name}' for xfm in fmap_xfms)) + fieldmap = resample_fieldmap( + coefficients=[nb.load(coeff) for coeff in fieldmap_coeffs], + fmap_reference=nb.load(fmapref), + target=ref_img, + transforms=transforms, + ) + fieldmap.to_filename(output_dir / f'{fmapid}.nii.gz') + + pe_dir = bold_meta['PhaseEncodingDirection'] + ro_time = bold_meta['TotalReadoutTime'] + pe_axis = 'ijk'.index(pe_dir[0]) + pe_flip = pe_dir.endswith('-') + + bold_img = nb.load(bold_file) + source, axcodes = ensure_positive_cosines(bold_img) + axis_flip = axcodes[pe_axis] in 'LPI' + + pe_info = (pe_axis, -ro_time if (axis_flip ^ pe_flip) else ro_time) + + if ref_img is None: + ref_img = nb.load(reference) + + print() + print(f'Resampling BOLD {bold_file.name} ({pe_info})') + print(f'Reference: {Path(reference).name}') + print('Transforms:') + print('\n'.join(f'\t{Path(xfm).name}' for xfm in bold_xfms)) + output_file = output_dir / bold_file.name + resample_bold( + source=source, + target=ref_img, + transforms=load_transforms(bold_xfms), + fieldmap=fieldmap, + pe_info=[pe_info for _ in range(source.shape[-1])], + nthreads=nthreads, + ).to_filename(output_file) + return output_file + + +if __name__ == '__main__': + typer.run(main) From 259356b94487b9a5fb4c9eea83632fb840f405c0 Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Sat, 4 May 2024 11:01:39 -0400 Subject: [PATCH 2/9] Something like this? --- src/fmripost_aroma/workflows/base.py | 44 ++++++++++++++++------------ 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/src/fmripost_aroma/workflows/base.py b/src/fmripost_aroma/workflows/base.py index 8e4d5de..874200e 100644 --- a/src/fmripost_aroma/workflows/base.py +++ b/src/fmripost_aroma/workflows/base.py @@ -318,6 +318,9 @@ def init_single_subject_wf(subject_id: str): """ for bold_file in subject_data["bold"]: + ica_aroma_wf = init_ica_aroma_wf(bold_file=bold_file) + ica_aroma_wf.__desc__ = func_pre_desc + (ica_aroma_wf.__desc__ or "") + functional_cache = {} if config.execution.derivatives: # Collect native-space derivatives and warp them to MNI152NLin6Asym @@ -332,30 +335,33 @@ def init_single_subject_wf(subject_id: str): entities=entities, ) ) + + resample_raw_wf = init_resample_raw_wf( + bold_file=bold_file, + precomputed=functional_cache, + ) + workflow.connect([ + (resample_raw_wf, ica_aroma_wf, [ + ("outputnode.bold_std", "inputnode.bold_std"), + ("outputnode.bold_mask_std", "inputnode.bold_mask_std"), + ]), + ]) # fmt:skip else: # Collect standard-space derivatives from fmripost_aroma.utils.bids import collect_derivatives - ... - - ica_aroma_wf = init_ica_aroma_wf( - bold_file=bold_file, - precomputed=functional_cache, - ) - ica_aroma_wf.__desc__ = func_pre_desc + (ica_aroma_wf.__desc__ or "") + functional_cache.update( + collect_derivatives( + derivatives_dir=deriv_dir, + entities=entities, + ) + ) + ica_aroma_wf.inputs.inputnode.bold_std = functional_cache["bold_std"] + ica_aroma_wf.inputs.inputnode.bold_mask_std = functional_cache["bold_mask_std"] - # fmt:off - workflow.connect([ - (inputnode, ica_aroma_wf, [ - ('bold_std', 'inputnode.bold_std'), - ("bold_mask_std", "inputnode.bold_mask_std"), - ("movpar_file", "inputnode.movpar_file"), - ("name_source", "inputnode.name_source"), - ("skip_vols", "inputnode.skip_vols"), - ("spatial_reference", "inputnode.spatial_reference"), - ]), - ]) - # fmt:on + ica_aroma_wf.inputs.inputnode.movpar_file = functional_cache["movpar_file"] + ica_aroma_wf.inputs.inputnode.skip_vols = functional_cache["skip_vols"] + ica_aroma_wf.inputs.inputnode.spatial_reference = functional_cache["spatial_reference"] return clean_datasinks(workflow) From 4994a106525111f83f989a426e49e20f8bafddcc Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Sat, 4 May 2024 11:13:00 -0400 Subject: [PATCH 3/9] Keep working. --- src/fmripost_aroma/interfaces/resampler.py | 2 - src/fmripost_aroma/workflows/base.py | 4 +- src/fmripost_aroma/workflows/resampling.py | 54 ++++++++++++++++++++++ 3 files changed, 57 insertions(+), 3 deletions(-) create mode 100644 src/fmripost_aroma/workflows/resampling.py diff --git a/src/fmripost_aroma/interfaces/resampler.py b/src/fmripost_aroma/interfaces/resampler.py index 7913064..61a30f8 100644 --- a/src/fmripost_aroma/interfaces/resampler.py +++ b/src/fmripost_aroma/interfaces/resampler.py @@ -3,10 +3,8 @@ from nipype.interfaces.base import ( BaseInterfaceInputSpec, File, - InputMultiObject, SimpleInterface, TraitedSpec, - isdefined, traits, ) diff --git a/src/fmripost_aroma/workflows/base.py b/src/fmripost_aroma/workflows/base.py index 874200e..c6b228c 100644 --- a/src/fmripost_aroma/workflows/base.py +++ b/src/fmripost_aroma/workflows/base.py @@ -40,6 +40,7 @@ from fmripost_aroma import config from fmripost_aroma.interfaces.bids import DerivativesDataSink from fmripost_aroma.interfaces.reportlets import AboutSummary, SubjectSummary +from fmripost_aroma.workflows.resampling import init_resample_raw_wf def init_fmripost_aroma_wf(): @@ -256,7 +257,8 @@ def init_single_subject_wf(subject_id: str): ) bids_info = pe.Node( - BIDSInfo(bids_dir=config.execution.bids_dir, bids_validate=False), name="bids_info" + BIDSInfo(bids_dir=config.execution.bids_dir, bids_validate=False), + name="bids_info", ) summary = pe.Node( diff --git a/src/fmripost_aroma/workflows/resampling.py b/src/fmripost_aroma/workflows/resampling.py new file mode 100644 index 0000000..15d00ad --- /dev/null +++ b/src/fmripost_aroma/workflows/resampling.py @@ -0,0 +1,54 @@ +"""Workflows to resample data.""" + +from nipype.interfaces import utility as niu +from nipype.pipeline import engine as pe + + +def init_resample_raw_wf(bold_file, functional_cache): + """Resample raw BOLD data to MNI152NLin6Asym:res-2mm space.""" + from fmriprep.workflows.bold.stc import init_bold_stc_wf + from niworkflows.engine.workflows import LiterateWorkflow as Workflow + + from fmripost_aroma.interfaces.resampler import Resampler + + workflow = Workflow(name="resample_raw_wf") + + inputnode = pe.Node( + niu.IdentityInterface(fields=["bold_file", "mask_file"]), + name="inputnode", + ) + inputnode.inputs.bold_file = bold_file + inputnode.inputs.mask_file = functional_cache["bold_mask"] + + outputnode = pe.Node( + niu.IdentityInterface(fields=["bold_std", "bold_mask_std"]), + name="outputnode", + ) + + stc_wf = init_bold_stc_wf(name="resample_stc_wf") + workflow.connect([ + (inputnode, stc_wf, [ + ('bold_file', 'inputnode.bold_file'), + ('mask_file', 'inputnode.mask_file'), + ]), + ]) # fmt:skip + + resample_bold = pe.Node( + Resampler(space="MNI152NLin6Asym", resolution="2"), + name="resample_bold", + ) + workflow.connect([ + (stc_wf, resample_bold, [('outputnode.bold_file', 'bold_file')]), + (resample_bold, outputnode, [('output_file', 'bold_std')]), + ]) # fmt:skip + + resample_bold_mask = pe.Node( + Resampler(space="MNI152NLin6Asym", resolution="2"), + name="resample_bold_mask", + ) + workflow.connect([ + (inputnode, resample_bold_mask, [('mask_file', 'bold_file')]), + (resample_bold_mask, outputnode, [('output_file', 'bold_mask_std')]), + ]) # fmt:skip + + return workflow From 328a1d02f35097cb364814035e3a412be0a8c140 Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Sat, 4 May 2024 11:15:16 -0400 Subject: [PATCH 4/9] Update base.py --- src/fmripost_aroma/workflows/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/fmripost_aroma/workflows/base.py b/src/fmripost_aroma/workflows/base.py index c6b228c..7791213 100644 --- a/src/fmripost_aroma/workflows/base.py +++ b/src/fmripost_aroma/workflows/base.py @@ -360,11 +360,16 @@ def init_single_subject_wf(subject_id: str): ) ica_aroma_wf.inputs.inputnode.bold_std = functional_cache["bold_std"] ica_aroma_wf.inputs.inputnode.bold_mask_std = functional_cache["bold_mask_std"] + workflow.add_nodes([ica_aroma_wf]) ica_aroma_wf.inputs.inputnode.movpar_file = functional_cache["movpar_file"] ica_aroma_wf.inputs.inputnode.skip_vols = functional_cache["skip_vols"] ica_aroma_wf.inputs.inputnode.spatial_reference = functional_cache["spatial_reference"] + # Now denoise the native-space BOLD data using ICA-AROMA + + # Now warp the denoised BOLD data to the requested output spaces + return clean_datasinks(workflow) From 5617d38e6f8e42af5712ec2d945dc65a64a1e2a6 Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Tue, 7 May 2024 15:50:33 -0400 Subject: [PATCH 5/9] Drop task_id from collect_derivatives. --- src/fmripost_aroma/utils/bids.py | 2 +- src/fmripost_aroma/workflows/base.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/fmripost_aroma/utils/bids.py b/src/fmripost_aroma/utils/bids.py index 7ca2369..cebc7d6 100644 --- a/src/fmripost_aroma/utils/bids.py +++ b/src/fmripost_aroma/utils/bids.py @@ -76,7 +76,7 @@ def collect_derivatives_old( """Collect preprocessing derivatives.""" subj_data = { "bold_raw": "", - "" "bold_boldref": "", + "bold_boldref": "", "bold_MNI152NLin6": "", } query = { diff --git a/src/fmripost_aroma/workflows/base.py b/src/fmripost_aroma/workflows/base.py index b5bf345..31c2fa0 100644 --- a/src/fmripost_aroma/workflows/base.py +++ b/src/fmripost_aroma/workflows/base.py @@ -219,8 +219,7 @@ def init_single_subject_wf(subject_id: str): subject_data = collect_derivatives( config.execution.layout, subject_id, - task=config.execution.task_id, - bids_filters=config.execution.bids_filters, + entities=config.execution.bids_filters, ) if "flair" in config.workflow.ignore: From d267a864b72535ff628bffab10221e83bc4042fa Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Tue, 7 May 2024 15:57:12 -0400 Subject: [PATCH 6/9] Run ruff. --- src/fmripost_aroma/interfaces/resampler.py | 18 +++++++++--------- src/fmripost_aroma/utils/resampler.py | 5 ++++- src/fmripost_aroma/workflows/base.py | 16 ++++++++-------- src/fmripost_aroma/workflows/resampling.py | 22 +++++++++++----------- 4 files changed, 32 insertions(+), 29 deletions(-) diff --git a/src/fmripost_aroma/interfaces/resampler.py b/src/fmripost_aroma/interfaces/resampler.py index 61a30f8..cf56900 100644 --- a/src/fmripost_aroma/interfaces/resampler.py +++ b/src/fmripost_aroma/interfaces/resampler.py @@ -10,29 +10,29 @@ class _ResamplerInputSpec(BaseInterfaceInputSpec): - bold_file = File(exists=True, desc="BOLD file to resample.") + bold_file = File(exists=True, desc='BOLD file to resample.') derivs_path = traits.Directory( exists=True, - desc="Path to derivatives.", + desc='Path to derivatives.', ) output_dir = traits.Directory( exists=True, - desc="Output directory.", + desc='Output directory.', ) space = traits.Str( - "MNI152NLin6Asym", + 'MNI152NLin6Asym', usedefault=True, - desc="Output space.", + desc='Output space.', ) resolution = traits.Str( - "2", + '2', usedefault=True, - desc="Output resolution.", + desc='Output resolution.', ) class _ResamplerOutputSpec(TraitedSpec): - output_file = File(exists=True, desc="Resampled BOLD file.") + output_file = File(exists=True, desc='Resampled BOLD file.') class Resampler(SimpleInterface): @@ -57,6 +57,6 @@ def _run_interface(self, runtime): resolution=self.inputs.resolution, ) - self._results["output_file"] = output_file + self._results['output_file'] = output_file return runtime diff --git a/src/fmripost_aroma/utils/resampler.py b/src/fmripost_aroma/utils/resampler.py index 00dcc28..dcc5ad3 100644 --- a/src/fmripost_aroma/utils/resampler.py +++ b/src/fmripost_aroma/utils/resampler.py @@ -1,3 +1,7 @@ +"""Resampler methods for fMRI data.""" + +from __future__ import annotations + import asyncio import os from functools import partial @@ -19,7 +23,6 @@ from templateflow import api as tf from typing_extensions import Annotated - R = TypeVar('R') nipreps_cfg = niworkflows.data.load('nipreps.json') diff --git a/src/fmripost_aroma/workflows/base.py b/src/fmripost_aroma/workflows/base.py index 2c01975..76b8948 100644 --- a/src/fmripost_aroma/workflows/base.py +++ b/src/fmripost_aroma/workflows/base.py @@ -321,7 +321,7 @@ def init_single_subject_wf(subject_id: str): for bold_file in subject_data['bold']: ica_aroma_wf = init_ica_aroma_wf(bold_file=bold_file) - ica_aroma_wf.__desc__ = func_pre_desc + (ica_aroma_wf.__desc__ or "") + ica_aroma_wf.__desc__ = func_pre_desc + (ica_aroma_wf.__desc__ or '') functional_cache = {} if config.execution.derivatives: @@ -344,8 +344,8 @@ def init_single_subject_wf(subject_id: str): ) workflow.connect([ (resample_raw_wf, ica_aroma_wf, [ - ("outputnode.bold_std", "inputnode.bold_std"), - ("outputnode.bold_mask_std", "inputnode.bold_mask_std"), + ('outputnode.bold_std', 'inputnode.bold_std'), + ('outputnode.bold_mask_std', 'inputnode.bold_mask_std'), ]), ]) # fmt:skip else: @@ -358,13 +358,13 @@ def init_single_subject_wf(subject_id: str): entities=entities, ) ) - ica_aroma_wf.inputs.inputnode.bold_std = functional_cache["bold_std"] - ica_aroma_wf.inputs.inputnode.bold_mask_std = functional_cache["bold_mask_std"] + ica_aroma_wf.inputs.inputnode.bold_std = functional_cache['bold_std'] + ica_aroma_wf.inputs.inputnode.bold_mask_std = functional_cache['bold_mask_std'] workflow.add_nodes([ica_aroma_wf]) - ica_aroma_wf.inputs.inputnode.movpar_file = functional_cache["movpar_file"] - ica_aroma_wf.inputs.inputnode.skip_vols = functional_cache["skip_vols"] - ica_aroma_wf.inputs.inputnode.spatial_reference = functional_cache["spatial_reference"] + ica_aroma_wf.inputs.inputnode.movpar_file = functional_cache['movpar_file'] + ica_aroma_wf.inputs.inputnode.skip_vols = functional_cache['skip_vols'] + ica_aroma_wf.inputs.inputnode.spatial_reference = functional_cache['spatial_reference'] # Now denoise the native-space BOLD data using ICA-AROMA diff --git a/src/fmripost_aroma/workflows/resampling.py b/src/fmripost_aroma/workflows/resampling.py index 15d00ad..ae4753c 100644 --- a/src/fmripost_aroma/workflows/resampling.py +++ b/src/fmripost_aroma/workflows/resampling.py @@ -11,21 +11,21 @@ def init_resample_raw_wf(bold_file, functional_cache): from fmripost_aroma.interfaces.resampler import Resampler - workflow = Workflow(name="resample_raw_wf") + workflow = Workflow(name='resample_raw_wf') inputnode = pe.Node( - niu.IdentityInterface(fields=["bold_file", "mask_file"]), - name="inputnode", + niu.IdentityInterface(fields=['bold_file', 'mask_file']), + name='inputnode', ) inputnode.inputs.bold_file = bold_file - inputnode.inputs.mask_file = functional_cache["bold_mask"] + inputnode.inputs.mask_file = functional_cache['bold_mask'] outputnode = pe.Node( - niu.IdentityInterface(fields=["bold_std", "bold_mask_std"]), - name="outputnode", + niu.IdentityInterface(fields=['bold_std', 'bold_mask_std']), + name='outputnode', ) - stc_wf = init_bold_stc_wf(name="resample_stc_wf") + stc_wf = init_bold_stc_wf(name='resample_stc_wf') workflow.connect([ (inputnode, stc_wf, [ ('bold_file', 'inputnode.bold_file'), @@ -34,8 +34,8 @@ def init_resample_raw_wf(bold_file, functional_cache): ]) # fmt:skip resample_bold = pe.Node( - Resampler(space="MNI152NLin6Asym", resolution="2"), - name="resample_bold", + Resampler(space='MNI152NLin6Asym', resolution='2'), + name='resample_bold', ) workflow.connect([ (stc_wf, resample_bold, [('outputnode.bold_file', 'bold_file')]), @@ -43,8 +43,8 @@ def init_resample_raw_wf(bold_file, functional_cache): ]) # fmt:skip resample_bold_mask = pe.Node( - Resampler(space="MNI152NLin6Asym", resolution="2"), - name="resample_bold_mask", + Resampler(space='MNI152NLin6Asym', resolution='2'), + name='resample_bold_mask', ) workflow.connect([ (inputnode, resample_bold_mask, [('mask_file', 'bold_file')]), From 00d34402ee45d73cc0518e3a72cf8729852ab2fa Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Tue, 7 May 2024 16:05:16 -0400 Subject: [PATCH 7/9] Address ruff's concerns. --- src/fmripost_aroma/utils/resampler.py | 31 +++++++++++++++------------ 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/fmripost_aroma/utils/resampler.py b/src/fmripost_aroma/utils/resampler.py index dcc5ad3..af2b0b3 100644 --- a/src/fmripost_aroma/utils/resampler.py +++ b/src/fmripost_aroma/utils/resampler.py @@ -312,11 +312,13 @@ def parse_combined_hdf5(h5_fn, to_ras=True): xform = ITKCompositeH5.from_h5obj(h) affine = xform[0].to_ras() # Confirm these transformations are applicable - assert ( + if ( h['TransformGroup']['2']['TransformType'][:][0] - == b'DisplacementFieldTransform_float_3_3' - ) - assert np.array_equal( + != b'DisplacementFieldTransform_float_3_3' + ): + raise ValueError('Unsupported transform type') + + if not np.array_equal( h['TransformGroup']['2']['TransformFixedParameters'][:], np.array( [ @@ -340,7 +342,9 @@ def parse_combined_hdf5(h5_fn, to_ras=True): 1.0, ] ), - ) + ): + raise ValueError('Unsupported fixed parameters') + warp = h['TransformGroup']['2']['TransformParameters'][:] warp = warp.reshape((193, 229, 193, 3)).transpose(2, 1, 0, 3) warp *= np.array([-1, -1, 1]) @@ -533,7 +537,8 @@ def resample_bold( The BOLD series resampled into the target space """ # HMC goes last - assert isinstance(transforms[-1], nt.linear.LinearTransformsMapping) + if not isinstance(transforms[-1], nt.linear.LinearTransformsMapping): + raise ValueError('Last transform must be a linear mapping') # Retrieve the RAS coordinates of the target space coordinates = ( @@ -645,8 +650,8 @@ def main( hmc = derivs.get( extension='.txt', **mkents('orig', 'boldref', **entities) )[0] - except IndexError: - raise ValueError('Could not find HMC transforms') + except IndexError as err: + raise ValueError('Could not find HMC transforms') from err bold_xfms.append(hmc) @@ -659,8 +664,8 @@ def main( coreg = derivs.get( extension='.txt', **mkents('boldref', 'T1w', **entities) )[0] - except IndexError: - raise ValueError('Could not find coregistration transform') + except IndexError as err: + raise ValueError('Could not find coregistration transform') from err bold_xfms.append(coreg) fmap_xfms.append(coreg) @@ -682,10 +687,8 @@ def main( subject=entities['subject'], **mkents('T1w', space), )[0] - except IndexError: - raise ValueError( - f'Could not find template registration for {space}' - ) + except IndexError as err : + raise ValueError(f'Could not find template registration for {space}') from err bold_xfms.append(template_reg) fmap_xfms.append(template_reg) From 707c0c0456e3d901da0d250271b68847ee97a75a Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Tue, 7 May 2024 16:09:17 -0400 Subject: [PATCH 8/9] Update resampler.py --- src/fmripost_aroma/utils/resampler.py | 56 +++++++-------------------- 1 file changed, 14 insertions(+), 42 deletions(-) diff --git a/src/fmripost_aroma/utils/resampler.py b/src/fmripost_aroma/utils/resampler.py index af2b0b3..7df2174 100644 --- a/src/fmripost_aroma/utils/resampler.py +++ b/src/fmripost_aroma/utils/resampler.py @@ -312,10 +312,7 @@ def parse_combined_hdf5(h5_fn, to_ras=True): xform = ITKCompositeH5.from_h5obj(h) affine = xform[0].to_ras() # Confirm these transformations are applicable - if ( - h['TransformGroup']['2']['TransformType'][:][0] - != b'DisplacementFieldTransform_float_3_3' - ): + if h['TransformGroup']['2']['TransformType'][:][0] != b'DisplacementFieldTransform_float_3_3': raise ValueError('Unsupported transform type') if not np.array_equal( @@ -404,9 +401,7 @@ def as_affine(xfm: nt.base.TransformBase) -> nt.Affine | None: if isinstance(xfm, nt.Affine): return xfm - if isinstance(xfm, nt.TransformChain) and all( - isinstance(x, nt.Affine) for x in xfm - ): + if isinstance(xfm, nt.TransformChain) and all(isinstance(x, nt.Affine) for x in xfm): return xfm.asaffine() return None @@ -466,9 +461,7 @@ def resample_fieldmap( ) else: if not aligned(fmap_reference.affine, coefficients[-1].affine): - raise ValueError( - 'Reference passed is not aligned with spline grids' - ) + raise ValueError('Reference passed is not aligned with spline grids') reference, _ = ensure_positive_cosines(fmap_reference) # Generate tensor-product B-Spline weights @@ -476,10 +469,7 @@ def resample_fieldmap( [grid_bspline_weights(reference, level) for level in coefficients] ).tocsr() coefficients = np.hstack( - [ - level.get_fdata(dtype='float32').reshape(-1) - for level in coefficients - ] + [level.get_fdata(dtype='float32').reshape(-1) for level in coefficients] ) # Reconstruct the fieldmap (in Hz) from coefficients @@ -493,9 +483,7 @@ def resample_fieldmap( fmap_img.header.set_intent('estimate', name='fieldmap Hz') fmap_img.header.set_data_dtype('float32') - fmap_img.header['cal_max'] = max( - (abs(fmap_img.dataobj.min()), fmap_img.dataobj.max()) - ) + fmap_img.header['cal_max'] = max((abs(fmap_img.dataobj.min()), fmap_img.dataobj.max())) fmap_img.header['cal_min'] = -fmap_img.header['cal_max'] return fmap_img @@ -541,9 +529,7 @@ def resample_bold( raise ValueError('Last transform must be a linear mapping') # Retrieve the RAS coordinates of the target space - coordinates = ( - nt.base.SpatialReference.factory(target).ndcoords.astype('f4').T - ) + coordinates = nt.base.SpatialReference.factory(target).ndcoords.astype('f4').T # We will operate in voxel space, so get the source affine vox2ras = source.affine @@ -559,9 +545,7 @@ def resample_bold( # Some identities to reduce special casing downstream if fieldmap is None: - fieldmap = nb.Nifti1Image( - np.zeros(target.shape[:3], dtype='f4'), target.affine - ) + fieldmap = nb.Nifti1Image(np.zeros(target.shape[:3], dtype='f4'), target.affine) if pe_info is None: pe_info = [[0, 0] for _ in range(source.shape[-1])] @@ -574,9 +558,7 @@ def resample_bold( output_dtype='f4', nthreads=nthreads, ) - resampled_img = nb.Nifti1Image( - resampled_data, target.affine, target.header - ) + resampled_img = nb.Nifti1Image(resampled_data, target.affine, target.header) resampled_img.set_data_dtype('f4') return resampled_img @@ -647,23 +629,17 @@ def main( fmap_xfms = [] try: - hmc = derivs.get( - extension='.txt', **mkents('orig', 'boldref', **entities) - )[0] + hmc = derivs.get(extension='.txt', **mkents('orig', 'boldref', **entities))[0] except IndexError as err: raise ValueError('Could not find HMC transforms') from err bold_xfms.append(hmc) if space == 'boldref': - reference = derivs.get( - desc='coreg', suffix='boldref', extension='.nii.gz', **entities - )[0] + reference = derivs.get(desc='coreg', suffix='boldref', extension='.nii.gz', **entities)[0] else: try: - coreg = derivs.get( - extension='.txt', **mkents('boldref', 'T1w', **entities) - )[0] + coreg = derivs.get(extension='.txt', **mkents('boldref', 'T1w', **entities))[0] except IndexError as err: raise ValueError('Could not find coregistration transform') from err @@ -687,7 +663,7 @@ def main( subject=entities['subject'], **mkents('T1w', space), )[0] - except IndexError as err : + except IndexError as err: raise ValueError(f'Could not find template registration for {space}') from err bold_xfms.append(template_reg) @@ -712,15 +688,11 @@ def main( ) ref_img = genref(nb.load(reference), zooms) - fmapregs = derivs.get( - extension='.txt', **mkents('boldref', derivs.get_fmapids(), **entities) - ) + fmapregs = derivs.get(extension='.txt', **mkents('boldref', derivs.get_fmapids(), **entities)) if not fmapregs: print('No fieldmap registrations found') elif len(fmapregs) > 1: - raise ValueError( - f'Found fieldmap registrations: {fmapregs}\nPass one as an argument.' - ) + raise ValueError(f'Found fieldmap registrations: {fmapregs}\nPass one as an argument.') fieldmap = None if fmapregs: From cb389d093b5dd33d1e4144b3b7f2dd977ffec7f9 Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Tue, 14 May 2024 15:16:57 -0400 Subject: [PATCH 9/9] Work on resampling workflow. --- src/fmripost_aroma/workflows/base.py | 42 +++++++++++++++------- src/fmripost_aroma/workflows/resampling.py | 42 ++++++++++++++++------ 2 files changed, 60 insertions(+), 24 deletions(-) diff --git a/src/fmripost_aroma/workflows/base.py b/src/fmripost_aroma/workflows/base.py index 76b8948..cce0b46 100644 --- a/src/fmripost_aroma/workflows/base.py +++ b/src/fmripost_aroma/workflows/base.py @@ -41,7 +41,7 @@ from fmripost_aroma import config from fmripost_aroma.interfaces.bids import DerivativesDataSink from fmripost_aroma.interfaces.reportlets import AboutSummary, SubjectSummary -from fmripost_aroma.workflows.resampling import init_resample_raw_wf +from fmripost_aroma.workflows.resampling import init_resample_volumetric_wf def init_fmripost_aroma_wf(): @@ -180,10 +180,13 @@ def init_single_subject_wf(subject_id: str): from niworkflows.interfaces.bids import BIDSDataGrabber, BIDSInfo from niworkflows.interfaces.nilearn import NILEARN_VERSION from niworkflows.utils.misc import fix_multi_T1w_source_name + from niworkflows.utils.spaces import Reference from fmripost_aroma.utils.bids import collect_derivatives from fmripost_aroma.workflows.aroma import init_ica_aroma_wf + spaces = config.workflow.spaces + workflow = Workflow(name=f'sub_{subject_id}_wf') workflow.__desc__ = f""" Results included in this manuscript come from preprocessing @@ -217,16 +220,10 @@ def init_single_subject_wf(subject_id: str): """ subject_data = collect_derivatives( - config.execution.layout, - subject_id, + raw_dir=config.execution.layout, entities=config.execution.bids_filters, ) - if 'flair' in config.workflow.ignore: - subject_data['flair'] = [] - if 't2w' in config.workflow.ignore: - subject_data['t2w'] = [] - anat_only = config.workflow.anat_only # Make sure we always go through these two checks if not anat_only and not subject_data['bold']: @@ -258,7 +255,8 @@ def init_single_subject_wf(subject_id: str): ) bids_info = pe.Node( - BIDSInfo(bids_dir=config.execution.bids_dir, bids_validate=False), name='bids_info' + BIDSInfo(bids_dir=config.execution.bids_dir, bids_validate=False), + name='bids_info', ) summary = pe.Node( @@ -338,9 +336,11 @@ def init_single_subject_wf(subject_id: str): ) ) - resample_raw_wf = init_resample_raw_wf( + # Resample to MNI152NLin6Asym:res-2, for ICA-AROMA classification + resample_raw_wf = init_resample_volumetric_wf( bold_file=bold_file, precomputed=functional_cache, + space=Reference.from_string("MNI152NLin6Asym:res-2")[0], ) workflow.connect([ (resample_raw_wf, ica_aroma_wf, [ @@ -349,7 +349,7 @@ def init_single_subject_wf(subject_id: str): ]), ]) # fmt:skip else: - # Collect standard-space derivatives + # Collect MNI152NLin6Asym:res-2 derivatives from fmripost_aroma.utils.bids import collect_derivatives functional_cache.update( @@ -367,8 +367,24 @@ def init_single_subject_wf(subject_id: str): ica_aroma_wf.inputs.inputnode.spatial_reference = functional_cache['spatial_reference'] # Now denoise the native-space BOLD data using ICA-AROMA - - # Now warp the denoised BOLD data to the requested output spaces + denoise_native_wf = init_denoise_wf(bold_file=bold_file) + workflow.connect([ + (ica_aroma_wf, denoise_native_wf, [ + ('outputnode.aroma_noise_ics', 'inputnode.aroma_noise_ics'), + ]), + ]) # fmt:skip + + for space in spaces: + resample_to_space_wf = init_resample_volumetric_wf( + bold_file=bold_file, + functional_cache=functional_cache, + space=space, + ) + workflow.connect([ + (denoise_native_wf, resample_to_space_wf, [ + ('outputnode.denoised_file', 'inputnode.bold_file'), + ]), + ]) # fmt:skip return clean_datasinks(workflow) diff --git a/src/fmripost_aroma/workflows/resampling.py b/src/fmripost_aroma/workflows/resampling.py index ae4753c..1c8627f 100644 --- a/src/fmripost_aroma/workflows/resampling.py +++ b/src/fmripost_aroma/workflows/resampling.py @@ -4,8 +4,20 @@ from nipype.pipeline import engine as pe -def init_resample_raw_wf(bold_file, functional_cache): - """Resample raw BOLD data to MNI152NLin6Asym:res-2mm space.""" +def init_resample_volumetric_wf(bold_file, functional_cache, space, run_stc): + """Resample raw BOLD data to requested volumetric space space. + + Parameters + ---------- + bold_file : str + Path to BOLD file. + functional_cache : dict + Dictionary with paths to functional data. + space : niworkflows.utils.spaces.Reference + Spatial reference. + run_stc : bool + Whether to run STC. + """ from fmriprep.workflows.bold.stc import init_bold_stc_wf from niworkflows.engine.workflows import LiterateWorkflow as Workflow @@ -25,20 +37,28 @@ def init_resample_raw_wf(bold_file, functional_cache): name='outputnode', ) - stc_wf = init_bold_stc_wf(name='resample_stc_wf') - workflow.connect([ - (inputnode, stc_wf, [ - ('bold_file', 'inputnode.bold_file'), - ('mask_file', 'inputnode.mask_file'), - ]), - ]) # fmt:skip + stc_buffer = pe.Node( + niu.IdentityInterface(fields=['bold_file']), + name='stc_buffer', + ) + if run_stc: + stc_wf = init_bold_stc_wf(name='resample_stc_wf') + workflow.connect([ + (inputnode, stc_wf, [ + ('bold_file', 'inputnode.bold_file'), + ('mask_file', 'inputnode.mask_file'), + ]), + (stc_wf, stc_buffer, [('outputnode.bold_file', 'bold_file')]), + ]) # fmt:skip + else: + workflow.connect([(inputnode, stc_buffer, [('bold_file', 'bold_file')])]) resample_bold = pe.Node( - Resampler(space='MNI152NLin6Asym', resolution='2'), + Resampler(space=space.space, **space.spec), name='resample_bold', ) workflow.connect([ - (stc_wf, resample_bold, [('outputnode.bold_file', 'bold_file')]), + (stc_buffer, resample_bold, [('outputnode.bold_file', 'bold_file')]), (resample_bold, outputnode, [('output_file', 'bold_std')]), ]) # fmt:skip