Skip to content

Commit

Permalink
ENH: Restore resampling to T1w target (#3116)
Browse files Browse the repository at this point in the history
## Changes proposed in this pull request

Adds T1w resampling. Unconditionally resamples in T1w space, but
conditionally outputs.
  • Loading branch information
effigies authored Oct 26, 2023
2 parents 978ae51 + 0d09cc0 commit 6bda5ce
Show file tree
Hide file tree
Showing 6 changed files with 470 additions and 712 deletions.
70 changes: 59 additions & 11 deletions fmriprep/interfaces/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class ResampleSeriesInputSpec(TraitedSpec):
in_file = File(exists=True, mandatory=True, desc="3D or 4D image file to resample")
ref_file = File(exists=True, mandatory=True, desc="File to resample in_file to")
transforms = InputMultiObject(
File(exists=True), mandatory=True, desc="Transform files, from in_file to ref_file (image mode)"
File(exists=True),
mandatory=True,
desc="Transform files, from in_file to ref_file (image mode)",
)
inverse = InputMultiObject(
traits.Bool,
Expand All @@ -48,6 +50,16 @@ class ResampleSeriesInputSpec(TraitedSpec):
desc="the phase-encoding direction corresponding to in_data",
)
num_threads = traits.Int(1, usedefault=True, desc="Number of threads to use for resampling")
output_data_type = traits.Str("float32", usedefault=True, desc="Data type of output image")
order = traits.Int(3, usedefault=True, desc="Order of interpolation (0=nearest, 3=cubic)")
mode = traits.Str(
'constant',
usedefault=True,
desc="How data is extended beyond its boundaries. "
"See scipy.ndimage.map_coordinates for more details.",
)
cval = traits.Float(0.0, usedefault=True, desc="Value to fill past edges of data")
prefilter = traits.Bool(True, usedefault=True, desc="Spline-prefilter data if order > 1")


class ResampleSeriesOutputSpec(TraitedSpec):
Expand Down Expand Up @@ -87,13 +99,18 @@ def _run_interface(self, runtime):

pe_info = [(pe_axis, -ro_time if (axis_flip ^ pe_flip) else ro_time)] * nvols

Check warning on line 100 in fmriprep/interfaces/resampling.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/interfaces/resampling.py#L100

Added line #L100 was not covered by tests

resampled = resample_bold(
resampled = resample_image(

Check warning on line 102 in fmriprep/interfaces/resampling.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/interfaces/resampling.py#L102

Added line #L102 was not covered by tests
source=source,
target=target,
transforms=transforms,
fieldmap=fieldmap,
pe_info=pe_info,
nthreads=self.inputs.num_threads,
output_dtype=self.inputs.output_data_type,
order=self.inputs.order,
mode=self.inputs.mode,
cval=self.inputs.cval,
prefilter=self.inputs.prefilter,
)
resampled.to_filename(out_path)

Check warning on line 115 in fmriprep/interfaces/resampling.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/interfaces/resampling.py#L115

Added line #L115 was not covered by tests

Expand All @@ -105,10 +122,16 @@ class ReconstructFieldmapInputSpec(TraitedSpec):
in_coeffs = InputMultiObject(
File(exists=True), mandatory=True, desc="SDCflows-style spline coefficient files"
)
target_ref_file = File(exists=True, mandatory=True, desc="Image to reconstruct the field in alignment with")
fmap_ref_file = File(exists=True, mandatory=True, desc="Reference file aligned with coefficients")
target_ref_file = File(
exists=True, mandatory=True, desc="Image to reconstruct the field in alignment with"
)
fmap_ref_file = File(
exists=True, mandatory=True, desc="Reference file aligned with coefficients"
)
transforms = InputMultiObject(
File(exists=True), mandatory=True, desc="Transform files, from in_file to ref_file (image mode)"
File(exists=True),
mandatory=True,
desc="Transform files, from in_file to ref_file (image mode)",
)
inverse = InputMultiObject(
traits.Bool,
Expand Down Expand Up @@ -252,6 +275,9 @@ def resample_vol(
coordinates = nb.affines.apply_affine(

Check warning on line 275 in fmriprep/interfaces/resampling.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/interfaces/resampling.py#L274-L275

Added lines #L274 - L275 were not covered by tests
hmc_xfm, coordinates.reshape(coords_shape[0], -1).T
).T.reshape(coords_shape)
else:
# Copy coordinates to avoid interfering with other calls
coordinates = coordinates.copy()

Check warning on line 280 in fmriprep/interfaces/resampling.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/interfaces/resampling.py#L280

Added line #L280 was not covered by tests

vsm = fmap_hz * pe_info[1]
coordinates[pe_info[0], ...] += vsm

Check warning on line 283 in fmriprep/interfaces/resampling.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/interfaces/resampling.py#L282-L283

Added lines #L282 - L283 were not covered by tests
Expand Down Expand Up @@ -346,15 +372,17 @@ async def resample_series_async(

semaphore = asyncio.Semaphore(max_concurrent)

Check warning on line 373 in fmriprep/interfaces/resampling.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/interfaces/resampling.py#L373

Added line #L373 was not covered by tests

out_array = np.zeros(coordinates.shape[1:] + data.shape[-1:], dtype=output_dtype)
# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
out_array = np.zeros(coordinates.shape[1:] + data.shape[-1:], dtype=output_dtype, order='F')

Check warning on line 377 in fmriprep/interfaces/resampling.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/interfaces/resampling.py#L377

Added line #L377 was not covered by tests

tasks = [

Check warning on line 379 in fmriprep/interfaces/resampling.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/interfaces/resampling.py#L379

Added line #L379 was not covered by tests
asyncio.create_task(
worker(
partial(
resample_vol,
data=volume,
coordinates=coordinates.copy(),
coordinates=coordinates,
pe_info=pe_info[volid],
hmc_xfm=hmc_xfms[volid] if hmc_xfms else None,
fmap_hz=fmap_hz,
Expand Down Expand Up @@ -451,21 +479,26 @@ def resample_series(
)


def resample_bold(
def resample_image(
source: nb.Nifti1Image,
target: nb.Nifti1Image,
transforms: nt.TransformChain,
fieldmap: nb.Nifti1Image | None,
pe_info: list[tuple[int, float]] | None,
nthreads: int = 1,
output_dtype: np.dtype | str | None = 'f4',
order: int = 3,
mode: str = 'constant',
cval: float = 0.0,
prefilter: bool = True,
) -> nb.Nifti1Image:
"""Resample a 4D bold series into a target space, applying head-motion
"""Resample a 3- or 4D image into a target space, applying head-motion
and susceptibility-distortion correction simultaneously.
Parameters
----------
source
The 4D bold series to resample.
The 3D bold image or 4D bold series to resample.
target
An image sampled in the target space.
transforms
Expand All @@ -480,6 +513,17 @@ def resample_bold(
of the data array in the second dimension.
nthreads
Number of threads to use for parallel resampling
output_dtype
The dtype of the output array.
order
Order of interpolation (default: 3 = cubic)
mode
How ``data`` is extended beyond its boundaries. See
:func:`scipy.ndimage.map_coordinates` for more details.
cval
Value to fill past edges of ``data`` if ``mode`` is ``'constant'``.
prefilter
Determines if ``data`` is pre-filtered before interpolation.
Returns
-------
Expand Down Expand Up @@ -527,8 +571,12 @@ def resample_bold(
pe_info=pe_info,
hmc_xfms=hmc_xfms,
fmap_hz=fieldmap.get_fdata(dtype='f4'),
output_dtype='f4',
output_dtype=output_dtype,
nthreads=nthreads,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)
resampled_img = nb.Nifti1Image(resampled_data, target.affine, target.header)
resampled_img.set_data_dtype('f4')

Check warning on line 582 in fmriprep/interfaces/resampling.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/interfaces/resampling.py#L581-L582

Added lines #L581 - L582 were not covered by tests
Expand Down
3 changes: 3 additions & 0 deletions fmriprep/workflows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,9 @@ def init_single_subject_wf(subject_id: str):
precomputed=functional_cache,
fieldmap_id=fieldmap_id,
)
if bold_wf is None:

Check warning on line 486 in fmriprep/workflows/base.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/base.py#L486

Added line #L486 was not covered by tests
continue

bold_wf.__desc__ = func_pre_desc + (bold_wf.__desc__ or "")

Check warning on line 489 in fmriprep/workflows/base.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/base.py#L489

Added line #L489 was not covered by tests

workflow.connect([
Expand Down
118 changes: 118 additions & 0 deletions fmriprep/workflows/bold/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import nipype.interfaces.utility as niu
import nipype.pipeline.engine as pe
from niworkflows.interfaces.header import ValidateImage
from niworkflows.interfaces.nibabel import GenerateSamplingReference
from niworkflows.interfaces.utility import KeySelect
from niworkflows.utils.connections import listify

Expand All @@ -25,6 +26,110 @@
from niworkflows.utils.spaces import SpatialReferences

Check warning on line 26 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L26

Added line #L26 was not covered by tests


def init_bold_volumetric_resample_wf(
*,
metadata: dict,
fieldmap_id: str | None = None,
omp_nthreads: int = 1,
name: str = 'bold_volumetric_resample_wf',
) -> pe.Workflow:
workflow = pe.Workflow(name=name)

Check warning on line 36 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L36

Added line #L36 was not covered by tests

inputnode = pe.Node(

Check warning on line 38 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L38

Added line #L38 was not covered by tests
niu.IdentityInterface(
fields=[
"bold_file",
"bold_ref_file",
"target_ref_file",
"target_mask",
# HMC
"motion_xfm",
# SDC
"boldref2fmap_xfm",
"fmap_ref",
"fmap_coeff",
"fmap_id",
# Anatomical
"boldref2anat_xfm",
# Template
"anat2std_xfm",
],
),
name='inputnode',
)

outputnode = pe.Node(niu.IdentityInterface(fields=["bold_file"]), name='outputnode')

Check warning on line 61 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L61

Added line #L61 was not covered by tests

gen_ref = pe.Node(GenerateSamplingReference(), name='gen_ref', mem_gb=0.3)

Check warning on line 63 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L63

Added line #L63 was not covered by tests

boldref2target = pe.Node(niu.Merge(2), name='boldref2target')
bold2target = pe.Node(niu.Merge(2), name='bold2target')
resample = pe.Node(ResampleSeries(), name="resample", n_procs=omp_nthreads)

Check warning on line 67 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L65-L67

Added lines #L65 - L67 were not covered by tests

workflow.connect([

Check warning on line 69 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L69

Added line #L69 was not covered by tests
(inputnode, gen_ref, [
('bold_ref_file', 'moving_image'),
('target_ref_file', 'fixed_image'),
('target_mask', 'fov_mask'),
]),
(inputnode, boldref2target, [
('boldref2anat_xfm', 'in1'),
('anat2std_xfm', 'in2'),
]),
(inputnode, bold2target, [('motion_xfm', 'in1')]),
(inputnode, resample, [('bold_file', 'in_file')]),
(gen_ref, resample, [('out_file', 'ref_file')]),
(boldref2target, bold2target, [('out', 'in2')]),
(bold2target, resample, [('out', 'transforms')]),
(resample, outputnode, [('out_file', 'bold_file')]),
]) # fmt:skip

if not fieldmap_id:
return workflow

Check warning on line 88 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L87-L88

Added lines #L87 - L88 were not covered by tests

fmap_select = pe.Node(

Check warning on line 90 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L90

Added line #L90 was not covered by tests
KeySelect(fields=["fmap_ref", "fmap_coeff"], key=fieldmap_id),
name="fmap_select",
run_without_submitting=True,
)
distortion_params = pe.Node(

Check warning on line 95 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L95

Added line #L95 was not covered by tests
DistortionParameters(metadata=metadata),
name="distortion_params",
run_without_submitting=True,
)
fmap2target = pe.Node(niu.Merge(2), name='fmap2target')
inverses = pe.Node(niu.Function(function=_gen_inverses), name='inverses')

Check warning on line 101 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L100-L101

Added lines #L100 - L101 were not covered by tests

fmap_recon = pe.Node(ReconstructFieldmap(), name="fmap_recon")

Check warning on line 103 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L103

Added line #L103 was not covered by tests

workflow.connect([

Check warning on line 105 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L105

Added line #L105 was not covered by tests
(inputnode, fmap_select, [
("fmap_ref", "fmap_ref"),
("fmap_coeff", "fmap_coeff"),
("fmap_id", "keys"),
]),
(inputnode, distortion_params, [('bold_file', 'in_file')]),
(inputnode, fmap2target, [('boldref2fmap_xfm', 'in1')]),
(gen_ref, fmap_recon, [('out_file', 'target_ref_file')]),
(boldref2target, fmap2target, [('out', 'in2')]),
(boldref2target, inverses, [('out', 'inlist')]),
(fmap_select, fmap_recon, [
("fmap_coeff", "in_coeffs"),
("fmap_ref", "fmap_ref_file"),
]),
(fmap2target, fmap_recon, [('out', 'transforms')]),
(inverses, fmap_recon, [('out', 'inverse')]),
# Inject fieldmap correction into resample node
(distortion_params, resample, [
("readout_time", "ro_time"),
("pe_direction", "pe_dir"),
]),
(fmap_recon, resample, [('out_file', 'fieldmap')]),
]) # fmt:skip

return workflow

Check warning on line 130 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L130

Added line #L130 was not covered by tests


def init_bold_apply_wf(
*,
spaces: SpatialReferences,
Expand All @@ -49,3 +154,16 @@ def init_bold_apply_wf(
# )

return workflow

Check warning on line 156 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L156

Added line #L156 was not covered by tests


def _gen_inverses(inlist: list) -> list[bool]:
"""Create a list indicating the first transform should be inverted.
The input list is the collection of transforms that follow the
inverted one.
"""
from niworkflows.utils.connections import listify

Check warning on line 165 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L165

Added line #L165 was not covered by tests

if not inlist:
return [True]
return [True] + [False] * len(listify(inlist))

Check warning on line 169 in fmriprep/workflows/bold/apply.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/apply.py#L167-L169

Added lines #L167 - L169 were not covered by tests
Loading

0 comments on commit 6bda5ce

Please sign in to comment.