Skip to content

Commit

Permalink
ENH: Restore resampling BOLD to volumetric templates (#3121)
Browse files Browse the repository at this point in the history
Builds on #3116.
  • Loading branch information
mgxd authored Nov 3, 2023
2 parents 6bda5ce + 2fd3012 commit ce7c65f
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 133 deletions.
112 changes: 66 additions & 46 deletions fmriprep/utils/transforms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities for loading transforms for resampling"""
import warnings
from pathlib import Path

import h5py
Expand Down Expand Up @@ -36,57 +37,76 @@ def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.Trans
return chain

Check warning on line 37 in fmriprep/utils/transforms.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/utils/transforms.py#L34-L37

Added lines #L34 - L37 were not covered by tests


def load_ants_h5(filename: Path) -> nt.TransformChain:
"""Load ANTs H5 files as a nitransforms TransformChain"""
affine, warp, warp_affine = parse_combined_hdf5(filename)
warp_transform = nt.DenseFieldTransform(nb.Nifti1Image(warp, warp_affine))
return nt.TransformChain([warp_transform, nt.Affine(affine)])
FIXED_PARAMS = np.array([
193.0, 229.0, 193.0, # Size
96.0, 132.0, -78.0, # Origin
1.0, 1.0, 1.0, # Spacing
-1.0, 0.0, 0.0, # Directions
0.0, -1.0, 0.0,
0.0, 0.0, 1.0,
]) # fmt:skip


def parse_combined_hdf5(h5_fn, to_ras=True):
def load_ants_h5(filename: Path) -> nt.base.TransformBase:
"""Load ANTs H5 files as a nitransforms TransformChain"""
# Borrowed from https://github.com/feilong/process
# process.resample.parse_combined_hdf5()
h = h5py.File(h5_fn)
#
# Changes:
# * Tolerate a missing displacement field
# * Return the original affine without a round-trip
# * Always return a nitransforms TransformChain
#
# This should be upstreamed into nitransforms
h = h5py.File(filename)
xform = ITKCompositeH5.from_h5obj(h)

Check warning on line 62 in fmriprep/utils/transforms.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/utils/transforms.py#L61-L62

Added lines #L61 - L62 were not covered by tests
affine = xform[0].to_ras()

# nt.Affine
transforms = [nt.Affine(xform[0].to_ras())]

Check warning on line 65 in fmriprep/utils/transforms.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/utils/transforms.py#L65

Added line #L65 was not covered by tests

if '2' not in h['TransformGroup']:
return transforms[0]

Check warning on line 68 in fmriprep/utils/transforms.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/utils/transforms.py#L67-L68

Added lines #L67 - L68 were not covered by tests

transform2 = h['TransformGroup']['2']

Check warning on line 70 in fmriprep/utils/transforms.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/utils/transforms.py#L70

Added line #L70 was not covered by tests

# Confirm these transformations are applicable
assert (
h['TransformGroup']['2']['TransformType'][:][0] == b'DisplacementFieldTransform_float_3_3'
)
assert np.array_equal(
h['TransformGroup']['2']['TransformFixedParameters'][:],
np.array(
[
193.0,
229.0,
193.0,
96.0,
132.0,
-78.0,
1.0,
1.0,
1.0,
-1.0,
0.0,
0.0,
0.0,
-1.0,
0.0,
0.0,
0.0,
1.0,
]
),
)
if transform2['TransformType'][:][0] != b'DisplacementFieldTransform_float_3_3':
msg = 'Unknown transform type [2]\n'
for i in h['TransformGroup'].keys():
msg += f'[{i}]: {h["TransformGroup"][i]["TransformType"][:][0]}\n'
raise ValueError(msg)

Check warning on line 77 in fmriprep/utils/transforms.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/utils/transforms.py#L73-L77

Added lines #L73 - L77 were not covered by tests

fixed_params = transform2['TransformFixedParameters'][:]
if not np.array_equal(fixed_params, FIXED_PARAMS):
msg = 'Unexpected fixed parameters\n'
msg += f'Expected: {FIXED_PARAMS}\n'
msg += f'Found: {fixed_params}'
if not np.array_equal(fixed_params[6:], FIXED_PARAMS[6:]):
raise ValueError(msg)
warnings.warn(msg)

Check warning on line 86 in fmriprep/utils/transforms.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/utils/transforms.py#L79-L86

Added lines #L79 - L86 were not covered by tests

shape = tuple(fixed_params[:3].astype(int))
warp = h['TransformGroup']['2']['TransformParameters'][:]
warp = warp.reshape((193, 229, 193, 3)).transpose(2, 1, 0, 3)
warp = warp.reshape((*shape, 3)).transpose(2, 1, 0, 3)
warp *= np.array([-1, -1, 1])

Check warning on line 91 in fmriprep/utils/transforms.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/utils/transforms.py#L88-L91

Added lines #L88 - L91 were not covered by tests
warp_affine = np.array(
[
[1.0, 0.0, 0.0, -96.0],
[0.0, 1.0, 0.0, -132.0],
[0.0, 0.0, 1.0, -78.0],
[0.0, 0.0, 0.0, 1.0],
]
)
return affine, warp, warp_affine

warp_affine = np.eye(4)
warp_affine[:3, :3] = fixed_params[9:].reshape((3, 3))
warp_affine[:3, 3] = fixed_params[3:6]
lps_to_ras = np.eye(4) * np.array([-1, -1, 1, 1])
warp_affine = lps_to_ras @ warp_affine
if np.array_equal(fixed_params, FIXED_PARAMS):

Check warning on line 98 in fmriprep/utils/transforms.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/utils/transforms.py#L93-L98

Added lines #L93 - L98 were not covered by tests
# Confirm that we construct the right affine when fixed parameters are known
assert np.array_equal(

Check warning on line 100 in fmriprep/utils/transforms.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/utils/transforms.py#L100

Added line #L100 was not covered by tests
warp_affine,
np.array(
[
[1.0, 0.0, 0.0, -96.0],
[0.0, 1.0, 0.0, -132.0],
[0.0, 0.0, 1.0, -78.0],
[0.0, 0.0, 0.0, 1.0],
]
),
)
transforms.insert(0, nt.DenseFieldTransform(nb.Nifti1Image(warp, warp_affine)))
return nt.TransformChain(transforms)

Check warning on line 112 in fmriprep/utils/transforms.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/utils/transforms.py#L111-L112

Added lines #L111 - L112 were not covered by tests
28 changes: 25 additions & 3 deletions fmriprep/workflows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def init_single_subject_wf(subject_id: str):
from niworkflows.utils.misc import fix_multi_T1w_source_name
from niworkflows.utils.spaces import Reference
from smriprep.workflows.anatomical import init_anat_fit_wf
from smriprep.workflows.outputs import init_template_iterator_wf

Check warning on line 154 in fmriprep/workflows/base.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/base.py#L153-L154

Added lines #L153 - L154 were not covered by tests

from fmriprep.workflows.bold.base import init_bold_wf

Check warning on line 156 in fmriprep/workflows/base.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/base.py#L156

Added line #L156 was not covered by tests

Expand Down Expand Up @@ -310,7 +311,6 @@ def init_single_subject_wf(subject_id: str):
skull_strip_fixed_seed=config.workflow.skull_strip_fixed_seed,
)

# fmt:off
workflow.connect([
(inputnode, anat_fit_wf, [('subjects_dir', 'inputnode.subjects_dir')]),
(bidssrc, bids_info, [(('t1w', fix_multi_T1w_source_name), 'in_file')]),
Expand All @@ -329,8 +329,18 @@ def init_single_subject_wf(subject_id: str):
(bidssrc, ds_report_about, [(('t1w', fix_multi_T1w_source_name), 'source_file')]),
(summary, ds_report_summary, [('out_report', 'in_file')]),
(about, ds_report_about, [('out_report', 'in_file')]),
])
# fmt:on
]) # fmt:skip

# Set up the template iterator once, if used
if config.workflow.level == "full":
if spaces.get_spaces(nonstandard=False, dim=(3,)):
template_iterator_wf = init_template_iterator_wf(spaces=spaces)
workflow.connect([

Check warning on line 338 in fmriprep/workflows/base.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/base.py#L335-L338

Added lines #L335 - L338 were not covered by tests
(anat_fit_wf, template_iterator_wf, [
('outputnode.template', 'inputnode.template'),
('outputnode.anat2std_xfm', 'inputnode.anat2std_xfm'),
]),
]) # fmt:skip

if config.workflow.anat_only:
return clean_datasinks(workflow)

Check warning on line 346 in fmriprep/workflows/base.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/base.py#L345-L346

Added lines #L345 - L346 were not covered by tests
Expand Down Expand Up @@ -510,6 +520,18 @@ def init_single_subject_wf(subject_id: str):
]),
]) # fmt:skip

if config.workflow.level == "full":
workflow.connect([

Check warning on line 524 in fmriprep/workflows/base.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/base.py#L523-L524

Added lines #L523 - L524 were not covered by tests
(template_iterator_wf, bold_wf, [
("outputnode.anat2std_xfm", "inputnode.anat2std_xfm"),
("outputnode.space", "inputnode.std_space"),
("outputnode.resolution", "inputnode.std_resolution"),
("outputnode.cohort", "inputnode.std_cohort"),
("outputnode.std_t1w", "inputnode.std_t1w"),
("outputnode.std_mask", "inputnode.std_mask"),
]),
]) # fmt:skip

return clean_datasinks(workflow)

Check warning on line 535 in fmriprep/workflows/base.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/base.py#L535

Added line #L535 was not covered by tests


Expand Down
144 changes: 60 additions & 84 deletions fmriprep/workflows/bold/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,13 @@ def init_bold_wf(
"fmap_mask",
"fmap_id",
"sdc_method",
# Volumetric templates
"anat2std_xfm",
"std_space",
"std_resolution",
"std_cohort",
"std_t1w",
"std_mask",
],
),
name="inputnode",
Expand Down Expand Up @@ -381,6 +388,59 @@ def init_bold_wf(
(bold_anat_wf, ds_bold_t1_wf, [('outputnode.bold_file', 'inputnode.bold')]),
]) # fmt:skip

if spaces.get_spaces(nonstandard=False, dim=(3,)):

Check warning on line 391 in fmriprep/workflows/bold/base.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/base.py#L391

Added line #L391 was not covered by tests
# Missing:
# * Clipping BOLD after resampling
# * Resampling parcellations
bold_std_wf = init_bold_volumetric_resample_wf(

Check warning on line 395 in fmriprep/workflows/bold/base.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/base.py#L395

Added line #L395 was not covered by tests
metadata=all_metadata[0],
fieldmap_id=fieldmap_id if not multiecho else None,
omp_nthreads=omp_nthreads,
name='bold_std_wf',
)
ds_bold_std_wf = init_ds_volumes_wf(

Check warning on line 401 in fmriprep/workflows/bold/base.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/base.py#L401

Added line #L401 was not covered by tests
bids_root=str(config.execution.bids_dir),
output_dir=fmriprep_dir,
multiecho=multiecho,
metadata=all_metadata[0],
name='ds_bold_std_wf',
)
ds_bold_std_wf.inputs.inputnode.source_files = bold_series

Check warning on line 408 in fmriprep/workflows/bold/base.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/base.py#L408

Added line #L408 was not covered by tests

workflow.connect([

Check warning on line 410 in fmriprep/workflows/bold/base.py

View check run for this annotation

Codecov / codecov/patch

fmriprep/workflows/bold/base.py#L410

Added line #L410 was not covered by tests
(inputnode, bold_std_wf, [
("std_t1w", "inputnode.target_ref_file"),
("std_mask", "inputnode.target_mask"),
("anat2std_xfm", "inputnode.anat2std_xfm"),
("fmap_ref", "inputnode.fmap_ref"),
("fmap_coeff", "inputnode.fmap_coeff"),
("fmap_id", "inputnode.fmap_id"),
]),
(bold_fit_wf, bold_std_wf, [
("outputnode.coreg_boldref", "inputnode.bold_ref_file"),
("outputnode.boldref2fmap_xfm", "inputnode.boldref2fmap_xfm"),
("outputnode.boldref2anat_xfm", "inputnode.boldref2anat_xfm"),
]),
(bold_native_wf, bold_std_wf, [
("outputnode.bold_minimal", "inputnode.bold_file"),
("outputnode.motion_xfm", "inputnode.motion_xfm"),
]),
(inputnode, ds_bold_std_wf, [
('std_t1w', 'inputnode.ref_file'),
('anat2std_xfm', 'inputnode.anat2std_xfm'),
('std_space', 'inputnode.space'),
('std_resolution', 'inputnode.resolution'),
('std_cohort', 'inputnode.cohort'),
]),
(bold_fit_wf, ds_bold_std_wf, [
('outputnode.bold_mask', 'inputnode.bold_mask'),
('outputnode.coreg_boldref', 'inputnode.bold_ref'),
('outputnode.boldref2anat_xfm', 'inputnode.boldref2anat_xfm'),
]),
(bold_native_wf, ds_bold_std_wf, [('outputnode.t2star_map', 'inputnode.t2star')]),
(bold_std_wf, ds_bold_std_wf, [('outputnode.bold_file', 'inputnode.bold')]),
]) # fmt:skip

# Fill-in datasinks of reportlets seen so far
for node in workflow.list_node_names():
if node.split(".")[-1].startswith("ds_report"):
Expand Down Expand Up @@ -629,90 +689,6 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):
)
bold_confounds_wf.get_node("inputnode").inputs.t1_transform_flags = [False]

if spaces.get_spaces(nonstandard=False, dim=(3,)):
# Apply transforms in 1 shot
bold_std_trans_wf = init_bold_std_trans_wf(
freesurfer=freesurfer,
mem_gb=mem_gb["resampled"],
omp_nthreads=omp_nthreads,
spaces=spaces,
multiecho=multiecho,
name="bold_std_trans_wf",
use_compression=not config.execution.low_mem,
)
bold_std_trans_wf.inputs.inputnode.fieldwarp = "identity"

# fmt:off
workflow.connect([
(inputnode, bold_std_trans_wf, [
("template", "inputnode.templates"),
("anat2std_xfm", "inputnode.anat2std_xfm"),
("bold_file", "inputnode.name_source"),
("t1w_aseg", "inputnode.bold_aseg"),
("t1w_aparc", "inputnode.bold_aparc"),
]),
(bold_final, bold_std_trans_wf, [
("mask", "inputnode.bold_mask"),
("t2star", "inputnode.t2star"),
]),
(bold_reg_wf, bold_std_trans_wf, [
("outputnode.itk_bold_to_t1", "inputnode.itk_bold_to_t1"),
]),
(bold_std_trans_wf, outputnode, [
("outputnode.bold_std", "bold_std"),
("outputnode.bold_std_ref", "bold_std_ref"),
("outputnode.bold_mask_std", "bold_mask_std"),
]),
])
# fmt:on

if freesurfer:
# fmt:off
workflow.connect([
(bold_std_trans_wf, func_derivatives_wf, [
("outputnode.bold_aseg_std", "inputnode.bold_aseg_std"),
("outputnode.bold_aparc_std", "inputnode.bold_aparc_std"),
]),
(bold_std_trans_wf, outputnode, [
("outputnode.bold_aseg_std", "bold_aseg_std"),
("outputnode.bold_aparc_std", "bold_aparc_std"),
]),
])
# fmt:on

if not multiecho:
# fmt:off
workflow.connect([
(bold_split, bold_std_trans_wf, [("out_files", "inputnode.bold_split")]),
(bold_hmc_wf, bold_std_trans_wf, [
("outputnode.xforms", "inputnode.hmc_xforms"),
]),
])
# fmt:on
else:
# fmt:off
workflow.connect([
(split_opt_comb, bold_std_trans_wf, [("out_files", "inputnode.bold_split")]),
(bold_std_trans_wf, outputnode, [("outputnode.t2star_std", "t2star_std")]),
])
# fmt:on

# Already applied in bold_bold_trans_wf, which inputs to bold_t2s_wf
bold_std_trans_wf.inputs.inputnode.hmc_xforms = "identity"

# fmt:off
# func_derivatives_wf internally parametrizes over snapshotted spaces.
workflow.connect([
(bold_std_trans_wf, func_derivatives_wf, [
("outputnode.template", "inputnode.template"),
("outputnode.spatial_reference", "inputnode.spatial_reference"),
("outputnode.bold_std_ref", "inputnode.bold_std_ref"),
("outputnode.bold_std", "inputnode.bold_std"),
("outputnode.bold_mask_std", "inputnode.bold_mask_std"),
]),
])
# fmt:on

# SURFACES ##################################################################################
# Freesurfer
if freesurfer and freesurfer_spaces:
Expand Down

0 comments on commit ce7c65f

Please sign in to comment.