diff --git a/fmriprep/utils/transforms.py b/fmriprep/utils/transforms.py index 738b7f797..951d1d9f9 100644 --- a/fmriprep/utils/transforms.py +++ b/fmriprep/utils/transforms.py @@ -1,4 +1,5 @@ """Utilities for loading transforms for resampling""" +import warnings from pathlib import Path import h5py @@ -36,57 +37,76 @@ def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.Trans return chain -def load_ants_h5(filename: Path) -> nt.TransformChain: - """Load ANTs H5 files as a nitransforms TransformChain""" - affine, warp, warp_affine = parse_combined_hdf5(filename) - warp_transform = nt.DenseFieldTransform(nb.Nifti1Image(warp, warp_affine)) - return nt.TransformChain([warp_transform, nt.Affine(affine)]) +FIXED_PARAMS = np.array([ + 193.0, 229.0, 193.0, # Size + 96.0, 132.0, -78.0, # Origin + 1.0, 1.0, 1.0, # Spacing + -1.0, 0.0, 0.0, # Directions + 0.0, -1.0, 0.0, + 0.0, 0.0, 1.0, +]) # fmt:skip -def parse_combined_hdf5(h5_fn, to_ras=True): +def load_ants_h5(filename: Path) -> nt.base.TransformBase: + """Load ANTs H5 files as a nitransforms TransformChain""" # Borrowed from https://github.com/feilong/process # process.resample.parse_combined_hdf5() - h = h5py.File(h5_fn) + # + # Changes: + # * Tolerate a missing displacement field + # * Return the original affine without a round-trip + # * Always return a nitransforms TransformChain + # + # This should be upstreamed into nitransforms + h = h5py.File(filename) xform = ITKCompositeH5.from_h5obj(h) - affine = xform[0].to_ras() + + # nt.Affine + transforms = [nt.Affine(xform[0].to_ras())] + + if '2' not in h['TransformGroup']: + return transforms[0] + + transform2 = h['TransformGroup']['2'] + # Confirm these transformations are applicable - assert ( - h['TransformGroup']['2']['TransformType'][:][0] == b'DisplacementFieldTransform_float_3_3' - ) - assert np.array_equal( - h['TransformGroup']['2']['TransformFixedParameters'][:], - np.array( - [ - 193.0, - 229.0, - 193.0, - 96.0, - 132.0, - -78.0, - 1.0, - 1.0, - 1.0, - -1.0, - 0.0, - 0.0, - 0.0, - -1.0, - 0.0, - 0.0, - 0.0, - 1.0, - ] - ), - ) + if transform2['TransformType'][:][0] != b'DisplacementFieldTransform_float_3_3': + msg = 'Unknown transform type [2]\n' + for i in h['TransformGroup'].keys(): + msg += f'[{i}]: {h["TransformGroup"][i]["TransformType"][:][0]}\n' + raise ValueError(msg) + + fixed_params = transform2['TransformFixedParameters'][:] + if not np.array_equal(fixed_params, FIXED_PARAMS): + msg = 'Unexpected fixed parameters\n' + msg += f'Expected: {FIXED_PARAMS}\n' + msg += f'Found: {fixed_params}' + if not np.array_equal(fixed_params[6:], FIXED_PARAMS[6:]): + raise ValueError(msg) + warnings.warn(msg) + + shape = tuple(fixed_params[:3].astype(int)) warp = h['TransformGroup']['2']['TransformParameters'][:] - warp = warp.reshape((193, 229, 193, 3)).transpose(2, 1, 0, 3) + warp = warp.reshape((*shape, 3)).transpose(2, 1, 0, 3) warp *= np.array([-1, -1, 1]) - warp_affine = np.array( - [ - [1.0, 0.0, 0.0, -96.0], - [0.0, 1.0, 0.0, -132.0], - [0.0, 0.0, 1.0, -78.0], - [0.0, 0.0, 0.0, 1.0], - ] - ) - return affine, warp, warp_affine + + warp_affine = np.eye(4) + warp_affine[:3, :3] = fixed_params[9:].reshape((3, 3)) + warp_affine[:3, 3] = fixed_params[3:6] + lps_to_ras = np.eye(4) * np.array([-1, -1, 1, 1]) + warp_affine = lps_to_ras @ warp_affine + if np.array_equal(fixed_params, FIXED_PARAMS): + # Confirm that we construct the right affine when fixed parameters are known + assert np.array_equal( + warp_affine, + np.array( + [ + [1.0, 0.0, 0.0, -96.0], + [0.0, 1.0, 0.0, -132.0], + [0.0, 0.0, 1.0, -78.0], + [0.0, 0.0, 0.0, 1.0], + ] + ), + ) + transforms.insert(0, nt.DenseFieldTransform(nb.Nifti1Image(warp, warp_affine))) + return nt.TransformChain(transforms) diff --git a/fmriprep/workflows/base.py b/fmriprep/workflows/base.py index ec901e407..75cbc7c9d 100644 --- a/fmriprep/workflows/base.py +++ b/fmriprep/workflows/base.py @@ -151,6 +151,7 @@ def init_single_subject_wf(subject_id: str): from niworkflows.utils.misc import fix_multi_T1w_source_name from niworkflows.utils.spaces import Reference from smriprep.workflows.anatomical import init_anat_fit_wf + from smriprep.workflows.outputs import init_template_iterator_wf from fmriprep.workflows.bold.base import init_bold_wf @@ -310,7 +311,6 @@ def init_single_subject_wf(subject_id: str): skull_strip_fixed_seed=config.workflow.skull_strip_fixed_seed, ) - # fmt:off workflow.connect([ (inputnode, anat_fit_wf, [('subjects_dir', 'inputnode.subjects_dir')]), (bidssrc, bids_info, [(('t1w', fix_multi_T1w_source_name), 'in_file')]), @@ -329,8 +329,18 @@ def init_single_subject_wf(subject_id: str): (bidssrc, ds_report_about, [(('t1w', fix_multi_T1w_source_name), 'source_file')]), (summary, ds_report_summary, [('out_report', 'in_file')]), (about, ds_report_about, [('out_report', 'in_file')]), - ]) - # fmt:on + ]) # fmt:skip + + # Set up the template iterator once, if used + if config.workflow.level == "full": + if spaces.get_spaces(nonstandard=False, dim=(3,)): + template_iterator_wf = init_template_iterator_wf(spaces=spaces) + workflow.connect([ + (anat_fit_wf, template_iterator_wf, [ + ('outputnode.template', 'inputnode.template'), + ('outputnode.anat2std_xfm', 'inputnode.anat2std_xfm'), + ]), + ]) # fmt:skip if config.workflow.anat_only: return clean_datasinks(workflow) @@ -510,6 +520,18 @@ def init_single_subject_wf(subject_id: str): ]), ]) # fmt:skip + if config.workflow.level == "full": + workflow.connect([ + (template_iterator_wf, bold_wf, [ + ("outputnode.anat2std_xfm", "inputnode.anat2std_xfm"), + ("outputnode.space", "inputnode.std_space"), + ("outputnode.resolution", "inputnode.std_resolution"), + ("outputnode.cohort", "inputnode.std_cohort"), + ("outputnode.std_t1w", "inputnode.std_t1w"), + ("outputnode.std_mask", "inputnode.std_mask"), + ]), + ]) # fmt:skip + return clean_datasinks(workflow) diff --git a/fmriprep/workflows/bold/base.py b/fmriprep/workflows/bold/base.py index da6052cd6..9423c9089 100644 --- a/fmriprep/workflows/bold/base.py +++ b/fmriprep/workflows/bold/base.py @@ -198,6 +198,13 @@ def init_bold_wf( "fmap_mask", "fmap_id", "sdc_method", + # Volumetric templates + "anat2std_xfm", + "std_space", + "std_resolution", + "std_cohort", + "std_t1w", + "std_mask", ], ), name="inputnode", @@ -381,6 +388,59 @@ def init_bold_wf( (bold_anat_wf, ds_bold_t1_wf, [('outputnode.bold_file', 'inputnode.bold')]), ]) # fmt:skip + if spaces.get_spaces(nonstandard=False, dim=(3,)): + # Missing: + # * Clipping BOLD after resampling + # * Resampling parcellations + bold_std_wf = init_bold_volumetric_resample_wf( + metadata=all_metadata[0], + fieldmap_id=fieldmap_id if not multiecho else None, + omp_nthreads=omp_nthreads, + name='bold_std_wf', + ) + ds_bold_std_wf = init_ds_volumes_wf( + bids_root=str(config.execution.bids_dir), + output_dir=fmriprep_dir, + multiecho=multiecho, + metadata=all_metadata[0], + name='ds_bold_std_wf', + ) + ds_bold_std_wf.inputs.inputnode.source_files = bold_series + + workflow.connect([ + (inputnode, bold_std_wf, [ + ("std_t1w", "inputnode.target_ref_file"), + ("std_mask", "inputnode.target_mask"), + ("anat2std_xfm", "inputnode.anat2std_xfm"), + ("fmap_ref", "inputnode.fmap_ref"), + ("fmap_coeff", "inputnode.fmap_coeff"), + ("fmap_id", "inputnode.fmap_id"), + ]), + (bold_fit_wf, bold_std_wf, [ + ("outputnode.coreg_boldref", "inputnode.bold_ref_file"), + ("outputnode.boldref2fmap_xfm", "inputnode.boldref2fmap_xfm"), + ("outputnode.boldref2anat_xfm", "inputnode.boldref2anat_xfm"), + ]), + (bold_native_wf, bold_std_wf, [ + ("outputnode.bold_minimal", "inputnode.bold_file"), + ("outputnode.motion_xfm", "inputnode.motion_xfm"), + ]), + (inputnode, ds_bold_std_wf, [ + ('std_t1w', 'inputnode.ref_file'), + ('anat2std_xfm', 'inputnode.anat2std_xfm'), + ('std_space', 'inputnode.space'), + ('std_resolution', 'inputnode.resolution'), + ('std_cohort', 'inputnode.cohort'), + ]), + (bold_fit_wf, ds_bold_std_wf, [ + ('outputnode.bold_mask', 'inputnode.bold_mask'), + ('outputnode.coreg_boldref', 'inputnode.bold_ref'), + ('outputnode.boldref2anat_xfm', 'inputnode.boldref2anat_xfm'), + ]), + (bold_native_wf, ds_bold_std_wf, [('outputnode.t2star_map', 'inputnode.t2star')]), + (bold_std_wf, ds_bold_std_wf, [('outputnode.bold_file', 'inputnode.bold')]), + ]) # fmt:skip + # Fill-in datasinks of reportlets seen so far for node in workflow.list_node_names(): if node.split(".")[-1].startswith("ds_report"): @@ -629,90 +689,6 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False): ) bold_confounds_wf.get_node("inputnode").inputs.t1_transform_flags = [False] - if spaces.get_spaces(nonstandard=False, dim=(3,)): - # Apply transforms in 1 shot - bold_std_trans_wf = init_bold_std_trans_wf( - freesurfer=freesurfer, - mem_gb=mem_gb["resampled"], - omp_nthreads=omp_nthreads, - spaces=spaces, - multiecho=multiecho, - name="bold_std_trans_wf", - use_compression=not config.execution.low_mem, - ) - bold_std_trans_wf.inputs.inputnode.fieldwarp = "identity" - - # fmt:off - workflow.connect([ - (inputnode, bold_std_trans_wf, [ - ("template", "inputnode.templates"), - ("anat2std_xfm", "inputnode.anat2std_xfm"), - ("bold_file", "inputnode.name_source"), - ("t1w_aseg", "inputnode.bold_aseg"), - ("t1w_aparc", "inputnode.bold_aparc"), - ]), - (bold_final, bold_std_trans_wf, [ - ("mask", "inputnode.bold_mask"), - ("t2star", "inputnode.t2star"), - ]), - (bold_reg_wf, bold_std_trans_wf, [ - ("outputnode.itk_bold_to_t1", "inputnode.itk_bold_to_t1"), - ]), - (bold_std_trans_wf, outputnode, [ - ("outputnode.bold_std", "bold_std"), - ("outputnode.bold_std_ref", "bold_std_ref"), - ("outputnode.bold_mask_std", "bold_mask_std"), - ]), - ]) - # fmt:on - - if freesurfer: - # fmt:off - workflow.connect([ - (bold_std_trans_wf, func_derivatives_wf, [ - ("outputnode.bold_aseg_std", "inputnode.bold_aseg_std"), - ("outputnode.bold_aparc_std", "inputnode.bold_aparc_std"), - ]), - (bold_std_trans_wf, outputnode, [ - ("outputnode.bold_aseg_std", "bold_aseg_std"), - ("outputnode.bold_aparc_std", "bold_aparc_std"), - ]), - ]) - # fmt:on - - if not multiecho: - # fmt:off - workflow.connect([ - (bold_split, bold_std_trans_wf, [("out_files", "inputnode.bold_split")]), - (bold_hmc_wf, bold_std_trans_wf, [ - ("outputnode.xforms", "inputnode.hmc_xforms"), - ]), - ]) - # fmt:on - else: - # fmt:off - workflow.connect([ - (split_opt_comb, bold_std_trans_wf, [("out_files", "inputnode.bold_split")]), - (bold_std_trans_wf, outputnode, [("outputnode.t2star_std", "t2star_std")]), - ]) - # fmt:on - - # Already applied in bold_bold_trans_wf, which inputs to bold_t2s_wf - bold_std_trans_wf.inputs.inputnode.hmc_xforms = "identity" - - # fmt:off - # func_derivatives_wf internally parametrizes over snapshotted spaces. - workflow.connect([ - (bold_std_trans_wf, func_derivatives_wf, [ - ("outputnode.template", "inputnode.template"), - ("outputnode.spatial_reference", "inputnode.spatial_reference"), - ("outputnode.bold_std_ref", "inputnode.bold_std_ref"), - ("outputnode.bold_std", "inputnode.bold_std"), - ("outputnode.bold_mask_std", "inputnode.bold_mask_std"), - ]), - ]) - # fmt:on - # SURFACES ################################################################################## # Freesurfer if freesurfer and freesurfer_spaces: