diff --git a/nibabies/workflows/anatomical/__init__.py b/nibabies/workflows/anatomical/__init__.py index 33d993b2..021d133d 100644 --- a/nibabies/workflows/anatomical/__init__.py +++ b/nibabies/workflows/anatomical/__init__.py @@ -1 +1 @@ -from .base import init_infant_anat_wf +from .base import init_infant_anat_wf, init_infant_single_anat_wf diff --git a/nibabies/workflows/anatomical/base.py b/nibabies/workflows/anatomical/base.py index 1a226e59..ca3d71df 100644 --- a/nibabies/workflows/anatomical/base.py +++ b/nibabies/workflows/anatomical/base.py @@ -46,8 +46,8 @@ "subject_id", "anat2std_xfm", "std2anat_xfm", - "t1w2fsnative_xfm", - "fsnative2t1w_xfm", + "anat2fsnative_xfm", + "fsnative2anat_xfm", "surfaces", "morphometrics", "anat_aseg", @@ -68,7 +68,7 @@ def init_infant_anat_wf( ants_affine_init: bool, t1w: list, t2w: list, - anat_modality: str, + contrast: ty.Literal['T1w', 'T2w'], bids_root: str | Path, derivatives: Derivatives, freesurfer: bool, @@ -252,7 +252,7 @@ def init_infant_anat_wf( # Segmentation - initial implementation should be simple: JLF anat_seg_wf = init_anat_segmentations_wf( - anat_modality=anat_modality.capitalize(), # TODO: Revisit this option + anat_modality=contrast.capitalize(), # TODO: Revisit this option template_dir=segmentation_atlases, sloppy=sloppy, omp_nthreads=omp_nthreads, @@ -625,7 +625,10 @@ def init_infant_single_anat_wf( aseg = derivatives.t2w_aseg config.loggers.workflow.info( - f"Derivatives used (%s):\n\t\n\t\n\t", contrast, bool(mask), bool(aseg) + f"Derivatives used (%s):\n\t\t\n\t\t\n", + contrast, + bool(mask), + bool(aseg), ) inputnode = pe.Node( @@ -635,8 +638,8 @@ def init_infant_single_anat_wf( outputnode = pe.Node(niu.IdentityInterface(fields=ANAT_OUT_FIELDS), name="outputnode") desc = _gen_anat_wf_desc( - t1w=anat_files if contrast == 'T1w' else None, - t2w=anat_files if contrast == 'T2w' else None, + t1w=t1w or None, + t2w=t2w or None, mask=bool(mask), ) workflow.__desc__ = desc.format( @@ -686,31 +689,33 @@ def init_infant_single_anat_wf( # T2-only segmentation anat_norm_wf = init_anat_norm_wf( sloppy=sloppy, - omp_ntheads=omp_nthreads, + omp_nthreads=omp_nthreads, templates=spaces.get_spaces(nonstandard=False, dim=(3,)), ) + # Aggregate mask, applied mask + mask_buffer = pe.Node( + niu.IdentityInterface(fields=['anat_mask', 'anat_brain']), + name='mask_buffer', + ) if mask: from niworkflows.interfaces.nibabel import ApplyMask + anat_template_wf.inputs.inputnode.anat_mask = mask mask_ref = derivatives.references[f'{contrast.lower()}_mask'] - anat_preproc_wf.inputnode.inputs.mask_reference = mask_ref + anat_template_wf.inputs.inputnode.mask_reference = mask_ref apply_deriv_mask = pe.Node(ApplyMask(), name='apply_deriv_mask') # fmt:off workflow.connect([ - (anat_preproc_wf, anat_norm_wf, [ - ('outputnode.anat_mask', 'inputnode.moving_mask')]), + (anat_template_wf, mask_buffer, [ + ('outputnode.anat_mask', 'anat_mask')]), (anat_preproc_wf, apply_deriv_mask, [ ('outputnode.anat_preproc', 'in_file')]), (anat_template_wf, apply_deriv_mask, [ ('outputnode.anat_mask', 'in_mask')]), - (apply_deriv_mask, anat_seg_wf, [ - ("out_mask", "inputnode.anat_brain")]), - (anat_template_wf, anat_derivatives_wf, [ - ('outputnode.anat_mask', 'inputnode.anat_mask')]), - (anat_template_wf, outputnode, [ - ('outputnode.anat_mask', 'anat_mask')]), + (apply_deriv_mask, mask_buffer, [ + ('out_file', 'anat_brain')]), ]) # fmt:on @@ -726,18 +731,28 @@ def init_infant_single_anat_wf( ) # fmt:off workflow.connect([ - (anat_preproc_wf, brain_extraction_wf, [('outputnode.anat_preproc', 'inputnode.t2w_preproc')]), - (brain_extraction_wf, anat_seg_wf, [('outputnode.t2w_brain', 'inputnode.anat_brain')]), - (brain_extraction_wf, anat_norm_wf, [('outputnode.out_mask', 'inputnode.moving_mask')]), - (brain_extraction_wf, anat_derivatives_wf, [('outputnode.out_mask', 'inputnode.anat_mask')]), + (anat_preproc_wf, brain_extraction_wf, [ + ('outputnode.anat_preproc', 'inputnode.t2w_preproc')]), + (brain_extraction_wf, mask_buffer, [ + ('outputnode.t2w_brain', 'anat_brain'), + ('outputnode.out_mask', 'anat_mask')]), ]) # fmt:on + if aseg: + anat_template_wf.inputs.inputnode.anat_aseg = aseg + aseg_ref = derivatives.references[f'{contrast.lower()}_aseg'] + anat_template_wf.inputs.inputnode.aseg_reference = aseg_ref + + workflow.connect( + anat_template_wf, 'outputnode.anat_aseg', anat_seg_wf, 'inputnode.anat_aseg' + ) + # fmt:off workflow.connect([ - (inputnode, anat_template_wf, [("anat_file", "inputnode.anat_files")]), - (inputnode, anat_reports_wf, [("anat_file", "inputnode.source_file")]), - (inputnode, anat_norm_wf, [(("anat_file", fix_multi_source_name), "inputnode.orig_t1w")]), + (inputnode, anat_template_wf, [(contrast.lower(), "inputnode.anat_files")]), + (inputnode, anat_reports_wf, [(contrast.lower(), "inputnode.source_file")]), + (inputnode, anat_norm_wf, [((contrast.lower(), fix_multi_source_name), "inputnode.orig_t1w")]), (anat_template_wf, outputnode, [ ("outputnode.anat_realign_xfm", "anat_ref_xfms")]), @@ -750,19 +765,26 @@ def init_infant_single_anat_wf( ("outputnode.out_report", "inputnode.anat_conform_report")]), (anat_preproc_wf, anat_norm_wf, [ ('outputnode.anat_preproc', 'inputnode.moving_image')]), + (anat_preproc_wf, outputnode, [ + ('outputnode.anat_preproc', 'anat_preproc')]), (anat_preproc_wf, anat_derivatives_wf, [ ('outputnode.anat_preproc', f'inputnode.{contrast.lower()}_preproc')]), + (mask_buffer, anat_derivatives_wf, [ + ('anat_mask', 'inputnode.anat_mask')]), + (mask_buffer, outputnode, [ + ('anat_mask', 'anat_mask')]), + (mask_buffer, anat_seg_wf, [('anat_brain', 'inputnode.anat_brain')]), (anat_seg_wf, outputnode, [ ("outputnode.anat_dseg", "anat_dseg"), ("outputnode.anat_tpms", "anat_tpms")]), (anat_seg_wf, anat_derivatives_wf, [ ("outputnode.anat_dseg", "inputnode.anat_dseg"), - ("outputnode.anat_tpms", "inputnode.anat_tpms"), - ]), + ("outputnode.anat_tpms", "inputnode.anat_tpms")]), + (mask_buffer, anat_norm_wf, [ + ('anat_mask', 'inputnode.moving_mask')]), (anat_seg_wf, anat_norm_wf, [ ("outputnode.anat_dseg", "inputnode.moving_segmentation"), ("outputnode.anat_tpms", "inputnode.moving_tpms")]), - (anat_norm_wf, anat_reports_wf, [("poutputnode.template", "inputnode.template")]), (anat_norm_wf, outputnode, [ ("poutputnode.standardized", "std_preproc"), @@ -777,7 +799,7 @@ def init_infant_single_anat_wf( ("outputnode.anat2std_xfm", "inputnode.anat2std_xfm"), ("outputnode.std2anat_xfm", "inputnode.std2anat_xfm")]), (outputnode, anat_reports_wf, [ - ("anat_preproc", "inputnode.t1w_preproc"), + ("anat_preproc", "inputnode.anat_preproc"), ("anat_mask", "inputnode.anat_mask"), ("anat_dseg", "inputnode.anat_dseg"), ("std_preproc", "inputnode.std_t1w"), @@ -828,17 +850,98 @@ def init_infant_single_anat_wf( else: # TODO: Use MCRIBS segmentation ... - if mask: - workflow.connect( - anat_template_wf, 'outputnode.anat_mask', surface_recon_wf, 'inputnode.anat_mask' - ) - else: - workflow.connect( - brain_extraction_wf, 'outputnode.out_mask', surface_recon_wf, 'inputnode.anat_mask' - ) else: raise NotImplementedError + # Anatomical ribbon file using HCP signed-distance volume method + anat_ribbon_wf = init_anat_ribbon_wf() + + # fmt:off + workflow.connect([ + (inputnode, surface_recon_wf, [ + ("subject_id", "inputnode.subject_id"), + ("subjects_dir", "inputnode.subjects_dir")]), + (anat_template_wf, surface_recon_wf, [ + ("outputnode.anat_ref", "inputnode.t1w"), + ]), + (mask_buffer, surface_recon_wf, [ + ("anat_brain", "inputnode.skullstripped_t1"), + ("anat_mask", "inputnode.anat_mask")]), + (anat_preproc_wf, surface_recon_wf, [ + ("outputnode.anat_preproc", "inputnode.corrected_t1")]), + (surface_recon_wf, outputnode, [ + ("outputnode.subjects_dir", "subjects_dir"), + ("outputnode.subject_id", "subject_id"), + ("outputnode.t1w2fsnative_xfm", "anat2fsnative_xfm"), + ("outputnode.fsnative2t1w_xfm", "fsnative2anat_xfm"), + ("outputnode.surfaces", "surfaces"), + ("outputnode.morphometrics", "morphometrics"), + ("outputnode.out_aparc", "anat_aparc"), + ("outputnode.out_aseg", "anat_aseg"), + ]), + (mask_buffer, anat_ribbon_wf, [ + ("anat_mask", "inputnode.t1w_mask"), + ]), + (surface_recon_wf, anat_ribbon_wf, [ + ("outputnode.surfaces", "inputnode.surfaces"), + ]), + (anat_ribbon_wf, outputnode, [ + ("outputnode.anat_ribbon", "anat_ribbon") + ]), + (anat_ribbon_wf, anat_derivatives_wf, [ + ("outputnode.anat_ribbon", "inputnode.anat_ribbon"), + ]), + (surface_recon_wf, sphere_reg_wf, [ + ('outputnode.subject_id', 'inputnode.subject_id'), + ('outputnode.subjects_dir', 'inputnode.subjects_dir'), + ]), + (surface_recon_wf, anat_reports_wf, [ + ("outputnode.subject_id", "inputnode.subject_id"), + ("outputnode.subjects_dir", "inputnode.subjects_dir"), + ]), + (surface_recon_wf, anat_derivatives_wf, [ + ("outputnode.out_aseg", "inputnode.anat_fs_aseg"), + ("outputnode.out_aparc", "inputnode.anat_fs_aparc"), + ("outputnode.t1w2fsnative_xfm", "inputnode.anat2fsnative_xfm"), + ("outputnode.fsnative2t1w_xfm", "inputnode.fsnative2anat_xfm"), + ("outputnode.surfaces", "inputnode.surfaces"), + ("outputnode.morphometrics", "inputnode.morphometrics"), + ]), + (sphere_reg_wf, outputnode, [ + ('outputnode.sphere_reg', 'sphere_reg'), + ('outputnode.sphere_reg_fsLR', 'sphere_reg_fsLR')]), + (sphere_reg_wf, anat_derivatives_wf, [ + ('outputnode.sphere_reg', 'inputnode.sphere_reg'), + ('outputnode.sphere_reg_fsLR', 'inputnode.sphere_reg_fsLR')]), + ]) + # fmt: on + + if cifti_output: + from nibabies.workflows.anatomical.resampling import ( + init_anat_fsLR_resampling_wf, + ) + + is_mcribs = recon_method == "mcribs" + # handles morph_grayords_wf + anat_fsLR_resampling_wf = init_anat_fsLR_resampling_wf(cifti_output, mcribs=is_mcribs) + anat_derivatives_wf.get_node('inputnode').inputs.cifti_density = cifti_output + # fmt:off + workflow.connect([ + (sphere_reg_wf, anat_fsLR_resampling_wf, [ + ('outputnode.sphere_reg', 'inputnode.sphere_reg'), + ('outputnode.sphere_reg_fsLR', 'inputnode.sphere_reg_fsLR')]), + (surface_recon_wf, anat_fsLR_resampling_wf, [ + ('outputnode.subject_id', 'inputnode.subject_id'), + ('outputnode.subjects_dir', 'inputnode.subjects_dir'), + ('outputnode.surfaces', 'inputnode.surfaces'), + ('outputnode.morphometrics', 'inputnode.morphometrics')]), + (anat_fsLR_resampling_wf, anat_derivatives_wf, [ + ("outputnode.cifti_morph", "inputnode.cifti_morph"), + ("outputnode.cifti_metadata", "inputnode.cifti_metadata")]), + (anat_fsLR_resampling_wf, outputnode, [ + ("outputnode.midthickness_fsLR", "midthickness_fsLR")]) + ]) + # fmt:on return workflow diff --git a/nibabies/workflows/base.py b/nibabies/workflows/base.py index 10b39499..9f5758ab 100644 --- a/nibabies/workflows/base.py +++ b/nibabies/workflows/base.py @@ -229,7 +229,8 @@ def init_single_subject_wf( anat_only = config.workflow.anat_only derivatives = Derivatives(bids_root=config.execution.layout.root) - anat_modality = "t1w" if subject_data["t1w"] else "t2w" + contrast = "T1w" if subject_data["t1w"] else "T2w" + single_modality = not (subject_data['t1w'] and subject_data['t2w']) # Make sure we always go through these two checks if not anat_only and not subject_data["bold"]: task_id = config.execution.task_id @@ -347,7 +348,7 @@ def init_single_subject_wf( wf_args = dict( ants_affine_init=True, age_months=age, - anat_modality=anat_modality, + contrast=contrast, t1w=subject_data["t1w"], t2w=subject_data["t2w"], bids_root=config.execution.bids_dir, @@ -364,33 +365,10 @@ def init_single_subject_wf( spaces=spaces, cifti_output=config.workflow.cifti_output, ) - - if subject_data['t1w'] and subject_data['t2w']: - anat_preproc_wf = init_infant_anat_wf(**wf_args) - else: - anat_preproc_wf = init_infant_single_anat_wf( - contrast='T1w' if subject_data['t1w'] else 'T2w', **wf_args - ) - # Preprocessing of anatomical (includes registration to UNCInfant) - anat_preproc_wf = init_infant_anat_wf( - ants_affine_init=True, - age_months=age, - anat_modality=anat_modality, - t1w=subject_data["t1w"], - t2w=subject_data["t2w"], - bids_root=config.execution.bids_dir, - derivatives=derivatives, - freesurfer=config.workflow.run_reconall, - hires=config.workflow.hires, - longitudinal=config.workflow.longitudinal, - omp_nthreads=config.nipype.omp_nthreads, - output_dir=nibabies_dir, - segmentation_atlases=config.execution.segmentation_atlases_dir, - skull_strip_mode=config.workflow.skull_strip_t1w, - skull_strip_template=Reference.from_string(config.workflow.skull_strip_template)[0], - sloppy=config.execution.sloppy, - spaces=spaces, - cifti_output=config.workflow.cifti_output, + anat_preproc_wf = ( + init_infant_anat_wf(**wf_args) + if not single_modality + else init_infant_single_anat_wf(**wf_args) ) # fmt: off @@ -424,17 +402,17 @@ def init_single_subject_wf( workflow.connect([ (bidssrc, bids_info, [ - (('t1w', fix_multi_source_name), 'in_file'), + ((contrast.lower(), fix_multi_source_name), 'in_file'), ]), (bidssrc, summary, [ ('t1w', 't1w'), ('t2w', 't2w'), ]), (bidssrc, ds_report_summary, [ - (('t1w', fix_multi_source_name), 'source_file'), + ((contrast.lower(), fix_multi_source_name), 'source_file'), ]), (bidssrc, ds_report_about, [ - (('t1w', fix_multi_source_name), 'source_file'), + ((contrast.lower(), fix_multi_source_name), 'source_file'), ]), ]) # fmt: on