diff --git a/nibabies/workflows/anatomical/fit.py b/nibabies/workflows/anatomical/fit.py index 4466575a..d0fe552d 100644 --- a/nibabies/workflows/anatomical/fit.py +++ b/nibabies/workflows/anatomical/fit.py @@ -36,7 +36,7 @@ from nibabies import config from nibabies.workflows.anatomical.brain_extraction import init_infant_brain_extraction_wf from nibabies.workflows.anatomical.outputs import init_anat_reports_wf -from nibabies.workflows.anatomical.preproc import init_anat_preproc_wf +from nibabies.workflows.anatomical.preproc import init_anat_preproc_wf, init_csf_norm_wf from nibabies.workflows.anatomical.registration import init_coregistration_wf from nibabies.workflows.anatomical.segmentation import init_segmentation_wf from nibabies.workflows.anatomical.surfaces import init_mcribs_dhcp_wf @@ -184,11 +184,13 @@ def init_infant_anat_fit_wf( name='anat_buffer', ) - # Additional CSF normalization, if necessary - anat_norm_buffer = pe.Node( + # Additional buffer if CSF normalization is used + anat_preproc_buffer = pe.Node( niu.IdentityInterface(fields=['anat_preproc']), - name='anat_norm_buffer', + name='anat_preproc_buffer', ) + if not config.workflow.norm_csf: + workflow.connect(anat_buffer, 'anat_preproc', anat_preproc_buffer, 'anat_preproc') if reference_anat == 'T1w': LOGGER.info('ANAT: Using T1w as the reference anatomical') @@ -254,7 +256,7 @@ def init_infant_anat_fit_wf( msm_buffer = pe.Node(niu.IdentityInterface(fields=['sphere_reg_msm']), name='msm_buffer') workflow.connect([ - (anat_buffer, outputnode, [ + (anat_preproc_buffer, outputnode, [ ('anat_preproc', 'anat_preproc'), ]), (refined_buffer, outputnode, [ @@ -886,6 +888,15 @@ def init_infant_anat_fit_wf( anat2std_buffer.inputs.in1 = [xfm['forward'] for xfm in found_xfms.values()] std2anat_buffer.inputs.in1 = [xfm['reverse'] for xfm in found_xfms.values()] + if config.workflow.norm_csf: + csf_norm_wf = init_csf_norm_wf() + + workflow.connect([ + (anat_buffer, csf_norm_wf, [('anat_preproc', 'inputnode.anat_preproc')]), + (seg_buffer, csf_norm_wf, [('anat_tpms', 'inputnode.anat_tpms')]), + (csf_norm_wf, anat_preproc_buffer, [('outputnode.anat_preproc', 'anat_preproc')]), + ]) # fmt:skip + if templates: LOGGER.info(f'ANAT Stage 5: Preparing normalization workflow for {templates}') register_template_wf = init_register_template_wf( @@ -901,7 +912,9 @@ def init_infant_anat_fit_wf( workflow.connect([ (inputnode, register_template_wf, [('roi', 'inputnode.lesion_mask')]), - (anat_buffer, register_template_wf, [('anat_preproc', 'inputnode.moving_image')]), + (anat_preproc_buffer, register_template_wf, [ + ('anat_preproc', 'inputnode.moving_image'), + ]), (refined_buffer, register_template_wf, [('anat_mask', 'inputnode.moving_mask')]), (sourcefile_buffer, ds_template_registration_wf, [ ('anat_source_files', 'inputnode.source_files') @@ -1094,7 +1107,7 @@ def init_infant_anat_fit_wf( (seg_buffer, refinement_wf, [ ('ants_segs', 'inputnode.ants_segs'), # TODO: Verify this is the same as dseg ]), - (anat_buffer, applyrefined, [('anat_preproc', 'in_file')]), + (anat_preproc_buffer, applyrefined, [('anat_preproc', 'in_file')]), (refinement_wf, applyrefined, [('outputnode.out_brainmask', 'in_mask')]), (refinement_wf, refined_buffer, [('outputnode.out_brainmask', 'anat_mask')]), (applyrefined, refined_buffer, [('out_file', 'anat_brain')]), @@ -1372,6 +1385,14 @@ def init_infant_single_anat_fit_wf( name='anat_buffer', ) + # Additional buffer if CSF normalization is used + anat_preproc_buffer = pe.Node( + niu.IdentityInterface(fields=['anat_preproc']), + name='anat_preproc_buffer', + ) + if not config.workflow.norm_csf: + workflow.connect(anat_buffer, 'anat_preproc', anat_preproc_buffer, 'anat_preproc') + aseg_buffer = pe.Node( niu.IdentityInterface(fields=['anat_aseg']), name='aseg_buffer', @@ -1411,7 +1432,7 @@ def init_infant_single_anat_fit_wf( msm_buffer = pe.Node(niu.IdentityInterface(fields=['sphere_reg_msm']), name='msm_buffer') workflow.connect([ - (anat_buffer, outputnode, [ + (anat_preproc_buffer, outputnode, [ ('anat_preproc', 'anat_preproc'), ]), (refined_buffer, outputnode, [ @@ -1712,6 +1733,15 @@ def init_infant_single_anat_fit_wf( anat2std_buffer.inputs.in1 = [xfm['forward'] for xfm in found_xfms.values()] std2anat_buffer.inputs.in1 = [xfm['reverse'] for xfm in found_xfms.values()] + if config.workflow.norm_csf: + csf_norm_wf = init_csf_norm_wf() + + workflow.connect([ + (anat_buffer, csf_norm_wf, [('anat_preproc', 'inputnode.anat_preproc')]), + (seg_buffer, csf_norm_wf, [('anat_tpms', 'inputnode.anat_tpms')]), + (csf_norm_wf, anat_preproc_buffer, [('outputnode.anat_preproc', 'anat_preproc')]), + ]) # fmt:skip + if templates: LOGGER.info(f'ANAT Stage 4: Preparing normalization workflow for {templates}') register_template_wf = init_register_template_wf( @@ -1727,7 +1757,9 @@ def init_infant_single_anat_fit_wf( workflow.connect([ (inputnode, register_template_wf, [('roi', 'inputnode.lesion_mask')]), - (anat_buffer, register_template_wf, [('anat_preproc', 'inputnode.moving_image')]), + (anat_preproc_buffer, register_template_wf, [ + ('anat_preproc', 'inputnode.moving_image'), + ]), (refined_buffer, register_template_wf, [('anat_mask', 'inputnode.moving_mask')]), (sourcefile_buffer, ds_template_registration_wf, [ ('anat_source_files', 'inputnode.source_files') @@ -1909,7 +1941,7 @@ def init_infant_single_anat_fit_wf( (seg_buffer, refinement_wf, [ ('ants_segs', 'inputnode.ants_segs'), ]), - (anat_buffer, applyrefined, [('anat_preproc', 'in_file')]), + (anat_preproc_buffer, applyrefined, [('anat_preproc', 'in_file')]), (refinement_wf, applyrefined, [('outputnode.out_brainmask', 'in_mask')]), (refinement_wf, refined_buffer, [('outputnode.out_brainmask', 'anat_mask')]), (applyrefined, refined_buffer, [('out_file', 'anat_brain')]),