From 996111e3d92515482a0747bc65ece04706fb20fd Mon Sep 17 00:00:00 2001 From: Gab-D-G Date: Tue, 13 Feb 2024 13:49:54 -0500 Subject: [PATCH] Implemented a --group_avg_prior to compute prior networks from the group average. --- .../diagnosis_pkg/diagnosis_wf.py | 1 + .../analysis_pkg/diagnosis_pkg/interfaces.py | 47 ++++++++++++++----- rabies/parser.py | 9 ++++ scripts/error_check_rabies.py | 2 +- 4 files changed, 46 insertions(+), 13 deletions(-) diff --git a/rabies/analysis_pkg/diagnosis_pkg/diagnosis_wf.py b/rabies/analysis_pkg/diagnosis_pkg/diagnosis_wf.py index c30b7ad..670d11e 100644 --- a/rabies/analysis_pkg/diagnosis_pkg/diagnosis_wf.py +++ b/rabies/analysis_pkg/diagnosis_pkg/diagnosis_wf.py @@ -125,6 +125,7 @@ def prep_scan_data(dict_file, spatial_info, temporal_info): DatasetDiagnosis_node.inputs.scan_QC_thresholds = analysis_opts.scan_QC_thresholds DatasetDiagnosis_node.inputs.figure_format = analysis_opts.figure_format DatasetDiagnosis_node.inputs.extended_QC = analysis_opts.extended_QC + DatasetDiagnosis_node.inputs.group_avg_prior = analysis_opts.group_avg_prior workflow.connect([ (inputnode, prep_scan_data_node, [ diff --git a/rabies/analysis_pkg/diagnosis_pkg/interfaces.py b/rabies/analysis_pkg/diagnosis_pkg/interfaces.py index 0f23692..327a543 100644 --- a/rabies/analysis_pkg/diagnosis_pkg/interfaces.py +++ b/rabies/analysis_pkg/diagnosis_pkg/interfaces.py @@ -161,6 +161,8 @@ class DatasetDiagnosisInputSpec(BaseInterfaceInputSpec): desc="Whether network maps are absolute or relative.") scan_QC_thresholds = traits.Dict( desc="Specifications for scan-level QC thresholds.") + group_avg_prior = traits.Bool( + desc="Whether to use the group average (median) as a network prior instead of using an external image.") figure_format = traits.Str( desc="Select file format for figures.") extended_QC = traits.Bool( @@ -368,10 +370,16 @@ def prep_QC_thresholds_i(scan_QC_thresholds, analysis, network_i, num_priors): scan_QC_thresholds = self.inputs.scan_QC_thresholds - prior_maps = scan_data['prior_maps'][:,non_zero_voxels] - num_priors = prior_maps.shape[0] DR_maps_list=np.array(FC_maps_dict['DR']) + + if self.inputs.group_avg_prior: + num_priors = DR_maps_list.shape[1] + prior_maps = np.median(DR_maps_list,axis=0)[:,non_zero_voxels] + else: + prior_maps = scan_data['prior_maps'][:,non_zero_voxels] + num_priors = prior_maps.shape[0] + for i in range(num_priors): if self.inputs.network_weighting=='relative': network_var=None @@ -393,6 +401,13 @@ def prep_QC_thresholds_i(scan_QC_thresholds, analysis, network_i, num_priors): NPR_maps_list=np.array(FC_maps_dict['NPR']) + if self.inputs.group_avg_prior: + num_priors = NPR_maps_list.shape[1] + prior_maps = np.median(NPR_maps_list,axis=0)[:,non_zero_voxels] + else: + prior_maps = scan_data['prior_maps'][:,non_zero_voxels] + num_priors = prior_maps.shape[0] + if NPR_maps_list.shape[1]>0: for i in range(num_priors): if self.inputs.network_weighting=='relative': @@ -414,17 +429,25 @@ def prep_QC_thresholds_i(scan_QC_thresholds, analysis, network_i, num_priors): analysis_QC_network_i(i,FC_maps_,prior_maps[i,:],non_zero_mask, corr_variable_, variable_name, template_file, out_dir_parametric, out_dir_non_parametric, analysis_prefix='NPR') - # prior maps are provided for seed-FC, tries to run the diagnosis on seeds - if len(self.inputs.seed_prior_maps)>0: - prior_maps=[] - for prior_map in self.inputs.seed_prior_maps: - # resample to match the subject - sitk_img = sitk.Resample(sitk.ReadImage(prior_map), sitk.ReadImage(mask_file)) - prior_maps.append(sitk.GetArrayFromImage(sitk_img)[volume_indices]) - - prior_maps = np.array(prior_maps)[:,non_zero_voxels] - num_priors = prior_maps.shape[0] + if self.inputs.group_avg_prior or (len(self.inputs.seed_prior_maps)>0): seed_maps_list=np.array(FC_maps_dict['SBC']) + + if self.inputs.group_avg_prior: + num_priors = seed_maps_list.shape[1] + prior_maps = np.median(seed_maps_list,axis=0)[:,non_zero_voxels] + + # prior maps are provided for seed-FC, tries to run the diagnosis on seeds + elif len(self.inputs.seed_prior_maps)>0: + prior_maps=[] + for prior_map in self.inputs.seed_prior_maps: + # resample to match the subject + sitk_img = sitk.Resample(sitk.ReadImage(prior_map), sitk.ReadImage(mask_file)) + prior_maps.append(sitk.GetArrayFromImage(sitk_img)[volume_indices]) + prior_maps = np.array(prior_maps)[:,non_zero_voxels] + num_priors = prior_maps.shape[0] + else: + raise + for i in range(num_priors): network_var=None diff --git a/rabies/parser.py b/rabies/parser.py index ed44c45..edae30f 100644 --- a/rabies/parser.py +++ b/rabies/parser.py @@ -860,6 +860,15 @@ def get_parser(): "(default: %(default)s)\n" "\n" ) + analysis.add_argument( + "--group_avg_prior", dest='group_avg_prior', action='store_true', + help= + "Select this option to use the group average (the median across subject) as a reference prior network map \n" + "instead of providing an external image. This option will circumvent --seed_prior_list, and the ICA \n" + "components selected with --prior_bold_idx won't be used for computing Dice overlap measures during QC. \n" + "(default: %(default)s)\n" + "\n" + ) analysis.add_argument( '--scan_QC_thresholds', type=str, default="{}", help= diff --git a/scripts/error_check_rabies.py b/scripts/error_check_rabies.py index a46f44d..10758cf 100755 --- a/scripts/error_check_rabies.py +++ b/scripts/error_check_rabies.py @@ -215,7 +215,7 @@ def get_parser(): # testing group level --data_diagnosis command = f"rabies --force --verbose 1 analysis {tmppath}/outputs {tmppath}/outputs --NPR_temporal_comp 1 \ - --data_diagnosis --extended_QC --DR_ICA --seed_list {tmppath}/inputs/token_mask_half.nii.gz" + --data_diagnosis --group_avg_prior --extended_QC --DR_ICA --seed_list {tmppath}/inputs/token_mask_half.nii.gz" process = subprocess.run( command, check=True,