Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented a --group_avg_prior to compute prior networks from the gr… #355

Merged
merged 1 commit into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rabies/analysis_pkg/diagnosis_pkg/diagnosis_wf.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def prep_scan_data(dict_file, spatial_info, temporal_info):
DatasetDiagnosis_node.inputs.scan_QC_thresholds = analysis_opts.scan_QC_thresholds
DatasetDiagnosis_node.inputs.figure_format = analysis_opts.figure_format
DatasetDiagnosis_node.inputs.extended_QC = analysis_opts.extended_QC
DatasetDiagnosis_node.inputs.group_avg_prior = analysis_opts.group_avg_prior

workflow.connect([
(inputnode, prep_scan_data_node, [
Expand Down
47 changes: 35 additions & 12 deletions rabies/analysis_pkg/diagnosis_pkg/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ class DatasetDiagnosisInputSpec(BaseInterfaceInputSpec):
desc="Whether network maps are absolute or relative.")
scan_QC_thresholds = traits.Dict(
desc="Specifications for scan-level QC thresholds.")
group_avg_prior = traits.Bool(
desc="Whether to use the group average (median) as a network prior instead of using an external image.")
figure_format = traits.Str(
desc="Select file format for figures.")
extended_QC = traits.Bool(
Expand Down Expand Up @@ -368,10 +370,16 @@ def prep_QC_thresholds_i(scan_QC_thresholds, analysis, network_i, num_priors):

scan_QC_thresholds = self.inputs.scan_QC_thresholds

prior_maps = scan_data['prior_maps'][:,non_zero_voxels]
num_priors = prior_maps.shape[0]

DR_maps_list=np.array(FC_maps_dict['DR'])

if self.inputs.group_avg_prior:
num_priors = DR_maps_list.shape[1]
prior_maps = np.median(DR_maps_list,axis=0)[:,non_zero_voxels]
else:
prior_maps = scan_data['prior_maps'][:,non_zero_voxels]
num_priors = prior_maps.shape[0]

for i in range(num_priors):
if self.inputs.network_weighting=='relative':
network_var=None
Expand All @@ -393,6 +401,13 @@ def prep_QC_thresholds_i(scan_QC_thresholds, analysis, network_i, num_priors):


NPR_maps_list=np.array(FC_maps_dict['NPR'])
if self.inputs.group_avg_prior:
num_priors = NPR_maps_list.shape[1]
prior_maps = np.median(NPR_maps_list,axis=0)[:,non_zero_voxels]
else:
prior_maps = scan_data['prior_maps'][:,non_zero_voxels]
num_priors = prior_maps.shape[0]

if NPR_maps_list.shape[1]>0:
for i in range(num_priors):
if self.inputs.network_weighting=='relative':
Expand All @@ -414,17 +429,25 @@ def prep_QC_thresholds_i(scan_QC_thresholds, analysis, network_i, num_priors):

analysis_QC_network_i(i,FC_maps_,prior_maps[i,:],non_zero_mask, corr_variable_, variable_name, template_file, out_dir_parametric, out_dir_non_parametric, analysis_prefix='NPR')

# prior maps are provided for seed-FC, tries to run the diagnosis on seeds
if len(self.inputs.seed_prior_maps)>0:
prior_maps=[]
for prior_map in self.inputs.seed_prior_maps:
# resample to match the subject
sitk_img = sitk.Resample(sitk.ReadImage(prior_map), sitk.ReadImage(mask_file))
prior_maps.append(sitk.GetArrayFromImage(sitk_img)[volume_indices])

prior_maps = np.array(prior_maps)[:,non_zero_voxels]
num_priors = prior_maps.shape[0]
if self.inputs.group_avg_prior or (len(self.inputs.seed_prior_maps)>0):
seed_maps_list=np.array(FC_maps_dict['SBC'])

if self.inputs.group_avg_prior:
num_priors = seed_maps_list.shape[1]
prior_maps = np.median(seed_maps_list,axis=0)[:,non_zero_voxels]

# prior maps are provided for seed-FC, tries to run the diagnosis on seeds
elif len(self.inputs.seed_prior_maps)>0:
prior_maps=[]
for prior_map in self.inputs.seed_prior_maps:
# resample to match the subject
sitk_img = sitk.Resample(sitk.ReadImage(prior_map), sitk.ReadImage(mask_file))
prior_maps.append(sitk.GetArrayFromImage(sitk_img)[volume_indices])
prior_maps = np.array(prior_maps)[:,non_zero_voxels]
num_priors = prior_maps.shape[0]
else:
raise

for i in range(num_priors):
network_var=None

Expand Down
9 changes: 9 additions & 0 deletions rabies/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,15 @@ def get_parser():
"(default: %(default)s)\n"
"\n"
)
analysis.add_argument(
"--group_avg_prior", dest='group_avg_prior', action='store_true',
help=
"Select this option to use the group average (the median across subject) as a reference prior network map \n"
"instead of providing an external image. This option will circumvent --seed_prior_list, and the ICA \n"
"components selected with --prior_bold_idx won't be used for computing Dice overlap measures during QC. \n"
"(default: %(default)s)\n"
"\n"
)
analysis.add_argument(
'--scan_QC_thresholds', type=str, default="{}",
help=
Expand Down
2 changes: 1 addition & 1 deletion scripts/error_check_rabies.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def get_parser():

# testing group level --data_diagnosis
command = f"rabies --force --verbose 1 analysis {tmppath}/outputs {tmppath}/outputs --NPR_temporal_comp 1 \
--data_diagnosis --extended_QC --DR_ICA --seed_list {tmppath}/inputs/token_mask_half.nii.gz"
--data_diagnosis --group_avg_prior --extended_QC --DR_ICA --seed_list {tmppath}/inputs/token_mask_half.nii.gz"
process = subprocess.run(
command,
check=True,
Expand Down
Loading