diff --git a/src/fmripost_aroma/workflows/base.py b/src/fmripost_aroma/workflows/base.py index ed87dfd..74d2e47 100644 --- a/src/fmripost_aroma/workflows/base.py +++ b/src/fmripost_aroma/workflows/base.py @@ -349,11 +349,13 @@ def init_single_subject_wf(subject_id: str): ica_aroma_wf.inputs.inputnode.bold_mask_std = functional_cache['bold_mask_std'] workflow.add_nodes([ica_aroma_wf]) - functional_cache['skip_vols'] = ( - config.workflow.dummy_scans or functional_cache['skip_vols'] - ) + if config.workflow.dummy_scans is not None: + skip_vols = config.workflow.dummy_scans + else: + skip_vols = get_nss(functional_cache['confounds']) + ica_aroma_wf.inputs.inputnode.confounds = functional_cache['confounds'] - ica_aroma_wf.inputs.inputnode.skip_vols = functional_cache['skip_vols'] + ica_aroma_wf.inputs.inputnode.skip_vols = skip_vols if config.workflow.denoise_method: for space in spaces: @@ -393,3 +395,24 @@ def clean_datasinks(workflow: pe.Workflow) -> pe.Workflow: if node.split('.')[-1].startswith('ds_'): workflow.get_node(node).interface.out_path_base = '' return workflow + + +def get_nss(confounds_file): + """Get number of non-steady state volumes.""" + import numpy as np + import pandas as pd + + df = pd.read_table(confounds_file) + + nss_cols = [c for c in df.columns if c.startswith("non_steady_state_outlier")] + + dummy_scans = 0 + if nss_cols: + initial_volumes_df = df[nss_cols] + dummy_scans = np.any(initial_volumes_df.to_numpy(), axis=1) + dummy_scans = np.where(dummy_scans)[0] + + # reasonably assumes all NSS volumes are contiguous + dummy_scans = int(dummy_scans[-1] + 1) + + return dummy_scans