Skip to content

Commit

Permalink
FIX: Missing connections, surfaces, syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
mgxd committed Oct 3, 2023
1 parent 2377ecc commit 93806c8
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 69 deletions.
2 changes: 1 addition & 1 deletion nibabies/workflows/anatomical/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .base import init_infant_anat_wf
from .base import init_infant_anat_wf, init_infant_single_anat_wf
175 changes: 139 additions & 36 deletions nibabies/workflows/anatomical/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
"subject_id",
"anat2std_xfm",
"std2anat_xfm",
"t1w2fsnative_xfm",
"fsnative2t1w_xfm",
"anat2fsnative_xfm",
"fsnative2anat_xfm",
"surfaces",
"morphometrics",
"anat_aseg",
Expand All @@ -68,7 +68,7 @@ def init_infant_anat_wf(
ants_affine_init: bool,
t1w: list,
t2w: list,
anat_modality: str,
contrast: ty.Literal['T1w', 'T2w'],
bids_root: str | Path,
derivatives: Derivatives,
freesurfer: bool,
Expand Down Expand Up @@ -252,7 +252,7 @@ def init_infant_anat_wf(

# Segmentation - initial implementation should be simple: JLF
anat_seg_wf = init_anat_segmentations_wf(
anat_modality=anat_modality.capitalize(), # TODO: Revisit this option
anat_modality=contrast.capitalize(), # TODO: Revisit this option
template_dir=segmentation_atlases,
sloppy=sloppy,
omp_nthreads=omp_nthreads,
Expand Down Expand Up @@ -625,7 +625,10 @@ def init_infant_single_anat_wf(
aseg = derivatives.t2w_aseg

config.loggers.workflow.info(
f"Derivatives used (%s):\n\t<mask - %s>\n\t<aseg %s>\n\t", contrast, bool(mask), bool(aseg)
f"Derivatives used (%s):\n\t\t<Mask: %s>\n\t\t<Aseg: %s>\n",
contrast,
bool(mask),
bool(aseg),
)

inputnode = pe.Node(
Expand All @@ -635,8 +638,8 @@ def init_infant_single_anat_wf(
outputnode = pe.Node(niu.IdentityInterface(fields=ANAT_OUT_FIELDS), name="outputnode")

desc = _gen_anat_wf_desc(
t1w=anat_files if contrast == 'T1w' else None,
t2w=anat_files if contrast == 'T2w' else None,
t1w=t1w or None,
t2w=t2w or None,
mask=bool(mask),
)
workflow.__desc__ = desc.format(
Expand Down Expand Up @@ -686,31 +689,33 @@ def init_infant_single_anat_wf(
# T2-only segmentation
anat_norm_wf = init_anat_norm_wf(
sloppy=sloppy,
omp_ntheads=omp_nthreads,
omp_nthreads=omp_nthreads,
templates=spaces.get_spaces(nonstandard=False, dim=(3,)),
)
# Aggregate mask, applied mask
mask_buffer = pe.Node(
niu.IdentityInterface(fields=['anat_mask', 'anat_brain']),
name='mask_buffer',
)

if mask:
from niworkflows.interfaces.nibabel import ApplyMask

anat_template_wf.inputs.inputnode.anat_mask = mask
mask_ref = derivatives.references[f'{contrast.lower()}_mask']
anat_preproc_wf.inputnode.inputs.mask_reference = mask_ref
anat_template_wf.inputs.inputnode.mask_reference = mask_ref

apply_deriv_mask = pe.Node(ApplyMask(), name='apply_deriv_mask')
# fmt:off
workflow.connect([
(anat_preproc_wf, anat_norm_wf, [
('outputnode.anat_mask', 'inputnode.moving_mask')]),
(anat_template_wf, mask_buffer, [
('outputnode.anat_mask', 'anat_mask')]),
(anat_preproc_wf, apply_deriv_mask, [
('outputnode.anat_preproc', 'in_file')]),
(anat_template_wf, apply_deriv_mask, [
('outputnode.anat_mask', 'in_mask')]),
(apply_deriv_mask, anat_seg_wf, [
("out_mask", "inputnode.anat_brain")]),
(anat_template_wf, anat_derivatives_wf, [
('outputnode.anat_mask', 'inputnode.anat_mask')]),
(anat_template_wf, outputnode, [
('outputnode.anat_mask', 'anat_mask')]),
(apply_deriv_mask, mask_buffer, [
('out_file', 'anat_brain')]),
])
# fmt:on

Expand All @@ -726,18 +731,28 @@ def init_infant_single_anat_wf(
)
# fmt:off
workflow.connect([
(anat_preproc_wf, brain_extraction_wf, [('outputnode.anat_preproc', 'inputnode.t2w_preproc')]),
(brain_extraction_wf, anat_seg_wf, [('outputnode.t2w_brain', 'inputnode.anat_brain')]),
(brain_extraction_wf, anat_norm_wf, [('outputnode.out_mask', 'inputnode.moving_mask')]),
(brain_extraction_wf, anat_derivatives_wf, [('outputnode.out_mask', 'inputnode.anat_mask')]),
(anat_preproc_wf, brain_extraction_wf, [
('outputnode.anat_preproc', 'inputnode.t2w_preproc')]),
(brain_extraction_wf, mask_buffer, [
('outputnode.t2w_brain', 'anat_brain'),
('outputnode.out_mask', 'anat_mask')]),
])
# fmt:on

if aseg:
anat_template_wf.inputs.inputnode.anat_aseg = aseg
aseg_ref = derivatives.references[f'{contrast.lower()}_aseg']
anat_template_wf.inputs.inputnode.aseg_reference = aseg_ref

workflow.connect(
anat_template_wf, 'outputnode.anat_aseg', anat_seg_wf, 'inputnode.anat_aseg'
)

# fmt:off
workflow.connect([
(inputnode, anat_template_wf, [("anat_file", "inputnode.anat_files")]),
(inputnode, anat_reports_wf, [("anat_file", "inputnode.source_file")]),
(inputnode, anat_norm_wf, [(("anat_file", fix_multi_source_name), "inputnode.orig_t1w")]),
(inputnode, anat_template_wf, [(contrast.lower(), "inputnode.anat_files")]),
(inputnode, anat_reports_wf, [(contrast.lower(), "inputnode.source_file")]),
(inputnode, anat_norm_wf, [((contrast.lower(), fix_multi_source_name), "inputnode.orig_t1w")]),

(anat_template_wf, outputnode, [
("outputnode.anat_realign_xfm", "anat_ref_xfms")]),
Expand All @@ -750,19 +765,26 @@ def init_infant_single_anat_wf(
("outputnode.out_report", "inputnode.anat_conform_report")]),
(anat_preproc_wf, anat_norm_wf, [
('outputnode.anat_preproc', 'inputnode.moving_image')]),
(anat_preproc_wf, outputnode, [
('outputnode.anat_preproc', 'anat_preproc')]),
(anat_preproc_wf, anat_derivatives_wf, [
('outputnode.anat_preproc', f'inputnode.{contrast.lower()}_preproc')]),
(mask_buffer, anat_derivatives_wf, [
('anat_mask', 'inputnode.anat_mask')]),
(mask_buffer, outputnode, [
('anat_mask', 'anat_mask')]),
(mask_buffer, anat_seg_wf, [('anat_brain', 'inputnode.anat_brain')]),
(anat_seg_wf, outputnode, [
("outputnode.anat_dseg", "anat_dseg"),
("outputnode.anat_tpms", "anat_tpms")]),
(anat_seg_wf, anat_derivatives_wf, [
("outputnode.anat_dseg", "inputnode.anat_dseg"),
("outputnode.anat_tpms", "inputnode.anat_tpms"),
]),
("outputnode.anat_tpms", "inputnode.anat_tpms")]),
(mask_buffer, anat_norm_wf, [
('anat_mask', 'inputnode.moving_mask')]),
(anat_seg_wf, anat_norm_wf, [
("outputnode.anat_dseg", "inputnode.moving_segmentation"),
("outputnode.anat_tpms", "inputnode.moving_tpms")]),

(anat_norm_wf, anat_reports_wf, [("poutputnode.template", "inputnode.template")]),
(anat_norm_wf, outputnode, [
("poutputnode.standardized", "std_preproc"),
Expand All @@ -777,7 +799,7 @@ def init_infant_single_anat_wf(
("outputnode.anat2std_xfm", "inputnode.anat2std_xfm"),
("outputnode.std2anat_xfm", "inputnode.std2anat_xfm")]),
(outputnode, anat_reports_wf, [
("anat_preproc", "inputnode.t1w_preproc"),
("anat_preproc", "inputnode.anat_preproc"),
("anat_mask", "inputnode.anat_mask"),
("anat_dseg", "inputnode.anat_dseg"),
("std_preproc", "inputnode.std_t1w"),
Expand Down Expand Up @@ -828,17 +850,98 @@ def init_infant_single_anat_wf(
else:
# TODO: Use MCRIBS segmentation
...
if mask:
workflow.connect(
anat_template_wf, 'outputnode.anat_mask', surface_recon_wf, 'inputnode.anat_mask'
)
else:
workflow.connect(
brain_extraction_wf, 'outputnode.out_mask', surface_recon_wf, 'inputnode.anat_mask'
)
else:
raise NotImplementedError

# Anatomical ribbon file using HCP signed-distance volume method
anat_ribbon_wf = init_anat_ribbon_wf()

# fmt:off
workflow.connect([
(inputnode, surface_recon_wf, [
("subject_id", "inputnode.subject_id"),
("subjects_dir", "inputnode.subjects_dir")]),
(anat_template_wf, surface_recon_wf, [
("outputnode.anat_ref", "inputnode.t1w"),
]),
(mask_buffer, surface_recon_wf, [
("anat_brain", "inputnode.skullstripped_t1"),
("anat_mask", "inputnode.anat_mask")]),
(anat_preproc_wf, surface_recon_wf, [
("outputnode.anat_preproc", "inputnode.corrected_t1")]),
(surface_recon_wf, outputnode, [
("outputnode.subjects_dir", "subjects_dir"),
("outputnode.subject_id", "subject_id"),
("outputnode.t1w2fsnative_xfm", "anat2fsnative_xfm"),
("outputnode.fsnative2t1w_xfm", "fsnative2anat_xfm"),
("outputnode.surfaces", "surfaces"),
("outputnode.morphometrics", "morphometrics"),
("outputnode.out_aparc", "anat_aparc"),
("outputnode.out_aseg", "anat_aseg"),
]),
(mask_buffer, anat_ribbon_wf, [
("anat_mask", "inputnode.t1w_mask"),
]),
(surface_recon_wf, anat_ribbon_wf, [
("outputnode.surfaces", "inputnode.surfaces"),
]),
(anat_ribbon_wf, outputnode, [
("outputnode.anat_ribbon", "anat_ribbon")
]),
(anat_ribbon_wf, anat_derivatives_wf, [
("outputnode.anat_ribbon", "inputnode.anat_ribbon"),
]),
(surface_recon_wf, sphere_reg_wf, [
('outputnode.subject_id', 'inputnode.subject_id'),
('outputnode.subjects_dir', 'inputnode.subjects_dir'),
]),
(surface_recon_wf, anat_reports_wf, [
("outputnode.subject_id", "inputnode.subject_id"),
("outputnode.subjects_dir", "inputnode.subjects_dir"),
]),
(surface_recon_wf, anat_derivatives_wf, [
("outputnode.out_aseg", "inputnode.anat_fs_aseg"),
("outputnode.out_aparc", "inputnode.anat_fs_aparc"),
("outputnode.t1w2fsnative_xfm", "inputnode.anat2fsnative_xfm"),
("outputnode.fsnative2t1w_xfm", "inputnode.fsnative2anat_xfm"),
("outputnode.surfaces", "inputnode.surfaces"),
("outputnode.morphometrics", "inputnode.morphometrics"),
]),
(sphere_reg_wf, outputnode, [
('outputnode.sphere_reg', 'sphere_reg'),
('outputnode.sphere_reg_fsLR', 'sphere_reg_fsLR')]),
(sphere_reg_wf, anat_derivatives_wf, [
('outputnode.sphere_reg', 'inputnode.sphere_reg'),
('outputnode.sphere_reg_fsLR', 'inputnode.sphere_reg_fsLR')]),
])
# fmt: on

if cifti_output:
from nibabies.workflows.anatomical.resampling import (
init_anat_fsLR_resampling_wf,
)

is_mcribs = recon_method == "mcribs"
# handles morph_grayords_wf
anat_fsLR_resampling_wf = init_anat_fsLR_resampling_wf(cifti_output, mcribs=is_mcribs)
anat_derivatives_wf.get_node('inputnode').inputs.cifti_density = cifti_output
# fmt:off
workflow.connect([
(sphere_reg_wf, anat_fsLR_resampling_wf, [
('outputnode.sphere_reg', 'inputnode.sphere_reg'),
('outputnode.sphere_reg_fsLR', 'inputnode.sphere_reg_fsLR')]),
(surface_recon_wf, anat_fsLR_resampling_wf, [
('outputnode.subject_id', 'inputnode.subject_id'),
('outputnode.subjects_dir', 'inputnode.subjects_dir'),
('outputnode.surfaces', 'inputnode.surfaces'),
('outputnode.morphometrics', 'inputnode.morphometrics')]),
(anat_fsLR_resampling_wf, anat_derivatives_wf, [
("outputnode.cifti_morph", "inputnode.cifti_morph"),
("outputnode.cifti_metadata", "inputnode.cifti_metadata")]),
(anat_fsLR_resampling_wf, outputnode, [
("outputnode.midthickness_fsLR", "midthickness_fsLR")])
])
# fmt:on
return workflow


Expand Down
42 changes: 10 additions & 32 deletions nibabies/workflows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ def init_single_subject_wf(

anat_only = config.workflow.anat_only
derivatives = Derivatives(bids_root=config.execution.layout.root)
anat_modality = "t1w" if subject_data["t1w"] else "t2w"
contrast = "T1w" if subject_data["t1w"] else "T2w"
single_modality = not (subject_data['t1w'] and subject_data['t2w'])
# Make sure we always go through these two checks
if not anat_only and not subject_data["bold"]:
task_id = config.execution.task_id
Expand Down Expand Up @@ -347,7 +348,7 @@ def init_single_subject_wf(
wf_args = dict(
ants_affine_init=True,
age_months=age,
anat_modality=anat_modality,
contrast=contrast,
t1w=subject_data["t1w"],
t2w=subject_data["t2w"],
bids_root=config.execution.bids_dir,
Expand All @@ -364,33 +365,10 @@ def init_single_subject_wf(
spaces=spaces,
cifti_output=config.workflow.cifti_output,
)

if subject_data['t1w'] and subject_data['t2w']:
anat_preproc_wf = init_infant_anat_wf(**wf_args)
else:
anat_preproc_wf = init_infant_single_anat_wf(
contrast='T1w' if subject_data['t1w'] else 'T2w', **wf_args
)
# Preprocessing of anatomical (includes registration to UNCInfant)
anat_preproc_wf = init_infant_anat_wf(
ants_affine_init=True,
age_months=age,
anat_modality=anat_modality,
t1w=subject_data["t1w"],
t2w=subject_data["t2w"],
bids_root=config.execution.bids_dir,
derivatives=derivatives,
freesurfer=config.workflow.run_reconall,
hires=config.workflow.hires,
longitudinal=config.workflow.longitudinal,
omp_nthreads=config.nipype.omp_nthreads,
output_dir=nibabies_dir,
segmentation_atlases=config.execution.segmentation_atlases_dir,
skull_strip_mode=config.workflow.skull_strip_t1w,
skull_strip_template=Reference.from_string(config.workflow.skull_strip_template)[0],
sloppy=config.execution.sloppy,
spaces=spaces,
cifti_output=config.workflow.cifti_output,
anat_preproc_wf = (
init_infant_anat_wf(**wf_args)
if not single_modality
else init_infant_single_anat_wf(**wf_args)
)

# fmt: off
Expand Down Expand Up @@ -424,17 +402,17 @@ def init_single_subject_wf(

workflow.connect([
(bidssrc, bids_info, [
(('t1w', fix_multi_source_name), 'in_file'),
((contrast.lower(), fix_multi_source_name), 'in_file'),
]),
(bidssrc, summary, [
('t1w', 't1w'),
('t2w', 't2w'),
]),
(bidssrc, ds_report_summary, [
(('t1w', fix_multi_source_name), 'source_file'),
((contrast.lower(), fix_multi_source_name), 'source_file'),
]),
(bidssrc, ds_report_about, [
(('t1w', fix_multi_source_name), 'source_file'),
((contrast.lower(), fix_multi_source_name), 'source_file'),
]),
])
# fmt: on
Expand Down

0 comments on commit 93806c8

Please sign in to comment.