diff --git a/README.md b/README.md
index 47d41d2..fc6f1ae 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,78 @@
-# cogitate-meeg-analysis
\ No newline at end of file
+# MEG
+This folder contains all the code created by Oscar Ferrante and Ling Liu in the frame of the COGITATE project.
+
+## Setup:
+Create a new conda environment by running the following:
+```
+conda env create --file=requirements_cogitate_meg.yaml
+```
+For the linear mixed model (LMM) analysis used in the activation analysis, create a specific LMM conda environment by running the following:
+```
+conda env create --file=requirements_cogitate_meg_lmm.yaml
+```
+The environments are tailored for Linux and the HPC; Windows and macOS have not been tested thoroughly, so some things may break.
+
+To recreate the exact environment in which the code was developed (for reproducibility), requirements files with build specifications are also provided:
+- requirements_cogitate_meg_exact.yml
+- requirements_cogitate_meg_lmm_exact.yml
+
+**Installation time ~= 90min**
+
+## Change root path:
+To run the analyses described below on the sample data, make sure to change the BIDS root path (`bids_root`) in /meeg/config/config.py (a sketch of this setting is given in the "Running analyses" section below) so that it points to:
+*$ROOT/sample_data/bids*
+
+# Sample data and demo
+
+Sample data, used to run a demo of the analysis pipeline, can be found [here](https://keeper.mpdl.mpg.de/d/a7b65a9ccc2745d58268/).
+
+MEG data from four subjects (two per data collection site) are provided. We provide BIDS-converted data as well as preprocessed data (in `./derivatives/preprocessing/` and `./derivatives/fs/`).
+
+To run the demo, edit the scripts so that the BIDS paths point to the downloaded data.
+
+### Running preprocessing:
+In the command line, enter:
+```
+python REPO_ROOT/cogitate-msp1/scripts/meeg/preprocessing/99_run_preproc.py --sub SA124 --visit V1 --record run --step 1
+```
+When the first preprocessing step is finished, enter:
+```
+python REPO_ROOT/cogitate-msp1/scripts/meeg/preprocessing/99_run_preproc.py --sub SA124 --visit V1 --record run --step 2
+```
+Expected output: the script should generate a directory under:
+*$ROOT/sample_data/bids/derivatives/preprocessing/sub-SA124*
+containing several subfolders, one for each preprocessing step. The epoching files contain the final state of
+the data, ready for the next analysis steps.
+
+**Run time ~= 90min**
+
+### Running analyses:
+For each analysis, run the scripts in the corresponding analysis folder (e.g., /meeg/activation) following the order
+reported in the file name (e.g., first run "S01_source_loc.py", then "S02_source_loc_ga.py", and so on).
+To run any of the individual-level analyses, enter:
+```
+python REPO_ROOT/cogitate-msp1/scripts/meeg/ANALYSIS_FOLDER/ANALYSIS_CODE.py --sub SA124 --visit V1
+```
+Replace ANALYSIS_FOLDER with the name of the folder corresponding to the analysis you want to run (e.g., activation)
+and ANALYSIS_CODE with the name of the script you want to execute (e.g., S01_source_loc.py).
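+
+As an illustration of the call above, the individual-level scripts can be batched over several subjects from Python. This is a minimal sketch rather than part of the pipeline: the script path and subject list are assumptions (only SA124 is used in the examples above), and REPO_ROOT is a placeholder for the actual repository location.
+```
+# batch_individual_analyses.py -- hypothetical helper, not included in the repo
+import subprocess
+
+script = "REPO_ROOT/cogitate-msp1/scripts/meeg/activation/S01_source_loc.py"
+subjects = ["SA124"]  # extend with the other sample subject IDs as needed
+
+for sub in subjects:
+    # Equivalent to the command-line call shown above, once per subject
+    subprocess.run(["python", script, "--sub", sub, "--visit", "V1"], check=True)
+```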
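+
+As noted in the "Change root path" section above, every analysis script imports `bids_root` from /meeg/config/config.py. The snippet below is a sketch of what that setting might look like once the sample data has been unpacked; the path and the surrounding layout of config.py are assumptions, so adapt them to your own $ROOT.
+```
+# config/config.py (sketch) -- only the bids_root setting is shown here
+import os.path as op
+
+ROOT = "/home/username"  # assumed location of the downloaded sample data
+bids_root = op.join(ROOT, "sample_data", "bids")
+```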
+To run any of the group-level analyses (i.e., the analyses marked in the script file name with the suffix "ga"), enter:
+```
+python REPO_ROOT/cogitate-msp1/scripts/meeg/ANALYSIS_FOLDER/ANALYSIS_CODE.py
+```
+
+## List of analyses and corresponding run times
+- activation
+**Individual-level analysis run time ~= 60min per participant**
+**Group-level analysis run time ~= 240min**
+- connectivity
+**Individual-level analysis run time ~= 90min per participant**
+**Group-level analysis run time ~= 30min**
+- ged (to be run before the connectivity analysis)
+**Individual-level analysis run time ~= 15min per participant**
+**Group-level analysis run time ~= 10min**
+- roi_mvpa
+**Individual-level analysis run time ~= XXmin per participant**
+**Group-level analysis run time ~= XXmin**
+- source_modelling
+**Individual-level analysis run time ~= 210min per participant**
+**Group-level analysis run time ~= 60min**
diff --git a/about.md b/about.md
new file mode 100644
index 0000000..01e552d
--- /dev/null
+++ b/about.md
@@ -0,0 +1,19 @@
+# About
+
+Please refer to the README.md file for usage instructions to run the code.
+
+## Information
+
+| | |
+| --- | --- |
+author_name | Oscar Ferrante, Ling Liu
+author_affiliation | University of Birmingham’s Centre for Human Brain Health (CHBH), Peking University (PKU)
+author_email | O.Ferrante@bham.ac.uk, ling.liu@pku.edu.cn
+PI_name | Ole Jensen, Huan Luo
+PI_affiliation | University of Birmingham’s Centre for Human Brain Health (CHBH), Peking University (PKU)
+PI_email | O.Jensen@bham.ac.uk, huan.luo@pku.edu.cn
+programming_language | python
+Is a readme file included with detailed instructions for running the code? | README.md
+Is the environment file provided? | requirements_cogitate_meg.yml, requirements_cogitate_meg_lmm.yml
+Is there a config file provided to change runtime parameters? | yes
+Does the code run on the sample dataset? | yes
diff --git a/activation/S01_source_loc.py b/activation/S01_source_loc.py
new file mode 100644
index 0000000..805d2d6
--- /dev/null
+++ b/activation/S01_source_loc.py
@@ -0,0 +1,399 @@
+"""
+================
+S02. Source localization of frequency-band-specific activity
+================
+
+Compute LCMV and DICS beamforming.
+
+@author: Oscar Ferrante oscfer88@gmail.com
+
+"""
+
+import os
+import os.path as op
+# import numpy as np
+# import matplotlib.pyplot as plt
+import argparse
+
+import mne
+from mne.cov import compute_covariance
+# from mne.beamformer import make_lcmv, apply_lcmv_cov, make_dics, apply_dics_csd
+# from mne.time_frequency import csd_multitaper
+import mne_bids
+
+import sys
+sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__))))
+
+from config.config import bids_root
+
+
+parser=argparse.ArgumentParser()
+parser.add_argument('--sub',
+                    type=str,
+                    default='SA113',
+                    help='site_id + subject_id (e.g. "SA101")')
+parser.add_argument('--visit',
+                    type=str,
+                    default='V2',
+                    help='visit_id (e.g. 
"V1")') +parser.add_argument('--method', + type=str, + default='dspm', + help='method used for the inverse solution ("lcmv", "dics", "dspm")') +opt=parser.parse_args() + + +# Set params +subject_id = opt.sub +visit_id = opt.visit +inv_method = opt.method + +debug = True +use_rs_noise = True + + +def run_sourcerecon(subject_id, visit_id): + # Set path to preprocessing derivatives and create the related folders + prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") + fwd_deriv_root = op.join(bids_root, "derivatives", "forward") + fs_deriv_root = op.join(bids_root, "derivatives", "fs") + + stfr_deriv_root = op.join(bids_root, "derivatives", "source_loc") + if not op.exists(stfr_deriv_root): + os.makedirs(stfr_deriv_root) + stfr_figure_root = op.join(stfr_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "figures") + if not op.exists(stfr_figure_root): + os.makedirs(stfr_figure_root) + + print("Processing subject: %s" % subject_id) + + # Set task + if visit_id == "V1": + bids_task = 'dur' + elif visit_id == "V2": + bids_task = 'vg' + # elif visit_id == "V2": #find a better way to set the task in V2 + # bids_task = 'replay' + else: + raise ValueError("Error: could not set the task") + + # Read epoched data + bids_path_epo = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + session=visit_id, + suffix='epo', + extension='.fif', + check=False) + + epochs = mne.read_epochs( + bids_path_epo.fpath, + preload=False) + + # Read resting-state data + if visit_id == "V2" and use_rs_noise: + bids_path_rs = bids_path_epo.copy().update( + task="rest", + check=False) + + epochs_rs = mne.read_epochs( + bids_path_rs.fpath, + preload=False) + + # Pick trials + if visit_id == "V1": + epochs = epochs['Task_relevance in ["Relevant non-target", "Irrelevant"]'] + if debug: + epochs = epochs[0:100] + + # Select sensor type + epochs.load_data().pick('meg') + if visit_id == "V2" and use_rs_noise: + epochs_rs.load_data().pick('meg') + + # Baseline correction + baseline_win = (-0.5, 0.) + active_win = (.0, .5) + if visit_id == "V1" or not use_rs_noise: + epochs.apply_baseline(baseline=baseline_win) + + # Compute rank + rank = mne.compute_rank(epochs, + tol=1e-6, + tol_kind='relative') + + # Read forward model + if inv_method == 'dspm': + space = "surface" + else: + space = "volume" + + if visit_id == "V1": + task = None + elif visit_id == "V2": + task = "vg" + + bids_path_fwd = bids_path_epo.copy().update( + root=fwd_deriv_root, + task=task, + suffix=space+"_fwd", + extension=".fif", + check=False) + fwd = mne.read_forward_solution(bids_path_fwd.fpath) + + # Loop iver frequency bands + for fr_band in ['alpha', 'gamma']: + + # Filter data + if fr_band == "alpha": + fmin = 8 + fmax = 13 + # bandwidth = 2. + elif fr_band == "gamma": + fmin = 60 + fmax = 90 + # bandwidth = 4. 
+ else: + raise ValueError("Error: 'band' value not valid") + + epochs_band = epochs.copy().filter(fmin, fmax) + if visit_id == "V2" and use_rs_noise: + epochs_rs_band = epochs_rs.copy().filter(fmin, fmax) + + # Source modelling + # if inv_method == 'lcmv': + + # # Compute covariance matrices + # noise_cov = compute_covariance(epochs_band, + # tmin=baseline_win[0], + # tmax=baseline_win[1], + # method='empirical', + # rank=rank) + # active_cov = compute_covariance(epochs_band, + # tmin=active_win[0], + # tmax=active_win[1], + # method='empirical', + # rank=rank) + # common_cov = noise_cov + active_cov + + # # Generate LCMV filter + # filters = make_lcmv(epochs_band.info, + # fwd, + # common_cov, + # noise_cov=noise_cov, + # reg=0, + # depth=0, + # pick_ori='max-power', + # rank=rank, + # weight_norm=None, + # reduce_rank=True) + + # elif inv_method == 'dics': + + # # Compute cross-spectral density matrices + # noise_csd = csd_multitaper(epochs_band, + # fmin=fmin, fmax=fmax, + # tmin=baseline_win[0], tmax=baseline_win[1]) + # common_csd = csd_multitaper(epochs_band, + # fmin=fmin, fmax=fmax, + # tmin=baseline_win[0], tmax=active_win[1]) + + # # Generate DICS filter + # filters = make_dics(epochs_band.info, + # fwd, + # common_csd.mean(), + # noise_csd=noise_csd.mean(), + # reg=0, + # pick_ori='max-power', + # reduce_rank=True, + # real_filter=True, + # rank=rank, + # depth=0) + + # elif inv_method == "dspm": + + # Compute covariance matrices + if visit_id == "V1" or not use_rs_noise: + noise_cov = compute_covariance(epochs_band, + tmin=baseline_win[0], + tmax=baseline_win[1], + method='empirical', + rank=rank) + elif visit_id == "V2": + noise_cov = compute_covariance(epochs_rs_band, + method='empirical', + rank=rank) + + active_cov = compute_covariance(epochs_band, + tmin=active_win[0], + tmax=active_win[1], + method='empirical', + rank=rank) + common_cov = noise_cov + active_cov + + # Make inverse operator + filters = mne.minimum_norm.make_inverse_operator( + epochs_band.info, + fwd, + common_cov, + loose=.2, + depth=.8, + fixed=False, + rank=rank, + use_cps=True) + + # else: + # raise ValueError("Error: 'inv_method' value not valid") + + for condition in range(1,3): + + # Pick condition + if visit_id == "V1": + if condition == 1: + epochs_cond = epochs_band['Task_relevance == "Relevant non-target"'].copy() + cond_name = "relevant non-target" + elif condition == 2: + epochs_cond = epochs_band['Task_relevance == "Irrelevant"'].copy() + cond_name = "irrelevant" + else: + raise ValueError("Condition %s does not exists" % condition) + elif visit_id == "V2": + if condition == 1: + epochs_cond = epochs_band.copy() + if use_rs_noise: + cond_name = "all" + else: + cond_name = "all_vgbase" + else: + continue + + print(f"\n\n\n### Running on task {cond_name} ###\n\n") + + # Apply filter + # if inv_method == 'lcmv': + + # # Compute covariance matrices + # act_cov_cond = compute_covariance(epochs_cond, + # tmin=active_win[0], + # tmax=active_win[1], + # method='empirical', + # rank=rank) + # noise_cov_cond = compute_covariance(epochs_cond, + # tmin=baseline_win[0], + # tmax=baseline_win[1], + # method='empirical', + # rank=rank) + + # # Apply LCMV filter + # stc_act = apply_lcmv_cov(act_cov_cond, filters) + # stc_base = apply_lcmv_cov(noise_cov_cond, filters) + + # elif inv_method == 'dics': + + # # Compute cross-spectral density matrices + # act_csd_cond = csd_multitaper(epochs_cond, + # fmin=fmin, fmax=fmax, + # tmin=active_win[0], + # tmax=active_win[1], + # bandwidth=bandwidth) + # base_csd_cond 
= csd_multitaper(epochs_cond, + # fmin=fmin, fmax=fmax, + # tmin=baseline_win[0], + # tmax=baseline_win[1], + # bandwidth=bandwidth) + + # # Apply DICS filter + # stc_base, freqs = apply_dics_csd(base_csd_cond.mean(), filters) + # stc_act, freqs = apply_dics_csd(act_csd_cond.mean(), filters) + + # elif inv_method == "dspm": + + # Compute covariance matrices + act_cov_cond = compute_covariance(epochs_cond, + tmin=active_win[0], + tmax=active_win[1], + method='empirical', + rank=rank) + if visit_id == "V1" or not use_rs_noise: + noise_cov_cond = compute_covariance(epochs_cond, + tmin=baseline_win[0], + tmax=baseline_win[1], + method='empirical', + rank=rank) + elif visit_id == "V2": + noise_cov_cond = compute_covariance(epochs_rs_band, + method='empirical', + rank=rank) + + # Apply dSPM filter + stc_act = mne.minimum_norm.apply_inverse_cov(act_cov_cond, + epochs_cond.info, + filters, + method='dSPM', + pick_ori=None, + verbose=True) + if visit_id == "V1" or not use_rs_noise: + stc_base = mne.minimum_norm.apply_inverse_cov(noise_cov_cond, + epochs_cond.info, + filters, + method='dSPM', + pick_ori=None, + verbose=True) + elif visit_id == "V2": + stc_base = mne.minimum_norm.apply_inverse_cov(noise_cov_cond, + epochs_rs_band.info, + filters, + method='dSPM', + pick_ori=None, + verbose=True) + + # else: + # raise ValueError("Error: 'inv_method' value not valid") + + # Compute baseline correction + stc_act /= stc_base + + # Save source estimates + bids_path_con = bids_path_epo.copy().update( + root=stfr_deriv_root, + suffix=f"stfr_beam-{inv_method}_band-{fr_band}_c-{cond_name}", + extension=None, + check=False) + + stc_act.save(bids_path_con) + + # Morph to fsaverage #not needed if morphing the forward solution + if inv_method in ["lcmv", "dics"]: + fname_fs_src = fs_deriv_root + '/fsaverage/bem/fsaverage-vol-5-src.fif' + elif inv_method == "dspm": + fname_fs_src = fs_deriv_root + '/fsaverage/bem/fsaverage-ico-5-src.fif' + + src_fs = mne.read_source_spaces(fname_fs_src) + + morph = mne.compute_source_morph( + fwd['src'], + subject_from="sub-"+subject_id, + subject_to='fsaverage', + src_to=src_fs, + subjects_dir=fs_deriv_root, + verbose=True) + + stc_fs = morph.apply(stc_act) + + # Save morphed source estimates + bids_path_sou = bids_path_epo.copy().update( + root=stfr_deriv_root, + suffix=f"stfr_beam-{inv_method}_band-{fr_band}_c-{cond_name}_morph", + extension=None, + check=False) + + stc_fs.save(bids_path_sou) + + +if __name__ == '__main__': + run_sourcerecon(subject_id, visit_id) + \ No newline at end of file diff --git a/activation/S02_source_loc_ga.py b/activation/S02_source_loc_ga.py new file mode 100644 index 0000000..3893fa0 --- /dev/null +++ b/activation/S02_source_loc_ga.py @@ -0,0 +1,255 @@ +""" +================ +S03. Grand-average source localization +================ + +Grand-average of source localization. 
+ +@author: Oscar Ferrante oscfer88@gmail.com + +""" + +import os +import os.path as op +import numpy as np +import argparse +import matplotlib +import matplotlib.pyplot as plt + +import mne +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +parser=argparse.ArgumentParser() +parser.add_argument('--method', + type=str, + default='dspm', + help='method used for the inverse solution ("lcmv", "dics", "dspm")') +parser.add_argument('--band', + type=str, + default='gamma', + help='frequency band of interest ("alpha", "beta", "gamma")') +# parser.add_argument('--bids_root', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids', +# help='Path to the BIDS root directory') +opt=parser.parse_args() + + +# Set params +inv_method = opt.method +fr_band = opt.band +visit_id = "V1" + +debug = False + + +# Set participant list +phase = 3 + +if debug: + sub_list = ["SA124", "SA126"] +else: + # Read the .txt file + f = open(op.join(bids_root, + f'participants_MEG_phase{phase}_included.txt'), 'r').read() + # Split text into list of elemetnts + sub_list = f.split("\n") + + +def source_loc_ga(): + # Set directory paths + fs_deriv_root = op.join(bids_root, "derivatives", "fs") + fwd_deriv_root = op.join(bids_root, "derivatives", "forward") + + stfr_deriv_root = op.join(bids_root, "derivatives", "source_loc") + if not op.exists(stfr_deriv_root): + os.makedirs(stfr_deriv_root) + stfr_figure_root = op.join(stfr_deriv_root, + f"sub-groupphase{phase}",f"ses-{visit_id}","meg", + "figures") + if not op.exists(stfr_figure_root): + os.makedirs(stfr_figure_root) + + # Set file name ending + if inv_method in ["lcmv", "dics"]: + fname_end = 'vl' + elif inv_method == "dspm": + fname_end = "lh" + + # Set task + if visit_id == "V1": + bids_task = 'dur' + elif visit_id == "V2": + bids_task = 'vg' + # elif visit_id == "V2": #find a better way to set the task in V2 + # bids_task = 'replay' + else: + raise ValueError("Error: could not set the task") + + # Load average source space + if inv_method in ["lcmv", "dics"]: + fname_fs_src = op.join(fs_deriv_root, 'fsaverage/bem/fsaverage-vol-5-src.fif') + elif inv_method == "dspm": + fname_fs_src = op.join(fs_deriv_root, 'fsaverage/bem/fsaverage-ico-5-src.fif') + src_fs = mne.read_source_spaces(fname_fs_src) + + # Loop over frequency bands + for fr_band in ['alpha', 'gamma']: + # Loop over conditions + stcs = {} + for condition in range(1,3): + + # Pick condition + if condition == 1: + cond_name = "relevant non-target" + elif condition == 2: + cond_name = "irrelevant" + else: + raise ValueError("Condition %s does not exists" % condition) + + print(f"\Task {cond_name}") + + # Load data + stcs_temp = [] + for sub in sub_list: + print("participant:", sub) + + # Set path + bids_path_sou = mne_bids.BIDSPath( + root=stfr_deriv_root, + subject=sub, + datatype='meg', + task=bids_task, + session=visit_id, + suffix=f"stfr_beam-{inv_method}_band-{fr_band}_c-{cond_name}-{fname_end}", + extension=".stc", + check=False) + + # Load stc data + stc = mne.read_source_estimate(bids_path_sou) + + # Read forward solution + bids_path_fwd = bids_path_sou.copy().update( + root=fwd_deriv_root, + task=None, + suffix="surface_fwd", + extension='.fif', + check=False) + + fwd = mne.read_forward_solution(bids_path_fwd.fpath) + + # Morph to fsaverage + if sub not in ['SA102', 'SA104', 'SA110', 'SA111', 'SA152']: + morph = mne.compute_source_morph( + fwd['src'], + subject_from="sub-"+sub, + 
subject_to='fsaverage', + src_to=src_fs, + subjects_dir=fs_deriv_root, + verbose=True) + + stc = morph.apply(stc) + + # Append to temp stc list + stcs_temp.append(stc) + + # Appenmd to full stcs list + stcs[cond_name] = stcs_temp + + del stc, stcs_temp + + # Average stcs across participants + stcs_data = [stc.data for stc in stcs[cond_name]] + stc_ga = stcs[cond_name][0] + stc_ga.data = np.mean(stcs_data, axis=0) + + # Save stc grandaverage + bids_path_sou = bids_path_sou.update( + subject=f"groupphase{phase}") + stc_ga.save(bids_path_sou) + + +def plot_source_loc_ga(stc_path, desc=None, lims=None, hemi="lh", surface='pial', + size=(800, 600), colormap="RdYlBu_r", colorbar=False, + transparent=False, background="white", subject="fsaverage", + subjects_dir=r'C:\Users\ferranto\Desktop\fs'): + + # Load stc data + stc_ga = mne.read_source_estimate(stc_path) + + # Set view + if desc: + views = ['lateral', 'ventral', 'caudal', 'medial'] + else: + views = ['lateral'] + + # Set limits + if lims == None: + if desc == "alpha": + lims=[0.8,1.,1.2] + elif desc == "gamma": + lims=[0.95,1.,1.05] + else: + lims=[0.8,1.,1.2] + + # Plot source estimates + for view in views: + fig = stc_ga.plot( + hemi=hemi, + views=view, + surface=surface, + size=size, + colormap=colormap, + colorbar=colorbar, + transparent=transparent, + background=background, + subject=subject, + clim={"kind": "value", "lims": lims}, + subjects_dir=subjects_dir) + + # Save figure + if desc: + fname_fig = op.join(op.dirname(stc_path), + f"stfr_{desc}_{view}.png") + fig.save_image(fname_fig) + fig.close() + + # Plot colorbar separately + if desc: + # Create a figure and axes object + fig, ax = plt.subplots(figsize=[3, 2]) + + # Create a colorbar + norm = matplotlib.colors.Normalize(vmin=lims[0], vmax=lims[-1]) + cb = fig.colorbar( + matplotlib.cm.ScalarMappable(norm=norm, cmap="RdYlBu_r"), + aspect=10, ax=ax) + + # Get the axes object for the colorbar + cb_ax = cb.ax + + # Remove the other axes objects from the figure + for ax in fig.axes: + if ax != cb_ax: + ax.remove() + + # Save the figure + fname_fig = os.path.join(os.path.dirname(stc_path), + f"stfr_{desc}__colorbar.png") + fig.savefig(fname_fig, dpi=300) + fname_fig = os.path.join(os.path.dirname(stc_path), + f"stfr_{desc}__colorbar.svg") + fig.savefig(fname_fig, dpi=300) + plt.close() + + return fig + + +if __name__ == '__main__': + source_loc_ga() diff --git a/activation/S03a_source_dur_spectral.py b/activation/S03a_source_dur_spectral.py new file mode 100644 index 0000000..2246163 --- /dev/null +++ b/activation/S03a_source_dur_spectral.py @@ -0,0 +1,331 @@ +""" +================ +S04. Source localization of frequency-band-specific duration activity +================ + + +@author: Oscar Ferrante oscfer88@gmail.com + +""" + +import os +import os.path as op +import numpy as np +import matplotlib.pyplot as plt +import argparse +import itertools +import json +import pandas as pd + + +import mne +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +parser=argparse.ArgumentParser() +parser.add_argument('--sub', + type=str, + default='SA124', + help='site_id + subject_id (e.g. "SA101")') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. 
"V1")') +parser.add_argument('--method', + type=str, + default='dspm', + help='method used for the inverse solution ("lcmv", "dics", "dspm")') +# parser.add_argument('--bids_root', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids', +# help='Path to the BIDS root directory') +opt=parser.parse_args() + + +# Set params +subject_id = opt.sub +visit_id = opt.visit +inv_method = opt.method #this variable is used only to set the output filename + +debug = False + + +factor = ['Category', 'Task_relevance', "Duration"] +conditions = [['face', 'object', 'letter', 'false'], + ['Relevant non-target','Irrelevant'], + ['500ms', '1000ms', '1500ms']] + + +def run_source_dur(subject_id, visit_id): + # Set directory paths + prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") + fwd_deriv_root = op.join(bids_root, "derivatives", "forward") + fs_deriv_root = op.join(bids_root, "derivatives", "fs") + rois_deriv_root = op.join(bids_root, "derivatives", "roilabel") + source_deriv_root = op.join(bids_root, "derivatives", "source_dur") + if not op.exists(source_deriv_root): + os.makedirs(source_deriv_root) + source_figure_root = op.join(source_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "figures") + if not op.exists(source_figure_root): + os.makedirs(source_figure_root) + + # Set task + if visit_id == "V1": + bids_task = 'dur' + elif visit_id == "V2": + bids_task = 'vg' + # elif visit_id == "V2": + # bids_task = 'replay' + else: + raise ValueError("Error: could not set the task") + + # Read epoched data + bids_path_epo = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + session=visit_id, + suffix='epo', + extension='.fif', + check=False) + + epochs_all = mne.read_epochs(bids_path_epo.fpath, + preload=False) + + # Pick trials + epochs_all = epochs_all['Task_relevance in ["Relevant non-target", "Irrelevant"]'] + if debug: + epochs_all = epochs_all[0:100] + + # Select sensor type + epochs_all.load_data().pick('meg') + + # Run baseline correction + b_tmin = -.5 + b_tmax = 0. 
+ baseline = (b_tmin, b_tmax) + epochs_all.apply_baseline(baseline=baseline) + + # Read labels from FS parc + if subject_id in ['SA102', 'SA104', 'SA110', 'SA111', 'SA152']: + labels_atlas = mne.read_labels_from_annot( + "fsaverage", + parc='aparc.a2009s', + subjects_dir=fs_deriv_root) + else: + labels_atlas = mne.read_labels_from_annot( + "sub-"+subject_id, + parc='aparc.a2009s', + subjects_dir=fs_deriv_root) + + # labels_atlas_names = [l.name for l in labels_atlas] + + # Read GNW and IIT ROI list + f = open(op.join(rois_deriv_root, + 'iit_gnw_rois.json')) + gnw_iit_rois = json.load(f) + + # Create labels for selected ROIs + labels = {} + if subject_id in ['SA102', 'SA104', 'SA110', 'SA111', 'SA152']: + for lab in gnw_iit_rois['surf_labels']['iit_1']: + lab = lab.replace('&','_and_') # Fix the label name to match the template one + print(lab) + # labels["iit_"+lab+"_lh"] = [l for l in labels_atlas if l.name == lab+"-lh"] + # labels["iit_"+lab+"_rh"] = [l for l in labels_atlas if l.name == lab+"-rh"] + labels["iit_"+lab.replace('_and_','&')] = np.sum([l for l in labels_atlas if lab in l.name]) + + for lab in gnw_iit_rois['surf_labels']['gnw']: + lab = lab.replace('&','_and_') # Fix the label name to match the template one + print(lab) + # labels["gnw_"+lab+"_lh"] = [l for l in labels_atlas if l.name == lab+"-lh"] + # labels["gnw_"+lab+"_rh"] = [l for l in labels_atlas if l.name == lab+"-rh"] + labels["gnw_"+lab.replace('_and_','&')] = np.sum([l for l in labels_atlas if lab in l.name]) + else: + for lab in gnw_iit_rois['surf_labels']['iit_1']: + print(lab) + # labels["iit_"+lab+"_lh"] = [l for l in labels_atlas if l.name == lab+"-lh"][0] + # labels["iit_"+lab+"_rh"] = [l for l in labels_atlas if l.name == lab+"-rh"][0] + labels["iit_"+lab] = np.sum([l for l in labels_atlas if lab in l.name]) + + for lab in gnw_iit_rois['surf_labels']['gnw']: + print(lab) + # labels["gnw_"+lab+"_lh"] = [l for l in labels_atlas if l.name == lab+"-lh"][0] + # labels["gnw_"+lab+"_rh"] = [l for l in labels_atlas if l.name == lab+"-rh"][0] + labels["gnw_"+lab] = np.sum([l for l in labels_atlas if lab in l.name]) + + # Merge all labels in a single one separatelly for GNW and IIT + labels['gnw_all'] = np.sum([l for l_name, l in labels.items() if 'gnw' in l_name]) + labels['iit_all'] = np.sum([l for l_name, l in labels.items() if 'iit' in l_name]) + + # Compute rank + rank = mne.compute_rank(epochs_all, + tol=1e-6, + tol_kind='relative') + + # Read forward model + bids_path_fwd = bids_path_epo.copy().update( + root=fwd_deriv_root, + task=None, + suffix="surface_fwd", + extension='.fif', + check=False) + fwd = mne.read_forward_solution(bids_path_fwd.fpath) + + # Compute covariance matrices + base_cov = mne.compute_covariance(epochs_all, + tmin=b_tmin, + tmax=b_tmax, + method='empirical', + rank=rank) + + active_cov = mne.compute_covariance(epochs_all, + tmin=0, + tmax=None, + method='empirical', + rank=rank) + common_cov = base_cov + active_cov + + # Make inverse operator + inverse_operator = mne.minimum_norm.make_inverse_operator( + epochs_all.info, + fwd, + common_cov, + loose=.2, + depth=.8, + fixed=False, + rank=rank, + use_cps=True) + + # Find all combinations between variables' levels + if len(factor) == 1: + cond_combs = list(itertools.product(conditions[0])) + if len(factor) == 2: + cond_combs = list(itertools.product(conditions[0], + conditions[1])) + if len(factor) == 3: + cond_combs = list(itertools.product(conditions[0], + conditions[1], + conditions[2])) + + # Set band-sepcific params + b_params = { + 
'alpha': { + 'bands': dict(alpha=[8, 13]), + 'n_cycles': np.arange(8, 14, 1) / 2., + 'df': 1, + 'baseline': (-.75, -.25)}, + 'gamma': { + 'bands': dict(gamma=[60, 90]), + 'n_cycles': np.arange(60, 91, 2) / 4., + 'df': 2, + 'baseline': (-.375, -.125)} } + + # Loop over conditions of interest + for cond_comb in cond_combs: + print("\nAnalyzing %s: %s" % (factor, cond_comb)) + + # Select epochs + if len(factor) == 1: + epochs = epochs_all['%s == "%s"' % ( + factor[0], cond_comb[0])] + fname = cond_comb[0] + if len(factor) == 2: + epochs = epochs_all['%s == "%s" and %s == "%s"' % ( + factor[0], cond_comb[0], + factor[1], cond_comb[1])] + fname = cond_comb[0] + "_" + cond_comb[1] + if len(factor) == 3: + epochs = epochs_all['%s == "%s" and %s == "%s" and %s == "%s"' % ( + factor[0], cond_comb[0], + factor[1], cond_comb[1], + factor[2], cond_comb[2])] + fname = cond_comb[0] + "_" + cond_comb[1] + "_" + cond_comb[2] + + # Compute inverse solution for each epoch + stcs = {} + for band_name in ['alpha', 'gamma']: + print(f"band: {band_name}") + stcs.update(mne.minimum_norm.source_band_induced_power( + epochs, + inverse_operator, + bands = b_params[band_name]['bands'], + method='dSPM', + n_cycles=b_params[band_name]['n_cycles'], + df=b_params[band_name]['df'], + baseline=b_params[band_name]['baseline'], + baseline_mode='ratio', + use_fft=True)) + + # Save stcs + for band, stc in stcs.items(): + bids_path_source = bids_path_epo.copy().update( + root=source_deriv_root, + suffix=f"desc-{fname},{band}_stc", + extension=None, + check=False) + + stc.save(bids_path_source) + + # Loop over bands + for band, stc in stcs.items(): + print(f"\nPlotting {band}") + # Loop over labels + for label_name, label in labels.items(): + print(f"label: {label_name}") + + # Select data in label + stc_in = stc.in_label(label) + + # Extract time course data + times = stc_in.times + data = stc_in.data.mean(axis=0) + + # Create and save a tsv table with the label time course data + df = pd.DataFrame({ + 'times': times, + 'data': data}) + + bids_path_source = bids_path_epo.copy().update( + root=source_deriv_root, + suffix=f"desc-{fname},{band},{label_name}_datatable", + extension='.tsv', + check=False) + df.to_csv(bids_path_source.fpath, sep="\t") + + # Plot + tmin = (np.abs(times - -.5)).argmin() + tmax = (np.abs(times - 2.)).argmin() + + plt.plot(times[tmin:tmax], data[tmin:tmax]) + plt.xlabel('Time (ms)') + plt.ylabel('Power') + plt.title(f'{band} power in {label_name}:\n{fname}') + + # Save figure + fname_fig = op.join(source_figure_root, + f'source_dur_{fname}_{band}_{label_name}.png') + plt.savefig(fname_fig) + plt.close('all') + + # Save label names + bids_path_source = bids_path_source.copy().update( + root=source_deriv_root, + suffix="desc-labels", + extension='.txt', + check=False) + + with open(bids_path_source.fpath, "w") as output: + output.write(str(list(labels.keys()))) + + +if __name__ == '__main__': + run_source_dur(subject_id, visit_id) \ No newline at end of file diff --git a/activation/S03b_source_dur_erf.py b/activation/S03b_source_dur_erf.py new file mode 100644 index 0000000..0065ee9 --- /dev/null +++ b/activation/S03b_source_dur_erf.py @@ -0,0 +1,330 @@ +""" +================ +S07. 
Source localization of ERF duration activity + +based on S04 Source Localization of frequency-band-specific duration activity @author: Oscar Ferrante oscfer88@gmail.com +================ + + +@author: Ling Liu ling.liu@pku.edu.cn + +""" + +import os +import os.path as op +import numpy as np +import matplotlib.pyplot as plt +import argparse +import itertools +import json +import pandas as pd + +import mne +import mne_bids +from mne.minimum_norm import apply_inverse + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root, sfreq + + +parser=argparse.ArgumentParser() +parser.add_argument('--sub', + type=str, + default='SA124', + help='site_id + subject_id (e.g. "SA101")') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +parser.add_argument('--method', + type=str, + default='dspm', + help='method used for the inverse solution ("lcmv", "dics", "dspm")') +# parser.add_argument('--bids_root', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids', +# help='Path to the BIDS root directory') +opt=parser.parse_args() + + +# Set params +subject_id = opt.sub +visit_id = opt.visit +inv_method = opt.method #this variable is used only to set the output filename + +debug = False + + +factor = ['Category', 'Task_relevance', "Duration"] +conditions = [['face', 'object', 'letter', 'false'], + ['Relevant non-target','Irrelevant'], + ['500ms', '1000ms', '1500ms']] + + +def run_source_dur(subject_id, visit_id): + # Set directory paths + prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") + fwd_deriv_root = op.join(bids_root, "derivatives", "forward") + fs_deriv_root = op.join(bids_root, "derivatives", "fs") + rois_deriv_root = op.join(bids_root, "derivatives", "roilabel") + source_deriv_root = op.join(bids_root, "derivatives", "source_dur_ERF") + if not op.exists(source_deriv_root): + os.makedirs(source_deriv_root) + source_figure_root = op.join(source_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "figures") + if not op.exists(source_figure_root): + os.makedirs(source_figure_root) + + # Set task + if visit_id == "V1": + bids_task = 'dur' + elif visit_id == "V2": + bids_task = 'vg' + # elif visit_id == "V2": + # bids_task = 'replay' + else: + raise ValueError("Error: could not set the task") + + # Read epoched data + bids_path_epo = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + session=visit_id, + suffix='epo', + extension='.fif', + check=False) + + epochs_all = mne.read_epochs(bids_path_epo.fpath, + preload=False) + + # Pick trials + epochs_all = epochs_all['Task_relevance in ["Relevant non-target", "Irrelevant"]'] + if debug: + epochs_all = epochs_all[0:100] + + # Select sensor type + epochs_all.load_data().pick('meg') + + # Downsample and filter to speed the decoding + # Downsample copy of raw + epochs_rs = epochs_all.copy().resample(sfreq, n_jobs=-1) + # Band-pass filter raw copy + epochs_rs.filter(0, 30, n_jobs=-1) + + epochs_rs.crop(tmin=-0.5, tmax=2,include_tmax=True, verbose=None) + + # Run baseline correction + b_tmin = -.5 + b_tmax = 0. 
+ baseline = (b_tmin, b_tmax) + epochs_rs.apply_baseline(baseline=baseline) + + # Read labels from FS parc + if subject_id in ['SA102','SA104','SA110', 'SA111','SA152']: + labels_atlas = mne.read_labels_from_annot( + "fsaverage", + parc='aparc.a2009s', + subjects_dir=fs_deriv_root) + else: + labels_atlas = mne.read_labels_from_annot( + "sub-"+subject_id, + parc='aparc.a2009s', + subjects_dir=fs_deriv_root) + + # labels_atlas_names = [l.name for l in labels_atlas] + + # Read GNW and IIT ROI list + f = open(op.join(rois_deriv_root, + 'iit_gnw_rois.json')) + gnw_iit_rois = json.load(f) + + # Create labels for selected ROIs + labels = {} + if subject_id in ['SA102','SA104','SA110', 'SA111','SA152']: + for lab in gnw_iit_rois['surf_labels']['iit_1']: + lab = lab.replace('&','_and_') # Fix the label name to match the template one + print(lab) + labels["iit_"+lab] = np.sum([l for l in labels_atlas if lab in l.name]) + + for lab in gnw_iit_rois['surf_labels']['gnw']: + lab = lab.replace('&','_and_') # Fix the label name to match the template one + print(lab) + labels["gnw_"+lab] = np.sum([l for l in labels_atlas if lab in l.name]) + else: + for lab in gnw_iit_rois['surf_labels']['iit_1']: + print(lab) + labels["iit_"+lab] = np.sum([l for l in labels_atlas if lab in l.name]) + + for lab in gnw_iit_rois['surf_labels']['gnw']: + print(lab) + labels["gnw_"+lab] = np.sum([l for l in labels_atlas if lab in l.name]) + + # Merge all labels in a single one separatelly for GNW and IIT + labels['gnw_all'] = np.sum([l for l_name, l in labels.items() if 'gnw' in l_name]) + labels['iit_all'] = np.sum([l for l_name, l in labels.items() if 'iit' in l_name]) + + # Compute rank + rank = mne.compute_rank(epochs_rs, + tol=1e-6, + tol_kind='relative') + + # Read forward model + bids_path_fwd = bids_path_epo.copy().update( + root=fwd_deriv_root, + task=None, + suffix="surface_fwd", + extension='.fif', + check=False) + fwd = mne.read_forward_solution(bids_path_fwd.fpath) + + # Compute covariance matrices + base_cov = mne.compute_covariance(epochs_rs, + tmin=b_tmin, + tmax=b_tmax, + method='empirical', + rank=rank) + + active_cov = mne.compute_covariance(epochs_rs, + tmin=0, + tmax=None, + method='empirical', + rank=rank) + common_cov = base_cov + active_cov + + # Make inverse operator + inverse_operator = mne.minimum_norm.make_inverse_operator( + epochs_rs.info, + fwd, + common_cov, + loose=.2, + depth=.8, + fixed=False, + rank=rank, + use_cps=True) + + # Find all combinations between variables' levels + if len(factor) == 1: + cond_combs = list(itertools.product(conditions[0])) + if len(factor) == 2: + cond_combs = list(itertools.product(conditions[0], + conditions[1])) + if len(factor) == 3: + cond_combs = list(itertools.product(conditions[0], + conditions[1], + conditions[2])) + + + + # Loop over conditions of interest + for cond_comb in cond_combs: + print("\nAnalyzing %s: %s" % (factor, cond_comb)) + + # Select epochs + if len(factor) == 1: + epochs = epochs_rs['%s == "%s"' % ( + factor[0], cond_comb[0])] + fname = cond_comb[0] + if len(factor) == 2: + epochs = epochs_rs['%s == "%s" and %s == "%s"' % ( + factor[0], cond_comb[0], + factor[1], cond_comb[1])] + fname = cond_comb[0] + "_" + cond_comb[1] + if len(factor) == 3: + epochs = epochs_rs['%s == "%s" and %s == "%s" and %s == "%s"' % ( + factor[0], cond_comb[0], + factor[1], cond_comb[1], + factor[2], cond_comb[2])] + fname = cond_comb[0] + "_" + cond_comb[1] + "_" + cond_comb[2] + + + # Get evoked response by computing the epoch average + evoked = 
epochs.average() + # Compute inverse solution for each epoch + snr = 3.0 + lambda2 = 1.0 / snr ** 2 + stc = apply_inverse(evoked, inverse_operator, 1. / lambda2, 'dSPM', pick_ori="normal") + + # # Save stc + # bids_path_source = bids_path_epo.copy().update( + # root=source_deriv_root, + # suffix=f"desc-{fname},ERF_stc", + # extension='.stc', + # check=False) + + # stc.save(bids_path_source) + + + + # Loop over labels + for label_name, label in labels.items(): + print(f"label: {label_name}") + + # extract time course in label with pca_flip mode + src = inverse_operator['src'] + # Extract time course data + times = epochs.times + tcs = stc.extract_label_time_course(label,src,mode='mean') + data = tcs[0] + + # Convert to root mean square + data = np.sqrt((np.array(data)**2)) + + # Create and save a tsv table with the label time course data + df = pd.DataFrame({ + 'times': times, + 'data': data}) + + bids_path_source = bids_path_epo.copy().update( + root=source_deriv_root, + suffix=f"desc-{fname},ERF,{label_name}_datatable", + extension='.tsv', + check=False) + df.to_csv(bids_path_source.fpath, sep="\t") + + # Show results + # Apply low-pass filter + dataf = apply_fir_filter(data) + + # Set time limits + tmin = (np.abs(times - -.5)).argmin() + tmax = (np.abs(times - 2.)).argmin() + + # Plot + plt.plot(times[tmin:tmax], dataf[tmin:tmax]) + plt.xlabel('Time (ms)') + plt.ylabel('rms') + plt.title(f'ERF in {label_name}:\n{fname}') + + # Save figure + fname_fig = op.join(source_figure_root, + f'source_dur_{fname}_ERF_{label_name}.png') + plt.savefig(fname_fig) + plt.close('all') + + del stc, evoked + + +def apply_fir_filter(time_series, fs=1000, fc=30, num_taps=101): + # Compute the center of the filter taps + center = (num_taps - 1) // 2 + + # Define the filter coefficients for a low-pass filter with the given cutoff frequency + filter_coeffs = np.sinc(2 * fc / fs * np.arange(center - num_taps + 1, center + 1)) + + # Reverse the filter coefficients since convolution uses a flipped version + filter_coeffs = np.flip(filter_coeffs) + + # Apply the filter to the time series using convolution + filtered_series = np.convolve(time_series, filter_coeffs, mode='same') + + return filtered_series + + + +if __name__ == '__main__': + run_source_dur(subject_id, visit_id) \ No newline at end of file diff --git a/activation/S04a_source_dur_spectral_ga.py b/activation/S04a_source_dur_spectral_ga.py new file mode 100644 index 0000000..da3d1ea --- /dev/null +++ b/activation/S04a_source_dur_spectral_ga.py @@ -0,0 +1,272 @@ +""" +================ +S05. 
Grand-average source epochs +================ + +@author: Oscar Ferrante oscfer88@gmail.com + +""" + +import os +import os.path as op +import numpy as np +# import matplotlib.pyplot as plt +import argparse +import itertools +import pandas as pd +import random +import string + +import mne +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +parser=argparse.ArgumentParser() +parser.add_argument('--method', + type=str, + default='dspm', + help='method used for the inverse solution') +# parser.add_argument('--bids_root', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids', +# help='Path to the BIDS root directory') +opt=parser.parse_args() + + +# Set params +inv_method = opt.method +visit_id = "V1" + +debug = False +bootstrap = False + + +factor = ['Category', 'Task_relevance', "Duration"] +conditions = [['face', 'object', 'letter', 'false'], + ['Relevant non-target','Irrelevant'], + ['500ms', '1000ms', '1500ms']] + + +# Set participant list +phase = 3 + +if debug: + sub_list = ["SA124", "SA124"] +elif bootstrap: + # Read the .txt file for phase 2 + f = open(op.join(bids_root, + 'participants_MEG_phase2_included.txt'), 'r').read() + # Split text into list of elements + sub_list_2 = f.split("\n") + + # Read the .txt file for phase 3 + f = open(op.join(bids_root, + 'participants_MEG_phase3_included.txt'), 'r').read() + # Split text into list of elements + sub_list_3 = f.split("\n") + + # Remove two participants from phase 3 and replace them with 2 random participants from phase 2 + removed_participants = ["SB003", "SB006"] #random.sample(sub_list_3, 2) + sub_list = [participant for participant in sub_list_3 if participant not in removed_participants] + random_participants = random.sample(sub_list_2, 2) + sub_list.extend(random_participants) + + # Replace phase number with "bootstrap" + phase = "bs" + "".join(random.choice(string.ascii_letters + string.digits) for _ in range(4)) +else: + # Read the .txt file + f = open(op.join(bids_root, + f'participants_MEG_phase{phase}_included.txt'), 'r').read() + # Split text into list of elemetnts + sub_list = f.split("\n") + + +def source_dur_ga(): + # Set directory paths + fwd_deriv_root = op.join(bids_root, "derivatives", "forward") + fs_deriv_root = op.join(bids_root, "derivatives", "fs") + source_deriv_root = op.join(bids_root, "derivatives", "source_dur") + if not op.exists(source_deriv_root): + os.makedirs(source_deriv_root) + source_figure_root = op.join(source_deriv_root, + f"sub-groupphase{phase}",f"ses-{visit_id}","meg", + "figures") + if not op.exists(source_figure_root): + os.makedirs(source_figure_root) + + # Set task + if visit_id == "V1": + bids_task = 'dur' + elif visit_id == "V2": + bids_task = 'vg' + # elif visit_id == "V2": #find a better way to set the task in V2 + # bids_task = 'replay' + else: + raise ValueError("Error: could not set the task") + + # Find all combinations between variables' levels + if len(factor) == 1: + cond_combs = list(itertools.product(conditions[0])) + if len(factor) == 2: + cond_combs = list(itertools.product(conditions[0], + conditions[1])) + if len(factor) == 3: + cond_combs = list(itertools.product(conditions[0], + conditions[1], + conditions[2])) + + # Read source space for morphing + fname_fs_src = fs_deriv_root + '/fsaverage/bem/fsaverage-ico-5-src.fif' + src_fs = mne.read_source_spaces(fname_fs_src) + + # Read list with label names + bids_path_source = mne_bids.BIDSPath( + 
root=source_deriv_root, + subject=sub_list[0], + datatype='meg', + task=bids_task, + session=visit_id, + suffix="desc-labels", + extension='.txt', + check=False) + + labels_names = open(bids_path_source.fpath, 'r').read() + labels_names = labels_names[2:-2].split("', '") + + # Create empty dataframe + all_data_df = pd.DataFrame() + + # Loop over conditions of interest + for cond_comb in cond_combs: + print("\n\nAnalyzing %s: %s" % (factor, cond_comb)) + + # Select epochs + if len(factor) == 1: + fname = cond_comb[0] + if len(factor) == 2: + fname = cond_comb[0] + "_" + cond_comb[1] + if len(factor) == 3: + fname = cond_comb[0] + "_" + cond_comb[1] + "_" + cond_comb[2] + + for band in ['alpha', 'gamma']: + print('\n\nfreq_band:', band) + + # Loop over participants + stcs = [] + for sub in sub_list: + print('\nsubject:', sub) + + # Read individual stc + bids_path_source = mne_bids.BIDSPath( + root=source_deriv_root, + subject=sub, + datatype='meg', + task=bids_task, + session=visit_id, + suffix=f"desc-{fname},{band}_stc", + extension=None, + check=False) + + stc = mne.read_source_estimate(bids_path_source) + + # Read forward solution + bids_path_fwd = bids_path_source.copy().update( + root=fwd_deriv_root, + task=None, + suffix="surface_fwd", + extension='.fif', + check=False) + + fwd = mne.read_forward_solution(bids_path_fwd.fpath) + + # Morph stc + if sub not in ['SA102', 'SA104', 'SA110', 'SA111', 'SA152']: + morph = mne.compute_source_morph( + fwd['src'], + subject_from=f"sub-{sub}", + subject_to='fsaverage', + src_to=src_fs, + subjects_dir=fs_deriv_root, + verbose=True) + + stc = morph.apply(stc) + + # Append to stcs list + stcs.append(stc) + + # Average stcs across participants + stcs_data = [stc.data for stc in stcs] + stc_ga = stcs[0] + stc_ga.data = np.mean(stcs_data, axis=0) + + # Save stc grandaverage + bids_path_source = bids_path_source.update( + subject=f"groupphase{phase}") + stc_ga.save(bids_path_source) + + # Loop over labels + for label in labels_names: + print('\nlabel:', label) + + # Create empty list + label_data = [] + + # Loop over participants + for sub in sub_list: + print('subject:', sub) + + # Read individual dataframe + bids_path_source = bids_path_source.update( + subject=sub, + suffix=f"desc-{fname},{band},{label}_datatable", + extension='.tsv', + check=False) + df = pd.read_csv(bids_path_source.fpath, sep="\t") + + # Append dataframe to list + label_data.append(df['data']) + + # Create table with the extracted label time course data + label_data_df = pd.DataFrame(sub_list,columns=['sub']) + label_data_df = pd.concat( + [label_data_df, + pd.DataFrame( + np.array(label_data), + columns=df['times'])], + axis=1) + + # Add info to the table regarding the conditions + if len(factor) == 1: + label_data_df[factor[0]] = cond_comb[0] + if len(factor) == 2: + label_data_df[factor[0]] = cond_comb[0] + label_data_df[factor[1]] = cond_comb[1] + if len(factor) == 3: + label_data_df[factor[0]] = cond_comb[0] + label_data_df[factor[1]] = cond_comb[1] + label_data_df[factor[2]] = cond_comb[2] + + label_data_df['band'] = band + label_data_df['label'] = label + + # Append label table to data table + all_data_df = all_data_df.append(label_data_df) + + # Save table as .tsv + bids_path_source = bids_path_source.copy().update( + root=source_deriv_root, + subject=f"groupphase{phase}", + suffix="datatable", + check=False) + all_data_df.to_csv(bids_path_source.fpath, + sep="\t", + index=False) + + +if __name__ == '__main__': + source_dur_ga() diff --git 
a/activation/S04b_source_dur_erf_ga.py b/activation/S04b_source_dur_erf_ga.py new file mode 100644 index 0000000..324f3e2 --- /dev/null +++ b/activation/S04b_source_dur_erf_ga.py @@ -0,0 +1,191 @@ +""" +================ +S05. Grand-average source epochs +================ + +@author: Oscar Ferrante oscfer88@gmail.com + +""" + +import os +import os.path as op +import numpy as np +import argparse +import itertools +import pandas as pd + +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +parser=argparse.ArgumentParser() +parser.add_argument('--method', + type=str, + default='dspm', + help='method used for the inverse solution') +opt=parser.parse_args() + + +# Set params +inv_method = opt.method +visit_id = "V1" + +debug = False + + +factor = ['Category', 'Task_relevance', "Duration"] +conditions = [['face', 'object', 'letter', 'false'], + ['Relevant non-target','Irrelevant'], + ['500ms', '1000ms', '1500ms']] + + +# Set participant list +phase = 3 + +if debug: + sub_list = ["SA124", "SA124"] +else: + # Read the .txt file + f = open(op.join(bids_root, + f'participants_MEG_phase{phase}_included.txt'), 'r').read() + # Split text into list of elemetnts + sub_list = f.split("\n") + + +def source_dur_ga(): + # Set directory paths + source_deriv_root = op.join(bids_root, "derivatives", "source_dur_ERF") + if not op.exists(source_deriv_root): + os.makedirs(source_deriv_root) + source_figure_root = op.join(source_deriv_root, + f"sub-groupphase{phase}",f"ses-{visit_id}","meg", + "figures") + if not op.exists(source_figure_root): + os.makedirs(source_figure_root) + + # Set task + if visit_id == "V1": + bids_task = 'dur' + elif visit_id == "V2": + bids_task = 'vg' + # elif visit_id == "V2": #find a better way to set the task in V2 + # bids_task = 'replay' + else: + raise ValueError("Error: could not set the task") + + # Find all combinations between variables' levels + if len(factor) == 1: + cond_combs = list(itertools.product(conditions[0])) + if len(factor) == 2: + cond_combs = list(itertools.product(conditions[0], + conditions[1])) + if len(factor) == 3: + cond_combs = list(itertools.product(conditions[0], + conditions[1], + conditions[2])) + + # Read list with label names + bids_path_source = mne_bids.BIDSPath( + root=source_deriv_root[:-10], + subject=sub_list[0], + datatype='meg', + task=bids_task, + session=visit_id, + suffix="desc-labels", + extension='.txt', + check=False) + + labels_names = open(bids_path_source.fpath, 'r').read() + labels_names = labels_names[2:-2].split("', '") + + # Create empty dataframe + all_data_df = pd.DataFrame() + + # Loop over conditions of interest + for cond_comb in cond_combs: + print("\n\nAnalyzing %s: %s" % (factor, cond_comb)) + + # Select epochs + if len(factor) == 1: + fname = cond_comb[0] + if len(factor) == 2: + fname = cond_comb[0] + "_" + cond_comb[1] + if len(factor) == 3: + fname = cond_comb[0] + "_" + cond_comb[1] + "_" + cond_comb[2] + + # Loop over labels + for label in labels_names: + print('\nlabel:', label) + + # Create empty list + label_data = [] + + # Loop over participants + for sub in sub_list: + print('subject:', sub) + + # Read individual dataframe + try: + bids_path_source = bids_path_source.update( + root=source_deriv_root, + subject=sub, + suffix=f"desc-{fname},ERF,{label}_datatable", + extension='.tsv', + check=False) + df = pd.read_csv(bids_path_source.fpath, sep="\t") + except: + label_r = label.replace("&", "_and_") + bids_path_source = 
bids_path_source.update( + root=source_deriv_root, + subject=sub, + suffix=f"desc-{fname},ERF,{label_r}_datatable", + extension='.tsv', + check=False) + df = pd.read_csv(str(bids_path_source.fpath), sep="\t") + + # Append dataframe to list + label_data.append(df['data']) + + # Create table with the extracted label time course data + label_data_df = pd.DataFrame(sub_list,columns=['sub']) + label_data_df = pd.concat( + [label_data_df, + pd.DataFrame( + np.array(label_data), + columns=df['times'])], + axis=1) + + # Add info to the table regarding the conditions + if len(factor) == 1: + label_data_df[factor[0]] = cond_comb[0] + if len(factor) == 2: + label_data_df[factor[0]] = cond_comb[0] + label_data_df[factor[1]] = cond_comb[1] + if len(factor) == 3: + label_data_df[factor[0]] = cond_comb[0] + label_data_df[factor[1]] = cond_comb[1] + label_data_df[factor[2]] = cond_comb[2] + + label_data_df['band'] = "ERF" + label_data_df['label'] = label + + # Append label table to data table + all_data_df = all_data_df.append(label_data_df) + + # Save table as .tsv + bids_path_source = bids_path_source.copy().update( + root=source_deriv_root, + subject=f"groupphase{phase}", + suffix="datatable", + check=False) + all_data_df.to_csv(bids_path_source.fpath, + sep="\t", + index=False) + + +if __name__ == '__main__': + source_dur_ga() diff --git a/activation/S05a_source_dur_spectral_lmm.py b/activation/S05a_source_dur_spectral_lmm.py new file mode 100644 index 0000000..23759c2 --- /dev/null +++ b/activation/S05a_source_dur_spectral_lmm.py @@ -0,0 +1,893 @@ +""" +======================================= +S06. Source spectal activation analysis +======================================= + +@author: Oscar Ferrante oscfer88@gmail.com + +""" + +import os +import os.path as op +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.ticker as tick +import pandas as pd +import statsmodels.formula.api as smf +import seaborn as sns +import ptitprince as pt #conda install -c conda-forge ptitprince + +from pymer4.models import Lmer #conda install -c ejolly -c conda-forge -c defaults pymer4 + +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +# Set params +visit_id = "V1" + + +debug = False +bootstrap = False + +factor = ['Category', 'Task_relevance', "Duration"] +conditions = [['face', 'object', 'letter', 'false'], + ['Relevant target','Relevant non-target','Irrelevant'], + ['500ms', '1000ms', '1500ms']] + + +# Set participant list +phase = 3 + +if debug: + sub_list = ["SA124", "SA124"] +elif bootstrap: + # Read the .txt file + f = open(op.join(bids_root, + 'participants_MEG_phase3_included.txt'), 'r').read() + # Split text into list of elemetnts + sub_list = f.split("\n") + # Rename phase variable with the random name assigned to the bootstrap (MANUAL STEP!) 
+ phase = "bsSGa5" +else: + # Read the .txt file + f = open(op.join(bids_root, + f'participants_MEG_phase{phase}_included.txt'), 'r').read() + # Split text into list of elemetnts + sub_list = f.split("\n") + + +def run_source_dur_activation(task_rel, tbins): + # Set directory paths + source_deriv_root = op.join(bids_root, "derivatives", "source_dur") + if not op.exists(source_deriv_root): + os.makedirs(source_deriv_root) + source_figure_root = op.join(source_deriv_root, + f"sub-groupphase{phase}",f"ses-{visit_id}","meg", + "figures") + if not op.exists(source_figure_root): + os.makedirs(source_figure_root) + + # Set task + if visit_id == "V1": + bids_task = 'dur' + elif visit_id == "V2": + bids_task = 'vg' + # elif visit_id == "V2": #find a better way to set the task in V2 + # bids_task = 'replay' + else: + raise ValueError("Error: could not set the task") + + # Read list with label names + bids_path_source = mne_bids.BIDSPath( + root=source_deriv_root, + subject=sub_list[0], + datatype='meg', + task=bids_task, + session=visit_id, + suffix="desc-labels", + extension='.txt', + check=False) + labels_names = open(bids_path_source.fpath, 'r').read() + labels_names = labels_names[2:-2].split("', '") + + # Read dataframe + bids_path_source = mne_bids.BIDSPath( + root=source_deriv_root, + subject=f"groupphase{phase}", + datatype='meg', + task=bids_task, + session=visit_id, + suffix="datatable", + extension='.tsv', + check=False) + df = pd.read_csv(bids_path_source.fpath, sep="\t") + + # Rename "False" as "false" + df.loc[df['Category']==False, 'Category'] = 'false' + + # Average data in the three time window of interest + times = np.array([float(t) for t in list(df.columns)[1:-5]]) + for tmin, tmax in tbins: + imin = (np.abs(times - tmin)).argmin() + imax = (np.abs(times - tmax)).argmin() + df[f"[{tmin}, {tmax}]"] = np.mean(df.iloc[:,imin:imax],axis=1) + + # Create theory predictors dict + predictors = {"iit_predictors": { + f"{tbins[0]}/500ms": "decativated", + f"{tbins[1]}/500ms": "decativated", + f"{tbins[2]}/500ms": "decativated", + f"{tbins[0]}/1000ms": "ativated", + f"{tbins[1]}/1000ms": "decativated", + f"{tbins[2]}/1000ms": "decativated", + f"{tbins[0]}/1500ms": "ativated", + f"{tbins[1]}/1500ms": "ativated", + f"{tbins[2]}/1500ms": "decativated" + }, + "gnw_predictors": { + f"{tbins[0]}/500ms": "ativated", + f"{tbins[1]}/500ms": "decativated", + f"{tbins[2]}/500ms": "decativated", + f"{tbins[0]}/1000ms": "decativated", + f"{tbins[1]}/1000ms": "ativated", + f"{tbins[2]}/1000ms": "decativated", + f"{tbins[0]}/1500ms": "decativated", + f"{tbins[1]}/1500ms": "decativated", + f"{tbins[2]}/1500ms": "ativated" + }} + + # Create LMM models dict + models = create_models() + + # Run analysis + df_all = pd.DataFrame() + for band in ['alpha', 'gamma']: + print('\nfreq_band:', band) + + for label in labels_names: + print('\nlabel:', label) + + # Select band and label + df_cond = df.query(f"band == '{band}' and label == '{label}' and Task_relevance == '{task_rel}'") + df_cond["sub"] = df_cond["sub"].astype(str) + + # Create long table + df_long = pd.melt(df_cond, + id_vars=['sub', 'Category', 'Task_relevance', 'Duration', 'band', 'label'], + value_vars=[str(tbins[0]), str(tbins[1]), str(tbins[2])], + var_name='time_bin') + + # Create theory predictors + data_df = create_theories_predictors(df_long, predictors) + + # Append data to list + df_all = pd.concat([df_all, data_df], ignore_index=True) + + # # Frequency table (used to check the predictors) + # a = pd.crosstab(index=data_df["iit_predictors"], + 
# columns=[data_df["Duration"], data_df["time_bin"]], + # normalize='columns') + # print(a.iloc[:, 6:9], a.iloc[:, :3], a.iloc[:, 3:6]) + # a = pd.crosstab(index=data_df["gnw_predictors"], + # columns=[data_df["Duration"], data_df["time_bin"]], + # normalize='columns') + # print(a.iloc[:, 6:9], a.iloc[:, :3], a.iloc[:, 3:6]) + + # Fit linear mixed model + results, anova = fit_lmm(data_df, models, re_group='sub') + + # Save LMM results + bids_path_source = bids_path_source.copy().update( + root=source_deriv_root, + suffix=f"desc-{band},{label},{tbins[0]},{task_rel[:3]}_lmm", + extension='.tsv', + check=False) + results.to_csv(bids_path_source.fpath, sep="\t", index=False) + + # Save ANOVA results + bids_path_source = bids_path_source.copy().update( + root=source_deriv_root, + suffix=f"desc-{band},{label},{tbins[0]},{task_rel[:3]}_anova", + extension='.tsv', + check=False) + anova.to_csv(bids_path_source.fpath, sep="\t", index=False) + + # Compare models + best_models = model_comparison(results, criterion="bic") + + # Save best LMM model results + bids_path_source = bids_path_source.copy().update( + root=source_deriv_root, + suffix=f"desc-{band},{label},{tbins[0]},{task_rel[:3]}_best_model", + extension='.tsv', + check=False) + best_models.to_csv(bids_path_source.fpath, sep="\t", index=False) + + + # Plot spectral activation time courses + + # Plot 1a # + # Group by category and duration and averaged across participants + data_m = df_cond.groupby(['Category','Duration'])[df_cond.keys()[1:-8]].mean() + + # Get 95% condidence intervals + data_std = df_cond.groupby(['Category','Duration'])[df_cond.keys()[1:-8]].std() + data_ci = (1.96 * data_std / np.sqrt(len(sub_list))) + + # Cut edges + tmin = (np.abs(times - -0.5)).argmin() + tmax = (np.abs(times - 2.5)).argmin() + t = times[tmin:tmax] + + # Loop over conditions + fig, axs = plt.subplots(4, 1, figsize=(8,8)) + for c in range(len(conditions[0])): + print("condition:",conditions[0][c]) + + # Get category data + d500_m = data_m.query(f"Category =='{conditions[0][c]}' and \ + Duration == '500ms'") + d1000_m = data_m.query(f"Category =='{conditions[0][c]}' and \ + Duration == '1000ms'") + d1500_m = data_m.query(f"Category =='{conditions[0][c]}' and \ + Duration == '1500ms'") + + d500_ci = data_ci.query(f"Category =='{conditions[0][c]}' and \ + Duration == '500ms'") + d1000_ci = data_ci.query(f"Category =='{conditions[0][c]}' and \ + Duration == '1000ms'") + d1500_ci = data_ci.query(f"Category =='{conditions[0][c]}' and \ + Duration == '1500ms'") + + # Cut edges + d500_m = np.squeeze(np.array(d500_m.iloc[:,tmin:tmax])) + d1000_m = np.squeeze(np.array(d1000_m.iloc[:,tmin:tmax])) + d1500_m = np.squeeze(np.array(d1500_m.iloc[:,tmin:tmax])) + + d500_ci = np.squeeze(np.array(d500_ci.iloc[:,tmin:tmax])) + d1000_ci = np.squeeze(np.array(d1000_ci.iloc[:,tmin:tmax])) + d1500_ci = np.squeeze(np.array(d1500_ci.iloc[:,tmin:tmax])) + + # Plot + axs[c].plot(t, np.vstack([d500_m, d1000_m, d1500_m]).transpose(), linewidth=2.0) + + for m, ci in zip([d500_m, d1000_m, d1500_m], + [d500_ci, d1000_ci, d1500_ci]): + axs[c].fill_between(t, m-ci, m+ci, alpha=.2) + + axs[c].set_xlabel('Time (s)', fontsize='x-large') + # axs[c].axhline(y=0, color="black", linestyle="--") + axs[c].axvline(x=0, color="black", linestyle="--") + + for ax in axs.flat: + ax.set_xlim([-.5, 2.4]) + if band == 'alpha': + ax.set_ylim([0.6, 1.4]) + elif band == 'gamma': + ax.set_ylim([0.9, 1.1]) + # ax.axvspan(.3, .5, color='grey', alpha=0.25) + ax.axvspan(tbins[0][0], tbins[0][1], color='red', 
alpha=0.25) + ax.axvspan(tbins[1][0], tbins[1][1], color='red', alpha=0.25) + ax.axvspan(tbins[2][0], tbins[2][1], color='red', alpha=0.25) + ax.legend(['500ms', '1000ms', '1500ms'], loc='lower left') + + axs[0].set_ylabel('Face', fontsize='x-large', fontweight='bold') + axs[1].set_ylabel('Object', fontsize='x-large', fontweight='bold') + axs[2].set_ylabel('Letter', fontsize='x-large', fontweight='bold') + axs[3].set_ylabel('False-font', fontsize='x-large', fontweight='bold') + plt.suptitle(f"{band}-band power: time course over {label} source", fontsize='xx-large', fontweight='bold') + + # Save figure + fname_fig = op.join(source_figure_root, + f"sourcedur-{band}_{label}_{tbins[0]}_{task_rel[:3]}_timecourse.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + + # Plot 1b # + # Group by participant and duration and average across categories + data_m = df_cond.groupby(['sub', 'Category', 'Duration'])[df_cond.keys()[1:-8]].mean() + + # Get category data + fig, axs = plt.subplots(4, 3, figsize=(8,8)) + for c in range(len(conditions[0])): + print("condition:",conditions[0][c]) + + d500_m = data_m.query(f"Category =='{conditions[0][c]}' and \ + Duration == '500ms'") + d1000_m = data_m.query(f"Category =='{conditions[0][c]}' and \ + Duration == '1000ms'") + d1500_m = data_m.query(f"Category =='{conditions[0][c]}' and \ + Duration == '1500ms'") + + # Make raster plot + if band == 'alpha': + v = [0.6, 1.4] + elif band == 'gamma': + v = [0.9, 1.1] + + for d, data in zip(range(len(conditions[2])), [d500_m, d1000_m, d1500_m]): + im = axs[c,d].imshow( + data, cmap="RdYlBu_r", + vmin=v[0], vmax=v[1], + origin="lower", aspect="auto", + extent=[times[0], times[-1], len(sub_list), 1]) + axs[c,d].set_xlim([-.5, 2]) + axs[c,d].axvline(x=0, color="black", linestyle="--") + if c == len(conditions[0])-1: + axs[c,d].set_xlabel('Time (s)', fontsize='x-large') + else: + axs[c,d].axes.xaxis.set_ticklabels([]) + if d != 0: + axs[c,d].axes.yaxis.set_ticklabels([]) + + axs[c,0].axvline(x=.5, color="black", linestyle="--") + axs[c,1].axvline(x=1., color="black", linestyle="--") + axs[c,2].axvline(x=1.5, color="black", linestyle="--") + + axs[0,0].set_ylabel('Face', fontsize='x-large', fontweight='bold') + axs[1,0].set_ylabel('Object', fontsize='x-large', fontweight='bold') + axs[2,0].set_ylabel('Letter', fontsize='x-large', fontweight='bold') + axs[3,0].set_ylabel('False-font', fontsize='x-large', fontweight='bold') + + fig.subplots_adjust(right=0.85) + cbar_ax = fig.add_axes([0.88, 0.15, 0.04, 0.7]) + fig.colorbar(im, cax=cbar_ax, + format=tick.FormatStrFormatter('%.2f')) + + # Save figure + fname_fig = op.join(source_figure_root, + f"sourcedur-{band}_{label}_{tbins[0]}_{task_rel[:3]}_raster.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + + # Plot 2a # + # Group by duration and average across participants and categories + data_m = df_cond.groupby(['Duration'])[df_cond.keys()[1:-8]].mean() + + # Get 95% condidence intervals + data_std = df_cond.groupby(['Duration'])[df_cond.keys()[1:-8]].std() + data_ci = (1.96 * data_std / np.sqrt(len(sub_list))) + + # Get category data + d500_m = data_m.query("Duration == '500ms'") + d1000_m = data_m.query("Duration == '1000ms'") + d1500_m = data_m.query("Duration == '1500ms'") + + d500_ci = data_ci.query("Duration == '500ms'") + d1000_ci = data_ci.query("Duration == '1000ms'") + d1500_ci = data_ci.query("Duration == '1500ms'") + + # Cut edges + d500_m = np.squeeze(np.array(d500_m.iloc[:,tmin:tmax])) + d1000_m = np.squeeze(np.array(d1000_m.iloc[:,tmin:tmax])) + 
d1500_m = np.squeeze(np.array(d1500_m.iloc[:,tmin:tmax])) + + d500_ci = np.squeeze(np.array(d500_ci.iloc[:,tmin:tmax])) + d1000_ci = np.squeeze(np.array(d1000_ci.iloc[:,tmin:tmax])) + d1500_ci = np.squeeze(np.array(d1500_ci.iloc[:,tmin:tmax])) + + # Plot + fig, ax = plt.subplots(figsize=(8,6)) + ax.plot(t, np.vstack([d500_m, d1000_m, d1500_m]).transpose(), linewidth=2.0) + + for m, ci in zip([d500_m, d1000_m, d1500_m], + [d500_ci, d1000_ci, d1500_ci]): + ax.fill_between(t, m-ci, m+ci, alpha=.2) + + ax.set_xlabel('Time (s)', fontsize='x-large') + ax.axvline(x=0, color="black", linestyle="--") + + ax.set_xlim([-.5, 2.4]) + if band == 'alpha': + ax.set_ylim([0.6, 1.4]) + elif band == 'gamma': + ax.set_ylim([0.9, 1.1]) + # ax.axvspan(.3, .5, color='grey', alpha=0.25) + ax.axvspan(tbins[0][0], tbins[0][1], color='red', alpha=0.25) + ax.axvspan(tbins[1][0], tbins[1][1], color='red', alpha=0.25) + ax.axvspan(tbins[2][0], tbins[2][1], color='red', alpha=0.25) + ax.legend(['500ms', '1000ms', '1500ms'], loc='lower left') + + ax.set_ylabel('Power (a.u.)', fontsize='x-large') + plt.suptitle(f"{band}-band power: time course over {label} source", fontsize='xx-large', fontweight='bold') + + # Save figure + fname_fig = op.join(source_figure_root, + f"sourcedur-{band}_{label}_{tbins[0]}_{task_rel[:3]}_timecourse_avg.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + + # Plot 2b # + # Group by participant and duration and average across categories + data_m = df_cond.groupby(['sub', 'Duration'])[df_cond.keys()[1:-8]].mean() + + # Get category data + d500_m = data_m.query("Duration == '500ms'") + d1000_m = data_m.query("Duration == '1000ms'") + d1500_m = data_m.query("Duration == '1500ms'") + + # Make raster plot + fig, axs = plt.subplots(3, 1, figsize=[8,6]) + if band == 'alpha': + v = [0.6, 1.4] + elif band == 'gamma': + v = [0.9, 1.1] + + for ax, data in zip(axs.flat, [d500_m, d1000_m, d1500_m]): + im = ax.imshow( + data, cmap="RdYlBu_r", + vmin=v[0], vmax=v[1], + origin="lower", aspect="auto", + extent=[times[0], times[-1], len(sub_list), 1]) + ax.set_xlim([-.5, 2]) + ax.axvline(x=0, color="black", linestyle="--") + + axs[0].axvline(x=.5, color="black", linestyle="--") + axs[1].axvline(x=1., color="black", linestyle="--") + axs[2].axvline(x=1.5, color="black", linestyle="--") + axs[2].set_xlabel('Time (s)', fontsize='x-large') + axs[0].axes.xaxis.set_ticklabels([]) + axs[1].axes.xaxis.set_ticklabels([]) + axs[0].set_ylabel('Participant', fontsize='x-large') + axs[1].set_ylabel('Participant', fontsize='x-large') + axs[2].set_ylabel('Participant', fontsize='x-large') + + fig.subplots_adjust(right=0.85) + cbar_ax = fig.add_axes([0.88, 0.15, 0.04, 0.7]) + fig.colorbar(im, cax=cbar_ax, + format=tick.FormatStrFormatter('%.2f')) + + # Save figure + fname_fig = op.join(source_figure_root, + f"sourcedur-{band}_{label}_{tbins[0]}_{task_rel[:3]}_raster_avg.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + + # Plot spectral activation raincloud + + # Get indivisual data by condition + data_sub_m = df_long.groupby(['sub','Category','Duration','time_bin'],as_index = False)["value"].mean() + + # Fix order of levels in duration variable + durations = ['500ms', '1000ms', '1500ms'] + data_sub_m['Duration'] = pd.Categorical( + data_sub_m['Duration'], + categories=durations, + ordered=True) + + # Loop over categories + fig, axs = plt.subplots(4, 3, figsize=(8,8)) + for c in range(len(conditions[0])): + print("condition:",conditions[0][c]) + + # Get data + d_m = data_sub_m.query(f"Category 
=='{conditions[0][c]}'") + + for d in range(len(durations)): + print("duration:",durations[d]) + + # Plot violin + pt.half_violinplot( + x = "time_bin", y = "value", + data = d_m.query(f"Duration =='{durations[d]}'"), + bw = .2, cut = 0., + scale = "area", width = .6, + inner = None, + ax = axs[c,d]) + + # Add points + sns.stripplot( + x = "time_bin", y = "value", + data = d_m.query(f"Duration =='{durations[d]}'"), + edgecolor = "white", + size = 3, jitter = 1, zorder = 0, + ax = axs[c,d]) + + # Add boxplot + sns.boxplot( + x = "time_bin", y = "value", + data = d_m.query(f"Duration =='{durations[d]}'"), + color = "black", width = .15, zorder = 10, + showcaps = True, boxprops = {'facecolor':'none', "zorder":10},\ + showfliers=True, whiskerprops = {'linewidth':2, "zorder":10},\ + saturation = 1, + ax = axs[c,d]) + + for ax in axs.flat: + if band == 'alpha': + ax.set_ylim([0.65, 1.35]) + elif band == 'gamma': + ax.set_ylim([0.9, 1.1]) + + axs[0,0].set_xlabel(None) + axs[0,1].set_xlabel(None) + axs[0,2].set_xlabel(None) + axs[1,0].set_xlabel(None) + axs[1,1].set_xlabel(None) + axs[1,2].set_xlabel(None) + axs[2,0].set_xlabel(None) + axs[2,1].set_xlabel(None) + axs[2,2].set_xlabel(None) + + axs[3,0].set_xlabel(f'{durations[0]} duration', fontsize='x-large', fontweight='bold') + axs[3,1].set_xlabel(f'{durations[1]} duration', fontsize='x-large', fontweight='bold') + axs[3,2].set_xlabel(f'{durations[1]} duration', fontsize='x-large', fontweight='bold') + + axs[0,0].set_ylabel('Face', fontsize='x-large', fontweight='bold') + axs[1,0].set_ylabel('Object', fontsize='x-large', fontweight='bold') + axs[2,0].set_ylabel('Letter', fontsize='x-large', fontweight='bold') + axs[3,0].set_ylabel('False-font', fontsize='x-large', fontweight='bold') + plt.suptitle(f"{band}-band power: {label}", fontsize='xx-large', fontweight='bold') + + plt.tight_layout() + + # Save figure + fname_fig = op.join(source_figure_root, + f"sourcedur-{band}_{label}_{tbins[0]}_{task_rel[:3]}_timebins.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + # Plot 2 # + # Get indivisual data by condition averaged across categories + data_sub_m = df_long.groupby(['sub','Duration','time_bin'],as_index = False)["value"].mean() + + # Fix order of levels in duration variable + durations = ['500ms', '1000ms', '1500ms'] + data_sub_m['Duration'] = pd.Categorical( + data_sub_m['Duration'], + categories=['500ms', '1000ms', '1500ms'], + ordered=True) + + # Create subplot + fig, axs = plt.subplots(1,3, figsize=(8,6)) + + # Loop over durations + for d in range(len(durations)): + print("duration:",durations[d]) + + # Plot violin + pt.half_violinplot( + x = "time_bin", y = "value", + data = d_m.query(f"Duration =='{durations[d]}'"), + bw = .2, cut = 0., + scale = "area", width = .6, + inner = None, + ax = axs[d]) + + # Add points + sns.stripplot( + x = "time_bin", y = "value", + data = d_m.query(f"Duration =='{durations[d]}'"), + edgecolor = "white", + size = 3, jitter = 1, zorder = 0, + ax = axs[d]) + + # Add boxplot + sns.boxplot( + x = "time_bin", y = "value", + data = d_m.query(f"Duration =='{durations[d]}'"), + color = "black", width = .15, zorder = 10, + showcaps = True, boxprops = {'facecolor':'none', "zorder":10},\ + showfliers=True, whiskerprops = {'linewidth':2, "zorder":10},\ + saturation = 1, + ax = axs[d]) + + for ax in axs.flat: + if band == 'alpha': + ax.set_ylim([0.65, 1.35]) + elif band == 'gamma': + ax.set_ylim([0.9, 1.1]) + + axs[0].set_ylabel('Power (a.u.)', fontsize='x-large') + + axs[0].set_xlabel(f'{durations[0]} duration', 
fontsize='x-large', fontweight='bold') + axs[1].set_xlabel(f'{durations[1]} duration', fontsize='x-large', fontweight='bold') + axs[2].set_xlabel(f'{durations[1]} duration', fontsize='x-large', fontweight='bold') + + plt.suptitle(f"{band}-band power: {label}", fontsize='xx-large', fontweight='bold') + + plt.tight_layout() + + # Save figure + fname_fig = op.join(source_figure_root, + f"sourcedur-{band}_{label}_{tbins[0]}_{task_rel[:3]}_timebins_avg.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + # Save table as .tsv + bids_path_source = bids_path_source.copy().update( + root=source_deriv_root, + subject=f"groupphase{phase}", + suffix=f"{tbins[0]}_{task_rel[:3]}_lmm_datatable", + check=False) + df_all.to_csv(bids_path_source.fpath, + sep="\t", + index=False) + + +def create_theories_predictors(df, predictors_mapping): + """ + This function adds predictors to the data frame based on the predictor mapping passed. The passed predictors + consist of dictionaries, providing mapping between specific experimental condition and a specific value to give it. + This function therefore loops through each of the predictor and through each of the condition of that predictor. + It will then look for the condition combination matching it to attribute it the value the predictor dictates. + Example: one predictor states: faces/short= 1, faces/intermediate=0... This is what that function does + DISCLAIMER: I know using groupby would be computationally more efficient, but this makes for more readable and easy + to encode the predictors, so I went this way. + :param df: (data frame) data frame to add the predictors to + :param predictors_mapping: (dict of dict) One dictionary per predictor. For each predictor, one dictionary + containing + mapping between condition combination and value to attribute to it + :return: (dataframe) the data frame that was fed in + predictors values + """ + print("-" * 40) + print("Creating theories' derived predictors: ") + # Getting the name of the columns which are not the ones automatically ouputed by mne, because these are the ones + # we created and that contain the info we seek: + col_list = [col for col in df.columns if col not in [ + "epoch", "channel", "value", "condition"]] + # Looping through the predictors: + for predictor in predictors_mapping.keys(): + df[predictor] = np.nan + # Now looping through the key of each predictor, as this contains the mapping info: + for key in predictors_mapping[predictor].keys(): + # Finding the index of each row matching the key: + bool_list = \ + [all(x in list(trial_info[col_list].values) + for x in key.split("/")) + for ind, trial_info in df.iterrows()] + # Using the boolean list to add the value of the predictor in the concerned row: + df.loc[bool_list, predictor] = predictors_mapping[predictor][key] + + return df + + +def create_models(package="lmer"): + if package == "stats_model": + models = { "null_model": { + "model": "value ~ 1", + "re_formula": None + }, + "time_win": { + "model": "value ~ time_bin", + "re_formula": None + }, + "duration": { + "model": "value ~ Duration", + "re_formula": None + }, + "time_win_dur": { + "model": "value ~ time_bin + Duration", + "re_formula": None + }, + "time_win_dur_iit": { + "model": "value ~ time_bin + Duration + iit_predictors", + "re_formula": None + }, + "time_win_dur_gnw": { + "model": "value ~ time_bin + Duration + gnw_predictors", + "re_formula": None + }, + "time_win_dur_cate_iit": { + "model": "value ~ time_bin + Duration + Category*iit_predictors", + "re_formula": None + 
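+        # A short gloss on the formula strings: they are Wilkinson-style model
+        # formulas, e.g. "value ~ time_bin + Duration + Category*iit_predictors"
+        # fits main effects of time window and duration plus Category, the IIT
+        # predictor and their interaction; in the lmer variant below, "(1|sub)"
+        # adds a by-participant random intercept.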
}, + "time_win_dur_cate_gnw": { + "model": "value ~ time_bin + Duration + Category*gnw_predictors", + "re_formula": None + }} + elif package == "lmer": + models = { + "null_model": { + "model": "value ~ 1 + (1|sub)", + "re_formula": None + }, + "time_win": { + "model": "value ~ time_bin + (1|sub)", + "re_formula": None + }, + "duration": { + "model": "value ~ Duration + (1|sub)", + "re_formula": None + }, + "time_win_dur": { + "model": "value ~ time_bin + Duration + (1|sub)", + "re_formula": None + }, + "time_win_dur_iit": { + "model": "value ~ time_bin + Duration + iit_predictors + (1|sub)", + "re_formula": None + }, + "time_win_dur_gnw": { + "model": "value ~ time_bin + Duration + gnw_predictors + (1|sub)", + "re_formula": None + }, + "time_win_dur_cate_iit": { + "model": "value ~ time_bin + Duration + Category * iit_predictors + (1|sub)", + "re_formula": None + }, + "time_win_dur_cate_gnw": { + "model": "value ~ time_bin + Duration + Category * gnw_predictors + (1|sub)", + "re_formula": None + }} + return models + + +def fit_lmm(data, models, re_group, group="", alpha=0.05, package="lmer"): + """ + This function fits the different linear mixed models passed in the model dict on the data + :param data: (pandas data frame) contains the data to fit the linear mixed model on + :param models: (dict) contains the different models: + "null_model": { + "model": "value ~ 1", + "re_formula": null + }, + "time_win": { + "model": "value ~ time_bin", + "re_formula": null + }, + "duration": { + "model": "value ~ duration", + "re_formula": null + }, + the key of each is the name of the model (used to identify it down the line), the model is the formula, the + re_formula is for the random slopes + :param re_group: (string) name of the random effect group. If you have measure repeated within trials, this should + be trial for example + :param group: (string) name of the column from the data table that corresponds to the groups for which to run the + model separately. You can run it on single channels, in which case group must be "channel" + :param alpha: (float) alpha to consider significance. 
Not really used + :return: + """ + print("-" * 40) + print("Welcome to fit_lmm") + results = pd.DataFrame() + anova_results = pd.DataFrame() + # Looping through the different models to apply to the data of that particular channel: + for model in models.keys(): + print(model) + if package == "stats_model": + print("Fitting {} model to group {}".format(model, group)) + # Applying the linear mixed model specified in the parameters: + md = smf.mixedlm(models[model]["model"], + data, groups=re_group, re_formula=models[model]["re_formula"]) + # Fitting the model: + mdf = md.fit(reml=False) + # Printing the summary in the command line: + print(mdf.summary()) + # Compute the r2: + # r2 = compute_lmm_r2(mdf) + # Extracting the results and storing them to the dataframe: + results = pd.concat([results, + pd.DataFrame({ + "subject": group.split("-")[0], + "analysis_name": ["linear_mixed_model"] * len(mdf.pvalues), + "model": [model] * len(mdf.pvalues), + "group": [group] * len(mdf.pvalues), + "coefficient-conditions": mdf.params.index.values, + "Coef.": mdf.params.values, + "Std.Err.": mdf.bse.values, + "z": mdf.tvalues.values, + "p-value": mdf.pvalues.values, + "reject": [True if p_val < alpha else False for p_val in mdf.pvalues.values], + "converged": [mdf.converged] * len(mdf.pvalues), + "log_likelyhood": [mdf.llf] * len(mdf.pvalues), + "aic": [mdf.aic] * len(mdf.pvalues), + "bic": [mdf.bic] * len(mdf.pvalues) + })], ignore_index=True) + elif package == "lmer": + # Fit the model: + mdl = Lmer(models[model]["model"], data=data) + print(mdl.fit(REML=False)) + # Append the coefs to the results table: + coefs = mdl.coefs + results = pd.concat([results, + pd.DataFrame({ + "subject": group.split("-")[0], + "analysis_name": ["linear_mixed_model"] * len(coefs["Estimate"]), + "model": [model] * len(coefs["Estimate"]), + "group": [group] * len(coefs["Estimate"]), + "coefficient-conditions": coefs.index.values, + "Coef.": coefs["Estimate"].to_list(), + "T-stat": coefs["T-stat"].to_list(), + "p-value": coefs["P-val"].to_list(), + "reject": [True if p_val < alpha else False for p_val in coefs["P-val"].to_list()], + "converged": [True] * len(coefs["Estimate"]), + "log_likelyhood": [mdl.logLike] * len(coefs["Estimate"]), + "aic": [mdl.AIC] * len(coefs["Estimate"]), + "bic": [mdl.BIC] * len(coefs["Estimate"]) + })], ignore_index=True) + + # In addition, run the anova on the model to extract the main effects: + anova_res = mdl.anova() + # For the null model, since there are no main effects, the anova results are empty: + if len(anova_res) == 0: + anova_results = pd.concat([anova_results, pd.DataFrame({ + "subject": group.split("-")[0], + "analysis_name": "anova", + "model": model, + "group": group, + "conditions": np.nan, + "F-stat": np.nan, + "NumDF": np.nan, + "DenomDF": np.nan, + "p-value": np.nan, + "reject": np.nan, + "converged": [True] * len(coefs["Estimate"]), + "SS": np.nan, + "aic": mdl.AIC, + "bic": mdl.BIC + }, index=[0])], ignore_index=True) + else: + anova_results = pd.concat([anova_results, pd.DataFrame({ + "subject": group.split("-")[0], + "analysis_name": ["anova"] * len(anova_res), + "model": [model] * len(anova_res), + "group": [group] * len(anova_res), + "conditions": anova_res.index.values, + "F-stat": anova_res["F-stat"].to_list(), + "NumDF": anova_res["NumDF"].to_list(), + "DenomDF": anova_res["DenomDF"].to_list(), + "p-value": anova_res["P-val"].to_list(), + "reject": [True if p_val < alpha else False for p_val in anova_res["P-val"].to_list()], + "converged": [True] * len(anova_res), + "SS": 
anova_res["SS"].to_list(), + "aic": [mdl.AIC] * len(anova_res), + "bic": [mdl.BIC] * len(anova_res) + })], ignore_index=True) + + return results, anova_results + + +def model_comparison(models_results, criterion="bic", test="linear_mixed_model"): + """ + The model results contain columns for fit criterion (log_likelyhood, aic, bic) that can be used to investigate + which model had the best fit. Indeed, because we are trying several models, if several of them are found to be + significant in the coefficients of interest, we need to arbitrate betweem them. This function does that by + looping through each channel and checking whether or not more than one model was found signficant. If yes, + then the best one is selected by checking the passed criterion. The criterion must match the string of one of the + column of the models_results dataframe + :param models_results: (dataframe) results of the linear mixed models, as returned by the fit_single_channels_lmm + function + :param criterion: (string) name of the criterion to use to arbitrate between models + :param test: (string) type of test (i.e. model that was preprocessing) + example, if you have ran several models per channels, pass here channels and it will look separately at each channel + :return: + best_models (pandas data frame) contains the results of the best models only, i.e. one model per channel only + """ + print("-" * 40) + print("Welcome to model comparison") + print("Comparing the fitted {} using {}".format(test, criterion)) + # Declare dataframe to store the best models only: + best_models = pd.DataFrame(columns=models_results.columns.values.tolist()) + # Removing any model that didn't converge in case a linear mixed model was used: + if test == "linear_mixed_model": + converge_models_results = models_results.loc[models_results["converged"]] + else: + converge_models_results = models_results + # In the linear mixed model function used before, the fit criterion are an extra column. Therefore, for a given + # electrode, the best fit is any row of the table that has the max of the criterion. Therefore, looping over + # the data: + for channel in converge_models_results["group"].unique(): + # Getting the results for that channel only + data = converge_models_results.loc[converge_models_results["group"] == channel] + # Extracting the rows with highest criterion + best_model = data.loc[data[criterion] == np.nanmin(data[criterion])] + # Adding it to the best_models dataframe, storing all the best models: + best_models = pd.concat([best_models, best_model], ignore_index=True) + return best_models + + +if __name__ == '__main__': + for task_rel in ['Irrelevant', 'Relevant non-target']: + for tbins in [ [[0.8,1.0],[1.3,1.5],[1.8,2.0]], [[1.0,1.2],[1.5,1.7],[2.0,2.2]] ]: + run_source_dur_activation(task_rel, tbins) diff --git a/activation/S05b_source_dur_erf_lmm.py b/activation/S05b_source_dur_erf_lmm.py new file mode 100644 index 0000000..e90f8b5 --- /dev/null +++ b/activation/S05b_source_dur_erf_lmm.py @@ -0,0 +1,606 @@ +""" +======================================= +S06. 
Source spectal activation analysis +======================================= + +@author: Oscar Ferrante oscfer88@gmail.com + +""" + +import os +import os.path as op +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.ticker as tick +import pandas as pd +import seaborn as sns +import ptitprince as pt #conda install -c conda-forge ptitprince + +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + +from S06_source_dur_lmm import (create_theories_predictors, fit_lmm, + model_comparison, create_models) + +# Set params +visit_id = "V1" + +task_rel = 'Irrelevant' +# task_rel = 'Relevant non-target' +tbins = [[0.8,1.0],[1.3,1.5],[1.8,2.0]] + +debug = False + + +factor = ['Category', 'Task_relevance', "Duration"] +conditions = [['face', 'object', 'letter', 'false'], + ['Relevant target','Relevant non-target','Irrelevant'], + ['500ms', '1000ms', '1500ms']] + + +# Set participant list +phase = 3 + +if debug: + sub_list = ["SA124", "SA124"] +else: + # Read the .txt file + f = open(op.join(bids_root, + f'participants_MEG_phase{phase}_included.txt'), 'r').read() + # Split text into list of elemetnts + sub_list = f.split("\n") + + +def run_source_dur_activation(): + # Set directory paths + source_deriv_root = op.join(bids_root, "derivatives", "source_dur_ERF") + if not op.exists(source_deriv_root): + os.makedirs(source_deriv_root) + source_figure_root = op.join(source_deriv_root, + f"sub-groupphase{phase}",f"ses-{visit_id}","meg", + "figures") + if not op.exists(source_figure_root): + os.makedirs(source_figure_root) + + # Set task + if visit_id == "V1": + bids_task = 'dur' + elif visit_id == "V2": + bids_task = 'vg' + # elif visit_id == "V2": #find a better way to set the task in V2 + # bids_task = 'replay' + else: + raise ValueError("Error: could not set the task") + + # Read list with label names + bids_path_source = mne_bids.BIDSPath( + root=source_deriv_root[:-10], + subject=sub_list[0], + datatype='meg', + task=bids_task, + session=visit_id, + suffix="desc-labels", + extension='.txt', + check=False) + labels_names = open(bids_path_source.fpath, 'r').read() + labels_names = labels_names[2:-2].split("', '") + + # Read dataframe + bids_path_source = mne_bids.BIDSPath( + root=source_deriv_root, + subject=f"groupphase{phase}", + datatype='meg', + task=bids_task, + session=visit_id, + suffix="datatable", + extension='.tsv', + check=False) + df = pd.read_csv(bids_path_source.fpath, sep="\t") + + # Rename "False" as "false" + df.loc[df['Category']==False, 'Category'] = 'false' + + # Average data in the three time window of interest + times = np.array([float(t) for t in list(df.columns)[1:-5]]) + for tmin, tmax in tbins: + imin = (np.abs(times - tmin)).argmin() + imax = (np.abs(times - tmax)).argmin() + df[f"[{tmin}, {tmax}]"] = np.mean(df.iloc[:,imin:imax],axis=1) + + # Create theory predictors dict + predictors = {"iit_predictors": { + f"{tbins[0]}/500ms": "decativated", + f"{tbins[1]}/500ms": "decativated", + f"{tbins[2]}/500ms": "decativated", + f"{tbins[0]}/1000ms": "ativated", + f"{tbins[1]}/1000ms": "decativated", + f"{tbins[2]}/1000ms": "decativated", + f"{tbins[0]}/1500ms": "ativated", + f"{tbins[1]}/1500ms": "ativated", + f"{tbins[2]}/1500ms": "decativated" + }, + "gnw_predictors": { + f"{tbins[0]}/500ms": "ativated", + f"{tbins[1]}/500ms": "decativated", + f"{tbins[2]}/500ms": "decativated", + f"{tbins[0]}/1000ms": "decativated", + f"{tbins[1]}/1000ms": "ativated", + f"{tbins[2]}/1000ms": 
"decativated", + f"{tbins[0]}/1500ms": "decativated", + f"{tbins[1]}/1500ms": "decativated", + f"{tbins[2]}/1500ms": "ativated" + }} + + # Create LMM models dict + models = create_models() + + # Run analysis + df_all = pd.DataFrame() + for band in ['ERF']: + print('\nfreq_band:', band) + + for label in labels_names: + print('\nlabel:', label) + + # Select band and label + df_cond = df.query(f"band == '{band}' and label == '{label}' and Task_relevance == '{task_rel}'") + + # Create long table + df_long = pd.melt(df_cond, + id_vars=['sub', 'Category', 'Task_relevance', 'Duration', 'band', 'label'], + value_vars=[str(tbins[0]), str(tbins[1]), str(tbins[2])], + var_name='time_bin') + + # Create theory predictors + data_df = create_theories_predictors(df_long, predictors) + + # Append data to list + df_all = pd.concat([df_all, data_df], ignore_index=True) + + # # Frequency table (used to check the predictors) + # pd.crosstab(index=data_df["iit_predictors"], + # columns=[data_df["Duration"], data_df["time_bin"]], + # normalize='columns') + # pd.crosstab(index=data_df["gnw_predictors"], + # columns=[data_df["Duration"], data_df["time_bin"]]) + + # Fit linear mixed model + results, anova = fit_lmm(data_df, models, re_group='sub') + + # Save LMM results + bids_path_source = bids_path_source.copy().update( + root=source_deriv_root, + suffix=f"desc-{band},{label},{tbins[0]},{task_rel[:3]}_lmm", + extension='.tsv', + check=False) + results.to_csv(bids_path_source.fpath, sep="\t", index=False) + + # Save ANOVA results + bids_path_source = bids_path_source.copy().update( + root=source_deriv_root, + suffix=f"desc-{band},{label},{tbins[0]},{task_rel[:3]}_anova", + extension='.tsv', + check=False) + anova.to_csv(bids_path_source.fpath, sep="\t", index=False) + + # Compare models + best_models = model_comparison(results, criterion="bic") + + # Save best LMM model results + bids_path_source = bids_path_source.copy().update( + root=source_deriv_root, + suffix=f"desc-{band},{label},{tbins[0]},{task_rel[:3]}_best_model", + extension='.tsv', + check=False) + best_models.to_csv(bids_path_source.fpath, sep="\t", index=False) + + + # Plot ERF time courses + + # Plot 1a # + # Group by category and duration and average across participants + data_m = df_cond.groupby(['Category','Duration'])[df_cond.keys()[1:-8]].mean() + + # Get 95% condidence intervals + data_std = df_cond.groupby(['Category','Duration'])[df_cond.keys()[1:-8]].std() + data_ci = (1.96 * data_std / np.sqrt(len(sub_list))) + + # Cut edges + tmin = (np.abs(times - -0.5)).argmin() + tmax = (np.abs(times - 2.5)).argmin() + t = times[tmin:tmax] + + # Loop over conditions + fig, axs = plt.subplots(4, 1, figsize=(8,8)) + for c in range(len(conditions[0])): + print("condition:",conditions[0][c]) + + # Get category data + d500_m = data_m.query(f"Category =='{conditions[0][c]}' and \ + Duration == '500ms'") + d1000_m = data_m.query(f"Category =='{conditions[0][c]}' and \ + Duration == '1000ms'") + d1500_m = data_m.query(f"Category =='{conditions[0][c]}' and \ + Duration == '1500ms'") + + d500_ci = data_ci.query(f"Category =='{conditions[0][c]}' and \ + Duration == '500ms'") + d1000_ci = data_ci.query(f"Category =='{conditions[0][c]}' and \ + Duration == '1000ms'") + d1500_ci = data_ci.query(f"Category =='{conditions[0][c]}' and \ + Duration == '1500ms'") + + # Cut edges + d500_m = np.squeeze(np.array(d500_m.iloc[:,tmin:tmax])) + d1000_m = np.squeeze(np.array(d1000_m.iloc[:,tmin:tmax])) + d1500_m = np.squeeze(np.array(d1500_m.iloc[:,tmin:tmax])) + + d500_ci = 
np.squeeze(np.array(d500_ci.iloc[:,tmin:tmax])) + d1000_ci = np.squeeze(np.array(d1000_ci.iloc[:,tmin:tmax])) + d1500_ci = np.squeeze(np.array(d1500_ci.iloc[:,tmin:tmax])) + + # Plot + axs[c].plot(t, np.vstack([d500_m, d1000_m, d1500_m]).transpose(), linewidth=2.0) + + for m, ci in zip([d500_m, d1000_m, d1500_m], + [d500_ci, d1000_ci, d1500_ci]): + axs[c].fill_between(t, m-ci, m+ci, alpha=.2) + + axs[c].set_xlabel('Time (s)', fontsize='x-large') + # axs[c].axhline(y=0, color="black", linestyle="--") + axs[c].axvline(x=0, color="black", linestyle="--") + + for ax in axs.flat: + # ax.set_xlim([-.5, 2.4]) + # if band == 'alpha': + # ax.set_ylim([0.6, 1.4]) + # elif band == 'gamma': + # ax.set_ylim([0.9, 1.1]) + # ax.axvspan(.3, .5, color='grey', alpha=0.25) + ax.axvspan(tbins[0][0], tbins[0][1], color='red', alpha=0.25) + ax.axvspan(tbins[1][0], tbins[1][1], color='red', alpha=0.25) + ax.axvspan(tbins[2][0], tbins[2][1], color='red', alpha=0.25) + ax.legend(['500ms', '1000ms', '1500ms'], loc='lower left') + + axs[0].set_ylabel('Face', fontsize='x-large', fontweight='bold') + axs[1].set_ylabel('Object', fontsize='x-large', fontweight='bold') + axs[2].set_ylabel('Letter', fontsize='x-large', fontweight='bold') + axs[3].set_ylabel('False-font', fontsize='x-large', fontweight='bold') + plt.suptitle(f"{band}: time course over {label} source", fontsize='xx-large', fontweight='bold') + + # Save figure + fname_fig = op.join(source_figure_root, + f"sourcedur-{band}_{label}_{tbins[0]}_{task_rel[:3]}_timecourse.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + + # Plot 1b # + # Group by participant and duration and average across categories + data_m = df_cond.groupby(['sub', 'Category', 'Duration'])[df_cond.keys()[1:-8]].mean() + + # Get category data + fig, axs = plt.subplots(4, 3, figsize=(8,8)) + for c in range(len(conditions[0])): + print("condition:",conditions[0][c]) + + d500_m = data_m.query(f"Category =='{conditions[0][c]}' and \ + Duration == '500ms'") + d1000_m = data_m.query(f"Category =='{conditions[0][c]}' and \ + Duration == '1000ms'") + d1500_m = data_m.query(f"Category =='{conditions[0][c]}' and \ + Duration == '1500ms'") + + # Make raster plot + # if band == 'alpha': + # v = [0.6, 1.4] + # elif band == 'gamma': + # v = [0.9, 1.1] + + for d, data in zip(range(len(conditions[2])), [d500_m, d1000_m, d1500_m]): + im = axs[c,d].imshow( + data, cmap="RdYlBu_r", + # vmin=v[0], vmax=v[1], + origin="lower", aspect="auto", + extent=[times[0], times[-1], len(sub_list), 1]) + axs[c,d].set_xlim([-.5, 2]) + axs[c,d].axvline(x=0, color="black", linestyle="--") + if c == len(conditions[0])-1: + axs[c,d].set_xlabel('Time (s)', fontsize='x-large') + else: + axs[c,d].axes.xaxis.set_ticklabels([]) + if d != 0: + axs[c,d].axes.yaxis.set_ticklabels([]) + + axs[c,0].axvline(x=.5, color="black", linestyle="--") + axs[c,1].axvline(x=1., color="black", linestyle="--") + axs[c,2].axvline(x=1.5, color="black", linestyle="--") + + axs[0,0].set_ylabel('Face', fontsize='x-large', fontweight='bold') + axs[1,0].set_ylabel('Object', fontsize='x-large', fontweight='bold') + axs[2,0].set_ylabel('Letter', fontsize='x-large', fontweight='bold') + axs[3,0].set_ylabel('False-font', fontsize='x-large', fontweight='bold') + + fig.subplots_adjust(right=0.85) + cbar_ax = fig.add_axes([0.88, 0.15, 0.04, 0.7]) + fig.colorbar(im, cax=cbar_ax, + format=tick.FormatStrFormatter('%.2f')) + + # Save figure + fname_fig = op.join(source_figure_root, + f"sourcedur-{band}_{label}_{tbins[0]}_{task_rel[:3]}_raster.png") + 
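+            # (In the raster panels saved here, each row of the image is one
+            #  participant, the x-axis is time in seconds, and colour codes the
+            #  ERF amplitude for the given category and duration.)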
fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + + # Plot 2a # + # Group by duration and average across participants and categories + data_m = df_cond.groupby(['Duration'])[df_cond.keys()[1:-8]].mean() + + # Get 95% condidence intervals + data_std = df_cond.groupby(['Duration'])[df_cond.keys()[1:-8]].std() + data_ci = (1.96 * data_std / np.sqrt(len(sub_list))) + + # Get category data + d500_m = data_m.query("Duration == '500ms'") + d1000_m = data_m.query("Duration == '1000ms'") + d1500_m = data_m.query("Duration == '1500ms'") + + d500_ci = data_ci.query("Duration == '500ms'") + d1000_ci = data_ci.query("Duration == '1000ms'") + d1500_ci = data_ci.query("Duration == '1500ms'") + + # Cut edges + d500_m = np.squeeze(np.array(d500_m.iloc[:,tmin:tmax])) + d1000_m = np.squeeze(np.array(d1000_m.iloc[:,tmin:tmax])) + d1500_m = np.squeeze(np.array(d1500_m.iloc[:,tmin:tmax])) + + d500_ci = np.squeeze(np.array(d500_ci.iloc[:,tmin:tmax])) + d1000_ci = np.squeeze(np.array(d1000_ci.iloc[:,tmin:tmax])) + d1500_ci = np.squeeze(np.array(d1500_ci.iloc[:,tmin:tmax])) + + # Plot + fig, ax = plt.subplots(figsize=(8,6)) + ax.plot(t, np.vstack([d500_m, d1000_m, d1500_m]).transpose(), linewidth=2.0) + + for m, ci in zip([d500_m, d1000_m, d1500_m], + [d500_ci, d1000_ci, d1500_ci]): + ax.fill_between(t, m-ci, m+ci, alpha=.2) + + ax.set_xlabel('Time (s)', fontsize='x-large') + ax.axvline(x=0, color="black", linestyle="--") + + ax.set_xlim([-.5, 2.4]) + # if band == 'alpha': + # ax.set_ylim([0.6, 1.4]) + # elif band == 'gamma': + # ax.set_ylim([0.9, 1.1]) + # ax.axvspan(.3, .5, color='grey', alpha=0.25) + ax.axvspan(tbins[0][0], tbins[0][1], color='red', alpha=0.25) + ax.axvspan(tbins[1][0], tbins[1][1], color='red', alpha=0.25) + ax.axvspan(tbins[2][0], tbins[2][1], color='red', alpha=0.25) + ax.legend(['500ms', '1000ms', '1500ms'], loc='lower left') + + ax.set_ylabel('Activation (rms)', fontsize='x-large') + plt.suptitle(f"{band}: time course over {label} source", fontsize='xx-large', fontweight='bold') + + # Save figure + fname_fig = op.join(source_figure_root, + f"sourcedur-{band}_{label}_{tbins[0]}_{task_rel[:3]}_timecourse_avg.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + + # Plot 2b # + # Group by participant and duration and average across categories + data_m = df_cond.groupby(['sub', 'Duration'])[df_cond.keys()[1:-8]].mean() + + # Get category data + d500_m = data_m.query("Duration == '500ms'") + d1000_m = data_m.query("Duration == '1000ms'") + d1500_m = data_m.query("Duration == '1500ms'") + + # Make raster plot + fig, axs = plt.subplots(3, 1, figsize=[8,6]) + # if band == 'alpha': + # v = [0.6, 1.4] + # elif band == 'gamma': + # v = [0.9, 1.1] + + for ax, data in zip(axs.flat, [d500_m, d1000_m, d1500_m]): + im = ax.imshow( + data, cmap="RdYlBu_r", + # vmin=v[0], vmax=v[1], + origin="lower", aspect="auto", + extent=[times[0], times[-1], len(sub_list), 1]) + ax.set_xlim([-.5, 2]) + ax.axvline(x=0, color="black", linestyle="--") + + axs[0].axvline(x=.5, color="black", linestyle="--") + axs[1].axvline(x=1., color="black", linestyle="--") + axs[2].axvline(x=1.5, color="black", linestyle="--") + axs[2].set_xlabel('Time (s)', fontsize='x-large') + axs[0].axes.xaxis.set_ticklabels([]) + axs[1].axes.xaxis.set_ticklabels([]) + axs[0].set_ylabel('Participant', fontsize='x-large') + axs[1].set_ylabel('Participant', fontsize='x-large') + axs[2].set_ylabel('Participant', fontsize='x-large') + + fig.subplots_adjust(right=0.85) + cbar_ax = fig.add_axes([0.88, 0.15, 0.04, 0.7]) + fig.colorbar(im, 
cax=cbar_ax, + format=tick.FormatStrFormatter('%.2f')) + + # Save figure + fname_fig = op.join(source_figure_root, + f"sourcedur-{band}_{label}_{tbins[0]}_{task_rel[:3]}_raster_avg.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + + # Plot ERF activation raincloud + + # Get indivisual data by condition + data_sub_m = df_long.groupby(['sub','Category','Duration','time_bin'],as_index = False)["value"].mean() + + # Fix order of levels in duration variable + data_sub_m['Duration'] = pd.Categorical( + data_sub_m['Duration'], + categories=['500ms', '1000ms', '1500ms'], + ordered=True) + + # Loop over categories + fig, axs = plt.subplots(4, 3, figsize=(8,8)) + for c in range(len(conditions[0])): + print("condition:",conditions[0][c]) + + # Get data + d_m = data_sub_m.query(f"Category =='{conditions[0][c]}'") + + for d in range(len(tbins)): + print("time bin:",tbins[d]) + + # Plot violin + pt.half_violinplot( + x = "Duration", y = "value", + data = d_m.query(f"time_bin =='{tbins[d]}'"), + bw = .2, cut = 0., + scale = "area", width = .6, + inner = None, + ax = axs[c,d]) + + # Add points + sns.stripplot( + x = "Duration", y = "value", + data = d_m.query(f"time_bin =='{tbins[d]}'"), + edgecolor = "white", + size = 3, jitter = 1, zorder = 0, + ax = axs[c,d]) + + # Add boxplot + sns.boxplot( + x = "Duration", y = "value", + data = d_m.query(f"time_bin =='{tbins[d]}'"), + color = "black", width = .15, zorder = 10, + showcaps = True, boxprops = {'facecolor':'none', "zorder":10},\ + showfliers=True, whiskerprops = {'linewidth':2, "zorder":10},\ + saturation = 1, + ax = axs[c,d]) + + # for ax in axs.flat: + # if band == 'alpha': + # ax.set_ylim([0.65, 1.35]) + # elif band == 'gamma': + # ax.set_ylim([0.9, 1.1]) + + axs[0,0].set_xlabel(None) + axs[0,1].set_xlabel(None) + axs[0,2].set_xlabel(None) + axs[1,0].set_xlabel(None) + axs[1,1].set_xlabel(None) + axs[1,2].set_xlabel(None) + axs[2,0].set_xlabel(None) + axs[2,1].set_xlabel(None) + axs[2,2].set_xlabel(None) + + axs[3,0].set_xlabel(f'{tbins[0]} time bin', fontsize='x-large', fontweight='bold') + axs[3,1].set_xlabel(f'{tbins[1]} time bin', fontsize='x-large', fontweight='bold') + axs[3,2].set_xlabel(f'{tbins[1]} time bin', fontsize='x-large', fontweight='bold') + + axs[0,0].set_ylabel('Face', fontsize='x-large', fontweight='bold') + axs[1,0].set_ylabel('Object', fontsize='x-large', fontweight='bold') + axs[2,0].set_ylabel('Letter', fontsize='x-large', fontweight='bold') + axs[3,0].set_ylabel('False-font', fontsize='x-large', fontweight='bold') + plt.suptitle(f"{band}: time bins over {label} source", fontsize='xx-large', fontweight='bold') + + plt.tight_layout() + + # Save figure + fname_fig = op.join(source_figure_root, + f"sourcedur-{band}_{label}_{tbins[0]}_{task_rel[:3]}_timebins.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + # Plot 2 # + # Get indivisual data by condition averaged across categories + data_sub_m = df_long.groupby(['sub','Duration','time_bin'],as_index = False)["value"].mean() + + # Fix order of levels in duration variable + data_sub_m['Duration'] = pd.Categorical( + data_sub_m['Duration'], + categories=['500ms', '1000ms', '1500ms'], + ordered=True) + + # Create subplot + fig, axs = plt.subplots(1,3, figsize=(8,6)) + + # Loop over durations + for d in range(len(tbins)): + print("time bin:",tbins[d]) + + # Plot violin + pt.half_violinplot( + x = "Duration", y = "value", + data = data_sub_m.query(f"time_bin =='{tbins[d]}'"), + bw = .2, cut = 0., + scale = "area", width = .6, + inner = None, + ax = axs[d]) + + # 
Add points + sns.stripplot( + x = "Duration", y = "value", + data = data_sub_m.query(f"time_bin =='{tbins[d]}'"), + edgecolor = "white", + size = 3, jitter = 1, zorder = 0, + ax = axs[d]) + + # Add boxplot + sns.boxplot( + x = "Duration", y = "value", + data = data_sub_m.query(f"time_bin =='{tbins[d]}'"), + color = "black", width = .15, zorder = 10, + showcaps = True, boxprops = {'facecolor':'none', "zorder":10},\ + showfliers=True, whiskerprops = {'linewidth':2, "zorder":10},\ + saturation = 1, + ax = axs[d]) + + # for ax in axs.flat: + # if band == 'alpha': + # ax.set_ylim([0.65, 1.35]) + # elif band == 'gamma': + # ax.set_ylim([0.9, 1.1]) + + axs[0].set_ylabel('Activaiton (rms)', fontsize='x-large') + + axs[0].set_xlabel('0.8-1.0 time bin', fontsize='x-large') + axs[1].set_xlabel('1.3-1.5 time bin', fontsize='x-large') + axs[2].set_xlabel('1.8-2.0 time bin', fontsize='x-large') + + plt.suptitle(f"{band}: time bins over {label} source", fontsize='xx-large', fontweight='bold') + + plt.tight_layout() + + # Save figure + fname_fig = op.join(source_figure_root, + f"sourcedur-{band}_{label}_{tbins[0]}_{task_rel[:3]}_timebins_avg.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + # Save table as .tsv + bids_path_source = bids_path_source.copy().update( + root=source_deriv_root, + subject=f"groupphase{phase}", + suffix=f"{tbins[0]}_{task_rel[:3]}_lmm_datatable", + check=False) + df_all.to_csv(bids_path_source.fpath, + sep="\t", + index=False) + + +if __name__ == '__main__': + run_source_dur_activation() diff --git a/activation/S06_source_dur_onsetoffset_control.py b/activation/S06_source_dur_onsetoffset_control.py new file mode 100644 index 0000000..6ddc53f --- /dev/null +++ b/activation/S06_source_dur_onsetoffset_control.py @@ -0,0 +1,349 @@ +""" +================ +S05. 
Grand-average source epochs +================ + +@author: Oscar Ferrante oscfer88@gmail.com + +""" + +import os +import os.path as op +import numpy as np +# import matplotlib.pyplot as plt +import argparse +import pandas as pd +from statsmodels.stats.multitest import multipletests +from scipy.stats import ttest_1samp, wilcoxon + +import mne +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +parser=argparse.ArgumentParser() +parser.add_argument('--method', + type=str, + default='dspm', + help='method used for the inverse solution') +# parser.add_argument('--bids_root', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids', +# help='Path to the BIDS root directory') +opt=parser.parse_args() + + +# Set params +inv_method = opt.method +visit_id = "V1" + +debug = False + + +factor = ['Category', 'Task_relevance', "Duration"] +conditions = [['face', 'object', 'letter', 'false'], + ['Relevant non-target','Irrelevant'], + ['500ms', '1000ms', '1500ms']] + + +# Set participant list +phase = 3 + +if debug: + sub_list = ["SA124", "SA126"] +elif phase == 2: + sub_list = ["SA106", "SA107", "SA109", "SA112", "SA113", "SA116", "SA124", + "SA126", "SA127", "SA128", "SA131", + "SA110", "SA142", "SA152", "SA160", "SA172", + "SB002", "SB015", "SB022", "SB030" ,"SB038", "SB041", "SB061", + "SB065", "SB071", "SB073", "SB085", + "SB013", "SB024", "SB045", "SB050", "SB078" + ] +elif phase == 3: + # Read the .txt list + f = open(op.join(bids_root, + 'participants_MEG_phase3_included.txt'), 'r').read() + # Add spaces and split elements in the list + sub_list = f.replace('\n', ' ').split(" ") + + +def source_dur_ga(): + # Set directory paths + source_deriv_root = op.join(bids_root, "derivatives", "source_dur") + if not op.exists(source_deriv_root): + os.makedirs(source_deriv_root) + source_figure_root = op.join(source_deriv_root, + f"sub-groupphase{phase}",f"ses-{visit_id}","meg", + "figures") + if not op.exists(source_figure_root): + os.makedirs(source_figure_root) + + # Set task + if visit_id == "V1": + bids_task = 'dur' + elif visit_id == "V2": + bids_task = 'vg' + # elif visit_id == "V2": #find a better way to set the task in V2 + # bids_task = 'replay' + else: + raise ValueError("Error: could not set the task") + + # Read the group data table + bids_path_source = mne_bids.BIDSPath( + root=source_deriv_root, + subject=f"groupphase{phase}", + datatype="meg", + task=bids_task, + session=visit_id, + suffix="datatable", + extension=".tsv", + check=False) + + df = pd.read_csv(bids_path_source.fpath, sep="\t") + + # Move power values to a single column + df['values'] = [np.array(df.iloc[i,1:-5]) for i in range(len(df))] + + # Drop the columns left + df = df.drop(df.columns[1:-6],axis=1) + + # Select task-irrelevant trials only + df = df[df['Task_relevance'] == 'Irrelevant'] + + # Create info + info = mne.create_info( + ch_names=['gnw_all', 'iit_all'], + sfreq=1000) + + # Create empy data frame + results = pd.DataFrame() + + # Loop over analysis + for analysis in ['onset', 'offset']: + + # Loop over freq bands + for band in ['alpha', 'gamma']: + print('\nfreq_band:', band) + + # Create empty list + all_df = np.empty((len(sub_list),2,(3501))) + + # Loop over labels + for i, label in enumerate(['gnw_all', 'iit_all']): + print('\nlabel:', label) + + # Get data for given conditions + data = df[(df['band'] == band) & (df['label'] == label)] + + # If offset analysis, lock data to stim offset + if analysis 
== 'offset': + for dur in np.unique(data["Duration"]): + data.loc[data['Duration'] == dur, "values"] = \ + data.loc[data['Duration'] == dur, "values"].apply( + lambda temp: np.concatenate( + [temp[int(dur[:-2]):],temp[:int(dur[:-2])]])) + + # Average across conditions + data = data.groupby( + ['sub','band','label'])['values'].apply( + np.mean,0).to_frame().reset_index() + + # Append data to group array + all_df[:,i,:] = np.stack(data['values']) + + # Empty list + data_df = [] + + # Loop across subjects + for i, sub in enumerate(sub_list): + print('\nsubject:', sub) + + # Create epoch object + epochs = mne.EpochsArray(all_df[np.newaxis,i,:,:], + info, + tmin=-1.) + + # Format the data for the test + data_df.append(format_tim_win_comp_data( + epochs, + sub, + baseline_window=[-0.2, 0.0], + test_window=[0.3, 0.5])) + + # Convert list to data frame: + data_df= pd.concat(data_df, + axis=0, + ignore_index=True) + + # Performing the moving window test + if band == "gamma": + alternative = "greater" + elif band == "alpha": + alternative = "less" + test_results = moving_window_test( + data_df, + onset=[0.3, 0.5][0], + alternative=alternative) + + # Append results to data frame + test_results["band"] = band + test_results["analysis"] = analysis + results = results.append(test_results) + + # Save results as .tsv + bids_path_source = bids_path_source.copy().update( + root=source_deriv_root, + subject=f"groupphase{phase}", + suffix="onset_offset_results", + check=False) + results.to_csv(bids_path_source.fpath, + sep="\t", + index=False) + # # Plot + # fig, axs = plt.subplots(2,1, figsize=(8,6)) + # for i, label in enumerate(['gnw_all', 'iit_all']): + # data = data_df.loc[data_df['channel'] == label, 'values'] + # axs[i].plot(np.mean(data)) + # axs[i].axhline(0, color='k', linestyle='--') + # # axs[i].set_ylim([-.01,.01]) + # axs[i].set_xlim([0,201]) + # plt.suptitle(f"{band}-{analysis}", fontsize='xx-large', fontweight='bold') + # plt.tight_layout() + + +def format_tim_win_comp_data(epochs, subject, baseline_window, test_window): + """ + This function formats data to compare the activation between different time windows. It will reformat the epochs + into data frames cropped into the specified time windows and take the subtraction between the two passed time + windows. 
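+    For example, with baseline_window=[-0.2, 0.0] and test_window=[0.3, 0.5]
+    (the values used by source_dur_ga above), each row of the returned data
+    frame holds, for one channel of one subject, the sample-by-sample
+    difference between the test-window data and the baseline-window data.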
One can then test whether that difference is above chance for an extended period of time + :param epochs: (mne epochs object) contains the data to compute the difference + :param subject: (string) name of the subject + :param baseline_window: (list of two floats) contains the onset and offset of the baseline + :param test_window: (list of two floats) contains the onset and offset of the test data + :return: + """ + print("=" * 40) + print("Welcome to format_cluster_based_data") + data_df = pd.DataFrame() + # Compute baseline and onset: + baseline_data = epochs.copy().crop(tmin=baseline_window[0], + tmax=baseline_window[1]) + onset_data = epochs.copy().crop(tmin=test_window[0], + tmax=test_window[1]) + # Looping through each channel to compute the difference between the two: + for channel in baseline_data.ch_names: + bs = np.squeeze(baseline_data.get_data(picks=channel)) + ons = np.squeeze(onset_data.get_data(picks=channel)) + # It can be that because of rounding the two arrays are not the same size, in which case, equating size + # by taking the smallest + if bs.shape[0] != ons.shape[0]: + min_len = min([bs.shape[0], ons.shape[0]]) + bs = bs[:, 0:min_len] + ons = ons[:, 0:min_len] + diff = ons - bs + # Add to the data_df frame: + data_df = data_df.append(pd.DataFrame( + {"subject": subject, + "channel": channel, + "values": [diff] + } + )) + return data_df + + +def test_sustained_threshold(y, stat_test="t-test", threshold=0.05, + window_sec=0.02, sr=1000, + alternative="greater", fdr_method=None): + """ + :param y: + :param stat_test: + :param threshold: + :param window_sec: + :param sr: + :param alternative: + :param fdr_method: + :return: + """ + # Handling data dimensions + if isinstance(y, np.ndarray): + if len(y.shape) > 2: + raise Exception("You have passed an numpy array of more than 2D! This function only works with 2D numpy " + "array or unnested list!") + elif isinstance(y, list): + if isinstance(y[0], list): + raise Exception("You have passed a nested list! 
This function only works with 1D numpy " + "array or unnested list!") + elif isinstance(y[0], np.ndarray): + raise Exception("You have passed a list of numpy arrays!This function only works with 1D numpy " + "array or unnested list!") + # Compute the test: + if stat_test == "t-test": + pop_mean = np.zeros(y.shape[0]) + y_stat, y_pval = ttest_1samp(y, pop_mean, axis=1, alternative=alternative) + elif stat_test == "wilcoxon": + y_stat, y_pval = wilcoxon(y, y=None, axis=1, alternative=alternative) + else: + raise Exception("You have passed a test that is not supported!") + # Do fdr correction if needed: + if fdr_method is not None: + y_bin, y_pval, _, _ = multipletests(y_pval, alpha=threshold, method=fdr_method) + else: + y_bin = y_pval < threshold + # Convert the time window from ms to samples: + window_samp = int(window_sec * (sr / 1)) + h0 = True + # Looping through each True in the binarize y: + for ind in np.where(y_bin)[0]: + if ind + window_samp < len(y_bin): + if all(y_bin[ind:ind + window_samp]): + h0 = False + # Finding the offset of the significant window: + onset_samp = ind + if len(np.where(np.diff(y_bin[ind:].astype(int)) == -1)[0]) > 0: + offset_samp = onset_samp + np.where(np.diff(y_bin[ind:].astype(int)) == -1)[0][0] + else: + offset_samp = len(y) - 1 + # Convert to me: + onset_sec, offset_sec = onset_samp * (1 / sr), offset_samp * (1 / sr) + break + else: + break + if h0: + onset_samp, offset_samp = None, None + onset_sec, offset_sec = None, None + return h0, [onset_sec, offset_sec], [onset_samp, offset_samp] + + +def moving_window_test(data_df, onset, groups="channel", alternative="greater"): + print("=" * 40) + print("Welcome to moving_window_test") + # Var to store the results + test_results = pd.DataFrame() + for group in data_df[groups].unique(): + print("Performing test for group: {}".format(group)) + # Get the data of this group + y = data_df.loc[data_df[groups] == group, "values"] + # Convert to array + y = np.array([np.array(yy) for yy in y]) + # Testing the sustained + h0, sig_window_sec, sig_window_samp = test_sustained_threshold(y, alternative=alternative) + + # Create results table: + test_results = test_results.append(pd.DataFrame({ + "channel": group, + "sign": not h0, + "onset": onset + sig_window_sec[0] if sig_window_sec[0] is not None + else None, + "offset": onset + sig_window_sec[1] if sig_window_sec[0] is not None + else None, + }, index=[0])) + + return test_results + + +if __name__ == '__main__': + source_dur_ga() diff --git a/config/config.py b/config/config.py new file mode 100644 index 0000000..b04505b --- /dev/null +++ b/config/config.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +""" +=========== +Config file +=========== + +Configurate the parameters of the study. 
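+Beside the BIDS root path, it groups the Maxwell-filtering, filtering and
+downsampling, epoching, ICA, condition, time-frequency, source-modelling and
+plotting parameters shared by the analysis scripts.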
+ +""" + +import os + +# ============================================================================= +# BIDS SETTINGS +# ============================================================================= +# if os.getlogin() in ['oscfe', 'ferranto', 'FerrantO']: #TODO: doesn't work on the HPC +# bids_root = r'Z:\_bids_' +# else: +bids_root = r'Z:\_bids_' + + +# ============================================================================= +# MAXWELL FILTERING SETTINGS +# ============================================================================= + +# Set filtering method +method='sss' +if method == 'tsss': + st_duration = 10 +else: + st_duration = None + + +# ============================================================================= +# FILTERING AND DOWNSAMPLING SETTINGS +# ============================================================================= + +# Filter and resampling params +l_freq = 1 +h_freq = 40 +sfreq = 100 + + +# ============================================================================= +# EPOCHING SETTINGS +# ============================================================================= + +# Set timewindow +tmin = -1 +tmax = 2.5 + +# Epoch rejection criteria +reject_meg_eeg = dict(grad=4000e-13, # T / m (gradiometers) + mag=6e-12 # T (magnetometers) + #eeg=200e-6 # V (EEG channels) + ) +reject_meg = dict(grad=4000e-13, # T / m (gradiometers) + mag=6e-12 # T (magnetometers) + ) + + +# ============================================================================= +# ICA SETTINGS +# ============================================================================= + +ica_method = 'fastica' +n_components = 0.99 +max_iter = 800 +random_state = 1688 + + +# ============================================================================= +# FACTOR AND CONDITIONS OF INTEREST +# ============================================================================= + +# factor = ['Category'] +# conditions = ['face', 'object', 'letter', 'false'] + +# factor = ['Duration'] +# conditions = ['500ms', '1000ms', '1500ms'] + +# factor = ['Task_relevance'] +# conditions = ['Relevant_target','Relevant_non-target','Irrelevant'] + +# factor = ['Duration', 'Task_relevance'] +# conditions = [['500ms', '1000ms', '1500ms'], +# ['Relevant target','Relevant non-target','Irrelevant']] + +factor = ['Category', 'Task_relevance'] +conditions = [['face', 'object', 'letter', 'false'], + ['Relevant target','Relevant non-target','Irrelevant']] + + +# ============================================================================= +# TIME-FREQUENCY REPRESENTATION SETTINGS +# ============================================================================= + +baseline_w = [-0.5, -0.25] #only for plotting +freq_band = 'both' #can be 'low', 'high' or 'both' + + +# ============================================================================= +# SOURCE MODELING +# ============================================================================= + +# Forward model +spacing='oct6' #from forward_model + +# Inverse model +# Beamforming +beam_method = 'dics' #'lcmv' or 'dics' + +active_win = (0.75, 1.25) +baseline_win = (-.5, 0) + + +# ============================================================================= +# PLOTTING +# ============================================================================= + +# Subset of posterior sensors +post_sens = {'grad': ['MEG1932', 'MEG1933', 'MEG2122', 'MEG2123', + 'MEG2332', 'MEG2333', 'MEG1922', 'MEG1923', + 'MEG2112', 'MEG2113', 'MEG2342', 'MEG2343'], + 'mag': ['MEG1931', 'MEG2121', + 'MEG2331', 
'MEG1921', + 'MEG2111', 'MEG2341'], + 'eeg': ['EEG056', 'EEG030', + 'EEG057', 'EEG018', + 'EEG032', 'EEG019']} diff --git a/config/iit_gnw_rois.json b/config/iit_gnw_rois.json new file mode 100644 index 0000000..84b0446 --- /dev/null +++ b/config/iit_gnw_rois.json @@ -0,0 +1,199 @@ +{ + "volume_labels": { + "iit_1": [ + "ctx_lh_G_temporal_inf", + "ctx_rh_G_temporal_inf", + "ctx_lh_Pole_temporal", + "ctx_rh_Pole_temporal", + "ctx_lh_G_cuneus", + "ctx_rh_G_cuneus", + "ctx_lh_G_occipital_sup", + "ctx_rh_G_occipital_sup", + "ctx_lh_G_oc-temp_med-Lingual", + "ctx_rh_G_oc-temp_med-Lingual", + "ctx_lh_Pole_occipital", + "ctx_rh_Pole_occipital", + "ctx_lh_G_oc-temp_med-Lingual", + "ctx_rh_G_oc-temp_med-Lingual", + "ctx_lh_S_calcarine", + "ctx_rh_S_calcarine", + "ctx_lh_G_and_S_occipital_inf", + "ctx_rh_G_and_S_occipital_inf", + "ctx_lh_G_occipital_middle", + "ctx_rh_G_occipital_middle", + "ctx_lh_G_oc-temp_lat-fusifor", + "ctx_rh_G_oc-temp_lat-fusifor", + "ctx_lh_G_oc-temp_med-Parahip", + "ctx_rh_G_oc-temp_med-Parahip", + "ctx_lh_S_intrapariet_and_P_trans", + "ctx_rh_S_intrapariet_and_P_trans", + "ctx_lh_G_oc-temp_med-Parahip", + "ctx_rh_G_oc-temp_med-Parahip", + "ctx_lh_S_oc_middle_and_Lunatus", + "ctx_rh_S_oc_middle_and_Lunatus", + "ctx_lh_S_oc_sup_and_transversal", + "ctx_rh_S_oc_sup_and_transversal", + "ctx_lh_S_temporal_sup", + "ctx_rh_S_temporal_sup" + ], + "iit_2": [ + "ctx_lh_G_temporal_inf", + "ctx_rh_G_temporal_inf", + "ctx_lh_Pole_temporal", + "ctx_rh_Pole_temporal", + "ctx_lh_G_cuneus", + "ctx_rh_G_cuneus", + "ctx_lh_G_occipital_sup", + "ctx_rh_G_occipital_sup", + "ctx_lh_G_oc-temp_med-Lingual", + "ctx_rh_G_oc-temp_med-Lingual", + "ctx_lh_Pole_occipital", + "ctx_rh_Pole_occipital", + "ctx_lh_G_oc-temp_med-Lingual", + "ctx_rh_G_oc-temp_med-Lingual", + "ctx_lh_S_calcarine", + "ctx_rh_S_calcarine", + "ctx_lh_G_and_S_occipital_inf", + "ctx_rh_G_and_S_occipital_inf", + "ctx_lh_G_occipital_middle", + "ctx_rh_G_occipital_middle", + "ctx_lh_G_oc-temp_lat-fusifor", + "ctx_rh_G_oc-temp_lat-fusifor", + "ctx_lh_G_oc-temp_med-Parahip", + "ctx_rh_G_oc-temp_med-Parahip", + "ctx_lh_S_intrapariet_and_P_trans", + "ctx_rh_S_intrapariet_and_P_trans", + "ctx_lh_G_oc-temp_med-Parahip", + "ctx_rh_G_oc-temp_med-Parahip", + "ctx_lh_S_oc_middle_and_Lunatus", + "ctx_rh_S_oc_middle_and_Lunatus", + "ctx_lh_S_oc_sup_and_transversal", + "ctx_rh_S_oc_sup_and_transversal", + "ctx_lh_S_temporal_sup", + "ctx_rh_S_temporal_sup", + "ctx_lh_G_precentral", + "ctx_rh_G_precentral", + "ctx_lh_G_temp_sup-Lateral", + "ctx_rh_G_temp_sup-Lateral", + "ctx_lh_G_temp_sup-Plan_tempo", + "ctx_rh_G_temp_sup-Plan_tempo", + "ctx_lh_S_front_inf", + "ctx_rh_S_front_inf", + "ctx_lh_G_pariet_inf-Supramar", + "ctx_rh_G_pariet_inf-Supramar", + "ctx_lh_G_temporal_middle", + "ctx_rh_G_temporal_middle", + "ctx_lh_S_temporal_inf", + "ctx_rh_S_temporal_inf", + "ctx_lh_G_orbital", + "ctx_rh_G_orbital", + "ctx_lh_G_pariet_inf-Angular", + "ctx_rh_G_pariet_inf-Angular", + "ctx_lh_S_interm_prim-Jensen", + "ctx_rh_S_interm_prim-Jensen", + "ctx_lh_S_occipital_ant", + "ctx_rh_S_occipital_ant", + "ctx_lh_S_oc-temp_lat", + "ctx_rh_S_oc-temp_lat", + "ctx_lh_S_precentral-inf-part", + "ctx_rh_S_precentral-inf-part" + ], + "gnw": [ + "ctx_lh_G_and_S_cingul-Ant", + "ctx_rh_G_and_S_cingul-Ant", + "ctx_lh_G_and_S_cingul-Mid-Ant", + "ctx_rh_G_and_S_cingul-Mid-Ant", + "ctx_lh_G_and_S_cingul-Mid-Post", + "ctx_rh_G_and_S_cingul-Mid-Post", + "ctx_lh_G_front_inf-Opercular", + "ctx_rh_G_front_inf-Opercular", + "ctx_lh_G_front_inf-Orbital", + "ctx_rh_G_front_inf-Orbital", 
+ "ctx_lh_G_front_inf-Triangul", + "ctx_rh_G_front_inf-Triangul", + "ctx_lh_G_front_middle", + "ctx_rh_G_front_middle", + "ctx_lh_Lat_Fis-ant-Horizont", + "ctx_rh_Lat_Fis-ant-Horizont", + "ctx_lh_Lat_Fis-ant-Vertical", + "ctx_rh_Lat_Fis-ant-Vertical", + "ctx_lh_S_front_inf", + "ctx_rh_S_front_inf", + "ctx_lh_S_front_middle", + "ctx_rh_S_front_middle", + "ctx_lh_S_front_sup", + "ctx_rh_S_front_sup" + ] + }, + "surf_labels": { + "iit_1": [ + "G_temporal_inf", + "Pole_temporal", + "G_cuneus", + "G_occipital_sup", + "G_oc-temp_med-Lingual", + "Pole_occipital", + "S_calcarine", + "G&S_occipital_inf", + "G_occipital_middle", + "G_oc-temp_lat-fusifor", + "G_oc-temp_med-Parahip", + "S_intrapariet&P_trans", + "S_oc_middle&Lunatus", + "S_oc_sup&transversal", + "S_temporal_sup" + ], + "iit_2": [ + "G_temporal_inf", + "Pole_temporal", + "G_cuneus", + "G_occipital_sup", + "G_oc-temp_med-Lingual", + "Pole_occipital", + "G_oc-temp_med-Lingual", + "S_calcarine", + "G&S_occipital_inf", + "G_occipital_middle", + "G_oc-temp_lat-fusifor", + "G_oc-temp_med-Parahip", + "S_intrapariet&P_trans", + "G_oc-temp_med-Parahip", + "S_oc_middle&Lunatus", + "S_oc_sup&transversal", + "S_temporal_sup", + "G_precentral", + "G_temp_sup-Lateral", + "G_temp_sup-Plan_tempo", + "S_front_inf", + "G_pariet_inf-Supramar", + "G_temporal_middle", + "S_temporal_inf", + "G_orbital", + "G_pariet_inf-Angular", + "S_interm_prim-Jensen", + "S_occipital_ant", + "S_oc-temp_lat", + "S_precentral-inf-part" + ], + "iit_wang": [ + "V1d", + "V1v", + "V2d", + "V2v" + ], + "gnw": [ + "G&S_cingul-Ant", + "G&S_cingul-Mid-Ant", + "G&S_cingul-Mid-Post", + "G_front_inf-Opercular", + "G_front_inf-Orbital", + "G_front_inf-Triangul", + "G_front_middle", + "Lat_Fis-ant-Horizont", + "Lat_Fis-ant-Vertical", + "S_front_inf", + "S_front_middle", + "S_front_sup" + ] + } +} diff --git a/config/participants_MEG_phase2_included.txt b/config/participants_MEG_phase2_included.txt new file mode 100644 index 0000000..bdab5fc --- /dev/null +++ b/config/participants_MEG_phase2_included.txt @@ -0,0 +1,32 @@ +SA106 +SA107 +SA109 +SA110 +SA112 +SA113 +SA116 +SA124 +SA126 +SA127 +SA128 +SA131 +SA142 +SA152 +SA160 +SA172 +SB002 +SB013 +SB015 +SB022 +SB024 +SB030 +SB038 +SB041 +SB045 +SB050 +SB061 +SB065 +SB071 +SB073 +SB078 +SB085 \ No newline at end of file diff --git a/config/participants_MEG_phase3_included.txt b/config/participants_MEG_phase3_included.txt new file mode 100644 index 0000000..b82ee3c --- /dev/null +++ b/config/participants_MEG_phase3_included.txt @@ -0,0 +1,65 @@ +SA102 +SA103 +SA104 +SA111 +SA114 +SA118 +SA121 +SA123 +SA125 +SA132 +SA133 +SA134 +SA136 +SA138 +SA139 +SA140 +SA144 +SA145 +SA146 +SA147 +SA148 +SA150 +SA151 +SA154 +SA158 +SA163 +SA166 +SA167 +SA169 +SA170 +SA173 +SA174 +SA176 +SB001 +SB003 +SB006 +SB008 +SB009 +SB011 +SB012 +SB016 +SB019 +SB020 +SB023 +SB027 +SB028 +SB029 +SB031 +SB035 +SB036 +SB039 +SB040 +SB042 +SB044 +SB049 +SB051 +SB056 +SB060 +SB063 +SB069 +SB072 +SB074 +SB081 +SB084 +SB999 \ No newline at end of file diff --git a/connectivity/Co01_connect_ppc.py b/connectivity/Co01_connect_ppc.py new file mode 100644 index 0000000..1bd7d58 --- /dev/null +++ b/connectivity/Co01_connect_ppc.py @@ -0,0 +1,547 @@ +# -*- coding: utf-8 -*- +""" +=================================== +Co01. 
Connectivity +=================================== + +Compute coherence in source space using a MNE inverse solution + +@author: Oscar Ferrante oscfer88@gmail.com +""" + +import numpy as np +import os +import os.path as op +import matplotlib.pyplot as plt +import argparse +import seaborn as sns +import json +import statsmodels.api as sm + +import mne +from mne.minimum_norm import (make_inverse_operator, apply_inverse_epochs, + # write_inverse_operator + ) +from mne_connectivity import spectral_connectivity_epochs #spectral_connectivity +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +parser=argparse.ArgumentParser() +parser.add_argument('--sub', + type=str, + default='SA124', + help='site_id + subject_id (e.g. "SA101")') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +parser.add_argument('--task_rel', + type=str, + default='irr', + help='specify the task condition ("irr", "rel" or "comb")') +parser.add_argument('--remove_evoked', + type=str, + default='false', + help='Remove evoked response? (true or false)') +opt=parser.parse_args() + + +# Set params +subject_id = opt.sub +visit_id = opt.visit +con_method = 'ppc' +durs = ["1000ms", "1500ms"] +task_rel = opt.task_rel +remove_evoked = opt.remove_evoked.lower() == 'true' + +surrogate = False +use_long_ged = False +surrogate = False + +debug = False + +# Define vars for output folder name +if task_rel == "comb": + tasks = ["Relevant non-target", "Irrelevant"] + t = "" +elif task_rel == "irr": + tasks = ["Irrelevant"] + t = "_irr" +elif task_rel == "rel": + tasks = ["Relevant non-target"] + t = "_rel" +else: + raise ValueError(f"Error: tasks={tasks} not valid") + +if len(durs) == 3: + d = "_all-durs" +else: + d = "" + +if remove_evoked: + e = "_no-evoked" +else: + e = "" + +if surrogate: + s = "_surrogate" +else: + s = "" + +if use_long_ged: + g = "_0.0-2.0" + ged_label_list = ['fusiform'] +else: + g = "" + ged_label_list = ['fusifor'] + + +def connectivity(subject_id, visit_id): + # Set path to preprocessing derivatives and create the related folders + prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") + fwd_deriv_root = op.join(bids_root, "derivatives", "forward") + fs_deriv_root = op.join(bids_root, "derivatives", "fs") + rois_deriv_root = op.join(bids_root, "derivatives", "roilabel") + ged_deriv_root = op.join(bids_root, "derivatives", "ged") + + con_deriv_root = op.join(bids_root, "derivatives", "connectivity"+t, d, g, e, s) + if not op.exists(con_deriv_root): + os.makedirs(con_deriv_root) + con_figure_root = op.join(con_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "figures", + con_method) + if not op.exists(con_figure_root): + os.makedirs(con_figure_root) + + print("Processing subject: %s" % subject_id) + + # Set task + if visit_id == "V1": + bids_task = 'dur' + elif visit_id == "V2": + bids_task = 'vg' + # elif visit_id == "V2": #find a better way to set the task in V2 + # bids_task = 'replay' + else: + raise ValueError("Error: could not set the task") + + # Read epoched data + bids_path_epo = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + session=visit_id, + suffix='epo', + extension='.fif', + check=False) + + epochs = mne.read_epochs(bids_path_epo.fpath, + preload=False) + + # Pick trials + epochs = epochs[f'Task_relevance in {tasks} and Duration in {durs}'] + if debug: + epochs = epochs[0:100] + + # Select 
sensor + epochs.load_data().pick('meg') #try with EEG as well + + # Get sampling frequency + sfreq = epochs.info['sfreq'] + + # Baseline correction + b_tmin = -.5 + b_tmax = -.0 + baseline = (b_tmin, b_tmax) + epochs.apply_baseline(baseline=baseline) + + # Crop epochs time window to save memory + e_tmin= -.5 + e_tmax= 2. + epochs.crop(e_tmin, e_tmax) + + + ## LABELS + + # Read labels from FS parc + if subject_id in ['SA102', 'SA104', 'SA110', 'SA111', 'SA152']: + labels_Datlas = mne.read_labels_from_annot( #Destrieux's atlas + "fsaverage", + parc='aparc.a2009s', + subjects_dir=fs_deriv_root) + labels_Watlas = mne.read_labels_from_annot( #Wang's atlas + "fsaverage", + parc='wang2015_mplbl', + subjects_dir=fs_deriv_root) + else: + labels_Datlas = mne.read_labels_from_annot( + "sub-"+subject_id, + parc='aparc.a2009s', + subjects_dir=fs_deriv_root) + labels_Watlas = mne.read_labels_from_annot( + "sub-"+subject_id, + parc='wang2015_mplbl', + subjects_dir=fs_deriv_root) + + # labels_Datlas_names = [l.name for l in labels_Datlas] + # labels_Watlas_names = [l.name for l in labels_Watlas] + + # Read GNW and IIT ROI list + f = open(op.join(rois_deriv_root, + 'iit_gnw_rois.json')) + gnw_iit_rois = json.load(f) + + # Create labels for selected ROIs + labels = {} + if subject_id in ['SA102', 'SA104', 'SA110', 'SA111', 'SA152']: + for lab in gnw_iit_rois['surf_labels']['iit_wang']: + lab = lab.replace('&','_and_') # Fix the label name to match the template one + print(lab) + labels["iit_"+lab+"_lh"] = [l for l in labels_Watlas if l.name == lab+"-lh"] + labels["iit_"+lab+"_rh"] = [l for l in labels_Watlas if l.name == lab+"-rh"] + + for lab in gnw_iit_rois['surf_labels']['gnw']: + lab = lab.replace('&','_and_') # Fix the label name to match the template one + print(lab) + labels["gnw_"+lab+"_lh"] = [l for l in labels_Datlas if l.name == lab+"-lh"] + labels["gnw_"+lab+"_rh"] = [l for l in labels_Datlas if l.name == lab+"-rh"] + else: + for lab in gnw_iit_rois['surf_labels']['iit_wang']: + print(lab) + labels["iit_"+lab+"_lh"] = [l for l in labels_Watlas if l.name == lab+"-lh"] + labels["iit_"+lab+"_rh"] = [l for l in labels_Watlas if l.name == lab+"-rh"] + + for lab in gnw_iit_rois['surf_labels']['gnw']: + print(lab) + labels["gnw_"+lab+"_lh"] = [l for l in labels_Datlas if l.name == lab+"-lh"][0] + labels["gnw_"+lab+"_rh"] = [l for l in labels_Datlas if l.name == lab+"-rh"][0] + + # # Save labels + # bids_path_con = bids_path_epo.copy().update( + # root=con_deriv_root, + # suffix="labels", + # extension='.pkl', + # check=False) + + # with open(bids_path_con.fpath, 'wb') as outp: + # pickle.dump(labels, outp, pickle.HIGHEST_PROTOCOL) + + # Get V1/V2 labels and sum + iit_v1v2_label = np.sum([labels["iit_V1d_lh"], + labels["iit_V1d_rh"], + labels["iit_V1v_lh"], + labels["iit_V1v_rh"], + labels["iit_V2d_lh"], + labels["iit_V2d_rh"], + labels["iit_V2v_lh"], + labels["iit_V2v_rh"]]) + + + ## Category-selective GED + + # Set params + ged_label_name = ''.join(ged_label_list) + + # Create label + ged_labels = [] + # Loop over labels + for regexp in ged_label_list: + + # Create label for the given region + if subject_id in ['SA102', 'SA104', 'SA110', 'SA111', 'SA152']: + lab = mne.read_labels_from_annot( + "fsaverage", + parc='aparc', #aparc aparc.a2009s + regexp=regexp, #'inferiortemporal' + hemi='both', + subjects_dir=fs_deriv_root) + else: + lab = mne.read_labels_from_annot( + "sub-"+subject_id, + parc='aparc', #aparc aparc.a2009s + regexp=regexp, #'inferiortemporal' + hemi='both', + subjects_dir=fs_deriv_root) 
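+
+        # Illustrative sanity check (added for clarity; not part of the original
+        # pipeline): with hemi='both', read_labels_from_annot() returns one
+        # mne.Label per hemisphere that matches the regexp (e.g. 'fusiform-lh'
+        # and 'fusiform-rh' for regexp='fusifor' under the 'aparc' parcellation).
+        assert len(lab) == 2, f"Expected lh/rh labels for regexp '{regexp}', got {len(lab)}"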
+ + # Append to GED labels + ged_labels.append(lab) + + # Combine GED labels + ged_labels = np.sum(ged_labels) + + # Read GED filter + bids_path_ged = bids_path_epo.copy().update( + root=op.join(ged_deriv_root,g), + suffix=f'desc-{ged_label_name},face_evecs', + extension='.npy', + check=False) + ged_face_evecs = np.load(bids_path_ged.fpath) + + bids_path_ged = bids_path_ged.copy().update( + suffix=f'desc-{ged_label_name},object_evecs') + ged_object_evecs = np.load(bids_path_ged.fpath) + + + ## GNW prefrontal GED + + # Merge all labels in prefrontal GNW ROI + ged_gnw_label = np.sum([l for l_name, l in labels.items() if 'gnw' in l_name]) + + # Read GNW prefrontal GED filter + bids_path_ged = bids_path_epo.copy().update( + root=ged_deriv_root, + suffix='desc-gnw_evecs', + extension='.npy', + check=False) + ged_gnw_evecs = np.load(bids_path_ged.fpath) + + + ## SOURCE MODELLING + + # Compute rank + rank = mne.compute_rank(epochs, + tol=1e-6, + tol_kind='relative') + + # Read forward model + bids_path_fwd = bids_path_epo.copy().update( + root=fwd_deriv_root, + task=None, + suffix="surface_fwd", + extension='.fif', + check=False) + + fwd = mne.read_forward_solution(bids_path_fwd.fpath) + + # Compute covariance matrices + base_cov = mne.compute_covariance(epochs, + tmin=b_tmin, + tmax=b_tmax, + method='empirical', + rank=rank) + + active_cov = mne.compute_covariance(epochs, + tmin=0, + tmax=None, + method='empirical', + rank=rank) + common_cov = base_cov + active_cov + + # Make inverse operator + inverse_operator = make_inverse_operator( + epochs.info, + fwd, + common_cov, + loose=.2, + depth=.8, + fixed=False, + rank=rank, + use_cps=True) + + # # Save inverse operator + # bids_path_inv = bids_path_con.copy().update( + # suffix="inv_c%s" % condition, + # extension='.fif', + # check=False) + # write_inverse_operator(bids_path_inv.fpath, + # inverse_operator) + + + ## CONNECTIVITY + if task_rel == "comb": + n_cond = 4 + else: + n_cond = 2 + + for condition in range(1,n_cond+1): + + # Pick condition + if condition == 1: + epochs_cond = epochs['Category == "object"'].copy() + cond_name = "object" + elif condition == 2: + epochs_cond = epochs['Category == "face"'].copy() + cond_name = "face" + elif condition == 3: + epochs_cond = epochs['Task_relevance == "Relevant non-target"'].copy() + cond_name = "relev" + elif condition == 4: + epochs_cond = epochs['Task_relevance == "Irrelevant"'].copy() + cond_name = "irrel" + else: + raise ValueError("Condition %s does not exists" % condition) + print("\n\n\n### Running condition " + cond_name + " ###\n\n") + + # Compute inverse solution for each epoch + snr = 3.0 + lambda2 = 1.0 / snr ** 2 + method = "dSPM" + + stcs = apply_inverse_epochs(epochs_cond, + inverse_operator, + lambda2, + method, + pick_ori="normal", + return_generator=False) + del epochs_cond + + # Average source estimates within each label to reduce signal cancellations + src = inverse_operator['src'] + iit_label_ts = mne.extract_label_time_course( + stcs, iit_v1v2_label, src, + mode='pca_flip', + return_generator=False) + + # Apply GED filter to source-level epochs + ged_face_ts = [] + ged_object_ts = [] + ged_gnw_ts = [] + for i in range(len(stcs)): + # Get data + data = stcs[i].in_label(ged_labels).data + data_gnw = stcs[i].in_label(ged_gnw_label).data + # Apply GED filter + ged_face_ts.append(ged_face_evecs[:,0].T @ data) + ged_object_ts.append(ged_object_evecs[:,0].T @ data) + ged_gnw_ts.append(ged_gnw_evecs[:,0].T @ data_gnw) + + del stcs + + # # Save GED time series + # bids_path_ged = 
bids_path_ged.copy().update( + # root=op.join(ged_deriv_root,g), + # suffix=f'desc-{ged_label_name},face_ts', + # extension='.npy', + # check=False) + # np.save(bids_path_ged.fpath, ged_face_ts) + + # bids_path_ged = bids_path_ged.copy().update( + # suffix=f'desc-{ged_label_name},object_ts') + # np.save(bids_path_ged.fpath, ged_object_ts) + + # bids_path_ged = bids_path_ged.copy().update( + # suffix='desc-gnw_ts') + # np.save(bids_path_ged.fpath, ged_gnw_ts) + + # Concatenate GNW & IIT labels and GED spatial filters + all_ts = [] + for i in range(len(ged_face_ts)): + all_ts.append(np.vstack([ged_gnw_ts[i], iit_label_ts[i], ged_face_ts[i], ged_object_ts[i]])) + ged_filter_labels = ['pfc','v1v2','face filter','object filter'] + + # Create indices of label-to-label couples for which to compute connectivity + n_labels = 2 + indices = (np.concatenate([range(0,n_labels),range(0,n_labels)]), + np.array([n_labels]*len(range(0,n_labels)) + [n_labels+1]*len(range(0,n_labels)))) + + # Create surrogate data by shuffling trial labels + if surrogate: + # Convert list to array + all_ts_array = np.array(all_ts) + + # Loop over nodes + for n in range(len(all_ts_array[0])): + + # Get trial number indices + ind = np.arange(len(all_ts_array)) + + # Shuffle trial indeces + np.random.shuffle(ind) + # plt.plot(ind) + + # Shuffle trials in the node data + all_ts_array[:,n,:] = all_ts_array[ind,n,:] + + # Convert array back to list + all_ts = [all_ts_array[i,:,:] for i in range(len(all_ts_array))] + + # Remove evoked using regression + if remove_evoked: + all_evoked = np.mean(all_ts, axis=0) + for node in range(len(all_ts[0])): + node_evoked = all_evoked[node,:] + for trial in range(len(all_ts)): + all_ts[trial][node,:] = sm.OLS(np.array(all_ts)[trial,node,:], node_evoked).fit().resid + + # Run connectivity separatelly for low and high frequencies + for freq_range in ['low', 'high']: + print('\nComputing connectivity in', freq_range, 'frequency range...') + + # Set connectivity params + mode = 'cwt_morlet' + if freq_range == 'low': + fmin = 2. + fmax = 30. + fstep = 1. + cwt_freqs = np.arange(fmin, fmax, fstep) + cwt_n_cycles = 4 + elif freq_range == 'high': + fmin = 30. + fmax = 101. + fstep = 2. + cwt_freqs = np.arange(fmin, fmax, fstep) + cwt_n_cycles = cwt_freqs / 4. 
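+
+        # Note on the wavelet settings (explanatory comment, added for clarity):
+        # the effective time window of a Morlet wavelet is n_cycles / freq.
+        # With cwt_n_cycles = 4 in the low range, the window shrinks from
+        # 4 / 2 Hz = 2.0 s down to 4 / 29 Hz ~ 0.14 s, whereas
+        # cwt_n_cycles = cwt_freqs / 4 in the high range keeps it fixed at
+        # (freq / 4) / freq = 0.25 s, i.e. a constant temporal resolution
+        # across the gamma band at the cost of frequency resolution.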
+ + # Run connectivity + con = spectral_connectivity_epochs( + all_ts, + method=con_method, + indices=indices, + mode=mode, + sfreq=sfreq, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles + ) + + # Save connectivity results + bids_path_con = bids_path_epo.copy().update( + root=con_deriv_root, + suffix=f"desc-gnw-pfc-ged,{con_method},{freq_range},{cond_name}_con", + extension='.nc', + check=False) + + con.save(bids_path_con.fpath) + + # Plot results (con_data = time x label1 x label2 x freq) + times = ['%.0f' %t for t in (np.array(con.times) - .5) * 1000] + freqs = ['%.0f' %f for f in con.freqs] + indices_comb = [[i,j] for i,j in zip(indices[0], indices[1])] + + for i in indices_comb: + fig, ax = plt.subplots() + sns.heatmap(con.get_data()[indices_comb.index(i),:,:], + xticklabels=250, yticklabels=5, + # vmin=0, vmax=.4, + cmap='RdYlBu_r', + ax=ax) + ax.set_xticklabels(times[0::250], + fontsize=8) + ax.invert_yaxis() + ax.set_yticklabels(freqs[0::5], rotation='horizontal', + fontsize=8) + + # ax.set_xticklabels(np.rint((tmins[1:]-(twin/2))*1000).astype(int)) + plt.xlabel("time (ms)", fontsize=14) + # ax.invert_yaxis() + # ax.set_yticklabels(freqs, rotation='horizontal') + plt.ylabel("Frequency (Hz)", fontsize=14) + plt.title(f"Connect b/w {ged_filter_labels[i[0]]} & {ged_filter_labels[i[1]]}", fontsize=14, fontweight="bold") + + # Save figure + fname_fig = op.join(con_figure_root, + f"conn-gnw-pfc-ged_{con_method}_{freq_range}_{cond_name}_{ged_filter_labels[i[0]]}-x-{ged_filter_labels[i[1]]}.png") + fig.savefig(fname_fig) + plt.close(fig) + + +if __name__ == '__main__': + # subject_id = input("Type the subject ID (e.g., SA101)\n>>> ") + # visit_id = input("Type the visit ID (V1 or V2)\n>>> ") + connectivity(subject_id, visit_id) diff --git a/connectivity/Co01c_connect_dfc.py b/connectivity/Co01c_connect_dfc.py new file mode 100644 index 0000000..052a585 --- /dev/null +++ b/connectivity/Co01c_connect_dfc.py @@ -0,0 +1,571 @@ +# -*- coding: utf-8 -*- +""" +=================================== +Co01. Connectivity DFC +=================================== + +Compute guassion-copula mutal information in source space using a MNE inverse solution + +@author: Oscar Ferrante oscfer88@gmail.com +""" + +import numpy as np +import os +import os.path as op +import matplotlib.pyplot as plt +import argparse +import json +import statsmodels.api as sm +import xarray as xr +from scipy import stats + +import mne +from mne.minimum_norm import (make_inverse_operator, apply_inverse_epochs, + # write_inverse_operator + ) +import mne_bids + +from frites.conn import conn_dfc, define_windows + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +parser=argparse.ArgumentParser() +parser.add_argument('--sub', + type=str, + default='SA124', + help='site_id + subject_id (e.g. "SA101")') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +parser.add_argument('--method', + type=str, + default='dfc', + help='method used to measure connectivity (e.g. 
"coh")') +# parser.add_argument('--bids_root', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids', +# help='Path to the BIDS root directory') +# parser.add_argument('--out_con', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/forward', +# help='Path to the connectivity (derivative) directory') +opt=parser.parse_args() + + +# Set params +subject_id = opt.sub +visit_id = opt.visit +con_method = opt.method +ged_label_list = ['fusifor'] + +task_rel = ["Irrelevant"] + +surrogate = False +remove_evoked = True + +debug = False + +# Define vars for output folder name +if task_rel == ["Relevant non-target", "Irrelevant"]: + t = "" +elif task_rel == ["Irrelevant"]: + t = "_irr" +elif task_rel == ["Relevant non-target"]: + t = "_rel" +if remove_evoked: + e = "_no-evoked" +else: + e = "" +if surrogate: + s = "_surrogate" +else: + s = "" + + +def connectivity_dfc(subject_id, visit_id): + # Set path to preprocessing derivatives and create the related folders + prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") + fwd_deriv_root = op.join(bids_root, "derivatives", "forward") + fs_deriv_root = op.join(bids_root, "derivatives", "fs") + rois_deriv_root = op.join(bids_root, "derivatives", "roilabel") + ged_deriv_root = op.join(bids_root, "derivatives", "ged") + + con_deriv_root = op.join(bids_root, "derivatives", "connectivity"+t, "_dfc", e, s) + if not op.exists(con_deriv_root): + os.makedirs(con_deriv_root) + con_figure_root = op.join(con_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "figures", + con_method) + if not op.exists(con_figure_root): + os.makedirs(con_figure_root) + + print("Processing subject: %s" % subject_id) + + # Set task + if visit_id == "V1": + bids_task = 'dur' + elif visit_id == "V2": + bids_task = 'vg' + # elif visit_id == "V2": #find a better way to set the task in V2 + # bids_task = 'replay' + else: + raise ValueError("Error: could not set the task") + + # Read epoched data + bids_path_epo = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + session=visit_id, + suffix='epo', + extension='.fif', + check=False) + + epochs = mne.read_epochs(bids_path_epo.fpath, + preload=False) + + # Pick trials + epochs = epochs[f'Task_relevance in {task_rel} and Duration != "500ms"'] + if debug: + epochs = epochs[0:100] + + # Select sensor + epochs.load_data().pick('meg') #try with EEG as well + + # Baseline correction + b_tmin = -.5 + b_tmax = -.0 + baseline = (b_tmin, b_tmax) + epochs.apply_baseline(baseline=baseline) + + # Crop epochs time window to save memory + e_tmin= -.5 + e_tmax= 2. 
+ epochs.crop(e_tmin, e_tmax) + + + ## LABELS + + # Read labels from FS parc + if subject_id in ['SA102', 'SA104', 'SA110', 'SA111', 'SA152']: + labels_Datlas = mne.read_labels_from_annot( #Destrieux's atlas + "fsaverage", + parc='aparc.a2009s', + subjects_dir=fs_deriv_root) + labels_Watlas = mne.read_labels_from_annot( #Wang's atlas + "fsaverage", + parc='wang2015_mplbl', + subjects_dir=fs_deriv_root) + else: + labels_Datlas = mne.read_labels_from_annot( + "sub-"+subject_id, + parc='aparc.a2009s', + subjects_dir=fs_deriv_root) + labels_Watlas = mne.read_labels_from_annot( + "sub-"+subject_id, + parc='wang2015_mplbl', + subjects_dir=fs_deriv_root) + + # labels_Datlas_names = [l.name for l in labels_Datlas] + # labels_Watlas_names = [l.name for l in labels_Watlas] + + # Read GNW and IIT ROI list + f = open(op.join(rois_deriv_root, + 'iit_gnw_rois.json')) + gnw_iit_rois = json.load(f) + + # Create labels for selected ROIs + labels = {} + if subject_id in ['SA102', 'SA104', 'SA110', 'SA111', 'SA152']: + for lab in gnw_iit_rois['surf_labels']['iit_wang']: + lab = lab.replace('&','_and_') # Fix the label name to match the template one + print(lab) + labels["iit_"+lab+"_lh"] = [l for l in labels_Watlas if l.name == lab+"-lh"] + labels["iit_"+lab+"_rh"] = [l for l in labels_Watlas if l.name == lab+"-rh"] + + for lab in gnw_iit_rois['surf_labels']['gnw']: + lab = lab.replace('&','_and_') # Fix the label name to match the template one + print(lab) + labels["gnw_"+lab+"_lh"] = [l for l in labels_Datlas if l.name == lab+"-lh"] + labels["gnw_"+lab+"_rh"] = [l for l in labels_Datlas if l.name == lab+"-rh"] + else: + for lab in gnw_iit_rois['surf_labels']['iit_wang']: + print(lab) + labels["iit_"+lab+"_lh"] = [l for l in labels_Watlas if l.name == lab+"-lh"] + labels["iit_"+lab+"_rh"] = [l for l in labels_Watlas if l.name == lab+"-rh"] + + for lab in gnw_iit_rois['surf_labels']['gnw']: + print(lab) + labels["gnw_"+lab+"_lh"] = [l for l in labels_Datlas if l.name == lab+"-lh"][0] + labels["gnw_"+lab+"_rh"] = [l for l in labels_Datlas if l.name == lab+"-rh"][0] + + # # Save labels + # bids_path_con = bids_path_epo.copy().update( + # root=con_deriv_root, + # suffix="labels", + # extension='.pkl', + # check=False) + + # with open(bids_path_con.fpath, 'wb') as outp: + # pickle.dump(labels, outp, pickle.HIGHEST_PROTOCOL) + + # Get V1/V2 labels and sum + iit_v1v2_label = np.sum([labels["iit_V1d_lh"], + labels["iit_V1d_rh"], + labels["iit_V1v_lh"], + labels["iit_V1v_rh"], + labels["iit_V2d_lh"], + labels["iit_V2d_rh"], + labels["iit_V2v_lh"], + labels["iit_V2v_rh"]]) + + + ## Category-selective GED + + # Set params + ged_label_name = ''.join(ged_label_list) + + # Create label + ged_labels = [] + # Loop over labels + for regexp in ged_label_list: + + # Create label for the given region + if subject_id in ['SA102', 'SA104', 'SA110', 'SA111', 'SA152']: + lab = mne.read_labels_from_annot( + "fsaverage", + parc='aparc', + regexp=regexp, + hemi='both', + subjects_dir=fs_deriv_root) + else: + lab = mne.read_labels_from_annot( + "sub-"+subject_id, + parc='aparc', + regexp=regexp, + hemi='both', + subjects_dir=fs_deriv_root) + + # # Save label + # bids_path_ged = mne_bids.BIDSPath( + # root=ged_deriv_root, + # subject=subject_id, + # datatype='meg', + # task=None, + # session=visit_id, + # suffix=f"desc-{regexp}_label-lh", + # extension='.label', + # check=False) + # lab[0].save(bids_path_ged.fpath) + + # bids_path_ged = bids_path_ged.copy().update( + # suffix=f"desc-{regexp}_label-rh",) + # 
lab[1].save(bids_path_ged.fpath) + + # Append to GED labels + ged_labels.append(lab) + + # Combine GED labels + ged_labels = np.sum(ged_labels) + + # Read GED filter + bids_path_ged = bids_path_epo.copy().update( + root=ged_deriv_root, + suffix=f'desc-{ged_label_name},face_evecs', + extension='.npy', + check=False) + ged_face_evecs = np.load(bids_path_ged.fpath) + + bids_path_ged = bids_path_ged.copy().update( + suffix=f'desc-{ged_label_name},object_evecs') + ged_object_evecs = np.load(bids_path_ged.fpath) + + + ## GNW prefrontal GED + + # Merge all labels in prefrontal GNW ROI + ged_gnw_label = np.sum([l for l_name, l in labels.items() if 'gnw' in l_name]) + + # Read GNW prefrontal GED filter + bids_path_ged = bids_path_epo.copy().update( + root=ged_deriv_root, + suffix='desc-gnw_evecs', + extension='.npy', + check=False) + ged_gnw_evecs = np.load(bids_path_ged.fpath) + + + ## SOURCE MODELLING + + # Compute rank + rank = mne.compute_rank(epochs, + tol=1e-6, + tol_kind='relative') + + # Read forward model + bids_path_fwd = bids_path_epo.copy().update( + root=fwd_deriv_root, + task=None, + suffix="surface_fwd", + extension='.fif', + check=False) + + fwd = mne.read_forward_solution(bids_path_fwd.fpath) + + # Compute covariance matrices + base_cov = mne.compute_covariance(epochs, + tmin=b_tmin, + tmax=b_tmax, + method='empirical', + rank=rank) + + active_cov = mne.compute_covariance(epochs, + tmin=0, + tmax=None, + method='empirical', + rank=rank) + common_cov = base_cov + active_cov + + # Make inverse operator + inverse_operator = make_inverse_operator( + epochs.info, + fwd, + common_cov, + loose=.2, + depth=.8, + fixed=False, + rank=rank, + use_cps=True) + + # # Save inverse operator + # bids_path_inv = bids_path_con.copy().update( + # suffix="inv_c%s" % condition, + # extension='.fif', + # check=False) + # write_inverse_operator(bids_path_inv.fpath, + # inverse_operator) + + + ## CONNECTIVITY + + # Loop over conditions + for condition in range(1,3): + + # Pick condition + if condition == 1: + epochs_cond = epochs['Category == "object"'].copy() + cond_name = "object" + elif condition == 2: + epochs_cond = epochs['Category == "face"'].copy() + cond_name = "face" + else: + raise ValueError("Condition %s does not exists" % condition) + print("\n Running condition " + cond_name + "\n") + + # Compute inverse solution for each epoch + snr = 3.0 + lambda2 = 1.0 / snr ** 2 + method = "dSPM" + + stcs = apply_inverse_epochs(epochs_cond, + inverse_operator, + lambda2, + method, + pick_ori="normal", + return_generator=False) + del epochs_cond + + # Average source estimates within each label to reduce signal cancellations + src = inverse_operator['src'] + iit_label_ts = mne.extract_label_time_course( + stcs, iit_v1v2_label, src, + mode='pca_flip', #was mean_flip + return_generator=False) + + # Apply GED filter to source-level epochs + ged_face_ts = [] + ged_object_ts = [] + ged_gnw_ts = [] + for i in range(len(stcs)): + # Get data + data = stcs[i].in_label(ged_labels).data + data_gnw = stcs[i].in_label(ged_gnw_label).data + # Apply GED filter + ged_face_ts.append(ged_face_evecs[:,0].T @ data) + ged_object_ts.append(ged_object_evecs[:,0].T @ data) + ged_gnw_ts.append(ged_gnw_evecs[:,0].T @ data_gnw) + + del stcs + + # Concatenate GNW & IIT labels and GED spatial filters + all_ts = [] + for i in range(len(ged_face_ts)): + all_ts.append(np.vstack([ged_gnw_ts[i], iit_label_ts[i], ged_face_ts[i], ged_object_ts[i]])) + ged_filter_labels = ['pfc','v1v2','face filter','object filter'] + + # Create surrogate 
data by shuffling trial labels + if surrogate: + # Convert list to array + all_ts_array = np.array(all_ts) + + # Loop over nodes + for n in range(len(all_ts_array[0])): + + # Get trial number indices + ind = np.arange(len(all_ts_array)) + + # Shuffle trial indeces + np.random.shuffle(ind) + # plt.plot(ind) + + # Shuffle trials in the node data + all_ts_array[:,n,:] = all_ts_array[ind,n,:] + + # Convert array back to list + all_ts = [all_ts_array[i,:,:] for i in range(len(all_ts_array))] + + # Remove evoked using regression + if remove_evoked: + all_evoked = np.mean(all_ts, axis=0) + for node in range(len(all_ts[0])): + node_evoked = all_evoked[node,:] + for trial in range(len(all_ts)): + all_ts[trial][node,:] = sm.OLS(np.array(all_ts)[trial,node,:], node_evoked).fit().resid + + + # Compute Dynamic Functional Connectivity using the Gaussian-Copula Mutual Information (GCMI) + + # Insert data in an epochs object + info = mne.create_info(ged_filter_labels, epochs.info['sfreq'], ch_types='grad') + ep = mne.EpochsArray(all_ts, info) #tmin set to 0 for convinience (real tmin = -500) + + # Create indices of label-to-label couples for which to compute connectivity + n_labels = 2 + indices = (np.concatenate([range(0,n_labels),range(0,n_labels)]), + np.array([n_labels]*len(range(0,n_labels)) + [n_labels+1]*len(range(0,n_labels)))) + + # Set params + times = epochs.times + trials = np.arange(len(all_ts)) + + # Define the sliding windows + window_len = 0.1 #100ms + step = 0.02 #20ms + sl_win = define_windows(times, + slwin_len=window_len, + slwin_step=step)[0] + + # Compute tfr + tfr = mne.time_frequency.tfr_multitaper( + ep, + freqs=np.concatenate( + (np.arange(2,30,1), + np.arange(30,101,2))), + n_cycles=np.concatenate( + (np.tile(4,len(np.arange(2,30,1))), + np.arange(30,101,2)/4)), + use_fft=True, + return_itc=False, + average=False, + time_bandwidth=2., + verbose=True) + + # Create empty array + conndat = np.empty( + (len(indices[0]),len(tfr.freqs),len(sl_win))) + + # Run DFC analysis + for f, freq in enumerate( tfr.freqs ): + for i_, ind_ in enumerate(zip( indices[0], indices[1])): + # Convert data to xarray + x = np.squeeze(tfr.data[:, [ind_[0], ind_[1]], f, :]) + rr = ['r0', 'r1'] + x = xr.DataArray( + x, + dims=('trials', 'space', 'times'), + coords=(trials, rr, times)) + + # Compute DFC on sliding windows + dfc = conn_dfc( + x, + times='times', + roi='space', + win_sample=sl_win) + + conndat[i_, f, :] = dfc.mean('trials').squeeze().data + + # Save results + print('\nSaving...') + bids_path_con = bids_path_epo.copy().update( + root=con_deriv_root, + suffix=f"desc-{con_method}_{cond_name}_con", + extension='.npy', + check=False) + + np.save(bids_path_con.fpath, conndat) + + # Save times and freqs info + bids_path_con = bids_path_epo.copy().update( + root=con_deriv_root, + suffix=f"desc-{con_method}_times", + extension='.npy', + check=False) + + np.save(bids_path_con.fpath, dfc['times'].values) + + bids_path_con = bids_path_epo.copy().update( + root=con_deriv_root, + suffix=f"desc-{con_method}_freqs", + extension='.npy', + check=False) + + np.save(bids_path_con.fpath, tfr.freqs) + + + # Plot + analysis_time = [round(x,3) for x in dfc['times'].values] + freqs = [int(x) for x in tfr.freqs] + extent = list([analysis_time[0],analysis_time[-1],1,len(freqs)]) + + indices_comb = [[i,j] for i,j in zip(indices[0], indices[1])] + + for i in indices_comb: + # Get data and do z-scoring by frequencies + data = stats.zscore(conndat[indices_comb.index(i),:,:], axis=1) + + # Plot + fig, ax = 
plt.subplots(figsize=[8,6]) + im = ax.imshow(data, + cmap='RdYlBu_r', + extent=extent, + origin="lower", + aspect='auto') + + cbar = plt.colorbar(im, ax=ax) + cbar.ax.tick_params(labelsize=8) + + ax.set_yticklabels(freqs[0::5]) + ax.axhline(freqs.index(30), color='w', lw=4) + + plt.xlabel("Time (ms)", fontsize=14) + plt.ylabel("Frequency (Hz)", fontsize=14) + plt.title(f"Conn {ged_filter_labels[i[0]]}-{ged_filter_labels[i[1]]}: {cond_name}", fontsize=14, fontweight="bold") + + # Save figure + fname_fig = op.join(con_figure_root, + f"conn_{con_method}_{cond_name}_{ged_filter_labels[i[0]]}-x-{ged_filter_labels[i[1]]}.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + +if __name__ == '__main__': + # subject_id = input("Type the subject ID (e.g., SA101)\n>>> ") + # visit_id = input("Type the visit ID (V1 or V2)\n>>> ") + connectivity_dfc(subject_id, visit_id) diff --git a/connectivity/Co02_connect_ppc_ga.py b/connectivity/Co02_connect_ppc_ga.py new file mode 100644 index 0000000..668af4b --- /dev/null +++ b/connectivity/Co02_connect_ppc_ga.py @@ -0,0 +1,494 @@ +# -*- coding: utf-8 -*- +""" +=================================== +Co02. Grand-average connectivity +=================================== + +Compute the grand average for the connectivity analysis + +@author: Oscar Ferrante oscfer88@gmail.com +""" + +import numpy as np +import os +import os.path as op +import matplotlib.pyplot as plt +import argparse +import seaborn as sns +from scipy import stats as stats +import pickle + +import mne +from mne_connectivity import read_connectivity, SpectroTemporalConnectivity +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +parser=argparse.ArgumentParser() +parser.add_argument('--method', + type=str, + default='ppc', + help='method used to measure connectivity (e.g. 
"coh")') +opt=parser.parse_args() + + +# Set params +visit_id = "V1" +con_method = opt.method + +task_rel = ["Irrelevant"] +remove_evoked = False + +all_durs = False +use_long_ged = False +surrogate = False + +debug = False + +# Define vars for output folder name +if task_rel == ["Relevant non-target", "Irrelevant"]: + t = "" +elif task_rel == ["Irrelevant"]: + t = "_irr" +elif task_rel == ["Relevant non-target"]: + t = "_rel" +if all_durs: + d = "_all-durs" +else: + d = "" +if remove_evoked: + e = "_no-evoked" +else: + e = "" +if surrogate: + s = "_surrogate" +else: + s = "" +if use_long_ged: + g = "_0.0-2.0" +else: + g = "" + + +# Set participant list +phase = 3 + +if debug: + sub_list = ["SA124", "SA124"] +else: + # Read the .txt file + f = open(op.join(bids_root, + f'participants_MEG_phase{phase}_included.txt'), 'r').read() + # Split text into list of elemetnts + sub_list = f.split("\n") + + +def connectivity_ga(sub_list, visit_id): + # Set path to preprocessing derivatives and create the related folders + con_deriv_root = op.join(bids_root, "derivatives", "connectivity"+t, d, g, e, s) + if not op.exists(con_deriv_root): + raise ValueError("Error: connectivity derivatives folder does not exist") + con_figure_root = op.join(con_deriv_root, + f"sub-groupphase{phase}",f"ses-{visit_id}","meg", + "figures", + con_method) + if not op.exists(con_figure_root): + os.makedirs(con_figure_root) + + # Set task + if visit_id == "V1": + bids_task = 'dur' + elif visit_id == "V2": + bids_task = 'vg' + # elif visit_id == "V2": #find a better way to set the task in V2 + # bids_task = 'replay' + else: + raise ValueError("Error: could not set the task") + + print('\nCompute connectivity grandaverage with method:', con_method) + + # Create indices of connections for which connectivity was computed + n_labels = 2 + indices = (np.concatenate([range(0,n_labels),range(0,n_labels)]), + np.array([n_labels]*len(range(0,n_labels)) + [n_labels+1]*len(range(0,n_labels)))) + + # Loop over frequencies + for freq_range in ['low', 'high']: + print(f'\nFreq range: {freq_range}') + + # Loop over analysis (i.e., contrasts) + con_dif = {} + if task_rel == ["Relevant non-target", "Irrelevant"]: + cond_contr = [['face', 'object'],['relev', 'irrel']] + else: + cond_contr = [['face', 'object']] + for anal_contr in cond_contr: + print(f"\nAnalysis: {anal_contr[0]} vs {anal_contr[1]}") + + # Loop over conditions + con_condlist = {} + for cond_name in anal_contr: + print(f"\nCondition: {cond_name}") + + # Load connectivity results + con_all = [] + for sub in sub_list: + print(f"subject id: {sub}") + + # Set path + bids_path_con = mne_bids.BIDSPath( + root=con_deriv_root, + subject=sub, + datatype='meg', + task=bids_task, + session=visit_id, + suffix=f"desc-gnw-pfc-ged,{con_method},{freq_range},{cond_name}_con", + extension='.nc', + check=False) + + # Load data + con_all.append( + read_connectivity(bids_path_con.fpath)) + + # Get data + con_all_data = [] + for con in con_all: + con_all_data.append(con.get_data()) + times = ['%.0f' %t for t in (np.array(con.times) - .5) * 1000] + freqs = ['%.0f' %f for f in con.freqs] + + con_all_data = np.asarray(con_all_data) #convert to array + + # Append individual con data to full data list + con_condlist[cond_name] = np.asarray(con_all_data) + + # Average data across participants and put them in a connectivity object + con_ga = SpectroTemporalConnectivity( + data = np.mean(con_all_data, axis=0), + freqs = con.freqs, + times = con.times, + n_nodes = con.n_nodes, + indices = indices) + + # Save 
grandaverage data + bids_path_con = bids_path_con.copy().update( + subject=f"groupphase{phase}", + check=False) + + con_ga.save(bids_path_con.fpath) + + + ## Plotting + + # Set plotting params + ged_filter_labels = ['pfc','v1v2','face filter','object filter'] + indices_comb = [[i,j] for i,j in zip(indices[0], indices[1])] + vmin = 0. + vmax = .15 + + # Plot individual ROI results + for i in indices_comb: + print(f'\nPlotting {ged_filter_labels[i[0]]}-{ged_filter_labels[i[1]]}...') + fig, ax = plt.subplots(figsize=[8,6]) + sns.heatmap(con_ga.get_data()[indices_comb.index(i),:,:], + xticklabels=250, yticklabels=5, + vmin=vmin, + vmax=vmax, + cmap='RdYlBu_r', + ax=ax) + ax.set_xticklabels(times[0::250], + fontsize=8) + ax.invert_yaxis() + ax.set_yticklabels(freqs[0::5], rotation='horizontal', + fontsize=8) + cbar = ax.collections[0].colorbar + cbar.ax.tick_params(labelsize=8) + + plt.xlabel("time (ms)", fontsize=14) + plt.ylabel("Frequency (Hz)", fontsize=14) + plt.title(f"{con_method} on {ged_filter_labels[i[0]]}-{ged_filter_labels[i[1]]}: {cond_name}", fontsize=14, fontweight="bold") + + # Save figure + fname_fig = op.join(con_figure_root, + f"conn-{con_method}_{freq_range}_{cond_name}_{ged_filter_labels[i[0]]}-x-{ged_filter_labels[i[1]]}.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + + ## Face vs. object / relev. vs. irrel.: Permutation analysis + + # Set test params + pval = 0.05 # arbitrary + n_observations = len(sub_list) + df = n_observations - 1 # degrees of freedom for the test + thresh = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + # Loop over indices + t_obs_all = [] + clusters_all = [] + cluster_p_values_all = [] + p_values_all = [] + for i in indices_comb: + print(f'\nTesting clursters for {ged_filter_labels[i[0]]} - {ged_filter_labels[i[1]]}') + + # Get data (subjects) × time × space + Xfac = con_condlist[f'{anal_contr[0]}'][:,indices_comb.index(i),:,:] + Xobj = con_condlist[f'{anal_contr[1]}'][:,indices_comb.index(i),:,:] + + # Run permutation analysis + t_obs, clusters, cluster_p_values, H0 = \ + mne.stats.permutation_cluster_1samp_test( + Xfac - Xobj, + threshold=thresh, + out_type='mask') + + # Append results to list + t_obs_all.append(t_obs) + clusters_all.append(clusters) + cluster_p_values_all.append(cluster_p_values) + p_values_all.append(cluster_p_values) + + # Select the clusters that are statistically significant at p < 0.05 + good_clusters_all = [] + for clusters, cluster_p_values in zip(clusters_all, cluster_p_values_all): + good_clusters_idx = np.where(cluster_p_values < 0.05)[0] + good_clusters = [clusters[idx] for idx in good_clusters_idx] + good_clusters_all.append(good_clusters) + + # Save significant clusters + bids_path_con = bids_path_con.copy().update( + subject=f"groupphase{phase}", + suffix=f"desc-gnw-pfc-ged,{con_method},{freq_range},{anal_contr}_clusters", + extension=".pkl", + check=False) + + with open(bids_path_con.fpath, 'wb') as file: + pickle.dump(good_clusters_all, file) + + + ## Face vs. object / relev. vs. 
irrel.: Plotting + + # Compute difference between face and object trials + con_dif[f"{anal_contr}"] = con_condlist[f'{anal_contr[0]}'] - con_condlist[f'{anal_contr[1]}'] + + con_dif_data = np.mean(con_dif[f"{anal_contr}"], axis=0) + + vmin = -.075 + vmax = .075 + + # Plot + for i in indices_comb: + print(f'\nPlotting {ged_filter_labels[i[0]]}-{ged_filter_labels[i[1]]}...') + # Get data + data = con_dif_data[indices_comb.index(i),:,:] + # extent = [0,len(times),0,len(freqs)] + extent = list(map(int, [times[0],times[-1],freqs[0],freqs[-1]])) + sig_mask = np.any(good_clusters_all[indices_comb.index(i)], axis=0) + masked_data = np.ma.masked_where(sig_mask == 0, data) + + # Open figure + fig, ax = plt.subplots(figsize=[8,6]) + + # Plot all data + ax.imshow(data, + cmap='RdYlBu_r', + extent=extent, + origin="lower", + alpha=.4, + aspect='auto', + vmin=vmin, vmax=vmax) + + # Plot masked data + im = ax.imshow(masked_data, + cmap='RdYlBu_r', + origin='lower', + extent=extent, + aspect='auto', + vmin=vmin, vmax=vmax) + + # Draw contour + if np.any(sig_mask == 1): + ax.contour(sig_mask, + levels=[0, 1], + colors="k", + origin="lower", + extent=extent) + + ax.set_yticklabels(freqs[0::5]) + + cbar = plt.colorbar(im, ax=ax) + cbar.ax.tick_params(labelsize=8) + + plt.xlabel("time (ms)", fontsize=14) + plt.ylabel("Frequency (Hz)", fontsize=14) + plt.title(f"{con_method} on {ged_filter_labels[i[0]]}-{ged_filter_labels[i[1]]}: {anal_contr[0]} vs {anal_contr[1]}", fontsize=14, fontweight="bold") + + # Save figure + fname_fig = op.join(con_figure_root, + f"conn-{con_method}_{freq_range}_{anal_contr[0][0]}vs{anal_contr[1][0]}_{ged_filter_labels[i[0]]}-x-{ged_filter_labels[i[1]]}.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + + # # Compute the difference of the difference (face trials vs object trials / face filter vs. object filter) + # con_dif_c1flt_data = con_condlist[f'{anal_contr[0]}'][:,:len(con_dif_data)//2,:,:] \ + # - con_condlist[f'{anal_contr[1]}'][:,:len(con_dif_data)//2,:,:] + # con_dif_c2flt_data = con_condlist[f'{anal_contr[0]}'][:,len(con_dif_data)//2:,:,:] \ + # - con_condlist[f'{anal_contr[1]}'][:,len(con_dif_data)//2:,:,:] + # con_dif_dif_data = np.mean(con_dif_c1flt_data - con_dif_c2flt_data, + # axis=0) + + # vmin = -.075 + # vmax = .075 + + # # Plot + # for i in range(len(con_dif_dif_data)): + # fig, ax = plt.subplots(figsize=[8,6]) + # sns.heatmap(con_dif_dif_data[i,:,:], + # xticklabels=250, yticklabels=5, + # vmin=vmin, + # vmax=vmax, + # cmap='RdYlBu_r', + # ax=ax) + # ax.set_xticklabels(times[0::250], + # fontsize=8) + # ax.invert_yaxis() + # ax.set_yticklabels(freqs[0::5], rotation='horizontal', + # fontsize=8) + # cbar = ax.collections[0].colorbar + # cbar.ax.tick_params(labelsize=8) + + # plt.xlabel("time (ms)", fontsize=14) + # plt.ylabel("Frequency (Hz)", fontsize=14) + # plt.title(f"{con_method} {anal_contr[0]}-vs-{anal_contr[1]} on {ged_filter_labels[i]}: {anal_contr[0]} vs {anal_contr[1]} filter", fontsize=14, fontweight="bold") + + # # Save figure + # fname_fig = op.join(con_figure_root, + # f"conn-{con_method}_{freq_range}_{anal_contr[0][0]}vs{anal_contr[1][0]}DiffDiff_{ged_filter_labels[i]}.png") + # fig.savefig(fname_fig, dpi=300) + # plt.close(fig) + + + ## Stimulus vs. 
task: Permutation analysis + + if task_rel == ["Relevant non-target", "Irrelevant"]: + # Set test params + pval = 0.05 # arbitrary + n_observations = len(sub_list) + df = n_observations - 1 # degrees of freedom for the test + thresh = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + # Loop over indices + t_obs_all = [] + clusters_all = [] + cluster_p_values_all = [] + p_values_all = [] + for i in indices_comb: + print(f'\nTesting clursters for {ged_filter_labels[i[0]]} - {ged_filter_labels[i[1]]}') + + # Get data (subjects) × time × space + Xsti = con_dif["['face', 'object']"][:,indices_comb.index(i),:,:] + Xtas = con_dif["['relev', 'irrel']"][:,indices_comb.index(i),:,:] + + # Run permutation analysis + t_obs, clusters, cluster_p_values, H0 = \ + mne.stats.permutation_cluster_1samp_test( + Xsti - Xtas, + threshold=thresh, + out_type='mask') + + # Append results to list + t_obs_all.append(t_obs) + clusters_all.append(clusters) + cluster_p_values_all.append(cluster_p_values) + p_values_all.append(cluster_p_values) + + # Select the clusters that are statistically significant at p < 0.05 + good_clusters_all = [] + for clusters, cluster_p_values in zip(clusters_all, cluster_p_values_all): + good_clusters_idx = np.where(cluster_p_values < 0.05)[0] + good_clusters = [clusters[idx] for idx in good_clusters_idx] + good_clusters_all.append(good_clusters) + + # Save significant clusters + bids_path_con = bids_path_con.copy().update( + subject=f"groupphase{phase}", + suffix=f"desc-gnw-pfc-ged,{con_method},{freq_range},stim_vs_relev_clusters", + extension=".pkl", + check=False) + + with open(bids_path_con.fpath, 'wb') as file: + pickle.dump(good_clusters_all, file) + + + ## Plotting + + # Compute difference between stimulus and task effects + con_dif_dif_data = np.mean( + con_dif["['face', 'object']"] - con_dif["['relev', 'irrel']"], + axis=0) + + vmin = -.075 + vmax = .075 + + # Plot + for i in indices_comb: + print(f'\nPlotting {ged_filter_labels[i[0]]}-{ged_filter_labels[i[1]]}...') + # Get data + data = con_dif_dif_data[indices_comb.index(i),:,:] + # extent = [0,len(times),0,len(freqs)] + extent = list(map(int, [times[0],times[-1],freqs[0],freqs[-1]])) + sig_mask = np.any(good_clusters_all[indices_comb.index(i)], axis=0) + masked_data = np.ma.masked_where(sig_mask == 0, data) + + # Open figure + fig, ax = plt.subplots(figsize=[8,6]) + + # Plot all data + ax.imshow(data, + cmap='RdYlBu_r', + extent=extent, + origin="lower", + alpha=.4, + aspect='auto', + vmin=vmin, vmax=vmax) + + # Plot masked data + im = ax.imshow(masked_data, + cmap='RdYlBu_r', + origin='lower', + extent=extent, + aspect='auto', + vmin=vmin, vmax=vmax) + + # Draw contour + if np.any(sig_mask == 1): + ax.contour(sig_mask, + levels=[0, 1], + colors="k", + origin="lower", + extent=extent) + + ax.set_yticklabels(freqs[0::5]) + + cbar = plt.colorbar(im, ax=ax) + cbar.ax.tick_params(labelsize=8) + + plt.xlabel("time (ms)", fontsize=14) + plt.ylabel("Frequency (Hz)", fontsize=14) + plt.title(f"{con_method} on {ged_filter_labels[i[0]]}-{ged_filter_labels[i[1]]}: stimuus vs task", fontsize=14, fontweight="bold") + + # Save figure + fname_fig = op.join(con_figure_root, + f"conn-{con_method}_{freq_range}_svst_{ged_filter_labels[i[0]]}-x-{ged_filter_labels[i[1]]}.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + +if __name__ == '__main__': + # subject_id = input("Type the subject ID (e.g., SA101)\n>>> ") + # visit_id = input("Type the visit ID (V1 or V2)\n>>> ") + connectivity_ga(sub_list, visit_id) diff --git 
a/connectivity/Co02c_connect_dfc_ga.py b/connectivity/Co02c_connect_dfc_ga.py new file mode 100644 index 0000000..c80eb02 --- /dev/null +++ b/connectivity/Co02c_connect_dfc_ga.py @@ -0,0 +1,311 @@ +# -*- coding: utf-8 -*- +""" +=================================== +Co02. Grand-average connectivity +=================================== + +Compute the grand average for the connectivity analysis + +@author: Oscar Ferrante oscfer88@gmail.com +""" + +import numpy as np +import os +import os.path as op +import matplotlib.pyplot as plt +import argparse +from scipy import stats as stats +import pickle + +import mne +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +parser=argparse.ArgumentParser() +parser.add_argument('--method', + type=str, + default='dfc', + help='method used to measure connectivity (e.g. "coh")') +opt=parser.parse_args() + + +# Set params +visit_id = "V1" +con_method = opt.method + +task_rel = ["Irrelevant"] + +surrogate = False +remove_evoked = True + +debug = False + +# Define vars for output folder name +if task_rel == ["Relevant non-target", "Irrelevant"]: + t = "" +elif task_rel == ["Irrelevant"]: + t = "_irr" +elif task_rel == ["Relevant non-target"]: + t = "_rel" +if remove_evoked: + e = "_no-evoked" +else: + e = "" +if surrogate: + s = "_surrogate" +else: + s = "" + + +# Set participant list +phase = 3 + +if debug: + sub_list = ["SA124", "SA124"] +else: + # Read the .txt file + f = open(op.join(bids_root, + f'participants_MEG_phase{phase}_included.txt'), 'r').read() + # Split text into list of elemetnts + sub_list = f.split("\n") + + +def connectivity_dfc_ga(sub_list, visit_id): + # Set path to preprocessing derivatives and create the related folders + con_deriv_root = op.join(bids_root, "derivatives", "connectivity"+t, "_dfc", e, s) + if not op.exists(con_deriv_root): + raise ValueError("Error: connectivity derivatives folder does not exist") + con_figure_root = op.join(con_deriv_root, + f"sub-groupphase{phase}",f"ses-{visit_id}","meg", + "figures", + con_method) + if not op.exists(con_figure_root): + os.makedirs(con_figure_root) + + # Set task + if visit_id == "V1": + bids_task = 'dur' + elif visit_id == "V2": + bids_task = 'vg' + # elif visit_id == "V2": #find a better way to set the task in V2 + # bids_task = 'replay' + else: + raise ValueError("Error: could not set the task") + + print('\nCompute connectivity grandaverage with method:', con_method) + + # Load times and freq + bids_path_times = mne_bids.BIDSPath( + root=con_deriv_root, + subject=sub_list[0], + datatype='meg', + task=bids_task, + session=visit_id, + suffix=f"desc-{con_method}_times", + extension='.npy', + check=False) + times = np.load(bids_path_times.fpath) + + bids_path_freqs = bids_path_times.copy().update( + root=con_deriv_root, + subject=sub_list[0], + suffix=f"desc-{con_method}_freqs") + freqs = np.load(bids_path_freqs.fpath) + + # Create indices of label-to-label couples for which to compute connectivity + n_labels = 2 + roi_labels = ['pfc','v1v2','face filter','object filter'] + indices = (np.concatenate([range(0,n_labels),range(0,n_labels)]), + np.array([n_labels]*len(range(0,n_labels)) + [n_labels+1]*len(range(0,n_labels)))) + indices_comb = [[i,j] for i,j in zip(indices[0], indices[1])] + + con_condlist = {} + # Loop over conditions + for cond_name in ["object", "face"]: + print("\n Running condition " + cond_name + "\n") + + # Load indivisual results + con_all = [] + for sub in sub_list: + 
print("subject id:", sub) + + # Load connectivity data + bids_path_con = mne_bids.BIDSPath( + root=con_deriv_root, + subject=sub, + datatype='meg', + task=bids_task, + session=visit_id, + suffix=f"desc-{con_method}_{cond_name}_con", + extension='.npy', + check=False) + + con_all.append(np.load(bids_path_con.fpath)) + + # Averaged over participants + con_all = np.array(con_all) + con_all_ga = np.mean(con_all, axis=0) + + # Save grandaverage + bids_path_con = bids_path_con.update( + subject=f"groupphase{phase}") + np.save(bids_path_con.fpath, con_all_ga) + + # Append to list + con_condlist[cond_name] = con_all + + # # Plot single condition + analysis_time = [round(x,3) for x in times] + freqs = [int(x) for x in freqs] + extent = list([analysis_time[0],analysis_time[-1],1,len(freqs)]) + + vmin = -5 + vmax = 5 + + for i in indices_comb: + # Get data and do z-scoring by frequencies + data = stats.zscore(con_all_ga[indices_comb.index(i),:,:], axis=1) + + # Plot + fig, ax = plt.subplots(figsize=[8,6]) + im = ax.imshow(data, + cmap='RdYlBu_r', + extent=extent, + origin="lower", + aspect='auto', + vmin=vmin, vmax=vmax) + + cbar = plt.colorbar(im, ax=ax) + cbar.ax.tick_params(labelsize=8) + + ax.set_yticklabels(freqs[0::5]) + ax.axhline(freqs.index(30), color='w', lw=4) + + plt.xlabel("Time (ms)", fontsize=14) + plt.ylabel("Frequency (Hz)", fontsize=14) + plt.title(f"{con_method} on {roi_labels[i[0]]} - {roi_labels[i[1]]}: {cond_name}", fontsize=14, fontweight="bold") + + # Save figure + fname_fig = op.join(con_figure_root, + f"conn_{con_method}_{roi_labels[i[0]]}-x-{roi_labels[i[1]]}_{cond_name}.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + + # Permutation analysis + + # Set test params + pval = 0.05 # arbitrary + n_observations = len(sub_list) + df = n_observations - 1 # degrees of freedom for the test + thresh = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + # Loop over indices + t_obs_all = [] + clusters_all = [] + cluster_p_values_all = [] + p_values_all = [] + for i in indices_comb: + print(f'\nTesting clursters for {roi_labels[i[0]]} - {roi_labels[i[1]]}') + + # Get data (subjects) × time × space + Xfac = con_condlist['face'][:,indices_comb.index(i),:,:] + Xobj = con_condlist['object'][:,indices_comb.index(i),:,:] + + # Run permutation analysis + t_obs, clusters, cluster_p_values, H0 = \ + mne.stats.permutation_cluster_1samp_test( + Xfac - Xobj, + threshold=thresh, + out_type='mask') + + # Append results to list + t_obs_all.append(t_obs) + clusters_all.append(clusters) + cluster_p_values_all.append(cluster_p_values) + p_values_all.append(cluster_p_values) + + # Select the clusters that are statistically significant at p < 0.05 + good_clusters_all = [] + for clusters, cluster_p_values in zip(clusters_all, cluster_p_values_all): + good_clusters_idx = np.where(cluster_p_values < 0.05)[0] + good_clusters = [clusters[idx] for idx in good_clusters_idx] + good_clusters_all.append(good_clusters) + + # Save significant clusters + anal_contr = ['face', 'object'] + bids_path_con = bids_path_con.copy().update( + subject=f"groupphase{phase}", + suffix=f"desc-{con_method}_{anal_contr}_clusters", + extension=".pkl", + check=False) + + with open(bids_path_con.fpath, 'wb') as file: + pickle.dump(good_clusters_all, file) + + + # Plotting + + # Compute difference between face and object trials + con_dif_data = np.mean(con_condlist['face'] - con_condlist['object'], + axis=0) + vmin = -.2 + vmax = .2 + + for i in indices_comb: + # Get data + data = 
con_dif_data[indices_comb.index(i),:,:] + sig_mask = np.any(good_clusters_all[indices_comb.index(i)], axis=0) + masked_data = np.ma.masked_where(sig_mask == 0, data) + + # Plot all data + fig, ax = plt.subplots(figsize=[8,6]) + ax.imshow(data, + cmap='RdYlBu_r', + extent=extent, + origin="lower", + alpha=.4, + aspect='auto', + vmin=vmin, vmax=vmax) + + # Plot masked data + im = ax.imshow(masked_data, + cmap='RdYlBu_r', + origin='lower', + extent=extent, + aspect='auto', + vmin=vmin, vmax=vmax) + + # Draw contour + if np.any(sig_mask == 1): + ax.contour(sig_mask == 0, sig_mask == 0, + colors="k", + origin="lower", + extent=extent) + + cbar = plt.colorbar(im, ax=ax) + cbar.ax.tick_params(labelsize=8) + + ax.set_yticklabels(freqs[0::5]) + ax.axhline(freqs.index(30), color='w', lw=4) + + plt.xlabel("time (ms)", fontsize=14) + plt.ylabel("Frequency (Hz)", fontsize=14) + plt.title(f"{con_method} on {roi_labels[i[0]]} - {roi_labels[i[1]]}: face vs object", fontsize=14, fontweight="bold") + + # Save figure + fname_fig = op.join(con_figure_root, + f"conn-{con_method}_FvsO_{roi_labels[i[0]]}-x-{roi_labels[i[1]]}_FvsO.png") + fig.savefig(fname_fig, dpi=300) + plt.close(fig) + + +if __name__ == '__main__': + # subject_id = input("Type the subject ID (e.g., SA101)\n>>> ") + # visit_id = input("Type the visit ID (V1 or V2)\n>>> ") + connectivity_dfc_ga(sub_list, visit_id) diff --git a/ged/Co01_ged_selectivity.py b/ged/Co01_ged_selectivity.py new file mode 100644 index 0000000..016ef01 --- /dev/null +++ b/ged/Co01_ged_selectivity.py @@ -0,0 +1,632 @@ +# -*- coding: utf-8 -*- +""" +=================================== +Co00. GED spatial filter +=================================== + +Create category-specific spatial filters through generalized eigendecomposition + +@author: Oscar Ferrante oscfer88@gmail.com +""" + +import os +import os.path as op +import numpy as np +import matplotlib.pyplot as plt +import scipy +import scipy.signal as ss +import argparse + +import mne +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +parser=argparse.ArgumentParser() +parser.add_argument('--sub', + type=str, + default='SA124', + help='site_id + subject_id (e.g. "SA101")') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +parser.add_argument('--labels', + nargs='+', + default=['fusifor'], + help='name of the label to which contrain the spatial filter (e.g., "fusiform"') +parser.add_argument('--parc', + type=str, + default='aparc', + help='name of the parcellation atlas to use for contraining the spatial filter (e.g., "aparc", "aparc.a2009s")') +# parser.add_argument('--bids_root', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids', +# help='Path to the BIDS root directory') +opt=parser.parse_args() + + +# Set params +subject_id = opt.sub +visit_id = opt.visit +label_list = opt.labels +label_name = ''.join(label_list) +parc = opt.parc + +act_win_tmin = 0. +act_win_tmax = .5 + +debug = False + +#aparc aparc.a2009s +#fusifor G_oc-temp_lat-fusifor +#inferiortemporal +#lateraloccipital G&S_occipital_inf + + +# Set derivatives paths +prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") +fwd_deriv_root = op.join(bids_root, "derivatives", "forward") +fs_deriv_root = op.join(bids_root, "derivatives", "fs") + +if act_win_tmin == 0. 
and act_win_tmax == .5: + ged_deriv_root = op.join(bids_root, "derivatives", "ged") +else: + ged_deriv_root = op.join(bids_root, "derivatives", "ged", f"_{act_win_tmin}-{act_win_tmax}") +if not op.exists(ged_deriv_root): + os.makedirs(ged_deriv_root) +ged_figure_root = op.join(ged_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "figures") +if not op.exists(ged_figure_root): + os.makedirs(ged_figure_root) + + +# Set task +if visit_id == "V1": + bids_task = 'dur' +elif visit_id == "V2": + bids_task = 'vg' +# elif visit_id == "V2": #find a better way to set the task in V2 +# bids_task = 'replay' +else: + raise ValueError("Error: could not set the task") + + +# ============================================================================= +# READ DATA +# ============================================================================= + +def read_cogitate_data(subject_id, visit_id): + print("Processing subject: %s" % subject_id) + + # Read epoched data + bids_path_epo = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + session=visit_id, + suffix='epo', + extension='.fif', + check=False) + + epochs = mne.read_epochs(bids_path_epo.fpath, + preload=False) + + # Pick all non-target trials + epochs = epochs['Task_relevance in ["Relevant non-target", "Irrelevant"] and Duration != "500ms"'] + if debug: + epochs = epochs[0:100] # ONLY for DEBUG + + # Load data + epochs.load_data() + + # Pick MEG sensors only #here I combine channel types + epochs = epochs.pick('meg') + + return epochs + + +def select_cond(cond, epochs): + # Select epochs + cond_epochs = epochs['Category == "%s"' % cond] + other_epochs = epochs['Category != "%s"' % cond] + + return cond_epochs, other_epochs + + +# ============================================================================= +# SOURCE MODELLING +# ============================================================================= + +def create_inverse(epochs): + # Apply baseline correction + b_tmin = -.5 + b_tmax = - 0. 
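+    # Baseline window: the 500 ms preceding stimulus onset. The baseline and
+    # active (post-onset) covariances computed below are summed into a common
+    # covariance used to build the dSPM inverse operator.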
+ baseline = (b_tmin, b_tmax) + epochs.apply_baseline(baseline=baseline) + + # Compute rank + rank = mne.compute_rank(epochs, + tol=1e-6, + tol_kind='relative') + + # Read forward model + bids_path_fwd = mne_bids.BIDSPath( + root=fwd_deriv_root, + subject=subject_id, + datatype='meg', + task=None, + session=visit_id, + suffix='surface_fwd', + extension='.fif', + check=False) + fwd = mne.read_forward_solution(bids_path_fwd.fpath) + + # Compute covariance matrices + base_cov = mne.compute_covariance(epochs, + tmin=b_tmin, + tmax=b_tmax, + method='empirical', + rank=rank) + active_cov = mne.compute_covariance(epochs, + tmin=0., + tmax=None, + method='empirical', + rank=rank) + common_cov = base_cov + active_cov + + # Compute inverse operator (filter) + inverse_operator = mne.minimum_norm.make_inverse_operator( + epochs.info, + fwd, + common_cov, + loose=.2, + depth=.8, + fixed=False, + rank=rank, + use_cps=True) + + src = inverse_operator['src'] + + return inverse_operator ,src + + +def apply_inverse(epochs, inverse_operator, label): + # Apply dSPM inverse solution to individual epochs + snr = 3.0 + lambda2 = 1.0 / snr ** 2 + method = "dSPM" + + stcs = mne.minimum_norm.apply_inverse_epochs( + epochs, + inverse_operator, + lambda2, + method, + pick_ori="normal", + label=label, + ) + + return stcs + + +def select_act_win_source(stcs, tmin=0., tmax=.5): + stcs_act = [] + # Loop over epochs + for i in range(len(stcs)): + # Select active time window + stcs_act.append(stcs[i].copy().crop(tmin, tmax)) + + return stcs_act + + +def create_label(label_list, parc): + labels = [] + # Loop over labels + for regexp in label_list: + print("\nReading label "+regexp) + + # Create label for the given region + if subject_id in ['SA102', 'SA104', 'SA110', 'SA111', 'SA152']: + lab = mne.read_labels_from_annot( + "fsaverage", + parc=parc, #aparc aparc.a2009s + regexp=regexp, #'inferiortemporal' + hemi='both', + subjects_dir=fs_deriv_root) + else: + lab = mne.read_labels_from_annot( + "sub-"+subject_id, + parc=parc, #aparc aparc.a2009s + regexp=regexp, #'inferiortemporal' + hemi='both', + subjects_dir=fs_deriv_root) + + # Save label + bids_path_ged = mne_bids.BIDSPath( + root=ged_deriv_root, + subject=subject_id, + datatype='meg', + task=None, + session=visit_id, + suffix=f"desc-{regexp}_label-lh", + extension='.label', + check=False) + lab[0].save(bids_path_ged.fpath) + + bids_path_ged = bids_path_ged.copy().update( + suffix=f"desc-{regexp}_label-rh",) + lab[1].save(bids_path_ged.fpath) + + # Append to labels + labels.append(lab) + + # # Show brain with label areas highlighted #3D plots not working on the hpc + # if os.getlogin() in ['oscfe', 'ferranto', 'FerrantO']: + # if regexp == label_list[0]: + # brain = mne.viz.Brain( + # "sub-"+subject_id, + # subjects_dir=fs_deriv_root) + # brain.add_label(lab[0]) + # brain.add_label(lab[1]) + + # # Save brain figure in different views + # if os.getlogin() in ['oscfe', 'ferranto', 'FerrantO']: + # #lateral + # brain.show_view('lateral') + # brain.save_image(op.join(ged_figure_root, + # f'label_{label_name}_lat.png')) + # #ventral + # brain.show_view('ventral') + # brain.save_image(op.join(ged_figure_root, + # f'label_{label_name}_ven.png')) + # #caudal + # brain.show_view('caudal') + # brain.save_image(op.join(ged_figure_root, + # f'label_{label_name}_cau.png')) + # brain.close() + + # Combine labels + label = np.sum(labels) + + return label + + +# ============================================================================= +# GED +# 
============================================================================= + +def comp_cov_stcs(stcs): + # Compute covariance matrices + cov = [] + #loop over trials + for stc in stcs: + #get trial data + data = stc.data + #mean-center + data = data - np.mean(data, axis=1, keepdims=True) + #compute covariance + cov_trial = data@data.T / (len(data[0]) - 1) + #append results to list + cov.append(cov_trial) + + return cov + + +def clean_and_average_cov(cov): + # Clean covariance data from outliers and average trials + + # Average covariance over trials + cov_m = np.mean(cov, axis=0) + + # Loop over trials + dists = [] + for i in range(len(cov)): + # Get data + tcov = cov[i] + # Compute euclidean distance + dists.append(np.sqrt(np.sum((tcov.reshape(1,-1)-cov_m.reshape(1,-1))**2))) + + # Compute z-scored distance + dists_Z = (dists-np.mean(dists)) / np.std(dists) + + # Average trial-covariances together, excluding outliers + cov_avg = np.mean( np.asarray(cov)[dists_Z<3] ,axis=0) + + return cov_avg + + +def apply_reg(cov): + # Apply regularization + gamma = .01 + cov_r = cov*(1-gamma) + gamma * np.mean(scipy.linalg.eigh(cov)[0]) * np.eye(len(cov)) + + return cov_r + + +def plot_cov(covSm, covRm, cond): + # Plot covariance matrices + fig,axs = plt.subplots(1,3,figsize=(8,4)) + # A matrix + axs[0].imshow(covSm,vmin=np.min(covSm),vmax=np.max(covSm),cmap='jet') + axs[0].set_title('S matrix') + # B matrix + axs[1].imshow(covRm,vmin=np.min(covRm),vmax=np.max(covRm),cmap='jet') + axs[1].set_title('R matrix') + # R^{-1}S + cov_sxinvr = np.linalg.inv(covRm)@covSm + axs[2].imshow(cov_sxinvr,vmin=np.min(cov_sxinvr),vmax=np.max(cov_sxinvr),cmap='jet') + axs[2].set_title('$R^{-1}S$ matrix') + plt.tight_layout() + + # Save figure + fname_fig = op.join(ged_figure_root, + f"ged_covariace_matrices_{label_name}_{cond}.png") + fig.savefig(fname_fig) + plt.close(fig) + + return fig + + +def comp_ged(covAm, covBm, cond): + # Run GED + evals,evecs = scipy.linalg.eigh(covAm,covBm) + + # Sort eigenvalues/vectors + sidx = np.argsort(evals)[::-1] + evals = evals[sidx] + evecs = evecs[:,sidx] + + # Save results + bids_path_ged = mne_bids.BIDSPath( + root=ged_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + session=visit_id, + suffix=f'desc-{label_name},{cond}_evals', + extension='.npy', + check=False) + np.save(bids_path_ged.fpath, evals) + + bids_path_ged = bids_path_ged.copy().update( + suffix=f'desc-{label_name},{cond}_evecs',) + np.save(bids_path_ged.fpath, evecs) + + return evals, evecs + + +def plot_ged_evals(evals, cond): + # Plot the eigenspectrum + fig = plt.figure() + plt.plot(evals[0:20],'s-',markersize=15,markerfacecolor='k') + plt.title('GED eigenvalues') + plt.xlabel('Component number') + plt.ylabel('Power ratio (norm-$\lambda$)') + + # Save figure + fname_fig = op.join(ged_figure_root, + f"ged_eigenvalues_sorted_{label_name}_{cond}.png") + fig.savefig(fname_fig) + plt.close(fig) + + return fig + + +def create_ged_spatial_filter(evecs, cond): + # Filter forward model + filt_topo = evecs[:,0] + + # Eigenvector sign + se = np.argmax(np.abs( filt_topo )) + filt_topo = filt_topo * np.sign(filt_topo[se]) + + # Save results + bids_path_ged = mne_bids.BIDSPath( + root=ged_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + session=visit_id, + suffix=f'desc-{label_name},{cond}_filttopo', + extension='.npy', + check=False) + np.save(bids_path_ged.fpath, filt_topo) + + return filt_topo + + +def get_ged_time_course(stcs, evecs, cond): + comp_ts = [] + # Loop over epochs + 
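+    # (each epoch's vertices-by-time source data are projected onto the first,
+    # i.e. largest-eigenvalue, GED eigenvector returned by scipy.linalg.eigh(S, R)
+    # above, yielding one spatially filtered component time course per trial)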
for i in range(len(stcs)): + + # Get data + data = stcs[i].data + + # Apply GED filter + comp_ts.append(evecs[:,0].T @ data) + + # Save results + bids_path_ged = mne_bids.BIDSPath( + root=ged_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + session=visit_id, + suffix=f'desc-{label_name},{cond}_compts', + extension='.npy', + check=False) + np.save(bids_path_ged.fpath, comp_ts) + + return comp_ts + + +def lowpass_filter(data, order=6, fs=1000.0, cutoff=30.0): + # Low-pass filter the data + b, a = ss.butter(order, + cutoff, + fs=fs, + btype='low', + analog=False) + + data_filt = ss.lfilter(b, a, data) + return data_filt + + +def plot_ged_result(stcs, comp_ts, comp_ts_other, cond): + # Low-pass filter + comp_ts = lowpass_filter(comp_ts) + comp_ts_other = lowpass_filter(comp_ts_other) + + # Compute root mean square + comp_ts_rms = np.sqrt((np.array(comp_ts)**2).mean(axis=0)) + comp_ts_other_rms = np.sqrt((np.array(comp_ts_other)**2).mean(axis=0)) + + # Baseline correction + imin = (np.abs(stcs[0].times - -.1)).argmin() #here I subtract a negative value + imax = (np.abs(stcs[0].times - 0.)).argmin() + + mean_ts = np.mean(comp_ts_rms[..., imin:imax], axis=-1, keepdims=True) + comp_ts_rms -= mean_ts + comp_ts_rms /= mean_ts + + mean_ts_other = np.mean(comp_ts_other_rms[..., imin:imax], axis=-1, keepdims=True) + comp_ts_other_rms -= mean_ts_other + comp_ts_other_rms /= mean_ts_other + + # Crop edges + tmin = (np.abs(stcs[0].times - -.5)).argmin() #here I subtract a negative value + tmax = (np.abs(stcs[0].times - 2.)).argmin() + + comp_ts_rms = comp_ts_rms[tmin:tmax] + comp_ts_other_rms = comp_ts_other_rms[tmin:tmax] + times = stcs[0].times[tmin:tmax] + + # Set labels + if cond == 'face': + color_cond = 'blue' + cond_other = 'object' + color_other = 'orange' + elif cond == 'object': + color_cond = 'orange' + cond_other = 'face' + color_other = 'blue' + + # Plot filter time course + fig = plt.figure() + plt.plot(times, comp_ts_rms, + label=cond, color=color_cond) + plt.plot(times, comp_ts_other_rms, + label=cond_other, color=color_other) + plt.legend() + plt.title("GED spatial filters' activity") + plt.xlabel('time (sec)') + plt.ylabel('RMS amplitude (a.u.)') + + # Save figure + fname_fig = op.join(ged_figure_root, + f"ged_filter_ts_{label_name}_{cond}.png") + fig.savefig(fname_fig) + plt.close(fig) + + return fig + + +# ============================================================================= +# RUN +# ============================================================================= + +if __name__ == '__main__': + # Read epoched data + epochs = read_cogitate_data(subject_id, visit_id) + + # Select conditions of interest #try faces vs. 
objects + fac_epochs, nofac_epochs = select_cond("face", epochs) + obj_epochs, noobj_epochs = select_cond("object", epochs) + + # Run source modeling (MNE-dSPM) + + # Create inverse solution + inverse_operator, src = create_inverse(epochs) + + # Create label for interiortemporal cortex + label = create_label(label_list, + parc=parc) + + # Apply inverse solution + stcs_fac = apply_inverse(fac_epochs, + inverse_operator, + label=label) + stcs_nofac = apply_inverse(nofac_epochs, + inverse_operator, + label=label) + stcs_obj = apply_inverse(obj_epochs, + inverse_operator, + label=label) + stcs_noobj = apply_inverse(noobj_epochs, + inverse_operator, + label=label) + + # Run GED + + # Select activation (i.e., stimulus presentation) window + stcs_fac_act = select_act_win_source(stcs_fac, + tmin=act_win_tmin, + tmax=act_win_tmax) + stcs_nofac_act = select_act_win_source(stcs_nofac, + tmin=act_win_tmin, + tmax=act_win_tmax) + stcs_obj_act = select_act_win_source(stcs_obj, + tmin=act_win_tmin, + tmax=act_win_tmax) + stcs_noobj_act = select_act_win_source(stcs_noobj, + tmin=act_win_tmin, + tmax=act_win_tmax) + + # Compute covariance + cov_fac = comp_cov_stcs(stcs_fac_act) + cov_nofac = comp_cov_stcs(stcs_nofac_act) + cov_obj = comp_cov_stcs(stcs_obj_act) + cov_noobj = comp_cov_stcs(stcs_noobj_act) + + # Remove outliers and average + cov_fac = clean_and_average_cov(cov_fac) + cov_nofac = clean_and_average_cov(cov_nofac) + cov_obj = clean_and_average_cov(cov_obj) + cov_noobj = clean_and_average_cov(cov_noobj) + + # Apply regularization + cov_nofac = apply_reg(cov_nofac) + cov_noobj = apply_reg(cov_noobj) + + # Plot covariance + plot_cov(cov_fac, cov_nofac, "face") + plot_cov(cov_obj, cov_noobj, "object") + + # Run GED + evals_fac, evecs_fac = comp_ged(cov_fac, cov_nofac, "face") + evals_obj, evecs_obj = comp_ged(cov_obj, cov_noobj, "object") + + # Plot GED eigenvalues + plot_ged_evals(evals_fac, "face") + plot_ged_evals(evals_obj, "object") + + # Create GED spatial filter + filt_topo_fac = create_ged_spatial_filter(evecs_fac, "face") + filt_topo_obj = create_ged_spatial_filter(evecs_obj, "object") + + # Get GED component time course + comp_ts_fac = get_ged_time_course(stcs_fac, evecs_fac, "facFilt_facCond") + comp_ts_obj = get_ged_time_course(stcs_obj, evecs_obj, "objFilt_objCond") + comp_ts_fac_on_obj = get_ged_time_course(stcs_obj, evecs_fac, "facFilt_objCond") + comp_ts_obj_on_fac = get_ged_time_course(stcs_fac, evecs_obj, "objFilt_facCond") + + # Plot GED spatial filter time course + plot_ged_result(stcs_fac, comp_ts_fac, comp_ts_fac_on_obj, "face") + plot_ged_result(stcs_obj, comp_ts_obj, comp_ts_obj_on_fac, "object") diff --git a/ged/Co02_ged_pfc.py b/ged/Co02_ged_pfc.py new file mode 100644 index 0000000..717abb3 --- /dev/null +++ b/ged/Co02_ged_pfc.py @@ -0,0 +1,564 @@ +# -*- coding: utf-8 -*- +""" +=================================== +Co00. GED spatial filter +=================================== + +Create prefrontal spatial filters through generalized eigendecomposition + +@author: Oscar Ferrante oscfer88@gmail.com +""" + +import os +import os.path as op +import numpy as np +import matplotlib.pyplot as plt +import scipy +import scipy.signal as ss +import argparse +import json + +import mne +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +parser=argparse.ArgumentParser() +parser.add_argument('--sub', + type=str, + default='SA124', + help='site_id + subject_id (e.g. 
"SA101")') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +# parser.add_argument('--bids_root', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids', +# help='Path to the BIDS root directory') +opt=parser.parse_args() + + +# Set params +subject_id = opt.sub +visit_id = opt.visit +# label_list = opt.labels +# label_name = ''.join(label_list) +# parc = opt.parc + +debug = False + + +# Set derivatives paths +prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") +fwd_deriv_root = op.join(bids_root, "derivatives", "forward") +fs_deriv_root = op.join(bids_root, "derivatives", "fs") + +ged_deriv_root = op.join(bids_root, "derivatives", "ged") +if not op.exists(ged_deriv_root): + os.makedirs(ged_deriv_root) +ged_figure_root = op.join(ged_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "figures") +if not op.exists(ged_figure_root): + os.makedirs(ged_figure_root) + + +# Set task +if visit_id == "V1": + bids_task = 'dur' +elif visit_id == "V2": + bids_task = 'vg' +# elif visit_id == "V2": #find a better way to set the task in V2 +# bids_task = 'replay' +else: + raise ValueError("Error: could not set the task") + + +# ============================================================================= +# READ DATA +# ============================================================================= + +def read_cogitate_data(subject_id, visit_id): + print("Processing subject: %s" % subject_id) + + # Read epoched data + bids_path_epo = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + session=visit_id, + suffix='epo', + extension='.fif', + check=False) + + epochs = mne.read_epochs(bids_path_epo.fpath, + preload=False) + + # Pick all non-target trials + epochs = epochs['Task_relevance in ["Relevant non-target", "Irrelevant"] and Duration != "500ms"'] + if debug: + epochs = epochs[0:100] # ONLY for DEBUG + + # Load data + epochs.load_data() + + # Pick MEG sensors only #here I combine channel types + epochs = epochs.pick('meg') + + return epochs + + +def select_time_window(epochs): + # Select epochs + epochs_acti = epochs.copy().crop(0., .5) + epochs_base = epochs.copy().crop(-.501, -.001) + + return epochs_acti, epochs_base + + +# ============================================================================= +# SOURCE MODELLING +# ============================================================================= + +def create_gnw_label(): + # Set path + rois_deriv_root = op.join(bids_root, "derivatives", "roilabel") + + # Read labels from FS parc + if subject_id in ['SA102', 'SA104', 'SA110', 'SA111', 'SA152']: + labels_atlas = mne.read_labels_from_annot( + "fsaverage", + parc='aparc.a2009s', + subjects_dir=fs_deriv_root) + else: + labels_atlas = mne.read_labels_from_annot( + "sub-"+subject_id, + parc='aparc.a2009s', + subjects_dir=fs_deriv_root) + + # labels_atlas_names = [l.name for l in labels_atlas] + + # Read GNW and IIT ROI list + f = open(op.join(rois_deriv_root, + 'iit_gnw_rois.json')) + gnw_iit_rois = json.load(f) + + # Create labels for selected ROIs + labels = {} + if subject_id in ['SA102', 'SA104', 'SA110', 'SA111', 'SA152']: + for lab in gnw_iit_rois['surf_labels']['gnw']: + if (lab.find('&') != -1): + lab = lab.replace('&','_and_') + print(lab) + labels["gnw_"+lab+"_lh"] = [l for l in labels_atlas if l.name == lab+"-lh"] + labels["gnw_"+lab+"_rh"] = [l for l in labels_atlas if l.name == lab+"-rh"] + else: + for lab in gnw_iit_rois['surf_labels']['gnw']: + 
print(lab) + labels["gnw_"+lab+"_lh"] = [l for l in labels_atlas if l.name == lab+"-lh"][0] + labels["gnw_"+lab+"_rh"] = [l for l in labels_atlas if l.name == lab+"-rh"][0] + + # # Show brain with label areas highlighted #3D plots not working on the hpc + # if os.getlogin() in ['oscfe', 'ferranto', 'FerrantO']: + # brain = mne.viz.Brain( + # "sub-"+subject_id, + # subjects_dir=fs_deriv_root) + # for n, l in labels.items(): + # brain.add_label(l, color='g') + + # # Save brain figure in different views + # #lateral + # brain.show_view('lateral') + # brain.save_image(op.join(ged_figure_root, + # 'label_gnw_pfc_lat.png')) + # #ventral + # brain.show_view('ventral') + # brain.save_image(op.join(ged_figure_root, + # 'label_gnw_pfc_ven.png')) + # #caudal + # brain.show_view('caudal') + # brain.save_image(op.join(ged_figure_root, + # 'label_gnw_pfc_cau.png')) + # brain.close() + + # Merge all labels in a single one separatelly for GNW and IIT + label = np.sum([l for l_name, l in labels.items() if 'gnw' in l_name]) + + return label + + +def create_inverse(epochs): + # Apply baseline correction + b_tmin = -.501 + b_tmax = -.001 + baseline = (b_tmin, b_tmax) + epochs.apply_baseline(baseline=baseline) + + # Compute rank + rank = mne.compute_rank(epochs, + tol=1e-6, + tol_kind='relative') + + # Read forward model + bids_path_fwd = mne_bids.BIDSPath( + root=fwd_deriv_root, + subject=subject_id, + datatype='meg', + task=None, + session=visit_id, + suffix='surface_fwd', + extension='.fif', + check=False) + fwd = mne.read_forward_solution(bids_path_fwd.fpath) + + # Compute covariance matrices + base_cov = mne.compute_covariance(epochs, + tmin=b_tmin, + tmax=b_tmax, + method='empirical', + rank=rank) + active_cov = mne.compute_covariance(epochs, + tmin=0., + tmax=None, + method='empirical', + rank=rank) + common_cov = base_cov + active_cov + + # Compute inverse operator (filter) + inverse_operator = mne.minimum_norm.make_inverse_operator( + epochs.info, + fwd, + common_cov, + loose=.2, + depth=.8, + fixed=False, + rank=rank, + use_cps=True) + + src = inverse_operator['src'] + + return inverse_operator ,src + + +def apply_inverse(epochs, inverse_operator, label, desc): + # Apply dSPM inverse solution to individual epochs + snr = 3.0 + lambda2 = 1.0 / snr ** 2 + method = "dSPM" + + stcs = mne.minimum_norm.apply_inverse_epochs( + epochs, + inverse_operator, + lambda2, + method, + pick_ori="normal", + label=label, + ) + + # Plot evoked averaged over vertices + data = np.mean(np.array([stc.data for stc in stcs]), axis=1) + times = stcs[0].times + evk_m = np.mean(data, axis=0) + evk_std = np.std(data,axis=0) + + plt.plot(times, evk_m) + plt.fill_between(times, evk_m-evk_std, evk_m+evk_std, color='b', alpha=.1) + + # Save figure + fname_fig = op.join(ged_figure_root, + f"ged_stc_evoked_{desc}_gnw.png") + plt.savefig(fname_fig) + plt.close() + + return stcs + + +# ============================================================================= +# GED +# ============================================================================= + +def comp_cov_stcs(stcs): + # Compute covariance matrices + cov = [] + #loop over trials + for stc in stcs: + #get trial data + data = stc.data + #mean-center + data = data - np.mean(data, axis=1, keepdims=True) + #compute covariance + cov_trial = data@data.T / (len(data[0]) - 1) + #append results to list + cov.append(cov_trial) + + return cov + + +def clean_and_average_cov(cov): + # Clean covariance data from outliers and average trials + + # Average covariance over trials + cov_m = 
np.mean(cov, axis=0) + + # Loop over trials + dists = [] + for i in range(len(cov)): + # Get data + tcov = cov[i] + # Compute euclidean distance + dists.append(np.sqrt(np.sum((tcov.reshape(1,-1)-cov_m.reshape(1,-1))**2))) + + # Compute z-scored distance + dists_Z = (dists-np.mean(dists)) / np.std(dists) + + # Average trial-covariances together, excluding outliers + cov_avg = np.mean( np.asarray(cov)[dists_Z<3] ,axis=0) + + return cov_avg + + +def apply_reg(cov): + # Apply regularization + gamma = .01 + cov_r = cov*(1-gamma) + gamma * np.mean(scipy.linalg.eigh(cov)[0]) * np.eye(len(cov)) + + return cov_r + + +def plot_cov(covSm, covRm): + # Plot covariance matrices + fig,axs = plt.subplots(1,3,figsize=(8,4)) + # A matrix + axs[0].imshow(covSm,vmin=np.min(covSm),vmax=np.max(covSm),cmap='jet') + axs[0].set_title('S matrix') + # B matrix + axs[1].imshow(covRm,vmin=np.min(covRm),vmax=np.max(covRm),cmap='jet') + axs[1].set_title('R matrix') + # R^{-1}S + cov_sxinvr = np.linalg.inv(covRm)@covSm + axs[2].imshow(cov_sxinvr,vmin=np.min(cov_sxinvr),vmax=np.max(cov_sxinvr),cmap='jet') + axs[2].set_title('$R^{-1}S$ matrix') + plt.tight_layout() + + # Save figure + fname_fig = op.join(ged_figure_root, + "ged_covariace_matrices_gnw.png") + fig.savefig(fname_fig) + plt.close(fig) + + return fig + + +def comp_ged(covAm, covBm): + # Run GED + evals,evecs = scipy.linalg.eigh(covAm,covBm) + + # Sort eigenvalues/vectors + sidx = np.argsort(evals)[::-1] + evals = evals[sidx] + evecs = evecs[:,sidx] + + # Save results + bids_path_ged = mne_bids.BIDSPath( + root=ged_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + session=visit_id, + suffix='desc-gnw_evals', + extension='.npy', + check=False) + np.save(bids_path_ged.fpath, evals) + + bids_path_ged = bids_path_ged.copy().update( + suffix='desc-gnw_evecs',) + np.save(bids_path_ged.fpath, evecs) + + return evals, evecs + + +def plot_ged_evals(evals): + # Plot the eigenspectrum + fig = plt.figure() + plt.plot(evals[0:20],'s-',markersize=15,markerfacecolor='k') + plt.title('GED eigenvalues') + plt.xlabel('Component number') + plt.ylabel('Power ratio (norm-$\lambda$)') + + # Save figure + fname_fig = op.join(ged_figure_root, + "ged_eigenvalues_sorted_gnw.png") + fig.savefig(fname_fig) + plt.close(fig) + + return fig + + +def create_ged_spatial_filter(evecs): + # Filter forward model + filt_topo = evecs[:,0] + + # Eigenvector sign + se = np.argmax(np.abs( filt_topo )) + filt_topo = filt_topo * np.sign(filt_topo[se]) + + # Save results + bids_path_ged = mne_bids.BIDSPath( + root=ged_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + session=visit_id, + suffix='desc-gnw_filttopo', + extension='.npy', + check=False) + np.save(bids_path_ged.fpath, filt_topo) + + return filt_topo + + +def get_ged_time_course(stcs, evecs): + comp_ts = [] + # Loop over epochs + for i in range(len(stcs)): + + # Get data + data = stcs[i].data + + # Apply GED filter + comp_ts.append(evecs[:,0].T @ data) + + # Save results + bids_path_ged = mne_bids.BIDSPath( + root=ged_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + session=visit_id, + suffix='desc-gnw_compts', + extension='.npy', + check=False) + np.save(bids_path_ged.fpath, comp_ts) + + return comp_ts + + +def lowpass_filter(data, order=6, fs=1000.0, cutoff=30.0): + # Low-pass filter the data + b, a = ss.butter(order, + cutoff, + fs=fs, + btype='low', + analog=False) + + data_filt = ss.lfilter(b, a, data) + return data_filt + + +def plot_ged_result(stcs, comp_ts): + # Low-pass 
filter + comp_ts = lowpass_filter(comp_ts) + + # Compute root mean square + comp_ts_rms = np.sqrt((np.array(comp_ts)**2).mean(axis=0)) + + # Baseline correction + imin = (np.abs(stcs[0].times - -.1)).argmin() #here I subtract a negative value + imax = (np.abs(stcs[0].times - 0.)).argmin() + + mean_ts = np.mean(comp_ts_rms[..., imin:imax], axis=-1, keepdims=True) + comp_ts_rms -= mean_ts + comp_ts_rms /= mean_ts + + # Crop edges + tmin = (np.abs(stcs[0].times - -.5)).argmin() #here I subtract a negative value + tmax = (np.abs(stcs[0].times - 2.)).argmin() + + comp_ts_rms = comp_ts_rms[tmin:tmax] + times = stcs[0].times[tmin:tmax] + + # Plot filter time course + fig = plt.figure() + plt.plot(times, comp_ts_rms) + plt.title("GED spatial filters' activity") + plt.xlabel('time (sec)') + plt.ylabel('RMS amplitude (a.u.)') + + # Save figure + fname_fig = op.join(ged_figure_root, + "ged_filter_ts_gnw.png") + fig.savefig(fname_fig) + plt.close(fig) + + return fig + + +# ============================================================================= +# RUN +# ============================================================================= + +if __name__ == '__main__': + # Read epoched data + epochs = read_cogitate_data(subject_id, visit_id) + + # Select conditions of interest #try faces vs. objects + epochs_acti, epochs_base = select_time_window(epochs) + + # Run source modeling (MNE-dSPM) + + # Create inverse solution + inverse_operator, src = create_inverse(epochs) + + # Create label for interiortemporal cortex + label = create_gnw_label() + + # Apply inverse solution + stcs_acti = apply_inverse(epochs_acti, + inverse_operator, + label=label, + desc='acti') + stcs_base = apply_inverse(epochs_base, + inverse_operator, + label=label, + desc='base') + stcs_whole = apply_inverse(epochs, + inverse_operator, + label=label, + desc='whole') + + # Run GED + + # Compute covariance + cov_acti = comp_cov_stcs(stcs_acti) + cov_base = comp_cov_stcs(stcs_base) + + # Remove outliers and average + cov_acti = clean_and_average_cov(cov_acti) + cov_base = clean_and_average_cov(cov_base) + + # Apply regularization + cov_base = apply_reg(cov_base) + + # Plot covariance + plot_cov(cov_acti, cov_base) + + # Run GED + evals_gnw_pfc, evecs_gnw_pfc = comp_ged(cov_acti, cov_base) + + # Plot GED eigenvalues + plot_ged_evals(evals_gnw_pfc) + + # Create GED spatial filter + filt_topo_gnw_pfc = create_ged_spatial_filter(evecs_gnw_pfc) + + # Get GED component time course + comp_ts_gnw_pfc = get_ged_time_course(stcs_whole, evecs_gnw_pfc) + + # Plot GED spatial filter time course + plot_ged_result(stcs_whole, comp_ts_gnw_pfc) diff --git a/ged/Co03_ged_selectivity_ga.py b/ged/Co03_ged_selectivity_ga.py new file mode 100644 index 0000000..df08d8d --- /dev/null +++ b/ged/Co03_ged_selectivity_ga.py @@ -0,0 +1,330 @@ +# -*- coding: utf-8 -*- +""" +=================================== +Co00. 
GED spatial filter +=================================== + +Create category-specific spatial filters through generalized eigendecomposition + +@author: Oscar Ferrante oscfer88@gmail.com +""" + +import os +import os.path as op +import numpy as np +import matplotlib.pyplot as plt +import scipy.signal as ss +from scipy import stats as stats +import argparse +import pickle + +import mne +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +parser=argparse.ArgumentParser() +parser.add_argument('--labels', + nargs='+', + default=['fusifor'], + help='name of the label to which contrain the spatial filter (e.g., "fusiform"') +parser.add_argument('--parc', + type=str, + default='aparc', + help='name of the parcellation atlas to use for contraining the spatial filter (e.g., "aparc", "aparc.a2009s")') +# parser.add_argument('--bids_root', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids', +# help='Path to the BIDS root directory') +opt=parser.parse_args() + + +# Set params +visit_id = 'V1' +label_list = opt.labels +label_name = ''.join(label_list) +parc = opt.parc + +act_win_tmin = 0. +act_win_tmax = .5 + +debug = False + + +# Set participant list +phase = 3 + +if debug: + sub_list = ["SA124", "SA124"] +else: + # Read the .txt file + f = open(op.join(bids_root, + f'participants_MEG_phase{phase}_included.txt'), 'r').read() + # Split text into list of elemetnts + sub_list = f.split("\n") + + +def ged_ga(): + # Set derivatives paths + if act_win_tmin == 0. and act_win_tmax == .5: + ged_deriv_root = op.join(bids_root, "derivatives", "ged") + else: + ged_deriv_root = op.join(bids_root, "derivatives", "ged", f"_{act_win_tmin}-{act_win_tmax}") + + ged_figure_root = op.join(ged_deriv_root, + f"sub-groupphase{phase}",f"ses-{visit_id}","meg", + "figures") + if not op.exists(ged_figure_root): + os.makedirs(ged_figure_root) + + # Set task + if visit_id == "V1": + bids_task = 'dur' + elif visit_id == "V2": + bids_task = 'vg' + # elif visit_id == "V2": #find a better way to set the task in V2 + # bids_task = 'replay' + else: + raise ValueError("Error: could not set the task") + + # Loop over subjects + ged_facFilt_facCond_ts = [] + ged_objFilt_objCond_ts = [] + ged_objFilt_facCond_ts = [] + ged_facFilt_objCond_ts = [] + for sub in sub_list: + print("\nLoading data for subject:", sub, "\nand label(s):", label_name) + + # Read GED filters' time courses + bids_path_ged = mne_bids.BIDSPath( + root=ged_deriv_root, + subject=sub, + datatype='meg', + task=bids_task, + session=visit_id, + suffix=f'desc-{label_name},facFilt_facCond_compts', + extension='.npy', + check=False) + ged_facFilt_facCond_ts.append(np.load(bids_path_ged.fpath)) + + bids_path_ged = bids_path_ged.copy().update( + suffix=f'desc-{label_name},objFilt_objCond_compts') + ged_objFilt_objCond_ts.append(np.load(bids_path_ged.fpath)) + + bids_path_ged = bids_path_ged.copy().update( + suffix=f'desc-{label_name},objFilt_facCond_compts') + ged_objFilt_facCond_ts.append(np.load(bids_path_ged.fpath)) + + bids_path_ged = bids_path_ged.copy().update( + suffix=f'desc-{label_name},facFilt_objCond_compts') + ged_facFilt_objCond_ts.append(np.load(bids_path_ged.fpath)) + + # Average trials within participants (ev=evoked) + ged_facFilt_facCond_ts_ev = [np.mean(ged, axis=0) for ged in ged_facFilt_facCond_ts] + ged_objFilt_objCond_ts_ev = [np.mean(ged, axis=0) for ged in ged_objFilt_objCond_ts] + ged_objFilt_facCond_ts_ev = [np.mean(ged, axis=0) for ged in 
ged_objFilt_facCond_ts] + ged_facFilt_objCond_ts_ev = [np.mean(ged, axis=0) for ged in ged_facFilt_objCond_ts] + + # Low-pass filter the data + print("\nLow-pass filtering the data...") + order = 6 + fs = 1000.0 # sample rate (Hz) + cutoff = 30.0 + b, a = ss.butter(order, + cutoff, + fs=fs, + btype='low', + analog=False) + + ged_facFilt_facCond_ts_lp = [ss.lfilter(b, a, ged) for ged in ged_facFilt_facCond_ts_ev] + ged_objFilt_objCond_ts_lp = [ss.lfilter(b, a, ged) for ged in ged_objFilt_objCond_ts_ev] + ged_objFilt_facCond_ts_lp = [ss.lfilter(b, a, ged) for ged in ged_objFilt_facCond_ts_ev] + ged_facFilt_objCond_ts_lp = [ss.lfilter(b, a, ged) for ged in ged_facFilt_objCond_ts_ev] + + # Compute root mean square + print("\nComputing RMS...") + ged_facFilt_facCond_ts_rms = [np.sqrt((np.array(ged)**2)) for ged in ged_facFilt_facCond_ts_lp] + ged_objFilt_objCond_ts_rms = [np.sqrt((np.array(ged)**2)) for ged in ged_objFilt_objCond_ts_lp] + ged_objFilt_facCond_ts_rms = [np.sqrt((np.array(ged)**2)) for ged in ged_objFilt_facCond_ts_lp] + ged_facFilt_objCond_ts_rms = [np.sqrt((np.array(ged)**2)) for ged in ged_facFilt_objCond_ts_lp] + + # Baseline correction + print("\nCorrecting for the baseline...") + baseline_win = [-.5, 0] + times = np.arange(-1, 2.501, .001) + + imin = (np.abs(times - baseline_win[0])).argmin() + imax = (np.abs(times - baseline_win[1])).argmin() + + ged_facFilt_facCond_ts_bc = [] + for ged in ged_facFilt_facCond_ts_rms: + mean_ts = np.mean(ged[..., imin:imax], axis=-1, keepdims=True) + ged -= mean_ts + ged /= mean_ts + ged_facFilt_facCond_ts_bc.append(ged) + + ged_objFilt_objCond_ts_bc = [] + for ged in ged_objFilt_objCond_ts_rms: + mean_ts = np.mean(ged[..., imin:imax], axis=-1, keepdims=True) + ged -= mean_ts + ged /= mean_ts + ged_objFilt_objCond_ts_bc.append(ged) + + ged_objFilt_facCond_ts_bc = [] + for ged in ged_objFilt_facCond_ts_rms: + mean_ts = np.mean(ged[..., imin:imax], axis=-1, keepdims=True) + ged -= mean_ts + ged /= mean_ts + ged_objFilt_facCond_ts_bc.append(ged) + + ged_facFilt_objCond_ts_bc = [] + for ged in ged_facFilt_objCond_ts_rms: + mean_ts = np.mean(ged[..., imin:imax], axis=-1, keepdims=True) + ged -= mean_ts + ged /= mean_ts + ged_facFilt_objCond_ts_bc.append(ged) + + # Average over participants + print("\nComputing grandaverage...") + ged_facFilt_facCond_ts_ga = np.mean(ged_facFilt_facCond_ts_bc, axis=0) + ged_objFilt_objCond_ts_ga = np.mean(ged_objFilt_objCond_ts_bc, axis=0) + ged_objFilt_facCond_ts_ga = np.mean(ged_objFilt_facCond_ts_bc, axis=0) + ged_facFilt_objCond_ts_ga = np.mean(ged_facFilt_objCond_ts_bc, axis=0) + + # Save averaged data + bids_path_ged = mne_bids.BIDSPath( + root=ged_deriv_root, + subject=f"groupphase{phase}", + datatype='meg', + task=bids_task, + session=visit_id, + suffix=f'desc-{label_name}_compts', + extension='.npy', + check=False) + np.save(bids_path_ged.fpath, + np.concatenate( + [ged_facFilt_facCond_ts_ga, + ged_objFilt_objCond_ts_ga, + ged_objFilt_facCond_ts_ga, + ged_facFilt_objCond_ts_ga])) + + # Set limits used to crop edges in figures + print("\nRemoving edges...") + t_win = [-.5, 2.] 
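+    # Epochs span -1 to 2.5 s (see `times` above), but figures are restricted to
+    # -0.5 to 2 s, presumably to keep epoch-edge and filter artifacts out of the
+    # plots. The cluster-based permutation tests further below use, as the
+    # cluster-forming threshold, the two-tailed critical t value for the group
+    # size; a minimal sketch, assuming a hypothetical n = 20 subjects:
+    #     from scipy import stats
+    #     thresh = stats.t.ppf(1 - 0.05 / 2, df=20 - 1)   # ~2.09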
+ + # Plot filter time course + print("\nPlotting...") + fig, axs = plt.subplots(2) + axs[0].plot(times, ged_facFilt_facCond_ts_ga, + label='face', color='b', linestyle='-') + axs[0].plot(times, ged_facFilt_objCond_ts_ga, + label='object', color='r', linestyle='-') + axs[1].plot(times, ged_objFilt_facCond_ts_ga, + label='face', color='b', linestyle='-') + axs[1].plot(times, ged_objFilt_objCond_ts_ga, + label='object', color='r', linestyle='-') + + axs[0].set_title('Grandaverage evoked-activity in face-selective filter') + axs[1].set_title('Grandaverage evoked-activity in object-selective filter') + for ax in axs: + ax.legend() + ax.legend() + ax.set_xlabel('time (sec)') + ax.set_ylabel('RMS amplitude (a.u.)') + ax.axvline(0, color='k', linestyle='--') + ax.set_xlim(t_win) + plt.tight_layout() + + # Save figure + print("\nSaving figures...") + fname_fig = op.join(ged_figure_root, + f"ged_filter_ts_{label_name}.png") + fig.savefig(fname_fig) + plt.close(fig) + + # Compute statistics + pval = 0.05 # arbitrary + n_observations = len(sub_list) + df = n_observations - 1 # degrees of freedom for the test + threshold = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + T_obs_facFilt, clusters_facFilt, cluster_p_values_facFilt, H0_facFilt = \ + mne.stats.permutation_cluster_1samp_test( + np.array(ged_facFilt_facCond_ts_rms) - np.array(ged_facFilt_objCond_ts_rms), + threshold=threshold, + out_type='mask') + + T_obs_objFilt, clusters_objFilt, cluster_p_values_objFilt, H0_objFilt = \ + mne.stats.permutation_cluster_1samp_test( + np.array(ged_objFilt_objCond_ts_rms) - np.array(ged_objFilt_facCond_ts_rms), + threshold=threshold, + out_type='mask') + + # Save significant clusters + bids_path_mask = bids_path_ged.copy().update( + subject=f"groupphase{phase}", + suffix=f"desc-{label_name},facFilt_clusters", + extension=".pkl", + check=False) + with open(bids_path_mask.fpath, 'wb') as file: + pickle.dump(clusters_facFilt, file) + + bids_path_mask = bids_path_mask.update( + suffix=f"desc-{label_name},objFilt_clusters",) + with open(bids_path_mask.fpath, 'wb') as file: + pickle.dump(clusters_objFilt, file) + + # Plot difference between conditions with significant clusters + fig, axs = plt.subplots(2) + axs[0].plot(times, + ged_facFilt_facCond_ts_ga - ged_facFilt_objCond_ts_ga, + color='k', linestyle='-') + axs[1].plot(times, + ged_objFilt_objCond_ts_ga - ged_objFilt_facCond_ts_ga, + color='k', linestyle='-') + + for i_c, c in enumerate(clusters_facFilt): + c = c[0] + if cluster_p_values_facFilt[i_c] < 0.05: + axs[0].axvspan(times[c.start], times[c.stop - 1], + color='r', alpha=0.3) + # else: + # axs[0].axvspan(times[c.start], times[c.stop - 1], color=(0.3, 0.3, 0.3), + # alpha=0.3) + for i_c, c in enumerate(clusters_objFilt): + c = c[0] + if cluster_p_values_objFilt[i_c] < 0.05: + axs[1].axvspan(times[c.start], times[c.stop - 1], + color='r', alpha=0.3) + # else: + # axs[1].axvspan(times[c.start], times[c.stop - 1], color=(0.3, 0.3, 0.3), + # alpha=0.3) + + axs[0].set_title('Face>Object evoked-activity in face-selective filter') + axs[1].set_title('Object>Face evoked-activity in object-selective filter') + for ax in axs: + ax.set_xlabel('time (sec)') + ax.set_ylabel('RMS amplitude (a.u.)') + ax.axvline(0, color='k', linestyle='--') + ax.set_xlim(t_win) + plt.tight_layout() + + # Save figure + print("\nSaving figures...") + fname_fig = op.join(ged_figure_root, + f"ged_filter_ts_{label_name}_diff.png") + fig.savefig(fname_fig) + plt.close(fig) + print("\nCompleted!") + + +# 
============================================================================= +# RUN +# ============================================================================= + +if __name__ == '__main__': + ged_ga() \ No newline at end of file diff --git a/meeg_environment.txt b/meeg_environment.txt new file mode 100644 index 0000000..e69de29 diff --git a/preprocessing/P01_maxwell_filtering.py b/preprocessing/P01_maxwell_filtering.py new file mode 100644 index 0000000..a1cdbfd --- /dev/null +++ b/preprocessing/P01_maxwell_filtering.py @@ -0,0 +1,318 @@ +""" +=================================== +01. Maxwell filter using MNE-python +=================================== + +The data are Maxwell filtered using tSSS/SSS. + +It is critical to mark bad channels before Maxwell filtering. + +@author: Oscar Ferrante oscfer88@gmail.com + +""" # noqa: E501 + +import os.path as op +import os +import numpy as np +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt +import shutil + +from fpdf import FPDF +import mne +from mne.preprocessing import find_bad_channels_maxwell +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +def run_maxwell_filter(subject_id, visit_id, record="run"): + + # Prepare PDF report + pdf = FPDF(orientation="P", unit="mm", format="A4") + + # Set path to preprocessing derivatives and create the related folders + prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") + if not op.exists(prep_deriv_root): + os.makedirs(prep_deriv_root) + prep_figure_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "figures") + if not op.exists(prep_figure_root): + os.makedirs(prep_figure_root) + prep_report_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "reports") + if not op.exists(prep_report_root): + os.makedirs(prep_report_root) + prep_code_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "codes") + if not op.exists(prep_code_root): + os.makedirs(prep_code_root) + + print("Processing subject: %s" % subject_id) + + # Loop over runs + data_path = os.path.join(bids_root,f"sub-{subject_id}",f"ses-{visit_id}","meg") + + for fname in sorted(os.listdir(data_path)): + if fname.endswith(".json") and record in fname: + + # Set run + if "run" in fname: + run = f"{int(fname[-10]):02}" + elif "rest" in fname: + run = None + print(" Run: %s" % run) + + # Set task + if 'dur' in fname: + bids_task = 'dur' + elif 'vg' in fname: + bids_task = 'vg' + elif 'replay' in fname: + bids_task = 'replay' + elif "rest" in fname: + bids_task = "rest" + else: + raise ValueError("Error: could not find the task for %s" % fname) + + # Set split + if len([f for f in os.listdir(data_path) if op.splitext(fname)[0][:-3] in f and f.endswith(".fif")]) > 1: + split = 1 + else: + split = None + + # Set BIDS path + bids_path = mne_bids.BIDSPath( + root=bids_root, + subject=subject_id, + datatype='meg', + task=bids_task, + run=run, + session=visit_id, + split=split, + extension='.fif') + + # Read raw data + raw = mne_bids.read_raw_bids(bids_path) + + # Find initial head position + if run in ["01", None]: + destination = raw.info['dev_head_t'] + + # Detect bad channels + raw.info['bads'] = [] + raw_check = raw.copy() + auto_noisy_chs, auto_flat_chs, auto_scores = find_bad_channels_maxwell( + raw_check, + cross_talk=bids_path.meg_crosstalk_fpath, + calibration=bids_path.meg_calibration_fpath, + return_scores=True, + 
verbose=True) + raw.info['bads'].extend(auto_noisy_chs + auto_flat_chs) + + # Mark bad channels in BIDS events + mne_bids.mark_channels(ch_names=raw.info['bads'], + bids_path=bids_path, + status='bad', + verbose=False) + + # Visualize the scoring used to classify channels as noisy or flat + ch_type = 'grad' + fig = viz_badch_scores(auto_scores, ch_type) + fname_fig = op.join(prep_figure_root, + "01_%sr%s_badchannels_%sscore.png" % (bids_task,run,ch_type)) + fig.savefig(fname_fig) + plt.close(fig) + ch_type = 'mag' + fig = viz_badch_scores(auto_scores, ch_type) + fname_fig = op.join(prep_figure_root, + "01_%sr%s_badchannels_%sscore.png" % (bids_task,run,ch_type)) + fig.savefig(fname_fig) + plt.close(fig) + + # Fix Elekta magnetometer coil types + raw.fix_mag_coil_types() + + # Set coordinate frame + if subject_id == 'empty': + coord_frame = 'meg' + else: + coord_frame = 'head' + + # Perform tSSS/SSS and Maxwell filtering + raw_sss = mne.preprocessing.maxwell_filter( + raw, + cross_talk=bids_path.meg_crosstalk_fpath, + calibration=bids_path.meg_calibration_fpath, + st_duration=None, + origin='auto', + destination=destination, #align head location to first run + coord_frame=coord_frame, + verbose=True) + + # Show original and filtered signals + fig = raw.copy().pick(['meg']).plot(duration=5, + start=100, + butterfly=True) + fname_fig = op.join(prep_figure_root, + '01_%sr%s_plotraw.png' % (bids_task,run)) + fig.savefig(fname_fig) + plt.close(fig) + fig = raw_sss.copy().pick(['meg']).plot(duration=5, + start=100, + butterfly=True) + fname_fig = op.join(prep_figure_root, + '01_%sr%s_plotrawsss.png' % (bids_task,run)) + fig.savefig(fname_fig) + plt.close(fig) + + # Show original and filtered power + fig1 = raw.plot_psd(picks = ['meg'],fmin = 1,fmax = 100) + fname_fig1 = op.join(prep_figure_root, + '01_%sr%s_plot_psd_raw100.png' % (bids_task,run)) + fig1.savefig(fname_fig1) + plt.close(fig1) + fig2 = raw_sss.plot_psd(picks = ['meg'],fmin = 1,fmax = 100) + fname_fig2 = op.join(prep_figure_root, + '01_%sr%s_plot_psd_raw100sss.png' % (bids_task,run)) + fig2.savefig(fname_fig2) + plt.close(fig2) + + # Add figures to report + pdf.add_page() + pdf.set_font('helvetica', 'B', 16) + pdf.cell(0, 10, fname[:-8]) + pdf.ln(20) + pdf.set_font('helvetica', 'B', 12) + pdf.cell(0, 10, 'Power Spectrum of Raw MEG Data', 'B', ln=1) + pdf.image(fname_fig1, 0, 45, pdf.epw) + pdf.ln(120) + pdf.cell(0, 10, 'Power Spectrum of Filtered MEG Data', 'B', ln=1) + pdf.image(fname_fig2, 0, 175, pdf.epw) + + # Save filtered data + bids_path_sss = bids_path.copy().update( + root=prep_deriv_root, + split=None, + suffix="sss", + check=False) + if not op.exists(bids_path_sss): + bids_path_sss.fpath.parent.mkdir(exist_ok=True, parents=True) + + raw_sss.save(bids_path_sss, overwrite=True) + + # Add note about reconstructed sensors to report + pdf.add_page() + pdf.set_font('helvetica', 'B', 16) + pdf.cell(0, 10, "Reconstructed sensors:") + pdf.ln(20) + pdf.set_font('helvetica', 'B', 12) + pdf.cell(0, 10, 'bad MEG sensors: %s' % raw.info['bads'], 'B', ln=1) + + # Save code + shutil.copy(__file__, prep_code_root) + + # Save report + if record == "rest": + pdf.output(op.join(prep_report_root, + os.path.basename(__file__) + '-report_rest.pdf')) + else: + pdf.output(op.join(prep_report_root, + os.path.basename(__file__) + '-report.pdf')) + + +def viz_badch_scores(auto_scores, ch_type): + fig, ax = plt.subplots(1, 4, figsize=(12, 8)) + fig.suptitle(f'Automated noisy/flat channel detection: {ch_type}', + fontsize=16, fontweight='bold') + + 
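+    # Four heatmap panels: all "noisy" scores, "noisy" scores exceeding the
+    # limit, all "flat" scores, and "flat" scores exceeding the limit, with
+    # channels as rows and the evaluated time windows as columns.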
#### Noisy channels #### + ch_subset = auto_scores['ch_types'] == ch_type + ch_names = auto_scores['ch_names'][ch_subset] + scores = auto_scores['scores_noisy'][ch_subset] + limits = auto_scores['limits_noisy'][ch_subset] + bins = auto_scores['bins'] #the windows that were evaluated + + # Label each segment by its start and stop time (3 digits / 1 ms precision) + bin_labels = [f'{start:3.3f} - {stop:3.3f}' + for start, stop in bins] + + # Store data in DataFrame + data_to_plot = pd.DataFrame(data=scores, + columns=pd.Index(bin_labels, name='Time (s)'), + index=pd.Index(ch_names, name='Channel')) + + # First, plot the raw scores + sns.heatmap(data=data_to_plot, + cmap='Reds', + cbar=False, + # cbar_kws=dict(label='Score'), + ax=ax[0]) + [ax[0].axvline(x, ls='dashed', lw=0.25, dashes=(25, 15), color='gray') + for x in range(1, len(bins))] + ax[0].set_title('Noisy: All Scores', fontweight='bold') + + # Second, highlight segments that exceeded the 'noisy' limit + sns.heatmap(data=data_to_plot, + vmin=np.nanmin(limits), + cmap='Reds', + cbar=True, + # cbar_kws=dict(label='Score'), + ax=ax[1]) + [ax[1].axvline(x, ls='dashed', lw=0.25, dashes=(25, 15), color='gray') + for x in range(1, len(bins))] + ax[1].set_title('Noisy: Scores > Limit', fontweight='bold') + + #### Flat channels #### + ch_subset = auto_scores['ch_types'] == ch_type + ch_names = auto_scores['ch_names'][ch_subset] + scores = auto_scores['scores_flat'][ch_subset] + limits = auto_scores['limits_flat'][ch_subset] + bins = auto_scores['bins'] #the windows that were evaluated + + # Label each segment by its start and stop time (3 digits / 1 ms precision) + bin_labels = [f'{start:3.3f} - {stop:3.3f}' + for start, stop in bins] + + # Store data in DataFrame + data_to_plot = pd.DataFrame(data=scores, + columns=pd.Index(bin_labels, name='Time (s)'), + index=pd.Index(ch_names, name='Channel')) + + # First, plot the raw scores + sns.heatmap(data=data_to_plot, + cmap='Reds', + cbar=False, + # cbar_kws=dict(label='Score'), + ax=ax[2]) + [ax[2].axvline(x, ls='dashed', lw=0.25, dashes=(25, 15), color='gray') + for x in range(1, len(bins))] + ax[2].set_title('Flat: All Scores', fontweight='bold') + + # Second, highlight segments that exceeded the 'noisy' limit + sns.heatmap(data=data_to_plot, + vmax=np.nanmax(limits), + cmap='Reds', + cbar=True, + # cbar_kws=dict(label='Score'), + ax=ax[3]) + [ax[3].axvline(x, ls='dashed', lw=0.25, dashes=(25, 15), color='gray') + for x in range(1, len(bins))] + ax[3].set_title('Flat: Scores > Limit', fontweight='bold') + + # Fit figure title to not overlap with the subplots + fig.tight_layout(rect=[0, 0.03, 1, 0.95]) + return fig + + +if __name__ == '__main__': + subject_id = input("Type the subject ID (e.g., SA101)\n>>> ") + visit_id = input("Type the visit ID (V1 or V2)\n>>> ") + run_maxwell_filter(subject_id, visit_id) diff --git a/preprocessing/P02_find_bad_eeg.py b/preprocessing/P02_find_bad_eeg.py new file mode 100644 index 0000000..a5a50dd --- /dev/null +++ b/preprocessing/P02_find_bad_eeg.py @@ -0,0 +1,271 @@ +""" +=================================== +02. 
Find bad EEG sensors +=================================== + +EEG bad sensors are detected using a revisited version of +the PREP pipeline https://doi.org/10.3389/fninf.2015.00016 + +@author: Oscar Ferrante oscfer88@gmail.com + +""" # noqa: E501 + +import os.path as op +import os +import numpy as np +import matplotlib.pyplot as plt +from scipy.stats import zscore +import shutil + +from fpdf import FPDF + +# PS +# from mne.time_frequency import psd_multitaper +import mne_bids +from pyprep.prep_pipeline import PrepPipeline + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + +def find_bad_eeg(subject_id, visit_id, record="run", has_eeg=False): + + # Check whether there are EEG data for this participant and stop if not + if not has_eeg: + raise ValueError("Error: there is no EEG recording for this participant (%s)" % subject_id) + + # Prepare PDF report + pdf = FPDF(orientation="P", unit="mm", format="A4") + + # Set path to preprocessing derivatives + prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") + prep_figure_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "figures") + prep_report_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "reports") + prep_code_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "codes") + + print("Processing subject: %s" % subject_id) + + # Loop over runs + data_path = os.path.join(bids_root,f"sub-{subject_id}",f"ses-{visit_id}","meg") + + for fname in sorted(os.listdir(data_path)): + if fname.endswith(".json") and record in fname: + + # Set run + if "run" in fname: + run = f"{int(fname[-10]):02}" + elif "rest" in fname: + run = None + print(" Run: %s" % run) + + # Set task + if 'dur' in fname: + bids_task = 'dur' + elif 'vg' in fname: + bids_task = 'vg' + elif 'replay' in fname: + bids_task = 'replay' + elif "rest" in fname: + bids_task = "rest" + else: + raise ValueError("Error: could not find the task for %s" % fname) + + # Set BIDS path + bids_path_sss = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + run=run, + session=visit_id, + suffix="sss", + extension='.fif', + check=False) + + # Read raw data + raw = mne_bids.read_raw_bids(bids_path_sss).load_data() + + # Read EEG electrode layout + if run in ["01", None] and bids_task in ['rest', 'dur', 'replay']: + montage = raw.get_montage() + + # Plot montage (EEG layout) + fig = montage.plot(kind='topomap', show_names=False) + fname_fig = op.join(prep_figure_root, + '02_rAll_eeg_montage.png') + fig.savefig(fname_fig) + plt.close() + + # Add montage figure to the report + pdf.add_page() + pdf.set_font('helvetica', 'B', 16) + pdf.cell(0, 10, fname[:-8]) + pdf.ln(20) + pdf.set_font('helvetica', 'B', 12) + pdf.cell(0, 10, 'EEG Montage', 'B', ln=1) + pdf.image(fname_fig, 0, 45, pdf.epw*.8) + + # Set line freq and its harmonics + line_freqs = np.arange(raw.info['line_freq'], raw.info["sfreq"] / 2, raw.info['line_freq']) + + # Set prep params + prep_params = { + "ref_chs": "eeg", + "reref_chs": "eeg", + "line_freqs": line_freqs, + "max_iterations": 4} + + # Run Prep pipeline + prep = PrepPipeline(raw, + prep_params, + montage, + ransac=True) + prep.fit() + + # Print results + print("Bad channels: {}".format(prep.interpolated_channels)) + print("Bad channels after interpolation: {}".format(prep.still_noisy_channels)) + + # Extract raw + raw_car = prep.raw + + # Interpolate bad channels 
left by the prep method + raw_car.interpolate_bads(reset_bads=True) + + # Mark bad channels in the raw bids folder + bids_path = mne_bids.BIDSPath( + root=bids_root, + subject=subject_id, + datatype='meg', + task=bids_task, + run=run, + session=visit_id, + extension='.fif') + + mne_bids.mark_channels(ch_names=(prep.interpolated_channels+prep.still_noisy_channels), + bids_path=bids_path, + status='bad', + verbose=False) + + # Save filtered data + bids_path_car = bids_path_sss.copy().update( + suffix="car", + check=False) + + raw_car.save(bids_path_car, overwrite=True) + + # Plot EEG data + fig = raw.copy().pick('eeg').plot(bad_color=(1., 0., 0.), + scalings = dict(eeg=10e-5), + duration=5, + start=100) + fname_fig = op.join(prep_figure_root, + '02_%sr%s_bad_egg_0raw.png' % (bids_task,run)) + fig.savefig(fname_fig) + plt.close() + + # Plot EEG power spectrum + fig1 = viz_psd(raw) + fname_fig1 = op.join(prep_figure_root, + '02_%sr%s_bad_egg_0pow.png' % (bids_task,run)) + fig1.savefig(fname_fig1) + plt.close() + + # Add figure to report + pdf.add_page() + pdf.set_font('helvetica', 'B', 16) + pdf.cell(0, 10, fname[:-8]) + pdf.ln(20) + pdf.set_font('helvetica', 'B', 12) + pdf.cell(0, 10, 'Power Spectrum of Raw EEG Data', 'B', ln=1) + pdf.image(fname_fig1, 0, 45, pdf.epw*.8) + + # Plot re-referenced EEG data + fig = raw_car.copy().pick('eeg').plot(bad_color=(1., 0., 0.), + scalings = dict(eeg=10e-5), + duration=5, + start=100) + fname_fig = op.join(prep_figure_root, + '02_%sr%s_bad_egg_3refer.png' % (bids_task,run)) + fig.savefig(fname_fig) + plt.close() + + # Plot re-referenced EEG power spectrum + fig1 = viz_psd(raw_car) + fname_fig1 = op.join(prep_figure_root, + '02_%sr%s_bad_egg_Ipow.png' % (bids_task,run)) + fig1.savefig(fname_fig1) + plt.close() + + # Add figures to report + pdf.ln(120) + pdf.cell(0, 10, 'Power Spectrum of Interpolated/Re-referenced EEG Data', 'B', ln=1) + pdf.image(fname_fig1, 0, 175, pdf.epw*.8) + + # Add note about bad channels + pdf.add_page() + pdf.set_font('helvetica', 'B', 16) + pdf.cell(0, 10, "Bad channels:") + pdf.ln(20) + pdf.set_font('helvetica', 'B', 12) + pdf.cell(0, 10, 'Before prep: %s' % prep.interpolated_channels, 'B', ln=1) + pdf.cell(0, 10, 'After prep: %s' % prep.still_noisy_channels, 'B', ln=1) + pdf.cell(0, 10, 'After intepolation: %s' % raw_car.info['bads'], 'B', ln=1) + + # Save code + shutil.copy(__file__, prep_code_root) + + # Save report + if record == "rest": + pdf.output(op.join(prep_report_root, + os.path.basename(__file__) + '-report_rest.pdf')) + else: + pdf.output(op.join(prep_report_root, + os.path.basename(__file__) + '-report.pdf')) + + +def viz_psd(raw): + # Compute averaged power + # PS + # psds, freqs = psd_multitaper(raw,fmin = 1,fmax = 40, picks=['eeg']) + psds, freqs = raw.compute_psd(method='multitaper', fmin = 1,fmax = 40, picks=['eeg']) + psds = np.sum(psds,axis = 1) + psds = 10. 
* np.log10(psds) + # Show power spectral density plot + fig, ax = plt.subplots(2, 1, figsize=(12, 8)) + raw.plot_psd(picks = ["eeg"], + fmin = 1,fmax = 40, + ax=ax[0]) + # Normalize (z-score) channel-specific average power values + psd = {} + psd_zscore = zscore(psds) + for i in range(len(psd_zscore)): + psd["EEG%03d"%(i+1)] = psd_zscore[i] + # Plot chennels ordered by power + ax[1].bar(sorted(psd, key=psd.get,reverse = True),sorted(psd.values(),reverse = True),width = 0.5) + labels = sorted(psd, key=psd.get,reverse = True) + ax[1].set_xticklabels(labels, rotation=90) + ax[1].annotate("Average power: %.2e dB"%(np.average(psds)),(27,np.max(psd_zscore)*0.9),fontsize = 'x-large') + return fig + +def input_bool(message): + value = input(message) + if value == "True": + return True + if value == "False": + return False + + +if __name__ == '__main__': + subject_id = input("Type the subject ID (e.g., SA101)\n>>> ") + visit_id = input("Type the visit ID (V1 or V2)\n>>> ") + has_eeg = input_bool("Has this recording EEG data? (True or False)\n>>> ") + find_bad_eeg(subject_id, visit_id, has_eeg) + diff --git a/preprocessing/P03_artifact_annotation.py b/preprocessing/P03_artifact_annotation.py new file mode 100644 index 0000000..377f627 --- /dev/null +++ b/preprocessing/P03_artifact_annotation.py @@ -0,0 +1,241 @@ +""" +=========================== +03. Artifact annotation +=========================== + +Detect and note ocular and muscle artifacts + +@author: Oscar Ferrante oscfer88@gmail.com + +""" # noqa: E501 + +import os.path as op +import os +import matplotlib.pyplot as plt +import shutil + +from fpdf import FPDF +import mne +from mne.preprocessing import annotate_muscle_zscore +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + +def artifact_annotation(subject_id, visit_id, record="run", has_eeg=False, threshold_muscle=7): + + # Prepare PDF report + pdf = FPDF(orientation="P", unit="mm", format="A4") + + # Set path to preprocessing derivatives + prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") + prep_figure_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "figures") + prep_report_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "reports") + prep_code_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "codes") + + # Find whether the recording has EEG data and set the suffix + if has_eeg: + suffix = "car" + else: + suffix = 'sss' + + print("Processing subject: %s" % subject_id) + + # Loop over runs + data_path = os.path.join(bids_root,f"sub-{subject_id}",f"ses-{visit_id}","meg") + + for fname in sorted(os.listdir(data_path)): + if fname.endswith(".json") and record in fname: + + # Set run + if "run" in fname: + run = f"{int(fname[-10]):02}" + elif "rest" in fname: + run = None + print(" Run: %s" % run) + + # Set task + if 'dur' in fname: + bids_task = 'dur' + elif 'vg' in fname: + bids_task = 'vg' + elif 'replay' in fname: + bids_task = 'replay' + elif "rest" in fname: + bids_task = "rest" + else: + raise ValueError("Error: could not find the task for %s" % fname) + + # Set BIDS path + bids_path_sss = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + run=run, + session=visit_id, + suffix=suffix, + extension='.fif', + check=False) + + # Read raw data + raw = mne_bids.read_raw_bids(bids_path_sss).load_data() + + ########################### + # 
Detect ocular artifacts # + ########################### + + if has_eeg: + # Resetting the EOG channel + eog_ch = raw.copy().pick_types(meg=False, eeg=False, eog=True) + if len(eog_ch.ch_names) < 2: + raw.set_channel_types({'BIO002':'eog'}) + raw.rename_channels({'BIO002': 'EOG002'}) + + # Find EOG events + eog_events = mne.preprocessing.find_eog_events(raw) + onsets = (eog_events[:, 0] - raw.first_samp) / raw.info['sfreq'] - 0.25 + durations = [0.5] * len(eog_events) + descriptions = ['Blink'] * len(eog_events) + + # Annotate events + annot_blink = mne.Annotations( + onsets, + durations, + descriptions) + + # Plot blink with EEG data + eeg_picks = mne.pick_types(raw.info, + meg=False, + eeg=True, + eog=True) + fig = raw.plot(events=eog_events, + start=100, + order=eeg_picks) + fname_fig = op.join(prep_figure_root, + "03_%sr%s_artifact_blink.png" % (bids_task,run)) + fig.savefig(fname_fig) + plt.close() + + ########################### + # Detect muscle artifacts # + ########################### + + # Notch filter + raw_muscle = raw.copy().notch_filter([50, 100]) + + # Choose one channel type, if there are axial gradiometers and magnetometers, + # select magnetometers as they are more sensitive to muscle activity. + annot_muscle, scores_muscle = annotate_muscle_zscore( + raw_muscle, + ch_type="mag", + threshold=threshold_muscle, + min_length_good=0.3, + filter_freq=[110, 140]) + + # Plot muscle z-scores across recording + fig1, ax = plt.subplots() + ax.plot(raw.times, scores_muscle) + ax.axhline(y=threshold_muscle, color='r') + ax.set(xlabel='time, (s)', ylabel='zscore', title='Muscle activity (threshold = %s)' % threshold_muscle) + fname_fig1 = op.join(prep_figure_root, + "03_%sr%s_artifact_muscle.png" % (bids_task,run)) + fig1.savefig(fname_fig1) + plt.close() + + # Add figure to report + pdf.add_page() + pdf.set_font('helvetica', 'B', 16) + pdf.cell(0, 10, fname[:-8]) + pdf.ln(20) + pdf.set_font('helvetica', 'B', 12) + pdf.cell(0, 10, 'Muscle artifact power', 'B', ln=1) + pdf.image(fname_fig1, 0, 45, pdf.epw*.8) + + ################# + # Detect breaks # + ################# + + if record == "run": + # Get events + events, event_id = mne.events_from_annotations(raw) + + # Detect breaks based on events + annot_break = mne.preprocessing.annotate_break( + raw=raw, + events=events, + min_break_duration=15.0) + + ########################### + + # Contatenate blink and muscle artifact annotations + if has_eeg: + annot_artifact = annot_blink + annot_muscle + else: + annot_artifact = annot_muscle + annot_artifact = mne.Annotations(onset = annot_artifact.onset + raw._first_time, + duration = annot_artifact.duration, + description = annot_artifact.description, + orig_time = raw.info['meas_date']) + + # Add artifact annotations in raw + if record == "run": + raw.set_annotations(raw.annotations + annot_artifact + annot_break) + elif record == "rest": + raw.set_annotations(raw.annotations + annot_artifact) + + # View raw with annotations + channel_picks = mne.pick_types(raw.info, + meg='mag', eog=True) + fig2 = raw.plot(duration=50, + start=100, + order=channel_picks) + fname_fig2 = op.join(prep_figure_root, + "03_%sr%s_artifact_annot.png" % (bids_task,run)) + fig2.savefig(fname_fig2) + plt.close() + + # Add figures to report + pdf.ln(120) + pdf.cell(0, 10, 'Data and annotations', 'B', ln=1) + pdf.image(fname_fig2, 0, 175, pdf.epw) + + # Save data with annotated artifacts + bids_path_annot = bids_path_sss.copy().update( + suffix="annot", + check=False) + + raw.save(bids_path_annot, overwrite=True) + + # 
Save code + shutil.copy(__file__, prep_code_root) + + # Save report + if record == "rest": + pdf.output(op.join(prep_report_root, + os.path.basename(__file__) + '-report_rest.pdf')) + else: + pdf.output(op.join(prep_report_root, + os.path.basename(__file__) + '-report.pdf')) + +def input_bool(message): + value = input(message) + if value == "True": + return True + if value == "False": + return False + +if __name__ == '__main__': + subject_id = input("Type the subject ID (e.g., SA101)\n>>> ") + visit_id = input("Type the visit ID (V1 or V2)\n>>> ") + has_eeg = input_bool("Has this recording EEG data? (True or False)\n>>> ") + threshold_muscle = int(input("Set the threshold for muscle artifact? (default is 7)\n>>> ")) + artifact_annotation(subject_id, visit_id, has_eeg=has_eeg, threshold_muscle=threshold_muscle) diff --git a/preprocessing/P04_extract_events.py b/preprocessing/P04_extract_events.py new file mode 100644 index 0000000..bb1abc0 --- /dev/null +++ b/preprocessing/P04_extract_events.py @@ -0,0 +1,278 @@ +""" +=================== +04. Extract events +=================== + +Extract events from the stimulus channel + +@author: Oscar Ferrante oscfer88@gmail.com + +""" + +import os.path as op +import os +import numpy as np +import pandas as pd +from fpdf import FPDF +import shutil + +import mne +import matplotlib.pyplot as plt +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +def run_events(subject_id, visit_id): + + # Prepare PDF report + pdf = FPDF(orientation="P", unit="mm", format="A4") + + # Set path to preprocessing derivatives + prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") + prep_figure_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "figures") + prep_report_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "reports") + prep_code_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "codes") + + print("Processing subject: %s" % subject_id) + + # Lopp over runs + data_path = os.path.join(bids_root,f"sub-{subject_id}",f"ses-{visit_id}","meg") + + for fname in sorted(os.listdir(data_path)): + if fname.endswith(".json") and "run" in fname: + + # Set run + run = int(fname[-10]) + print(" Run: %s" % run) + + # Set task + if 'dur' in fname: + bids_task = 'dur' + elif 'vg' in fname: + bids_task = 'vg' + elif 'replay' in fname: + bids_task = 'replay' + else: + raise ValueError("Error: could not find the task for %s" % fname) + + # Set BIDS path + bids_path_annot = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + run=f"{run:02}", + session=visit_id, + suffix='annot', + extension='.fif', + check=False) + + # Read raw data + raw = mne_bids.read_raw_bids(bids_path_annot) + + ############### + # Read events # + ############### + + # Find response events + response = mne.find_events(raw, + stim_channel='STI101', + consecutive = False, + mask = 65280, + mask_type = 'not_and' + ) + response = response[response[:,2] == 255] + + # Find all other events + events = mne.find_events(raw, + stim_channel='STI101', + consecutive = True, + min_duration=0.001001, + mask = 65280, + mask_type = 'not_and' + ) + events = events[events[:,2] != 255] + + # Concatenate all events + events = np.concatenate([response,events],axis = 0) + events = events[events[:,0].argsort(),:] + + # Show events + fig = mne.viz.plot_events(events) + fname_fig = 
op.join(prep_figure_root, + "04_%sr%s_events.png" % (bids_task,run)) + fig.savefig(fname_fig) + plt.close(fig) + + # Add figure to report + pdf.add_page() + pdf.set_font('helvetica', 'B', 16) + pdf.cell(0, 10, fname[:-8]) + pdf.ln(20) + pdf.set_font('helvetica', 'B', 12) + pdf.cell(0, 10, 'Events', 'B', ln=1) + pdf.image(fname_fig, 0, 45, pdf.epw) + + # Save event array + bids_path_eve = bids_path_annot.copy().update( + suffix="eve", + check=False) + if not op.exists(bids_path_eve): + bids_path_eve.fpath.parent.mkdir(exist_ok=True, parents=True) + + mne.write_events(bids_path_eve.fpath, events) + + ################# + # Read metadata # + ################# + + # # Generate metadata table + if visit_id == 'V1': + eve = events.copy() + events = eve[eve[:, 2] < 81].copy() + metadata = {} + metadata = pd.DataFrame(metadata, index=np.arange(len(events)), + columns=['Stim_trigger', 'Category', + 'Orientation', 'Duration', + 'Task_relevance', 'Trial_ID', + 'Response', 'Response_time(s)']) + Category = ['face', 'object', 'letter', 'false'] + Orientation = ['Center', 'Left', 'Right'] + Duration = ['500ms', '1000ms', '1500ms'] + Relevance = ['Relevant target', 'Relevant non-target', 'Irrelevant'] + k = 0 + for i in range(eve.shape[0]): + if eve[i, 2] < 81: + ##find the end of each trial (trigger 97) + t = [t for t, j in enumerate(eve[i:i + 9, 2]) if j == 97][0] + metadata.loc[k]['Stim_trigger'] = eve[i,2] + metadata.loc[k]['Category'] = Category[int((eve[i,2]-1)//20)] + metadata.loc[k]['Orientation'] = Orientation[[j-100 for j in eve[i:i+t,2] + if j in [101,102,103]][0]-1] + metadata.loc[k]['Duration'] = Duration[[j-150 for j in eve[i:i+t,2] + if j in [151,152,153]][0]-1] + metadata.loc[k]['Task_relevance'] = Relevance[[j-200 for j in eve[i:i+t,2] + if j in [201,202,203]][0]-1] + metadata.loc[k]['Trial_ID'] = [j for j in eve[i:i+t,2] + if (j>110) and (j<149)][0] + metadata.loc[k]['Response'] = True if any(eve[i:i+t,2] == 255) else False + if metadata.loc[k]['Response'] == True: + r = [r for r,j in enumerate(eve[i:i+t,2]) if j == 255][0] + metadata.loc[k]['Response_time(s)'] = (eve[i+r,0] - eve[i,0]) + # miniblock = [j for j in eve[i:i+t,2] if (j>160) and (j<201)] + # metadata.loc[k]['Miniblock_ID'] = miniblock[0] if miniblock != [] else np.nan + k += 1 + + elif visit_id == 'V2': + if bids_task == "vg": + eve = events.copy() + metadata = {} + metadata = pd.DataFrame(metadata, index=np.arange(np.sum(events[:, 2] < 51)), + columns=['Trial_type', + 'Stim_trigger', + 'Stimuli_type', + 'Location', + 'Response', + 'Response_time']) + types0 = ['Filler', 'Probe'] + type1 = ['Face', 'Object', 'Blank'] + location = ['Upper Left', 'Upper Right', 'Lower Right', 'Lower Left'] + response = ['Seen', 'Unseen'] + k = 0 + for i in range(eve.shape[0]): + if eve[i, 2] < 51: + metadata.loc[k]['Stim_trigger'] = eve[i, 2] + t = int(eve[i + 1, 2] % 10) + metadata.loc[k]['Trial_type'] = types0[t] + if eve[i, 2] == 50: + metadata.loc[k]['Stimuli_type'] = type1[2] + else: + metadata.loc[k]['Stimuli_type'] = type1[eve[i, 2] // 20] + metadata.loc[k]['Location'] = location[eve[i + 1, 2] // 10 - 6] + if t == 1: + metadata.loc[k]['Response'] = response[int(eve[i + 4, 2] - 98)] + metadata.loc[k]['Response_time(s)'] = (eve[i + 4, 0] - eve[i + 3, 0]) #/ sfreq + k += 1 + elif bids_task == "replay": + eve = events.copy() + metadata = {} + metadata = pd.DataFrame(metadata, + index=np.arange(np.size( + [i for i in events[:, 2] if i in list(range(101,151)) + list(range(201,251))])), + columns=['Stim_trigger', + 'Stimuli_type', + 
'Trial_type', + 'Location', + 'Response', + 'Response_time']) + types0 = ['Non-Target', 'Target'] + type1 = ['Face', 'Object', 'Black'] + # type1 = ['Face Target', 'Object Non-Target', 'Blank during Face Target', + # 'Object Target', 'Face Non-Target', 'Blank during Object Target'] + location = ['Upper Left', 'Upper Right', 'Lower Right', 'Lower Left'] + response = ['Seen', 'Unseen'] + k = 0 + for i in range(eve.shape[0]): + if eve[i, 2] in list(range(101,151)) + list(range(201,251)): + metadata.loc[k]['Stim_trigger'] = eve[i, 2] + # t = int(eve[i + 1, 2] % 10) + if eve[i, 2] in range(101,111): + metadata.loc[k]['Stimuli_type'] = type1[0] + metadata.loc[k]['Trial_type'] = types0[1] + elif eve[i, 2] in range(121,131): + metadata.loc[k]['Stimuli_type'] = type1[1] + metadata.loc[k]['Trial_type'] = types0[0] + elif eve[i, 2] == 150: + metadata.loc[k]['Stimuli_type'] = type1[2] + elif eve[i, 2] in range(221,231): + metadata.loc[k]['Stimuli_type'] = type1[1] + metadata.loc[k]['Trial_type'] = types0[1] + elif eve[i, 2] in range(201,211): + metadata.loc[k]['Stimuli_type'] = type1[0] + metadata.loc[k]['Trial_type'] = types0[0] + elif eve[i, 2] == 250: + metadata.loc[k]['Stimuli_type'] = type1[2] + metadata.loc[k]['Location'] = location[eve[i + 1, 2] // 10 - 6] + if metadata.loc[k]['Trial_type'] == 'Target': + if 198 in eve[i:i + 4, 2]: + print(eve[i:i + 4, 2]) + metadata.loc[k]['Response'] = response[0] + metadata.loc[k]['Response_time'] = (eve[i + 4, 0] - eve[i + 3, 0]) #/ sfreq + else: + metadata.loc[k]['Response'] = response[1] + k += 1 + + # Save metadata table as csv + bids_path_meta = bids_path_annot.copy().update( + suffix="meta", + extension='.csv', + check=False) + if not op.exists(bids_path_meta): + bids_path_meta.fpath.parent.mkdir(exist_ok=True, parents=True) + + metadata.to_csv(bids_path_meta.fpath, + index=False) + + # Save code + shutil.copy(__file__, prep_code_root) + + # Save report + pdf.output(op.join(prep_report_root, + os.path.basename(__file__) + '-report.pdf')) + + +if __name__ == '__main__': + subject_id = input("Type the subject ID (e.g., SA101)\n>>> ") + visit_id = input("Type the visit ID (V1 or V2)\n>>> ") + run_events(subject_id, visit_id) + \ No newline at end of file diff --git a/preprocessing/P05_run_ica.py b/preprocessing/P05_run_ica.py new file mode 100644 index 0000000..3b1f202 --- /dev/null +++ b/preprocessing/P05_run_ica.py @@ -0,0 +1,265 @@ +""" +=========== +05. Run ICA +=========== + +Run indipendent component analysis. 
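+
+The decomposition is estimated on a copy of all runs concatenated together,
+band-pass filtered between 1 and 40 Hz and downsampled to 200 Hz, separately
+for MEG and (when present) EEG channels, and the fitted ICA is saved to the
+preprocessing derivatives. The components to exclude are selected afterwards
+and applied to the continuous data in P06_apply_ica.py.
+
+A minimal sketch of how a saved decomposition could be inspected when choosing
+components (the file name below is hypothetical; the real path is built with
+mne_bids.BIDSPath as in the code):
+
+    from mne.preprocessing import read_ica
+    ica = read_ica("sub-SA101_ses-V1_meg_ica.fif")
+    ica.plot_components()      # IC topographies
+    ica.plot_sources(raw_all)  # IC time courses, given a loaded Raw object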
+ +@author: Oscar Ferrante oscfer88@gmail.com + +""" + +import os.path as op +import os +import matplotlib.pyplot as plt +import shutil + +from fpdf import FPDF +import mne +from mne.preprocessing import ICA +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +def run_ica(subject_id, visit_id, has_eeg=False): + + # Set path to preprocessing derivatives + prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") + prep_figure_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "figures") + prep_report_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "reports") + prep_code_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "codes") + + print("Processing subject: %s" % subject_id) + + # Loop over runs + data_path = os.path.join(bids_root,f"sub-{subject_id}",f"ses-{visit_id}","meg") + + for fname in sorted(os.listdir(data_path)): + if fname.endswith(".json") and "run" in fname: + + # Set run + run = int(fname[-10]) + print(" Run: %s" % run) + + # Set task + if 'dur' in fname: + bids_task = 'dur' + elif 'vg' in fname: + bids_task = 'vg' + elif 'replay' in fname: + bids_task = 'replay' + else: + raise ValueError("Error: could not find the task for %s" % fname) + + # Set BIDS path + bids_path_annot = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + run=f"{run:02}", + session=visit_id, + suffix='annot', + extension='.fif', + check=False) + + # Read raw data + raw = mne_bids.read_raw_bids(bids_path_annot).load_data() + raw.info['bads'] = [] + + # Band-pass filter raw between 1 and 40 Hz + raw.filter(1, 40) + + # Downsample raw to 200 Hz + raw.resample(200) + + # Concatenate raw copies + if run == 1: + raw_all = mne.io.concatenate_raws([raw]) + else: + raw_all = mne.io.concatenate_raws([raw_all, raw]) + + del raw + + ################### + # ICA on MEG data # + ################### + + # Prepare PDF report + pdf = FPDF(orientation="P", unit="mm", format="A4") + + # Define ICA settings + ica = ICA(method='fastica', + random_state=1688, + n_components=0.99, + verbose=True) + + # Run ICA on filtered raw data + ica.fit(raw_all, + picks='meg', + reject_by_annotation=True, + verbose=True) + + # Plot timecourse and topography of the ICs + # before, get the total number of ICs and divide them into n sets of 20 + n_comp_list = range(ica.n_components_) + plot_comp_list = [n_comp_list[i:i + 20] for i in range(0, len(n_comp_list), 20)] + + for i in range(len(plot_comp_list)): + + # Plot timecourse + fig = ica.plot_sources(raw_all, + picks=plot_comp_list[i], + start=100, + show_scrollbars=False, + title='ICA_MEG') + fname_fig = op.join(prep_figure_root, + "05_rAll_ica_meg_src%d.png" % i) + fig.savefig(fname_fig) + plt.close(fig) + + # Add timecourse figure to report + pdf.add_page() + pdf.set_font('helvetica', 'B', 16) + pdf.cell(0, 10, fname[:-8] + ' - MEG') + pdf.ln(20) + pdf.set_font('helvetica', 'B', 12) + pdf.cell(0, 10, 'Timecourse of MEG ICs', 'B', ln=1) + pdf.image(fname_fig, 0, 45, pdf.epw) + + # Plot topography + fig = ica.plot_components(title='ICA_MEG', + picks=plot_comp_list[i]) + fname_fig = op.join(prep_figure_root, + '05_rAll_ica_meg_cmp%d.png' % i) + fig.savefig(fname_fig) + plt.close(fig) + + # Add topography figure to report + pdf.add_page() + pdf.set_font('helvetica', 'B', 16) + pdf.cell(0, 10, fname[:-8] + ' - MEG') + pdf.ln(20) + 
pdf.set_font('helvetica', 'B', 12) + pdf.cell(0, 10, 'Topography of MEG ICs', 'B', ln=1) + pdf.image(fname_fig, 0, 45, pdf.epw) + + # Save ICA file + bids_path_ica = bids_path_annot.copy().update( + task=None, + run=None, + suffix="meg_ica", + check=False) + if not op.exists(bids_path_ica): + bids_path_ica.fpath.parent.mkdir(exist_ok=True, parents=True) + + ica.save(bids_path_ica) + + # Save report + pdf.output(op.join(prep_report_root, + os.path.basename(__file__) + 'MEG-report.pdf')) + + ################### + # ICA on EEG data # + ################### + + if has_eeg: + # Prepare PDF report + pdf = FPDF(orientation="P", unit="mm", format="A4") + + # Define ICA settings + ica = ICA(method='fastica', + random_state=1688, + n_components=0.99, + verbose=True) + + # Run ICA on filtered raw data + ica.fit(raw_all, + picks='eeg', + verbose=True) + + # Plot timecourse and topography of the ICs + # Get the total number of ICs and divide them into sets of 20 ICs + n_comp_list = range(ica.n_components_) + plot_comp_list = [n_comp_list[i:i + 20] for i in range(0, len(n_comp_list), 20)] + + for i in range(len(plot_comp_list)): + # Plot timecourse + fig = ica.plot_sources(raw_all, + picks=plot_comp_list[i], + start=100, + show_scrollbars=False, + title='ICA_EEG') + + fname_fig = op.join(prep_figure_root, + "05_rAll_ica_eeg_src%d.png" % i) + fig.savefig(fname_fig) + plt.close(fig) + + # Add timecourse figure to report + pdf.add_page() + pdf.set_font('helvetica', 'B', 16) + pdf.cell(0, 10, fname[:16] + ' - EEG') + pdf.ln(20) + pdf.set_font('helvetica', 'B', 12) + pdf.cell(0, 10, 'Timecourse of EEG ICs', 'B', ln=1) + pdf.image(fname_fig, 0, 45, pdf.epw) + + # Plot topography + fig = ica.plot_components(title='ICA_EEG', + picks=plot_comp_list[i]) + fname_fig = op.join(prep_figure_root, + '05_rAll_ica_eeg_cmp%d.png' % i) + fig.savefig(fname_fig) + plt.close(fig) + + # Add topography figure to report + pdf.add_page() + pdf.set_font('helvetica', 'B', 16) + pdf.cell(0, 10, fname[:16] + ' - EEG') + pdf.ln(20) + pdf.set_font('helvetica', 'B', 12) + pdf.cell(0, 10, 'Topography of EEG ICs', 'B', ln=1) + pdf.image(fname_fig, 0, 45, pdf.epw) + + # Save ICA file + bids_path_ica = bids_path_annot.copy().update( + task=None, + run=None, + suffix="eeg_ica", + check=False) + if not op.exists(bids_path_ica): + bids_path_ica.fpath.parent.mkdir(exist_ok=True, parents=True) + + ica.save(bids_path_ica) + + # Save report + pdf.output(op.join(prep_report_root, + os.path.basename(__file__) + 'EEG-report.pdf')) + # Save code + shutil.copy(__file__, prep_code_root) + +def input_bool(message): + value = input(message) + if value == "True": + return True + if value == "False": + return False + + +if __name__ == '__main__': + subject_id = input("Type the subject ID (e.g., SA101)\n>>> ") + visit_id = input("Type the visit ID (V1 or V2)\n>>> ") + has_eeg = input_bool("Has this recording EEG data? (True or False)\n>>> ") + run_ica(subject_id, visit_id, has_eeg) + \ No newline at end of file diff --git a/preprocessing/P06_apply_ica.py b/preprocessing/P06_apply_ica.py new file mode 100644 index 0000000..4f791f8 --- /dev/null +++ b/preprocessing/P06_apply_ica.py @@ -0,0 +1,231 @@ +""" +=============== +06. 
Apply ICA +=============== + +This relies on the ICAs computed in P05-run_ica.py + +@author: Oscar Ferrante oscfer88@gmail.com + +""" + +import os.path as op +import os +import matplotlib.pyplot as plt +import shutil +import json + +from fpdf import FPDF +from mne.preprocessing import read_ica +import mne_bids + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + + +def apply_ica(subject_id, visit_id, record="run", has_eeg=False): + + # Prepare PDF report + pdf = FPDF(orientation="P", unit="mm", format="A4") + + # Set path to preprocessing derivatives + prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") + prep_figure_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "figures") + prep_report_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "reports") + prep_code_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "codes") + + # Read what component to reject from the JSON file + with open(op.join(prep_deriv_root, 'P05_rej_comp.json'), 'r') as openfile: + rej_comp_json = json.load(openfile) + + # # Compute mean (and sd) number of rejected components + # ica = [] + # for key, value in rej_comp_json.items(): + # ica.append(value['V1']['meg_ica_eog'] + value['V1']['meg_ica_ecg']) + # count = [] + # for i in ica: + # count.append(len(i)) + # import numpy as np + # print(np.mean(count)) + # print(np.std(count)) + + meg_ica_eog = rej_comp_json[subject_id][visit_id].get('meg_ica_eog') + meg_ica_ecg = rej_comp_json[subject_id][visit_id].get('meg_ica_ecg') + eeg_ica_eog = rej_comp_json[subject_id][visit_id].get('eeg_ica_eog') + eeg_ica_ecg = rej_comp_json[subject_id][visit_id].get('eeg_ica_ecg') + + if meg_ica_eog + meg_ica_ecg != []: + # Read ICA mixing matrices + bids_path_meg_ica = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + session=visit_id, + suffix='meg_ica', + extension='.fif', + check=False) + + ica_meg = read_ica(bids_path_meg_ica.fpath) + + # Select EOG- and ECG-related components for exclusion + ica_meg.exclude.extend(meg_ica_eog + meg_ica_ecg) + + if eeg_ica_eog + eeg_ica_ecg != []: + # Read ICA mixing matrices + bids_path_eeg_ica = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + session=visit_id, + suffix='eeg_ica', + extension='.fif', + check=False) + + ica_eeg = read_ica(bids_path_eeg_ica.fpath) + + # Select EOG- and ECG-related components for exclusion + ica_eeg.exclude.extend(eeg_ica_eog + eeg_ica_ecg) + + print("Processing subject: %s" % subject_id) + + # Loop over runs + data_path = os.path.join(bids_root,f"sub-{subject_id}",f"ses-{visit_id}","meg") + for fname in sorted(os.listdir(data_path)): + if fname.endswith(".json") and record in fname: + + # Set run + if "run" in fname: + run = f"{int(fname[-10]):02}" + elif "rest" in fname: + run = None + print(" Run: %s" % run) + + # Set task + if 'dur' in fname: + bids_task = 'dur' + elif 'vg' in fname: + bids_task = 'vg' + elif 'replay' in fname: + bids_task = 'replay' + elif "rest" in fname: + bids_task = "rest" + else: + raise ValueError("Error: could not find the task for %s" % fname) + + # Set BIDS path + bids_path_annot = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + run=run, + session=visit_id, + suffix='annot', + extension='.fif', + check=False) + + # Read raw data + raw = 
mne_bids.read_raw_bids(bids_path_annot).load_data() + + # Fix EOG001 channel name (required for SA only) + if 'EOG004' in raw.ch_names: + raw.rename_channels({'EOG004': 'EOG001'}) + + # Show original signal + if has_eeg: + chs = ['MEG0311', 'MEG0121', 'MEG1211', 'MEG1411', 'EEG001','EEG002', 'EOG001','EOG002'] + else: + chs = ['MEG0311', 'MEG0121', 'MEG1211', 'MEG1411', 'EOG001','EOG002'] + chan_idxs = [raw.ch_names.index(ch) for ch in chs] + fig1 = raw.plot(order=chan_idxs, + duration=20, + start=100) + fname_fig1 = op.join(prep_figure_root, + '06_%sr%s_ica_raw0.png' % (bids_task,run)) + fig1.savefig(fname_fig1) + plt.close() + + # Add figure to report + pdf.add_page() + pdf.set_font('helvetica', 'B', 16) + pdf.cell(0, 10, fname[:-8]) + pdf.ln(20) + pdf.set_font('helvetica', 'B', 12) + pdf.cell(0, 10, 'Timecourse of input data', 'B', ln=1) + pdf.image(fname_fig1, 0, 45, pdf.epw) + + # Remove component from MEG signal + if meg_ica_eog + meg_ica_ecg != []: + ica_meg.apply(raw) + + # Remove component from EEG signal + if eeg_ica_eog + eeg_ica_ecg != []: + ica_eeg.apply(raw) + + # Save filtered data + bids_path_filt = bids_path_annot.copy().update( + root=prep_deriv_root, + suffix="filt", + check=False) + + raw.save(bids_path_filt, overwrite=True) + + # Show cleaned signal + fig_ica = raw.plot(order=chan_idxs, + duration=20, + start=100) + fname_fig_ica = op.join(prep_figure_root, + '06_%sr%s_ica_rawICA.png' % (bids_task,run)) + fig_ica.savefig(fname_fig_ica) + plt.close() + + # Add figures to report + pdf.ln(120) + pdf.cell(0, 10, 'Timecourse of output data', 'B', ln=1) + pdf.image(fname_fig_ica, 0, 175, pdf.epw) + + # Save code + shutil.copy(__file__, prep_code_root) + + # Add note about removed ICs to report + pdf.add_page() + pdf.set_font('helvetica', 'B', 16) + pdf.cell(0, 10, "Excluded indipendent components:") + pdf.ln(20) + pdf.set_font('helvetica', 'B', 12) + pdf.cell(0, 10, 'MEG eog: %s' % meg_ica_eog, 'B', ln=1) + pdf.cell(0, 10, 'MEG ecg: %s' % meg_ica_ecg, 'B', ln=1) + pdf.ln(20) + pdf.cell(0, 10, 'EEG eog: %s' % eeg_ica_eog, 'B', ln=1) + pdf.cell(0, 10, 'EEG ecg: %s' % eeg_ica_ecg, 'B', ln=1) + + # Save report + if record == "rest": + pdf.output(op.join(prep_report_root, + os.path.basename(__file__) + '-report_rest.pdf')) + else: + pdf.output(op.join(prep_report_root, + os.path.basename(__file__) + '-report.pdf')) + +def input_bool(message): + value = input(message) + if value == "True": + return True + if value == "False": + return False + + +if __name__ == '__main__': + subject_id = input("Type the subject ID (e.g., SA101)\n>>> ") + visit_id = input("Type the visit ID (V1 or V2)\n>>> ") + has_eeg = input_bool("Has this recording EEG data? (True or False)\n>>> ") + apply_ica(subject_id, visit_id, has_eeg) + \ No newline at end of file diff --git a/preprocessing/P07_make_epochs.py b/preprocessing/P07_make_epochs.py new file mode 100644 index 0000000..6ed7993 --- /dev/null +++ b/preprocessing/P07_make_epochs.py @@ -0,0 +1,288 @@ +""" +==================== +07. Make epochs +==================== + +Open questions: + - separate MEG and EEG in two different FIF files? + - Exp.2: separate VG and replay in two different files? 
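+
+Filtered runs are concatenated, cut into epochs from tmin to tmax (defined in
+config/config.py), annotated with the trial metadata extracted by
+P04_extract_events.py, and cleaned with peak-to-peak rejection thresholds
+estimated by autoreject before being saved with the "epo" suffix.
+
+A minimal sketch (the file name is hypothetical) of how the saved epochs can
+later be selected through their metadata, as done in the QC scripts:
+
+    import mne
+    epochs = mne.read_epochs("sub-SA101_ses-V1_task-dur_epo.fif")
+    faces = epochs['Task_relevance == "Irrelevant" and Category == "face"']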
+ +@author: Oscar Ferrante oscfer88@gmail.com + +""" + +import os.path as op +import os +import pandas as pd +import matplotlib.pyplot as plt +import shutil + +from fpdf import FPDF +import mne +import mne_bids +from autoreject import get_rejection_threshold + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import (bids_root, tmin, tmax) + + +def run_epochs(subject_id, visit_id, task, has_eeg=False): + + # Prepare PDF report + pdf = FPDF(orientation="P", unit="mm", format="A4") + + # Set path to preprocessing derivatives + prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") + prep_figure_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "figures") + prep_report_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "reports") + prep_code_root = op.join(prep_deriv_root, + f"sub-{subject_id}",f"ses-{visit_id}","meg", + "codes") + + # Create empty lists + raw_list = list() + events_list = list() + metadata_list = list() + + print("Processing subject: %s" % subject_id) + + # Loop over runs + data_path = os.path.join(bids_root,f"sub-{subject_id}",f"ses-{visit_id}","meg") + for fname in sorted(os.listdir(data_path)): + if fname.endswith(".json") and task in fname: + if "run" in fname or "rest" in fname: + # Set run + if "run" in fname: + run = f"{int(fname[-10]):02}" + elif "rest" in fname: + run = None + print(" Run: %s" % run) + + # Read filtered data + bids_path_filt = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=task, + run=run, + session=visit_id, + suffix='filt', + extension='.fif', + check=False) + + raw_tmp = mne_bids.read_raw_bids(bids_path_filt) + + # Read events + if "run" in fname: + bids_path_eve = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=task, + run=run, + session=visit_id, + suffix='eve', + extension='.fif', + check=False) + + events_tmp = mne.read_events(bids_path_eve.fpath) + + # Read metadata + bids_path_meta = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=task, + run=run, + session=visit_id, + suffix='meta', + extension='.csv', + check=False) + + metadata_tmp = pd.read_csv(op.join(bids_path_meta.fpath)) + + metadata_list.append(metadata_tmp) + + elif "rest" == task: + events_tmp = mne.make_fixed_length_events( + raw_tmp, duration=5) + + # Append read data to list + raw_list.append(raw_tmp) + events_list.append(events_tmp) + + + # Concatenate raw instances as if they were continuous + raw, events = mne.concatenate_raws(raw_list, + events_list=events_list) + del raw_list + + # Concatenate metadata tables and save it + if task != "rest": + metadata = pd.concat(metadata_list) + + bids_path_meta.update( + run=None, + check=False) + + metadata.to_csv(bids_path_meta.fpath, + index=False) + + # Set trial-onset event_ids + if task == "rest": + events_id = {"rest": 1} + elif visit_id == 'V1': + events_id = {} + types = ['face','object','letter','false'] + for j,t in enumerate(types): + for i in range(1,21): + events_id[t+str(i)] = i + j * 20 + elif visit_id == 'V2': + if task == "vg": + events_id = {} + events_id['blank'] = 50 + types = ['face','object'] + for j,t in enumerate(types): + for i in range(1,11): + events_id[t+str(i)] = i + j * 20 + elif task == "replay": + events_id = {} + events_id['blankFT'] = 150 + events_id['blankOT'] = 250 + typesF = ['faceT','faceNT'] + for j,t in enumerate(typesF): + for i in 
range(1,11): + events_id[t+str(i)] = i + 100 + j * 100 + typesO = ['objectNT','objectT'] + for j,t in enumerate(typesO): + for i in range(1,11): + events_id[t+str(i)] = i + 120 + j * 100 + + # Select sensor types + picks = mne.pick_types(raw.info, + meg = True, + eeg = has_eeg, + stim = True, + eog = has_eeg, + ecg = has_eeg) + + # Epoch raw data + epochs = mne.Epochs(raw, + events, + events_id, + tmin, tmax, + baseline=None, + proj=True, + picks=picks, + detrend=1, + reject=None, + reject_by_annotation=True, #reject muscle artifacts + verbose=True) + del raw + + # Add metadata + if task != "rest": + epochs.metadata = metadata + + # Get rejection thresholds + reject = get_rejection_threshold(epochs, + ch_types=['mag', 'grad'], #'eeg'], #eeg not used for epoch rejection + decim=2) + + # Drop bad epochs based on peak-to-peak magnitude + nr_epo_prerej = len(epochs.events) + epochs.drop_bad(reject=reject) + nr_epo_postrej = len(epochs.events) + + # Plot percentage of rejected epochs per channel + fig1 = epochs.plot_drop_log() + fname_fig1 = op.join(prep_figure_root, + f'07_{task}rAll_epoch_drop.png') + fig1.savefig(fname_fig1) + plt.close() + + # Add figure to report + pdf.add_page() + pdf.set_font('helvetica', 'B', 16) + pdf.cell(0, 10, fname[:16]) + pdf.ln(120) + pdf.set_font('helvetica', 'B', 12) + pdf.cell(0, 10, 'Percentage of rejected epochs', 'B', ln=1) + pdf.image(fname_fig1, 0, 45, pdf.epw*.8) + pdf.ln(20) + pdf.cell(0, 10, "Number of epochs:") + pdf.ln(20) + pdf.set_font('helvetica', 'B', 12) + pdf.cell(0, 10, f'Before rejection: {nr_epo_prerej}', 'B', ln=1) + pdf.cell(0, 10, f'After rejection: {nr_epo_postrej}', 'B', ln=1) + + # Plot evoked by epochs + fig2 = epochs.plot(picks='meg', + title='meg', + n_epochs=10) + fname_fig2 = op.join(prep_figure_root, + f'07_{task}rAll_epoch_evk.png') + fig2.savefig(fname_fig2) + plt.close(fig2) + + # Add figures to report + pdf.add_page() + pdf.set_font('helvetica', 'B', 16) + pdf.cell(0, 10, fname[:16]) + pdf.ln(20) + pdf.cell(0, 10, 'Epoched data', 'B', ln=1) + pdf.image(fname_fig2, 0, 45, pdf.epw) + + # Count the number of epochs defined by different events + num = {} + for key in events_id: + num[key] = len(epochs[key]) + df = pd.DataFrame(num, + index = ["Total"]) + df.to_csv(op.join(prep_report_root, + f'P07_make_epochs-count_{task}_event.csv'), + index=False) + print(df) + + # Save epoched data + bids_path_epo = bids_path_filt.copy().update( + root=prep_deriv_root, + run=None, + suffix="epo", + check=False) + + epochs.save(bids_path_epo, overwrite=True) + + # Save code + shutil.copy(__file__, prep_code_root) + + # Save report + if task == "rest": + pdf.output(op.join(prep_report_root, + os.path.basename(__file__) + f'-{task}-report_rest.pdf')) + else: + pdf.output(op.join(prep_report_root, + os.path.basename(__file__) + f'_{task}-report.pdf')) + + +def input_bool(message): + value = input(message) + if value == "True": + return True + if value == "False": + return False + + +if __name__ == '__main__': + subject_id = input("Type the subject ID (e.g., SA101)\n>>> ") + visit_id = input("Type the visit ID (V1 or V2)\n>>> ") + task = input("Type the task (dur, vg or replay)\n>>> ") + has_eeg = input_bool("Has this recording EEG data? 
(True or False)\n>>> ") + run_epochs(subject_id, visit_id, task, has_eeg) + \ No newline at end of file diff --git a/preprocessing/P99_run_preproc.py b/preprocessing/P99_run_preproc.py new file mode 100644 index 0000000..177e359 --- /dev/null +++ b/preprocessing/P99_run_preproc.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +""" +Created on Mon Apr 5 16:56:07 2021 + +@author: Oscar Ferrante oscfer88@gmail.com +""" + + +import argparse + +import P01_maxwell_filtering +import P02_find_bad_eeg +import P03_artifact_annotation +import P04_extract_events +import P05_run_ica +import P06_apply_ica +import P07_make_epochs + + +# ============================================================================= +# PARSER SETTINGS +# ============================================================================= + +parser=argparse.ArgumentParser() +parser.add_argument('--sub',type=str,default='SA101',help='subject_id') +parser.add_argument('--visit',type=str,default='V1',help='visit_id') +parser.add_argument('--record',type=str,default='run',help='recording_type (run or rest') +parser.add_argument('--step',type=str,default='1',help='preprocess step') + +opt=parser.parse_args() + + +# ============================================================================= +# SESSION-SPECIFIC SETTINGS +# ============================================================================= + +subject_id = opt.sub +visit_id = opt.visit +record = opt.record + +# Find out whwether the participant has EEG data +if visit_id.upper() == 'V1': + if subject_id.upper() in ['SA101', 'SA102', 'SA103', 'SA104', 'SB036']: + has_eeg = False + else: + has_eeg = True +elif visit_id.upper() == 'V2': + if subject_id.upper() in ['SA104', 'SA106', 'SA125', 'SB036']: + has_eeg = False + else: + has_eeg = True + + +# ============================================================================= +# DEFINE PREPROCESSING STEPS +# ============================================================================= + +def pre_step1(): + print("\n\n\n#######################\nP01_maxwell_filtering\n#######################\n") + P01_maxwell_filtering.run_maxwell_filter(subject_id, + visit_id, + record) + if has_eeg: + print("\n\n\n#######################\nP02_find_bad_eeg\n#######################\n") + P02_find_bad_eeg.find_bad_eeg(subject_id, + visit_id, + record, + has_eeg) + print("\n\n\n#######################\nP03_artifact_annotation\n#######################\n") + P03_artifact_annotation.artifact_annotation(subject_id, + visit_id, + record, + has_eeg, + # threshold_muscle, + ) + if record == "run": + print("\n\n\n#######################\nP04_extract_events\n#######################\n") + P04_extract_events.run_events(subject_id, + visit_id) + print("\n\n\n#######################\nP05_run_ica\n#######################\n") + P05_run_ica.run_ica(subject_id, + visit_id, + has_eeg) + +def pre_step2(): + print("\n\n\n#######################\nP06_apply_ica\n#######################\n") + P06_apply_ica.apply_ica(subject_id, + visit_id, + record, + has_eeg) + + print("\n\n\n#######################\nP07_make_epochs\n#######################\n") + if record == "rest": + P07_make_epochs.run_epochs(subject_id, + visit_id, + "rest", + has_eeg) + elif visit_id == 'V1': + P07_make_epochs.run_epochs(subject_id, + visit_id, + 'dur', + has_eeg) + elif visit_id == 'V2': + P07_make_epochs.run_epochs(subject_id, + visit_id, + 'vg', + has_eeg) + P07_make_epochs.run_epochs(subject_id, + visit_id, + 'replay', + has_eeg) + + +# 
============================================================================= +# RUN +# ============================================================================= +if opt.step == '1': + pre_step1() +elif opt.step == '2': + pre_step2() +elif opt.step == '0': + pre_step1() + pre_step2() diff --git a/qc/P00_bids_conversion.py b/qc/P00_bids_conversion.py new file mode 100644 index 0000000..a8d8dc1 --- /dev/null +++ b/qc/P00_bids_conversion.py @@ -0,0 +1,386 @@ +# -*- coding: utf-8 -*- +""" +# ==================== +# 00. BIDS convertion +# ==================== + +https://mne.tools/mne-bids/stable/index.html + +Questions/Issues: + - what to write in the participants and dataset_description metadata + files? + - what session ID should we give to the anat scan (e.g., v0, v2, mri)? + - for visit 2, what to count the replay runs? Continue the count from + where it was left from the VG (run 4) or restart from run 1? + +Notes: + - the conversion must be done after reading the events. Here, the event + list includes all the triggers/events + - the T1 scan can be added under 'anat' together with the transformation + matrix obtained from the coregistraction. T1s can be defaced at this stage. + - participant info must be updated manually (e.g., age, sex) + - datatype can be 'meg', 'eeg', or a few options (no meeg). Moreover, + fif files are automatically read as 'meg' datatype and this cannot + be overwritten by the datatype option. Concurrent MEG-EEG data type is MEG + - For the anat conversion, you need to run FreeSurfer and complete the + co-registration step first. + - BIDS does not allow DICOM scans. NIfTI conversion is required. + +@author: Oscar Ferrante oscfer88@gmail.com + +""" + +import os +import os.path as op +import argparse + +import mne +from mne_bids import (write_raw_bids, write_meg_calibration, + write_meg_crosstalk, BIDSPath, write_anat) +# from mne_bids import (print_dir_tree, make_report, write_anat) +# from mne_bids.write import get_anat_landmarks +# from mne_bids.stats import count_events + +# import dicom2nifti # conda install -c conda-forge dicom2nifti + +# from config import (subject_list, file_exts, raw_path, cal_path, t1_path, +# bids_root) + + +parser=argparse.ArgumentParser() +parser.add_argument('--sub', + type=str, + default='SA101', + help='site_id + subject_id (e.g. "SA101")') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. 
"V1")') +parser.add_argument('--in_raw', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/Raw/projects/CoG_MEG_PhaseII', + help='Path to the RAW data directory') +# RAW: /mnt/beegfs/XNAT/COGITATE/MEG/Raw/projects/CoG_MEG_PhaseII +# MEG: /mnt/beegfs/XNAT/COGITATE/MEG/Raw/projects/CoG_MEG_PhaseII/SA101/SA101_MEEG_V1/SCANS/DurR1/FIF/SA101_MEEG_V1_DurR1.fif +# T1: /mnt/beegfs/XNAT/COGITATE/MEG/Raw/projects/CoG_MEG_PhaseII/SA101/SA101_MR_V0/SCANS/5/DICOM/xxx.dcm +parser.add_argument('--in_cal', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/preprocessing/cal_files', + help='Path to the fine-calibration and cross-talk files') +parser.add_argument('--out_bids', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids', + help='Path to the BIDS root directory') +opt=parser.parse_args() + + +def init_par(opt): #subject_id, visit_id, t1_path, file_exts, bids_root + # Prepare BIDS fields + bids = {} + bids["root"] = opt.out_bids + bids["subject"] = opt.sub + bids["site"] = bids["subject"][0:2] + bids["session"] = opt.visit + bids["datatype"] = 'meg' + + # Set MEG data path and name + if bids["session"] == 'V1': + file_exts = ['%s_MEEG_V1_DurR1', + '%s_MEEG_V1_DurR2', + '%s_MEEG_V1_DurR3', + '%s_MEEG_V1_DurR4', + '%s_MEEG_V1_DurR5'] + elif bids["session"] == 'V2': + file_exts = ['%s_MEEG_V2_VGR1', + '%s_MEEG_V2_VGR2', + '%s_MEEG_V2_VGR3', + '%s_MEEG_V2_VGR4', + '%s_MEEG_V2_ReplayR1', + '%s_MEEG_V2_ReplayR2'] + meg = {} + meg["subject"] = bids["subject"] + meg["fnames"] = [f % bids["subject"] for f in file_exts] + meg["data_path"] = op.join(opt.in_raw, bids["subject"], + bids["subject"]+"_MEEG_"+bids["session"], + "SCANS", + "%s", #task+run (e.g., DurR1) + "FIF", + ) + meg["cal_path"] = op.join(opt.in_cal, bids["site"]) + + # Set anat MRI path and names + t1 ={} + t1["subject"] = bids["subject"] + t1["fname"] = bids["subject"] + '_MR_V0_anat' + t1["nifti_path"] = op.join(bids["root"], + "derivatives", + "dicom2nifti", + bids["subject"], + ) #TODO + t1["dicom_path"] = op.join(opt.in_raw, bids["subject"], + bids["subject"]+"_MR_V0", + "SCANS", + "%s", #task+run (e.g., DurR1) + "5", #TODO: what is this folder? 
+ "DICOM", + ) + t1["fs_path"] = op.join(t1["nifti_path"], "fs") #TODO + t1["mgz_path"] = op.join(t1["fs_path"], bids["subject"], 'mri', 'T1.mgz') #TODO + t1["trans_path"] = op.join(t1["fs_path"], bids["subject"]+"-trans.fif") #TODO + + return bids, meg, t1 + +def raw_to_bids(bids, meg): + for file_name in meg["fnames"]: + run = file_name[-1] + + # Set task + if 'Dur' in file_name: + bids["task"] = 'Dur' + elif 'VG' in file_name: + bids["task"] = 'VG' + elif 'Replay' in file_name: + bids["task"] = 'Replay' + else: + raise ValueError("Error: could not find the task for %s" % file_name) + + # Read raw + raw_fname = op.join(meg["data_path"] % (bids["task"]+'R'+run), file_name + '.fif') + raw = mne.io.read_raw_fif(raw_fname, allow_maxshield=True) + + # Read events + # events_data = op.join(meg["out_path"], + # meg["fnames"][0]+'-bids_eve.fif') + + events = mne.find_events(raw, + stim_channel='STI101', + consecutive = True, + min_duration=0.001001, + mask = 65280, + mask_type = 'not_and', + verbose=True) + + # Set event IDs + if bids["session"] == 'V1': + # Stimulus type and image ID + stimulus_id = {} + types = ['face','object','letter','false'] + for j,t in enumerate(types): + for i in range(1,21): + stimulus_id[t+f'{i:02d}'] = i + j * 20 + # Trial number (block-wise) + trial_id = {} + for i in range(111,149): + trial_id['trial'+f'{i-110:02d}'] = i + # Sequence number + sequence_id = {} + for i in range(161,201): + sequence_id['sequence'+f'{i-160:02d}'] = i + # Other events + other_id = {'onset of recording':81, 'offset of recording':83, + 'start experiment': 86, 'stimulus offset':96, + 'blank offset':97, + 'center': 101, 'left':102, 'right':103, + '500ms': 151, '1000ms': 152, '1500ms': 153, + 'task relevant target': 201, + 'task relevant non target': 202, + 'task irrelevant': 203, 'response': 255} + # Merge all event IDs in event_id + event_id = stimulus_id | trial_id | sequence_id | other_id + + elif bids["session"] == "V2": + # Stimulus type and image ID during the game + stimulus_id = {} + stimulus_id['blank'] = 50 + types = ['face','object'] + for j,t in enumerate(types): + for i in range(1,11): + stimulus_id[t+f'{i:02d}'] = i + j * 20 + # Stimulus type and image ID during the replay + stimulus_id['blanks during face target'] = 150 + stimulus_id['blanks during object target'] = 250 + types = ['face target','object non-target'] + for j,t in enumerate(types): + for i in range(1,11): + stimulus_id[t+f'{i:02d}'] = (i + j * 20) + 100 + types = ['face non-target','object target'] + for j,t in enumerate(types): + for i in range(1,11): + stimulus_id[t+f'{i:02d}'] = (i + j * 20) + 200 + # Stimulus location + location_id = {} + location_id['upper left'] = 60 + location_id['upper left probed or target'] = 61 + location_id['upper right'] = 70 + location_id['upper right probed or target'] = 71 + location_id['lower right'] = 80 + location_id['lower right probed or target'] = 81 + location_id['lower left'] = 90 + location_id['lower left probed or target'] = 91 + # Response + response_id = {} + response_id['seen'] = 98 + response_id['unseen'] = 99 + response_id['response during replay'] = 198 + response_id['end replay response window'] = 196 + # Other events + other_id = {'probe onset': 100, 'filler':95, + 'filler during replay':195, + 'level begin': 251, 'level end':252, + 'animation peak end':253} + + # Merge all event IDs in event_id + event_id = stimulus_id | location_id | response_id | other_id + + # Set BIDS path + bids_path = BIDSPath(subject=bids["subject"], + session=bids["session"], + 
task=bids["task"].lower(), + run='0'+run, + datatype=bids["datatype"], + root=bids["root"]) + # Write BIDS + write_raw_bids(raw, + bids_path=bids_path, + events_data=events, + event_id=event_id, + overwrite=True) + + return raw + +def rest_to_bids(bids, meg): + # Add resting state data #TODO: declare that it is 5-min eyes open RS + rs_raw_fname = op.join(meg["data_path"] % "RestinEO", bids["subject"] + '_MEEG_' + bids["session"] + '_RestinEO.fif') + rs_raw = mne.io.read_raw_fif(rs_raw_fname, allow_maxshield=True) + + # Write to bids + rs_bids_path = BIDSPath(subject=bids["subject"], + session=bids["session"], + task='rest', + datatype=bids["datatype"], + root=bids["root"]) + write_raw_bids(rs_raw, rs_bids_path, overwrite=True) + +def empty_to_bids(bids, meg): + # Add empty room data + er_raw_fname = op.join(meg["data_path"] % "Rnoise", bids["subject"] + '_MEEG_' + bids["session"] + '_Rnoise.fif') + er_raw = mne.io.read_raw_fif(er_raw_fname, allow_maxshield=True) + + # For empty room data we need to specify the recording date + er_date = er_raw.info['meas_date'].strftime('%Y%m%d') + print(er_date) + + # Write to bids + er_bids_path = BIDSPath(subject=bids['site']+'emptyroom', + session=er_date, + task='noise', + datatype=bids["datatype"], + root=bids["root"]) + write_raw_bids(er_raw, er_bids_path, overwrite=True) + +def maxfiles_to_bids(bids, meg): + # Find fine-calibration and crosstalk files + cal_fname = op.join(meg["cal_path"], 'sss_cal_' + bids["site"] + '.dat') + ct_fname = op.join(meg["cal_path"], 'ct_sparse_' + bids["site"] + '.fif') + + # Set BIDS path + bids_path = BIDSPath(subject=bids["subject"], + session=bids["session"], + task=bids["task"], + run=1, + datatype=bids["datatype"], + root=bids["root"]) + + # Add files to the bids structure + write_meg_calibration(cal_fname, bids_path) + write_meg_crosstalk(ct_fname, bids_path) + +# def dicom_to_nifti(t1): +# # Convert dicom to nifti +# dicom2nifti.convert_directory(t1["dicom_path"], +# t1["nifti_path"], +# compression=True, +# reorient=True) + +# # Rename nifti.gz file +# for file in os.listdir(t1["nifti_path"]): +# if file.endswith(".gz"): +# os.rename(op.join(t1["nifti_path"], file), +# op.join(t1["nifti_path"], t1["fname"] + '.nii.gz')) + +# def t1_to_bids(t1, bids, raw): +# # Load the transformation matrix and show what it looks like +# if op.exists(t1["trans_path"]): +# trans = mne.read_trans(t1["trans_path"]) +# else: +# raise FileNotFoundError("No such file or directory: " + t1["trans_path"]) + +# # Use trans to transform landmarks to the voxel space of the T1 +# t1["nifti"] = op.join(t1["nifti_path"], t1["fname"] + '.nii.gz') +# landmarks = get_anat_landmarks( +# t1["mgz_path"], +# info=raw.info, +# trans=trans, +# fs_subject=t1["subject"], +# fs_subjects_dir=t1["fs_path"]) + +# # Create the BIDSPath object. 
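+# # Note: get_anat_landmarks maps the digitized fiducials from raw.info into
+# # the voxel space of the T1 using the coregistration trans, so that
+# # write_anat can store them in the T1w JSON sidecar and use them to deface
+# # the scan (deface=True).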
+# t1w_bids_path = BIDSPath(subject=t1["subject"], +# session="V0", +# root=bids["root"], +# suffix='T1w') + +# # We use the write_anat function +# t1w_bids_path = write_anat( +# image = t1["nifti"], # path to the MRI scan +# bids_path = t1w_bids_path, +# landmarks=landmarks, # the landmarks in MRI voxel space +# deface=True, +# overwrite=True) + + +# ============================================================================= +# RUN +# ============================================================================= +if __name__ == '__main__': + # First, convert visit 1 MEG data + # visit_id = "V1" + bids, meg, t1 = init_par(opt) + # Convert raw (task) data to BIDS + raw = raw_to_bids(bids, meg) + # Add resting-state data + rest_to_bids(bids, meg) + # Add empty room data + empty_to_bids(bids, meg) + # Add fine-calibration and crosstalk files (maxfilter files) + maxfiles_to_bids(bids, meg) + print("\n#######################################" + +"\nBIDS conversion completed successfully!" + +"\n#######################################") + + # # Then, convert visit 2 MEG data + # # visit_id = "V2" + # bids, meg, t1 = init_par(subject_id, visit_id, t1_path, file_exts, bids_root) + # # Convert raw (task) data to BIDS + # raw = raw_to_bids(bids, meg) + # # Add resting-state data + # rest_to_bids(bids, meg) + # # Add empty room data + # empty_to_bids(bids, meg) + # # Add fine-calibration and crosstalk files (maxfilter files) + # maxfiles_to_bids(bids, meg) + + # # Eventually, convert T1 anat data + # # Convert DICOM to NIFTI + # dicom_to_nifti(t1) + # # Add T1 anatomical scan + # t1_to_bids(t1, bids, raw) + + # # Show BIDS tree + # print_dir_tree(bids_root, max_depth=4) + + # # Show report + # print(make_report(bids_root)) + + # # Count events + # count_events(bids_root) \ No newline at end of file diff --git a/qc/P00_run_qc.py b/qc/P00_run_qc.py new file mode 100644 index 0000000..ffd8939 --- /dev/null +++ b/qc/P00_run_qc.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +""" +@author: Urszula Górska gorska@wisc.edu +""" + +import argparse + +import QC_processing + +# ============================================================================= +# PARSER SETTINGS +# ============================================================================= + +parser=argparse.ArgumentParser() +parser.add_argument('--sub', type=str, default='SA101', help='subject_id') +parser.add_argument('--visit', type=str, default='V1', help='visit_id') + +opt=parser.parse_args() + +# ============================================================================= +# SESSION-SPECIFIC SETTINGS +# ============================================================================= + +subject_id = opt.sub +visit_id = opt.visit + +# Find out whether the participant has EEG data +if visit_id.upper() == 'V1': + if subject_id.upper() in ['SA101', 'SA102', 'SA103', 'SA104']: + has_eeg = False + else: + has_eeg = True +elif visit_id.upper() == 'V2': + if subject_id.upper() in ['SA104', 'SA106']: + has_eeg = False + else: + has_eeg = True + + +# # ============================================================================= +# # DEFINE PREPROCESSING STEPS +# # ============================================================================= + +# def pre_step1(): +# P01_maxwell_filtering.run_maxwell_filter(subject_id, +# visit_id) +# if has_eeg: +# P02_find_bad_eeg.find_bad_eeg(subject_id, +# visit_id, +# has_eeg) +# P03_artifact_annotation.artifact_annotation(subject_id, +# visit_id, +# has_eeg, +# # threshold_muscle, +# ) +# 
P04_extract_events.run_events(subject_id, +# visit_id) +# P05_run_ica.run_ica(subject_id, +# visit_id, +# has_eeg) + +# def pre_step2( +# # meg_ica_eog=opt.mICA_eog, meg_ica_ecg=opt.mICA_ecg, +# # eeg_ica_eog=opt.eICA_eog, eeg_ica_ecg=opt.eICA_ecg, +# ): +# P06_apply_ica.apply_ica(subject_id, +# visit_id, +# has_eeg) + +# P07_make_epochs.run_epochs(subject_id, +# visit_id, +# has_eeg) + + +# ============================================================================= +# RUN +# ============================================================================= +# if opt.step == '1': +# pre_step1() +# elif opt.step == '2': +# pre_step2() +QC_processing.run_qc_processing(subject_id, visit_id, has_eeg) diff --git a/qc/P00_run_qc_epochs.py b/qc/P00_run_qc_epochs.py new file mode 100644 index 0000000..703f55a --- /dev/null +++ b/qc/P00_run_qc_epochs.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +""" +@author: Urszula Górska gorska@wisc.edu +""" + +import argparse + +import QC_epochs + +# ============================================================================= +# PARSER SETTINGS +# ============================================================================= + +parser=argparse.ArgumentParser() +parser.add_argument('--sub', type=str, default='SA101', help='subject_id') +parser.add_argument('--visit', type=str, default='V1', help='visit_id') + +opt=parser.parse_args() + +# ============================================================================= +# SESSION-SPECIFIC SETTINGS +# ============================================================================= + +subject_id = opt.sub +visit_id = opt.visit + +# Find out whether the participant has EEG data +if visit_id.upper() == 'V1': + if subject_id.upper() in ['SA101', 'SA102', 'SA103', 'SA104']: + has_eeg = False + else: + has_eeg = True +elif visit_id.upper() == 'V2': + if subject_id.upper() in ['SA104', 'SA106']: + has_eeg = False + else: + has_eeg = True + + +# # ============================================================================= +# # DEFINE PREPROCESSING STEPS +# # ============================================================================= + +# def pre_step1(): +# P01_maxwell_filtering.run_maxwell_filter(subject_id, +# visit_id) +# if has_eeg: +# P02_find_bad_eeg.find_bad_eeg(subject_id, +# visit_id, +# has_eeg) +# P03_artifact_annotation.artifact_annotation(subject_id, +# visit_id, +# has_eeg, +# # threshold_muscle, +# ) +# P04_extract_events.run_events(subject_id, +# visit_id) +# P05_run_ica.run_ica(subject_id, +# visit_id, +# has_eeg) + +# def pre_step2( +# # meg_ica_eog=opt.mICA_eog, meg_ica_ecg=opt.mICA_ecg, +# # eeg_ica_eog=opt.eICA_eog, eeg_ica_ecg=opt.eICA_ecg, +# ): +# P06_apply_ica.apply_ica(subject_id, +# visit_id, +# has_eeg) + +# P07_make_epochs.run_epochs(subject_id, +# visit_id, +# has_eeg) + + +# ============================================================================= +# RUN +# ============================================================================= +# if opt.step == '1': +# pre_step1() +# elif opt.step == '2': +# pre_step2() +QC_epochs.run_qc_epochs(subject_id, visit_id, has_eeg) diff --git a/qc/QC_epochs.py b/qc/QC_epochs.py new file mode 100644 index 0000000..cfa4571 --- /dev/null +++ b/qc/QC_epochs.py @@ -0,0 +1,231 @@ +import os +import os.path as op +import matplotlib.pyplot as plt +import pandas as pd + +import mne +import mne_bids +from pyprep.prep_pipeline import PrepPipeline +from mne.preprocessing import annotate_muscle_zscore +from autoreject import get_rejection_threshold +from numpy import 
arange + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from .config.config import (bids_root, tmin, tmax) + +from matplotlib.backends.backend_pdf import PdfPages + +from qc.maxwell_filtering import run_maxwell_filter +from qc.extract_events import run_events + +from qc.viz_psd import viz_psd + + +def run_qc_epochs(subject_id, visit_id, has_eeg): + prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") + + # Set path to qc derivatives and create the related folders + qc_output_path = op.join(bids_root, "derivatives", "qc", visit_id) + if not op.exists(qc_output_path): + os.makedirs(qc_output_path) + + print("Processing subject: %s" % subject_id) + + #raw_list = list() + #events_list = list() + #metadata_list = list() + + # Set task + if visit_id == "V1": + bids_task = 'dur' + elif visit_id == "V2": + bids_task = 'vg' + elif visit_id == "V2": + bids_task = 'replay' + else: + raise ValueError("Error: could not set the task") + + + #with PdfPages(op.join(qc_output_path, subject_id + '_' + visit_id + '_MEG_V1_epochs.pdf')) as pdf: + + #FirstPage = plt.figure(figsize=(8,1), dpi=108) + #FirstPage.clf() + #plt.axis('off') + #plt.text(0.5, 0.5, subject_id, transform=FirstPage.transFigure, size=16, ha="center") + #pdf.savefig(FirstPage) + #plt.close() + + + #raw, events = mne.concatenate_raws(raw_list, events_list=events_list) + #del raw_list + + # Concatenate metadata tables + #metadata = pd.concat(metadata_list) + # metadata.to_csv(op.join(out_path, file_name[0:14] + 'ALL-meta.csv'), index=False) + + # Select sensor types + #picks = mne.pick_types(raw.info, + # meg = True, + # eeg = has_eeg, + # stim = True, + # eog = has_eeg, + # ecg = has_eeg, + # ) + + # Set trial-onset event_ids + if visit_id == 'V1': + events_id = {} + types = ['face','object','letter','false'] + for j,t in enumerate(types): + for i in range(1,21): + events_id[t+str(i)] = i + j * 20 +# elif visit_id == 'V2': +# events_id = {} +# events_id['blank'] = 50 +# types = ['face','object'] +# for j,t in enumerate(types): +# for i in range(1,11): +# events_id[t+str(i)] = i + j * 20 + + elif visit_id == 'V2': + if bids_task == "vg": + events_id = {} + events_id['blank'] = 50 + types = ['face','object'] + for j,t in enumerate(types): + for i in range(1,11): + events_id[t+str(i)] = i + j * 20 + elif bids_task == "replay": + events_id = {} + events_id['blankFT'] = 150 + events_id['blankOT'] = 250 + typesF = ['faceT','faceNT'] + for j,t in enumerate(typesF): + for i in range(1,11): + events_id[t+str(i)] = i + 100 + j * 100 + typesO = ['objectNT','objectT'] + for j,t in enumerate(typesO): + for i in range(1,11): + events_id[t+str(i)] = i + 120 + j * 100 + + # Epoch raw data + #epochs = mne.Epochs(raw, + # events, + # events_id, + # tmin, tmax, + # baseline=None, + # proj=True, + # picks=picks, + # detrend=1, + # reject=None, + # reject_by_annotation=True, + # verbose=True) + + # Add metadata + #epochs.metadata = metadata + + # Read epoched data from preprocessed + bids_path_epo = mne_bids.BIDSPath( + root=prep_deriv_root, + subject=subject_id, + datatype='meg', + task=bids_task, + session=visit_id, + suffix='epo', + extension='.fif', + check=False) + + epochs = mne.read_epochs(bids_path_epo.fpath, preload=False) + + if visit_id == 'V1': + print("VERY_IMPORTANT :)") + print("FACES task relevant") + epochs_rel_F = epochs['Task_relevance == "Relevant non-target" and Category == "face"'] + print(epochs_rel_F) + print("FACES task irrelevant") + epochs_irr_F = epochs['Task_relevance == 
"Irrelevant" and Category == "face"'] + print(epochs_irr_F) + + print("OBJECTS task relevant") + epochs_rel_O = epochs['Task_relevance == "Relevant non-target" and Category == "object"'] + print(epochs_rel_O) + print("OBJECTS task irrelevant") + epochs_irr_O = epochs['Task_relevance == "Irrelevant" and Category == "object"'] + print(epochs_irr_O) + + print("LETTERS task relevant") + epochs_rel_L = epochs['Task_relevance == "Relevant non-target" and Category == "letter"'] + print(epochs_rel_L) + print("LETTERS task irrelevant") + epochs_irr_L = epochs['Task_relevance == "Irrelevant" and Category == "letter"'] + print(epochs_irr_L) + + print("FALSE FONTS task relevant") + epochs_rel_S = epochs['Task_relevance == "Relevant non-target" and Category == "false"'] + print(epochs_rel_S) + print("FALSE FONTS task irrelevant") + epochs_irr_S = epochs['Task_relevance == "Irrelevant" and Category == "false"'] + print(epochs_irr_S) + + elif visit_id == 'V2': + print("FACES probe") + epochs_rel_F = epochs['Trial_type == "Probe" and Stimuli_type == "Face"'] + print(epochs_rel_F) + print("FACES filler") + epochs_irr_F = epochs['Trial_type == "Filler" and Stimuli_type == "Face"'] + print(epochs_irr_F) + + print("OBJECTS probe") + epochs_rel_O = epochs['Trial_type == "Probe" and Stimuli_type == "Object"'] + print(epochs_rel_O) + print("OBJECTS filler") + epochs_irr_O = epochs['Trial_type == "Filler" and Stimuli_type == "Object"'] + print(epochs_irr_O) + + print("BLANKS probe") + epochs_rel_O = epochs['Trial_type == "Probe" and Stimuli_type == "Blank"'] + print(epochs_rel_O) + print("BLANKS filler") + epochs_irr_O = epochs['Trial_type == "Filler" and Stimuli_type == "Blank"'] + print(epochs_irr_O) + + #types0 = ['Filler', 'Probe'] + #type1 = ['Face', 'Object', 'Blank'] + #location = ['Upper Left', 'Upper Right', 'Lower Right', 'Lower Left'] + #response = ['Seen', 'Unseen'] + + #print(epochs[['face1', 'face2','face3','face4','face5','face6','face7','face8','face9','face10']]) + #print(epochs[['object1', 'object2','object3','object4','object5','object6','object7','object8','object9','object10']]) + #print(epochs[['letter1', 'letter2','letter3','letter4','letter5','letter6','letter7','letter8','letter9','letter10']]) + #print(epochs[['false1', 'false2','false3','false4','false5','false6','false7','false8','false9','false10']]) + + # Get rejection thresholds - MEG only + #reject = get_rejection_threshold(epochs, ch_types=['mag', 'grad'], #'eeg'], + # decim=2) + + # Drop bad epochs based on peak-to-peak magnitude + #epochs.drop_bad(reject=reject) + + #print("VERY_IMPORTANT EPOCHS faces after drop") + #epochs_rel_F = epochs['Task_relevance == "Relevant non-target" and Category == "face"'] + #print(epochs_rel_F) + #print(epochs[['face1', 'face2','face3','face4','face5','face6','face7','face8','face9','face10']]) + #print(epochs[['object1', 'object2','object3','object4','object5','object6','object7','object8','object9','object10']]) + #print(epochs[['letter1', 'letter2','letter3','letter4','letter5','letter6','letter7','letter8','letter9','letter10']]) + #print(epochs[['false1', 'false2','false3','false4','false5','false6','false7','false8','false9','false10']]) + + #print("VERY_IMPORTANT DROP LOG") + #print(epochs.drop_log) + + # Plot percentage of rejected epochs per channel + #fig1 = epochs.plot_drop_log() + #pdf.savefig(fig1) + plt.close() + + +if __name__ == '__main__': + subject_id = input("Type the subject ID (e.g., SA101)\n>>> ") + visit_id = input("Type the visit ID (V1 or V2)\n>>> ") + 
run_qc_epochs(subject_id, visit_id) \ No newline at end of file diff --git a/qc/QC_processing.py b/qc/QC_processing.py new file mode 100644 index 0000000..9ac99cf --- /dev/null +++ b/qc/QC_processing.py @@ -0,0 +1,463 @@ +import os +import os.path as op +import matplotlib.pyplot as plt +import pandas as pd + +import mne +import mne_bids +from pyprep.prep_pipeline import PrepPipeline +from mne.preprocessing import annotate_muscle_zscore +from autoreject import get_rejection_threshold +from numpy import arange + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import (bids_root, tmin, tmax) + +from matplotlib.backends.backend_pdf import PdfPages + +from qc.maxwell_filtering import run_maxwell_filter +from qc.extract_events import run_events + +from qc.viz_psd import viz_psd + +run_part_1 = True +run_part_2 = True +run_part_3 = True + +def run_qc_processing(subject_id, visit_id, has_eeg): + prep_deriv_root = op.join(bids_root, "derivatives", "preprocessing") + + # Set path to qc derivatives and create the related folders + qc_output_path = op.join(bids_root, "derivatives", "qc", visit_id) + if not op.exists(qc_output_path): + os.makedirs(qc_output_path) + + print("Processing subject: %s" % subject_id) + + raw_list = list() + events_list = list() + metadata_list = list() + + with PdfPages(op.join(qc_output_path, subject_id + '_' + visit_id + '_MEG_QC.pdf')) as pdf: + #region first page + + FirstPage = plt.figure(figsize=(8,1), dpi=108) + FirstPage.clf() + plt.axis('off') + plt.text(0.5, 0.5, subject_id, transform=FirstPage.transFigure, size=16, ha="center") + pdf.savefig(FirstPage) + plt.close() + + #endregion first page + + bids_eo_path = mne_bids.BIDSPath(root=bids_root, + datatype='meg', + subject=subject_id, + session=visit_id, + task='rest', + extension='.fif' + ) + #Read raw data + raw_eo = mne_bids.read_raw_bids(bids_eo_path) + + fig_eo = plt.figure(figsize=(9,6), dpi=108) + ax1 = plt.subplot2grid((2,2), (0,0), colspan=2) + ax1.set_title('Resting EO spectra') + ax2 = plt.subplot2grid((2,2), (1,0), colspan=2) + + raw_eo.plot_psd(fmax=100, ax=[ax1, ax2], picks = ['meg']) + pdf.savefig(fig_eo) + plt.close() + + data_path = os.path.join(bids_root, f"sub-{subject_id}", f"ses-{visit_id}", "meg") + run = 0 + for fname in os.listdir(data_path): #bug - not listed in order #TODO + if fname.endswith(".fif") and "run" in fname: + run = run + 1 + print(" Run: %s" % run) + + # Set task + if 'dur' in fname: + bids_task = 'dur' + elif 'vg' in fname: + bids_task = 'vg' + elif 'replay' in fname: + bids_task = 'replay' + else: + raise ValueError("Error: could not find the task for %s" % fname) + + # Set BIDS path + bids_path = mne_bids.BIDSPath( + root=bids_root, + subject=subject_id, + datatype='meg', + task=bids_task, + run=f"{run:02}", + session=visit_id, + extension='.fif') + + # Read raw data + raw = mne_bids.read_raw_bids(bids_path) + + #region PART 2a - MEG data filtering using Maxwell filters, method - defined in config + + if run_part_2: + # Find initial head position + if run == 1: + destination = raw.info['dev_head_t'] + + # # Set BIDS path + # bids_path_sss = mne_bids.BIDSPath( + # root=op.join(bids_root, "derivatives", "preprocessing"), + # subject=subject_id, + # datatype='meg', + # task=bids_task, + # run=f"{run:02}", + # session=visit_id, + # suffix="sss", + # extension='.fif', + # check=False) + + # # Read raw data + # raw_sss = mne_bids.read_raw_bids(bids_path_sss).load_data() + + raw_sss, bad_chan = run_maxwell_filter(raw, 
destination, bids_path.meg_crosstalk_fpath, bids_path.meg_calibration_fpath) + + fig = plt.figure(figsize=(9,6), dpi=108) + ax1 = plt.subplot2grid((3,4), (0,0), colspan=2) + ax1.set_title('EEG spectra before filtering') #TODO why it is not working? + ax2 = plt.subplot2grid((3,4), (1,0), colspan=2) + ax3 = plt.subplot2grid((3,4), (0,2), colspan=2) + ax3.set_title('MEG spectra after filtering') + ax4 = plt.subplot2grid((3,4), (1,2), colspan=2) + + # raw.plot_psd(picks=['meg'], fmin=1, fmax=100, ax=[axes[0][0], axes[1][0]]) + raw.plot_psd(picks=['meg'], fmin=1, fmax=100, ax=[ax1, ax2], show=False) + + # raw_sss.plot_psd(picks=['meg'], fmin=1, fmax=100, ax=[axes[0][1], axes[1][1]]) + raw_sss.plot_psd(picks=['meg'], fmin=1, fmax=100, ax=[ax3, ax4], show=False) + + plt.axis('on') + ax5 = plt.subplot2grid((3,4), (2,0), colspan=2) + plt.axis('off') + ax5.text(0, 0.7, 'noisy: ' + ', '.join(bad_chan['noisy'])) + ax5.text(0, 0.4, 'flat: ' + ', '.join(bad_chan['flat'])) + + pdf.savefig(fig) + plt.close() + + ########################### + # Check EEG data quality # + ########################### + + if has_eeg: + print("has_eeg: viz_psd") + + fig = viz_psd(raw_sss) + pdf.savefig(fig) + plt.close() + + line_freqs = arange(raw.info['line_freq'], raw.info["sfreq"] / 2, raw.info['line_freq']) + + prep_params = { + "ref_chs": "eeg", + "reref_chs": "eeg", + "line_freqs": line_freqs, + "max_iterations": 8} + + montage = raw.get_montage() + prep = PrepPipeline(raw_sss, prep_params, montage, ransac=True) + prep.fit() + raw_car = prep.raw + raw_car.interpolate_bads(reset_bads=True) + + + # Print results + print("Bad channels: {}".format(prep.interpolated_channels)) + print("Bad channels after interpolation: {}".format(prep.still_noisy_channels)) + + fig = plt.figure(figsize=(9,6), dpi=108) + plt.axis('on') + ax1 = plt.subplot2grid((3,4), (0,0), colspan=2) + plt.axis('off') + ax1.text(0, 0.7, 'interpolated: ' + ', '.join(prep.interpolated_channels)) + ax1.text(0, 0.4, 'remained noisy: ' + ', '.join(prep.still_noisy_channels)) + pdf.savefig(fig) + plt.close() + + + # # Mark bad channels in the raw bids folder + # mne_bids.mark_channels(ch_names=(prep.interpolated_channels+prep.still_noisy_channels), + # bids_path=bids_path, + # status='bad', + # verbose=False) + # # end of mark + + fig = viz_psd(raw_car) + pdf.savefig(fig) + plt.close() + print("end - has_eeg: viz_psd") + + #endregion PART 2a - MEG data filtering using Maxwell filters, method - defined in config + + #region annotations + + ########################### + # Detect ocular artifacts # + ########################### + + if has_eeg: + # Resetting the EOG channel + eog_ch = raw_sss.copy().pick_types(meg=False, eeg=False, eog=True) + if len(eog_ch.ch_names) < 2: + raw_sss.set_channel_types({'BIO002':'eog'}) + raw_sss.rename_channels({'BIO002': 'EOG002'}) + + # Find EOG events + eog_events = mne.preprocessing.find_eog_events(raw_sss) + onsets = (eog_events[:, 0] - raw_sss.first_samp) / raw_sss.info['sfreq'] - 0.25 + durations = [0.5] * len(eog_events) + descriptions = ['Blink'] * len(eog_events) + + # Annotate events + annot_blink = mne.Annotations( + onsets, + durations, + descriptions) + + ########################### + # Detect muscle artifacts # + ########################### + threshold_muscle = 7 + + # Notch filter + raw_muscle = raw_sss.copy().notch_filter([50, 100]) + + # Choose one channel type, if there are axial gradiometers and magnetometers, + # select magnetometers as they are more sensitive to muscle activity. 
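+            # annotate_muscle_zscore band-pass filters the data within filter_freq, z-scores the
+            # envelope across channels and annotates stretches exceeding threshold_muscle as
+            # "BAD_muscle"; good segments shorter than min_length_good are merged into the
+            # surrounding annotations.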
+ annot_muscle, scores_muscle = annotate_muscle_zscore( + raw_muscle, + ch_type="mag", + threshold=threshold_muscle, + min_length_good=0.3, + filter_freq=[110, 140]) + + ################# + # Detect breaks # + ################# + + # Get events + # events, event_id = mne.events_from_annotations(raw_sss) + + # Detect breaks based on events + # annot_break = mne.preprocessing.annotate_break( + # raw=raw_sss, + # events=events, + # min_break_duration=15.0) + + ########################### + + # Contatenate blink and muscle artifact annotations + if has_eeg: + annot_artifact = annot_blink + annot_muscle + else: + annot_artifact = annot_muscle + annot_artifact = mne.Annotations(onset = annot_artifact.onset + raw_sss._first_time, + duration = annot_artifact.duration, + description = annot_artifact.description, + orig_time = raw_sss.info['meas_date']) + + # Add artifact annotations in raw_sss + # raw_sss.set_annotations(raw_sss.annotations + annot_artifact + annot_break) + raw_sss.set_annotations(raw_sss.annotations + annot_artifact) + + #endregion annotations + + events, metadata = run_events(raw_sss, visit_id) + # Show events + fig = mne.viz.plot_events(events) + pdf.savefig(fig) + plt.close() + + raw_list.append(raw_sss) + events_list.append(events) + metadata_list.append(metadata) + + if run_part_3: + raw, events = mne.concatenate_raws(raw_list, events_list=events_list) + del raw_list + + # Concatenate metadata tables + metadata = pd.concat(metadata_list) + # metadata.to_csv(op.join(out_path, file_name[0:14] + 'ALL-meta.csv'), index=False) + + # Select sensor types + picks = mne.pick_types(raw.info, + meg = True, + #eeg = has_eeg, + stim = True, + #eog = has_eeg, + #ecg = has_eeg, + ) + + # Set trial-onset event_ids + if visit_id == 'V1': + events_id = {} + types = ['face','object','letter','false'] + for j,t in enumerate(types): + for i in range(1,21): + events_id[t+str(i)] = i + j * 20 + elif visit_id == 'V2': + events_id = {} + events_id['blank'] = 50 + types = ['face','object'] + for j,t in enumerate(types): + for i in range(1,11): + events_id[t+str(i)] = i + j * 20 + + # Epoch raw data + epochs = mne.Epochs(raw, + events, + events_id, + tmin, tmax, + baseline=None, + proj=True, + picks=picks, + detrend=1, + reject=None, + reject_by_annotation=True, + verbose=True) + + # Add metadata + epochs.metadata = metadata + + # ALTERNATIVE + # Read epoched data from preprocessed + #bids_path_epo = mne_bids.BIDSPath( + # root=prep_deriv_root, + # subject=subject_id, + # datatype='meg', + # task=bids_task, + # session=visit_id, + # suffix='epo', + # extension='.fif', + # check=False) + #epochs = mne.read_epochs(bids_path_epo.fpath, preload=False) + + print("VERY_IMPORTANT EPOCHS faces") # save outputs to pdf for everyone to access + epochs_rel_F = epochs['Task_relevance == "Relevant non-target" and Category == "face"'] + print(epochs_rel_F) + #print(epochs[['face1', 'face2','face3','face4','face5','face6','face7','face8','face9','face10']]) + + # Get rejection thresholds - MEG only + reject = get_rejection_threshold(epochs, ch_types=['mag', 'grad'], #'eeg'], #TODO: eeg not use for epoch rejection + decim=2) + + # Drop bad epochs based on peak-to-peak magnitude + epochs.drop_bad(reject=reject) + + print("VERY_IMPORTANT AFTER DROP :)") + print("FACES task relevant") + epochs_rel_F = epochs['Task_relevance == "Relevant non-target" and Category == "face"'] + print(epochs_rel_F) + print("FACES task irrelevant") + epochs_irr_F = epochs['Task_relevance == "Irrelevant" and Category == "face"'] + 
print(epochs_irr_F) + + print("OBJECTS task relevant") + epochs_rel_O = epochs['Task_relevance == "Relevant non-target" and Category == "object"'] + print(epochs_rel_O) + print("OBJECTS task irrelevant") + epochs_irr_O = epochs['Task_relevance == "Irrelevant" and Category == "object"'] + print(epochs_irr_O) + + print("LETTERS task relevant") + epochs_rel_L = epochs['Task_relevance == "Relevant non-target" and Category == "letter"'] + print(epochs_rel_L) + print("LETTERS task irrelevant") + epochs_irr_L = epochs['Task_relevance == "Irrelevant" and Category == "letter"'] + print(epochs_irr_L) + + print("FALSE FONTS task relevant") + epochs_rel_S = epochs['Task_relevance == "Relevant non-target" and Category == "false"'] + print(epochs_rel_S) + print("FALSE FONTS task irrelevant") + epochs_irr_S = epochs['Task_relevance == "Irrelevant" and Category == "false"'] + print(epochs_irr_S) + + + print("VERY_IMPORTANT DROP LOG") + print(epochs.drop_log) + + # Plot percentage of rejected epochs per channel + fig1 = epochs.plot_drop_log() + pdf.savefig(fig1) + plt.close() + + # Epoch raw data 2/3 + #epochs = mne.Epochs(raw, + # events, + # events_id, + # tmin, tmax, + # baseline=None, + # proj=True, + # picks=picks, + # detrend=1, + # reject=None, + # reject_by_annotation=True, + # verbose=True) + + # Add metadata + #epochs.metadata = metadata + + # Get rejection thresholds - EEG only + #reject = get_rejection_threshold(epochs, ch_types=['eeg'],decim=2) + + # Drop bad epochs based on peak-to-peak magnitude + #epochs.drop_bad(reject=reject) + + #print("VERY_IMPORTANT DROP LOG") + #print(epochs.drop_log) + + # Plot percentage of rejected epochs per channel + #fig1 = epochs.plot_drop_log() + #pdf.savefig(fig1) + #plt.close() + + + # Epoch raw data 3/3 + #epochs = mne.Epochs(raw, + # events, + # events_id, + # tmin, tmax, + # baseline=None, + # proj=True, + # picks=picks, + # detrend=1, + # reject=None, + # reject_by_annotation=True, + # verbose=True) + + #del raw + + # Add metadata + #epochs.metadata = metadata + + # Get rejection thresholds - all + #reject = get_rejection_threshold(epochs, ch_types=['mag', 'grad', 'eeg'], + # decim=2) + + # Drop bad epochs based on peak-to-peak magnitude + #epochs.drop_bad(reject=reject) + + # Plot percentage of rejected epochs per channel + #fig1 = epochs.plot_drop_log() + #pdf.savefig(fig1) + #plt.close() + +if __name__ == '__main__': + subject_id = input("Type the subject ID (e.g., SA101)\n>>> ") + visit_id = input("Type the visit ID (V1 or V2)\n>>> ") + run_qc_processing(subject_id, visit_id) \ No newline at end of file diff --git a/qc/QC_processing_eeg.py b/qc/QC_processing_eeg.py new file mode 100644 index 0000000..d1c1e64 --- /dev/null +++ b/qc/QC_processing_eeg.py @@ -0,0 +1,331 @@ +import os +import os.path as op +import matplotlib.pyplot as plt +import pandas as pd + +import mne +import mne_bids +from pyprep.prep_pipeline import PrepPipeline +from mne.preprocessing import annotate_muscle_zscore +from autoreject import get_rejection_threshold +from numpy import arange + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import (bids_root, tmin, tmax) + +from matplotlib.backends.backend_pdf import PdfPages + +from qc.maxwell_filtering import run_maxwell_filter +from qc.extract_events import run_events + +from qc.viz_psd import viz_psd + +run_part_1 = True +run_part_2 = True +run_part_3 = False + +def run_qc_processing(subject_id, visit_id, has_eeg): + # Set path to qc derivatives and create the related folders 
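+    # The QC report produced below is written to
+    # <bids_root>/derivatives/qc/<visit_id>/<subject_id>_<visit_id>_MEG_QC.pdf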
+ prep_deriv_root = op.join(bids_root, "derivatives", "qc", visit_id) + if not op.exists(prep_deriv_root): + os.makedirs(prep_deriv_root) + + print("Processing subject: %s" % subject_id) + + raw_list = list() + events_list = list() + metadata_list = list() + + with PdfPages(op.join(prep_deriv_root, subject_id + '_' + visit_id + '_MEG_QC.pdf')) as pdf: + #region first page + + FirstPage = plt.figure(figsize=(8,1), dpi=108) + FirstPage.clf() + plt.axis('off') + plt.text(0.5, 0.5, subject_id, transform=FirstPage.transFigure, size=16, ha="center") + pdf.savefig(FirstPage) + plt.close() + + #endregion first page + + bids_eo_path = mne_bids.BIDSPath(root=bids_root, + datatype='meg', + subject=subject_id, + session=visit_id, + task='rest', + extension='.fif' + ) + # Read raw data + raw_eo = mne_bids.read_raw_bids(bids_eo_path) + + fig_eo = plt.figure(figsize=(9,6), dpi=108) + ax1 = plt.subplot2grid((2,2), (0,0), colspan=2) + ax1.set_title('Resting EO spectra') + ax2 = plt.subplot2grid((2,2), (1,0), colspan=2) + + raw_eo.plot_psd(fmax=100, ax=[ax1, ax2], picks = ['meg']) + pdf.savefig(fig_eo) + plt.close() + + data_path = os.path.join(bids_root, f"sub-{subject_id}", f"ses-{visit_id}", "meg") + run = 0 + for fname in os.listdir(data_path): + if fname.endswith(".fif") and "run" in fname: + run = run + 1 + print(" Run: %s" % run) + + # Set task + if 'dur' in fname: + bids_task = 'dur' + elif 'vg' in fname: + bids_task = 'vg' + elif 'replay' in fname: + bids_task = 'replay' + else: + raise ValueError("Error: could not find the task for %s" % fname) + + # Set BIDS path + bids_path = mne_bids.BIDSPath( + root=bids_root, + subject=subject_id, + datatype='meg', + task=bids_task, + run=f"{run:02}", + session=visit_id, + extension='.fif') + + # Read raw data + raw = mne_bids.read_raw_bids(bids_path) + + #region PART 2a - MEG data filtering using Maxwell filters, method - defined in config + + if run_part_2: + # Find initial head position + if run == 1: + destination = raw.info['dev_head_t'] + + # # Set BIDS path + # bids_path_sss = mne_bids.BIDSPath( + # root=op.join(bids_root, "derivatives", "preprocessing"), + # subject=subject_id, + # datatype='meg', + # task=bids_task, + # run=f"{run:02}", + # session=visit_id, + # suffix="sss", + # extension='.fif', + # check=False) + + # # Read raw data + # raw_sss = mne_bids.read_raw_bids(bids_path_sss).load_data() + + raw_sss, bad_chan = run_maxwell_filter(raw, destination, bids_path.meg_crosstalk_fpath, bids_path.meg_calibration_fpath) + + # fig = plt.figure(figsize=(9,6), dpi=108) + # ax1 = plt.subplot2grid((3,4), (0,0), colspan=2) + # ax1.set_title('EEG spectra before filtering') #TODO why it is not working? 
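+                    # (the custom title is likely overwritten because raw.plot_psd() relabels the axes it draws into)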
+ # ax2 = plt.subplot2grid((3,4), (1,0), colspan=2) + # ax3 = plt.subplot2grid((3,4), (0,2), colspan=2) + # ax3.set_title('MEG spectra after filtering') + # ax4 = plt.subplot2grid((3,4), (1,2), colspan=2) + + # # raw.plot_psd(picks=['meg'], fmin=1, fmax=100, ax=[axes[0][0], axes[1][0]]) + # raw.plot_psd(picks=['meg'], fmin=1, fmax=100, ax=[ax1, ax2], show=False) + + # # raw_sss.plot_psd(picks=['meg'], fmin=1, fmax=100, ax=[axes[0][1], axes[1][1]]) + # raw_sss.plot_psd(picks=['meg'], fmin=1, fmax=100, ax=[ax3, ax4], show=False) + + # plt.axis('on') + # ax5 = plt.subplot2grid((3,4), (2,0), colspan=2) + # plt.axis('off') + # ax5.text(0, 0.7, 'noisy: ' + ', '.join(bad_chan['noisy'])) + # ax5.text(0, 0.4, 'flat: ' + ', '.join(bad_chan['flat'])) + + # pdf.savefig(fig) + # plt.close() + + ########################### + # Check EEG data quality # + ########################### + + if has_eeg: + print("has_eeg: viz_psd") + + fig = viz_psd(raw_sss) + pdf.savefig(fig) + plt.close() + + line_freqs = arange(raw.info['line_freq'], raw.info["sfreq"] / 2, raw.info['line_freq']) + + prep_params = { + "ref_chs": "eeg", + "reref_chs": "eeg", + "line_freqs": line_freqs, + "max_iterations": 4} + + montage = raw.get_montage() + prep = PrepPipeline(raw_sss, prep_params, montage, ransac=True) + prep.fit() + raw_car = prep.raw + raw_car.interpolate_bads(reset_bads=True) + + fig = viz_psd(raw_car) + pdf.savefig(fig) + plt.close() + print("end - has_eeg: viz_psd") + + #endregion PART 2a - MEG data filtering using Maxwell filters, method - defined in config + + #region annotations + + ########################### + # Detect ocular artifacts # + ########################### + + if has_eeg: + # Resetting the EOG channel + eog_ch = raw_sss.copy().pick_types(meg=False, eeg=False, eog=True) + if len(eog_ch.ch_names) < 2: + raw_sss.set_channel_types({'BIO002':'eog'}) + raw_sss.rename_channels({'BIO002': 'EOG002'}) + + # Find EOG events + eog_events = mne.preprocessing.find_eog_events(raw_sss) + onsets = (eog_events[:, 0] - raw_sss.first_samp) / raw_sss.info['sfreq'] - 0.25 + durations = [0.5] * len(eog_events) + descriptions = ['Blink'] * len(eog_events) + + # Annotate events + annot_blink = mne.Annotations( + onsets, + durations, + descriptions) + + ########################### + # Detect muscle artifacts # + ########################### + threshold_muscle = 7 + + # Notch filter + raw_muscle = raw_sss.copy().notch_filter([50, 100]) + + # Choose one channel type, if there are axial gradiometers and magnetometers, + # select magnetometers as they are more sensitive to muscle activity. 
+ annot_muscle, scores_muscle = annotate_muscle_zscore( + raw_muscle, + ch_type="mag", + threshold=threshold_muscle, + min_length_good=0.3, + filter_freq=[110, 140]) + + ################# + # Detect breaks # + ################# + + # Get events + # events, event_id = mne.events_from_annotations(raw_sss) + + # Detect breaks based on events + # annot_break = mne.preprocessing.annotate_break( + # raw=raw_sss, + # events=events, + # min_break_duration=15.0) + + ########################### + + # Contatenate blink and muscle artifact annotations + if has_eeg: + annot_artifact = annot_blink + annot_muscle + else: + annot_artifact = annot_muscle + annot_artifact = mne.Annotations(onset = annot_artifact.onset + raw_sss._first_time, + duration = annot_artifact.duration, + description = annot_artifact.description, + orig_time = raw_sss.info['meas_date']) + + # Add artifact annotations in raw_sss + # raw_sss.set_annotations(raw_sss.annotations + annot_artifact + annot_break) + raw_sss.set_annotations(raw_sss.annotations + annot_artifact) + + #endregion annotations + + events, metadata = run_events(raw_sss, visit_id) + # Show events + fig = mne.viz.plot_events(events) + pdf.savefig(fig) + plt.close() + + raw_list.append(raw_sss) + events_list.append(events) + metadata_list.append(metadata) + + if run_part_3: + raw, events = mne.concatenate_raws(raw_list, events_list=events_list) + del raw_list + + # Concatenate metadata tables + metadata = pd.concat(metadata_list) + # metadata.to_csv(op.join(out_path, file_name[0:14] + 'ALL-meta.csv'), index=False) + + # Select sensor types + picks = mne.pick_types(raw.info, + meg = True, + eeg = has_eeg, + stim = True, + eog = has_eeg, + ecg = has_eeg, + ) + + # Set trial-onset event_ids + if visit_id == 'V1': + events_id = {} + types = ['face','object','letter','false'] + for j,t in enumerate(types): + for i in range(1,21): + events_id[t+str(i)] = i + j * 20 + elif visit_id == 'V2': + events_id = {} + events_id['blank'] = 50 + types = ['face','object'] + for j,t in enumerate(types): + for i in range(1,11): + events_id[t+str(i)] = i + j * 20 + + # Epoch raw data + epochs = mne.Epochs(raw, + events, + events_id, + tmin, tmax, + baseline=None, + proj=True, + picks=picks, + detrend=1, + reject=None, + reject_by_annotation=True, + verbose=True) + + del raw + + # Add metadata + epochs.metadata = metadata + + # Get rejection thresholds + reject = get_rejection_threshold(epochs, + ch_types=['mag', 'grad'], #'eeg'], #TODO: eeg not use for epoch rejection + decim=2) + + # Drop bad epochs based on peak-to-peak magnitude + epochs.drop_bad(reject=reject) + + # Plot percentage of rejected epochs per channel + fig1 = epochs.plot_drop_log() + pdf.savefig(fig1) + plt.close() + + +if __name__ == '__main__': + subject_id = input("Type the subject ID (e.g., SA101)\n>>> ") + visit_id = input("Type the visit ID (V1 or V2)\n>>> ") + run_qc_processing(subject_id, visit_id) \ No newline at end of file diff --git a/qc/qc/extract_events.py b/qc/qc/extract_events.py new file mode 100644 index 0000000..ad68a6f --- /dev/null +++ b/qc/qc/extract_events.py @@ -0,0 +1,115 @@ +""" +Modified by Urszula Gorska (gorska@wisc.edu) +=================== +04. 
Extract events +=================== + +Extract events from the stimulus channel + + +""" +import numpy as np +import mne +import pandas as pd + + + +def run_events(raw, experiment_id): + + ############### + # Read events # + ############### + + # Find response events + response = mne.find_events(raw, + stim_channel='STI101', + consecutive = False, + mask = 65280, + mask_type = 'not_and' + ) + response = response[response[:,2] == 255] + + # Find all other events + events = mne.find_events(raw, + stim_channel='STI101', + consecutive = True, + min_duration=0.001001, + mask = 65280, + mask_type = 'not_and' + ) + events = events[events[:,2] != 255] + + # Concatenate all events + events = np.concatenate([response,events],axis = 0) + events = events[events[:,0].argsort(),:] + + + ################# + # Read metadata # + ################# + + # # Generate metadata table + if experiment_id == 'V1': + eve = events.copy() + events = eve[eve[:, 2] < 81].copy() + metadata = {} + metadata = pd.DataFrame(metadata, index=np.arange(len(events)), + columns=['Stim_trigger', 'Category', + 'Orientation', 'Duration', + 'Task_relevance', 'Trial_ID', + 'Response', 'Response_time(s)']) + Category = ['face', 'object', 'letter', 'false'] + Orientation = ['Center', 'Left', 'Right'] + Duration = ['500ms', '1000ms', '1500ms'] + Relevance = ['Relevant target', 'Relevant non-target', 'Irrelevant'] + k = 0 + for i in range(eve.shape[0]): + if eve[i, 2] < 81: + ##find the end of each trial (trigger 97) + t = [t for t, j in enumerate(eve[i:i + 9, 2]) if j == 97][0] + metadata.loc[k]['Stim_trigger'] = eve[i,2] + metadata.loc[k]['Category'] = Category[int((eve[i,2]-1)//20)] + metadata.loc[k]['Orientation'] = Orientation[[j-100 for j in eve[i:i+t,2] + if j in [101,102,103]][0]-1] + metadata.loc[k]['Duration'] = Duration[[j-150 for j in eve[i:i+t,2] + if j in [151,152,153]][0]-1] + metadata.loc[k]['Task_relevance'] = Relevance[[j-200 for j in eve[i:i+t,2] + if j in [201,202,203]][0]-1] + metadata.loc[k]['Trial_ID'] = [j for j in eve[i:i+t,2] + if (j>110) and (j<149)][0] + metadata.loc[k]['Response'] = True if any(eve[i:i+t,2] == 255) else False + if metadata.loc[k]['Response'] == True: + r = [r for r,j in enumerate(eve[i:i+t,2]) if j == 255][0] + metadata.loc[k]['Response_time(s)'] = (eve[i+r,0] - eve[i,0]) + k += 1 + + elif experiment_id == 'V2': + eve = events.copy() + metadata = {} + metadata = pd.DataFrame(metadata, index=np.arange(np.sum(events[:, 2] < 51)), + columns=['Trial_type', 'Stim_trigger', + 'Stimuli_type', + 'Location', 'Response', + 'Response_time(s)']) + types0 = ['Filler', 'Probe'] + type1 = ['Face', 'Object', 'Blank'] + location = ['Upper Left', 'Upper Right', 'Lower Right', 'Lower Left'] + response = ['Seen', 'Unseen'] + k = 0 + for i in range(eve.shape[0]): + if eve[i, 2] < 51: + metadata.loc[k]['Stim_trigger'] = eve[i, 2] + t = int(eve[i + 1, 2] % 10) + metadata.loc[k]['Trial_type'] = types0[t] + if eve[i, 2] == 50: + metadata.loc[k]['Stimuli_type'] = type1[2] + else: + metadata.loc[k]['Stimuli_type'] = type1[eve[i, 2] // 20] + metadata.loc[k]['Location'] = location[eve[i + 1, 2] // 10 - 6] + if t == 1: + metadata.loc[k]['Response'] = response[int(eve[i + 4, 2] - 98)] + metadata.loc[k]['Response_time(s)'] = (eve[i + 4, 0] - eve[i + 3, 0]) + k += 1 + + + return eve, metadata diff --git a/qc/qc/maxwell_filtering.py b/qc/qc/maxwell_filtering.py new file mode 100644 index 0000000..122eb69 --- /dev/null +++ b/qc/qc/maxwell_filtering.py @@ -0,0 +1,49 @@ +""" +Modified by Urszula Gorska (gorska@wisc.edu) 
+=================================== +01. Maxwell filter using MNE-python +=================================== + +Includes Maxwell filter function used by MEG Team preprocessing. +The data are Maxwell filtered using tSSS/SSS. +It is critical to mark bad channels before Maxwell filtering. + +""" + +import os.path as op +import os + +import mne +from mne.preprocessing import find_bad_channels_maxwell + + +def run_maxwell_filter(raw, destination, crosstalk_file, fine_cal_file): + # Detect bad channels + raw.info['bads'] = [] + raw_check = raw.copy() + auto_noisy_chs, auto_flat_chs, auto_scores = find_bad_channels_maxwell( + raw_check, + cross_talk=crosstalk_file, + calibration=fine_cal_file, + return_scores=True, + verbose=True) + raw.info['bads'].extend(auto_noisy_chs + auto_flat_chs) + + # Fix Elekta magnetometer coil types + raw.fix_mag_coil_types() + + # Perform tSSS/SSS and Maxwell filtering + raw_sss = mne.preprocessing.maxwell_filter( + raw, + cross_talk=crosstalk_file, + calibration=fine_cal_file, + st_duration=None, + origin='auto', + destination=destination, + coord_frame='head', # 'meg' only for empy room, comment it if using HPI + verbose=True) + + return raw_sss, { + 'noisy': auto_noisy_chs, + 'flat': auto_flat_chs + } \ No newline at end of file diff --git a/qc/qc/viz_psd.py b/qc/qc/viz_psd.py new file mode 100644 index 0000000..2a30d79 --- /dev/null +++ b/qc/qc/viz_psd.py @@ -0,0 +1,26 @@ +from mne.time_frequency import psd_multitaper +import numpy as np +import matplotlib.pyplot as plt +from scipy.stats import zscore + +def viz_psd(raw): + # Compute averaged power + psds, freqs = psd_multitaper(raw, fmin = 1, fmax = 40, picks=['eeg']) + psds = np.sum(psds, axis = 1) + psds = 10. * np.log10(psds) + # Show power spectral density plot + fig, ax = plt.subplots(2, 1, figsize=(12, 8)) + raw.plot_psd(picks = ["eeg"], + fmin = 1, fmax = 40, + ax=ax[0]) + # Normalize (z-score) channel-specific average power values + psd = {} + psd_zscore = zscore(psds) + for i in range(len(psd_zscore)): + psd["EEG%03d"%(i+1)] = psd_zscore[i] + # Plot chennels ordered by power + ax[1].bar(sorted(psd, key=psd.get, reverse = True), sorted(psd.values(), reverse = True), width = 0.5) + labels = sorted(psd, key=psd.get, reverse = True) + ax[1].set_xticklabels(labels, rotation=90) + ax[1].annotate("Average power: %.2e dB"%(np.average(psds)), (27, np.max(psd_zscore)*0.9), fontsize = 'x-large') + return fig \ No newline at end of file diff --git a/qc/srun_bids.py b/qc/srun_bids.py new file mode 100644 index 0000000..688ac98 --- /dev/null +++ b/qc/srun_bids.py @@ -0,0 +1,30 @@ +#!/bin/bash +#SBATCH --partition=xnat +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=2 +#SBATCH --mem-per-cpu=32G +#SBATCH --mail-type=BEGIN,END +#SBATCH --mail-user=gorska@wisc.edu +#SBATCH --time 1:00:00 +#SBATCH --chdir=/hpc/users/urszula.gorska/codes/MEEG/MNE-python_pipeline_v3/ + +if [ $# -ne 2 ]; + then echo "Please pass sub_prefix visit and step as command line arguments. E.g." + echo "sbatch --array=101,103,105 srun_bids.sh SA V1" + echo "Exiting." + exit 1 +fi + +sub_prefix=$1 # Prefix of the subjects we're working on e.g. SA SB etc... 
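+# Visit/session ID (e.g. V1 or V2); forwarded to P00_bids_conversion.py as --visit below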
+visit=$2 + +set -- + +module purge +module load Anaconda3/2020.11 +source /hpc/shared/EasyBuild/apps/Anaconda3/2020.11/bin/activate +conda activate /hpc/users/urszula.gorska/.conda/envs/mne_meg01_clone + +#srun python P00_bids_conversion.py --sub ${sub_prefix}${SLURM_ARRAY_TASK_ID} --visit ${visit} + +srun python P00_bids_conversion.py --sub ${sub_prefix}`printf "%03d" $SLURM_ARRAY_TASK_ID` --visit ${visit} \ No newline at end of file diff --git a/requirements_cogitate_meg.yml b/requirements_cogitate_meg.yml new file mode 100644 index 0000000..460e173 --- /dev/null +++ b/requirements_cogitate_meg.yml @@ -0,0 +1,60 @@ +name: meg_msp1_env +channels: + - conda-forge + - defaults +dependencies: + - python=3.8 + - pip + - h5py + - hdf4 + - hdf5 + - imageio + - imageio-ffmpeg + - ipython + - jsonschema + - jupyter + - matplotlib-inline + - matplotlib-venn + - mayavi + - mne=0.24.0 + - mne-bids + - networkx + - nibabel + - nilearn + - numba + - openneuro-py + - pandas + - pandoc + - patsy + - pillow + - pingouin + - pydicom + - pyface + - pyqt + - pysurfer + - pytables + - pyvista + - pywavelets + - pyyaml + - pyprep + - qt + - scikit-image + - scikit-learn + - scipy + - seaborn + - statsmodels + - tabulate + - yaml + - pip: + - autoreject + - h5io + - matplotlib + - memory-profiler + - mne-connectivity + - mne-rsa + - multidict + - neuropythy + - nipype + - pydeface + - ptitprince + - fpdf \ No newline at end of file diff --git a/requirements_cogitate_meg_exact.yml b/requirements_cogitate_meg_exact.yml new file mode 100644 index 0000000..75fd615 --- /dev/null +++ b/requirements_cogitate_meg_exact.yml @@ -0,0 +1,373 @@ +name: null +channels: + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=1_gnu + - aiofiles=0.7.0=pyhd8ed1ab_0 + - alsa-lib=1.2.3=h516909a_0 + - anyio=3.3.4=py39hf3d152e_1 + - aom=3.2.0=h9c3ff4c_2 + - appdirs=1.4.4=pyh9f0ad1d_0 + - apptools=5.1.0=pyh44b312d_0 + - argon2-cffi=21.1.0=py39h3811e60_2 + - async_generator=1.10=py_0 + - attrs=21.2.0=pyhd8ed1ab_0 + - backcall=0.2.0=pyh9f0ad1d_0 + - backports=1.0=py_2 + - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 + - bleach=4.1.0=pyhd8ed1ab_0 + - blosc=1.21.0=h9c3ff4c_0 + - brotli=1.0.9=h7f98852_6 + - brotli-bin=1.0.9=h7f98852_6 + - brotlipy=0.7.0=py39h3811e60_1003 + - brunsli=0.1=h9c3ff4c_0 + - bzip2=1.0.8=h7f98852_4 + - c-ares=1.18.1=h7f98852_0 + - c-blosc2=2.0.4=h5f21a17_1 + - ca-certificates=2022.9.24=ha878542_0 + - cached-property=1.5.2=hd8ed1ab_1 + - cached_property=1.5.2=pyha770c72_1 + - certifi=2022.9.24=pyhd8ed1ab_0 + - cffi=1.15.0=py39h4bc2ebd_0 + - cfitsio=4.0.0=h9a35b8e_0 + - cftime=1.5.1.1=py39hce5d2b2_1 + - chardet=4.0.0=py39hf3d152e_2 + - charls=2.2.0=h9c3ff4c_0 + - charset-normalizer=2.0.7=pyhd8ed1ab_0 + - click=8.0.3=py39hf3d152e_1 + - cloudpickle=2.0.0=pyhd8ed1ab_0 + - colorama=0.4.4=pyh9f0ad1d_0 + - configobj=5.0.6=py_0 + - cryptography=35.0.0=py39h95dcef6_2 + - curl=7.80.0=h2574ce0_0 + - cycler=0.11.0=pyhd8ed1ab_0 + - cython=0.29.30=py39h5a03fae_0 + - cytoolz=0.11.2=py39h3811e60_1 + - darkdetect=0.5.1=pyhd8ed1ab_0 + - dask-core=2021.11.1=pyhd8ed1ab_0 + - dbus=1.13.6=h48d8840_2 + - debugpy=1.5.1=py39he80948d_0 + - decorator=5.1.0=pyhd8ed1ab_0 + - defusedxml=0.7.1=pyhd8ed1ab_0 + - deprecated=1.2.13=pyh6c4a22f_0 + - dicom2nifti=2.3.0=pyhd3deb0d_0 + - dipy=1.4.1=py39hce5d2b2_0 + - double-conversion=3.1.5=h9c3ff4c_2 + - eigen=3.4.0=h4bd325d_0 + - entrypoints=0.3=pyhd8ed1ab_1003 + - envisage=6.0.1=pyhd8ed1ab_0 + - expat=2.4.1=h9c3ff4c_0 + - ffmpeg=4.3.2=h37c90e5_3 + - 
fontconfig=2.13.1=hba837de_1005 + - fonttools=4.28.1=py39h3811e60_0 + - freetype=2.10.4=h0708190_1 + - gettext=0.19.8.1=h73d1719_1008 + - giflib=5.2.1=h36c2ea0_2 + - gl2ps=1.4.2=h0708190_0 + - glew=2.1.0=h9c3ff4c_2 + - glib=2.70.1=h780b84a_0 + - glib-tools=2.70.1=h780b84a_0 + - gmp=6.2.1=h58526e2_0 + - gnutls=3.6.13=h85f3911_1 + - graphql-core=3.1.6=pyhd8ed1ab_0 + - gst-plugins-base=1.18.5=hf529b03_2 + - gstreamer=1.18.5=h9f60fe5_2 + - h11=0.12.0=pyhd8ed1ab_0 + - h2=4.1.0=py39hf3d152e_0 + - h5netcdf=1.0.0=pyhd8ed1ab_0 + - h5py=3.3.0=nompi_py39h98ba4bc_100 + - hdf4=4.2.15=h10796ff_3 + - hdf5=1.10.6=nompi_h6a2412b_1114 + - hpack=4.0.0=pyh9f0ad1d_0 + - httpcore=0.14.3=pyhd8ed1ab_0 + - httpx=0.21.1=py39hf3d152e_0 + - hyperframe=6.0.1=pyhd8ed1ab_0 + - icu=68.2=h9c3ff4c_0 + - idna=2.10=pyh9f0ad1d_0 + - imagecodecs=2021.8.26=py39h571908b_2 + - imageio=2.9.0=py_0 + - imageio-ffmpeg=0.4.5=pyhd8ed1ab_0 + - importlib-metadata=4.8.2=py39hf3d152e_0 + - importlib_metadata=4.8.2=hd8ed1ab_0 + - importlib_resources=5.4.0=pyhd8ed1ab_0 + - ipykernel=6.5.0=py39hef51801_1 + - ipython=7.29.0=py39hef51801_2 + - ipython_genutils=0.2.0=py_1 + - ipywidgets=7.6.5=pyhd8ed1ab_0 + - jbig=2.1=h7f98852_2003 + - jedi=0.18.1=py39hf3d152e_0 + - jinja2=3.0.3=pyhd8ed1ab_0 + - joblib=1.1.0=pyhd8ed1ab_0 + - jpeg=9d=h36c2ea0_0 + - jsoncpp=1.9.4=h4bd325d_3 + - jsonschema=4.2.1=pyhd8ed1ab_0 + - jupyter=1.0.0=py39hf3d152e_7 + - jupyter_client=6.1.12=pyhd8ed1ab_0 + - jupyter_console=6.4.0=pyhd8ed1ab_1 + - jupyter_core=4.9.1=py39hf3d152e_1 + - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0 + - jupyterlab_widgets=1.0.2=pyhd8ed1ab_0 + - jxrlib=1.1=h7f98852_2 + - kiwisolver=1.3.2=py39h1a9c180_1 + - krb5=1.19.2=hcc1bbae_3 + - lame=3.100=h7f98852_1001 + - lcms2=2.12=hddcbb42_0 + - ld_impl_linux-64=2.36.1=hea4e1c9_2 + - lerc=3.0=h9c3ff4c_0 + - libaec=1.0.6=h9c3ff4c_0 + - libblas=3.9.0=12_linux64_openblas + - libbrotlicommon=1.0.9=h7f98852_6 + - libbrotlidec=1.0.9=h7f98852_6 + - libbrotlienc=1.0.9=h7f98852_6 + - libcblas=3.9.0=12_linux64_openblas + - libclang=11.1.0=default_ha53f305_1 + - libcurl=7.80.0=h2574ce0_0 + - libdeflate=1.8=h7f98852_0 + - libedit=3.1.20191231=he28a2e2_2 + - libev=4.33=h516909a_1 + - libevent=2.1.10=h9b69904_4 + - libffi=3.4.2=h7f98852_5 + - libgcc-ng=11.2.0=h1d223b6_11 + - libgfortran-ng=11.2.0=h69a702a_11 + - libgfortran5=11.2.0=h5c6108e_11 + - libglib=2.70.1=h174f98d_0 + - libglu=9.0.0=he1b5a44_1001 + - libgomp=11.2.0=h1d223b6_11 + - libiconv=1.16=h516909a_0 + - liblapack=3.9.0=12_linux64_openblas + - libllvm10=10.0.1=he513fc3_3 + - libllvm11=11.1.0=hf817b99_2 + - libnetcdf=4.8.1=nompi_hcd642e3_100 + - libnghttp2=1.43.0=h812cca2_1 + - libogg=1.3.4=h7f98852_1 + - libopenblas=0.3.18=pthreads_h8fe5266_0 + - libopus=1.3.1=h7f98852_1 + - libpng=1.6.37=h21135ba_2 + - libpq=13.3=hd57d9b9_3 + - libsodium=1.0.18=h36c2ea0_1 + - libssh2=1.10.0=ha56f1ee_2 + - libstdcxx-ng=11.2.0=he4da1e4_11 + - libtheora=1.1.1=h7f98852_1005 + - libtiff=4.3.0=h6f004c6_2 + - libuuid=2.32.1=h7f98852_1000 + - libvorbis=1.3.7=h9c3ff4c_0 + - libvpx=1.11.0=h9c3ff4c_3 + - libwebp-base=1.2.1=h7f98852_0 + - libxcb=1.13=h7f98852_1004 + - libxkbcommon=1.0.3=he3ba5ed_0 + - libxml2=2.9.12=h72842e0_0 + - libzip=1.8.0=h4de3113_1 + - libzlib=1.2.11=h36c2ea0_1013 + - libzopfli=1.0.3=h9c3ff4c_0 + - littleutils=0.2.2=py_0 + - llvmlite=0.36.0=py39h1bbdace_0 + - locket=0.2.0=py_2 + - loguru=0.5.3=py39hf3d152e_3 + - lz4-c=1.9.3=h9c3ff4c_1 + - lzo=2.10=h516909a_1000 + - markupsafe=2.0.1=py39h3811e60_1 + - matplotlib-inline=0.1.3=pyhd8ed1ab_0 + - 
matplotlib-venn=0.11.6=pyh9f0ad1d_0 + - mayavi=4.7.2=py39h71d8d94_5 + - meshio=4.4.6=pyhd8ed1ab_0 + - mffpy=0.6.3=pyhd8ed1ab_0 + - mistune=0.8.4=py39h3811e60_1005 + - mne=0.24.0=hd8ed1ab_0 + - mne-base=0.24.0=pyhd8ed1ab_0 + - mne-bids=0.8=pyhd8ed1ab_2 + - mock=4.0.3=py39hf3d152e_2 + - mpmath=1.2.1=pyhd8ed1ab_0 + - munkres=1.1.4=pyh9f0ad1d_0 + - mysql-common=8.0.27=ha770c72_1 + - mysql-libs=8.0.27=hfa10184_1 + - nbclient=0.5.9=pyhd8ed1ab_0 + - nbconvert=6.3.0=py39hf3d152e_1 + - nbformat=5.1.3=pyhd8ed1ab_0 + - ncurses=6.2=h58526e2_4 + - nest-asyncio=1.5.1=pyhd8ed1ab_0 + - netcdf4=1.5.7=nompi_py39hd2e3950_101 + - nettle=3.6=he412f7d_0 + - networkx=2.6.3=pyhd8ed1ab_1 + - nibabel=3.2.1=pyhd8ed1ab_0 + - nilearn=0.8.1=pyhd8ed1ab_0 + - nomkl=1.0=h5ca1d4c_0 + - notebook=6.4.6=pyha770c72_0 + - nspr=4.32=h9c3ff4c_1 + - nss=3.72=hb5efdd6_0 + - numba=0.53.1=py39h56b8d98_1 + - numexpr=2.7.3=py39hbd72853_102 + - numpy=1.21.4=py39hdbf815f_0 + - olefile=0.46=pyh9f0ad1d_1 + - openh264=2.1.1=h780b84a_0 + - openjpeg=2.4.0=hb52868f_1 + - openneuro-py=2021.10.1=py39hf3d152e_1 + - openssl=1.1.1l=h7f98852_0 + - outdated=0.2.2=pyhd8ed1ab_0 + - packaging=21.3=pyhd8ed1ab_0 + - pandas=1.3.4=py39hde0f152_1 + - pandas-flavor=0.2.0=py_0 + - pandoc=2.16.1=h7f98852_0 + - pandocfilters=1.5.0=pyhd8ed1ab_0 + - parso=0.8.2=pyhd8ed1ab_0 + - partd=1.2.0=pyhd8ed1ab_0 + - patsy=0.5.2=pyhd8ed1ab_0 + - pcre=8.45=h9c3ff4c_0 + - pexpect=4.8.0=pyh9f0ad1d_2 + - pickleshare=0.7.5=py_1003 + - pillow=8.4.0=py39ha612740_0 + - pingouin=0.5.2=pyhd8ed1ab_0 + - pip=21.3.1=pyhd8ed1ab_0 + - pooch=1.5.2=pyhd8ed1ab_0 + - proj=8.2.0=h277dcde_0 + - prometheus_client=0.12.0=pyhd8ed1ab_0 + - prompt-toolkit=3.0.22=pyha770c72_0 + - prompt_toolkit=3.0.22=hd8ed1ab_0 + - psutil=5.8.0=py39h3811e60_2 + - pthread-stubs=0.4=h36c2ea0_1001 + - ptitprince=0.2.5=py39h3811e60_3 + - ptyprocess=0.7.0=pyhd3deb0d_0 + - pugixml=1.11.4=h9c3ff4c_0 + - pybv=0.6.0=pyhd8ed1ab_2 + - pycparser=2.21=pyhd8ed1ab_0 + - pydicom=2.2.2=pyh6c4a22f_0 + - pyface=7.3.0=pyh44b312d_1 + - pygments=2.10.0=pyhd8ed1ab_0 + - pyhamcrest=2.0.3=pyhd8ed1ab_0 + - pyopenssl=21.0.0=pyhd8ed1ab_0 + - pyparsing=3.0.6=pyhd8ed1ab_0 + - pyprep=0.4.0=pyhd8ed1ab_1 + - pyqt=5.12.3=py39hf3d152e_8 + - pyqt-impl=5.12.3=py39hde8b62d_8 + - pyqt5-sip=4.19.18=py39he80948d_8 + - pyqtchart=5.12=py39h0fcd23e_8 + - pyqtwebengine=5.12.1=py39h0fcd23e_8 + - pyrsistent=0.18.0=py39h3811e60_0 + - pysocks=1.7.1=py39hf3d152e_4 + - pysurfer=0.11.0=py_0 + - pytables=3.6.1=py39hf6dc253_3 + - python=3.9.7=hb7a2778_3_cpython + - python-dateutil=2.8.2=pyhd8ed1ab_0 + - python-picard=0.7=pyh8a188c0_0 + - python_abi=3.9=2_cp39 + - pytz=2021.3=pyhd8ed1ab_0 + - pyvista=0.32.1=pyhd8ed1ab_0 + - pyvistaqt=0.5.0=pyhd8ed1ab_0 + - pywavelets=1.2.0=py39hce5d2b2_0 + - pyyaml=6.0=py39h3811e60_3 + - pyzmq=22.3.0=py39h37b5a0c_1 + - qdarkstyle=3.0.2=pyhd8ed1ab_0 + - qt=5.12.9=hda022c4_4 + - qtconsole=5.2.0=pyhd8ed1ab_0 + - qtpy=1.11.2=pyhd8ed1ab_0 + - readline=8.1=h46c0cb4_0 + - reportlab=3.5.68=py39he59360d_1 + - requests=2.25.1=pyhd3deb0d_0 + - rfc3986=1.5.0=pyhd8ed1ab_0 + - scikit-image=0.18.3=py39hde0f152_0 + - scikit-learn=1.0.1=py39h4dfa638_2 + - scipy=1.7.2=py39hee8e79c_0 + - scooby=0.5.7=pyhd8ed1ab_0 + - seaborn=0.11.2=hd8ed1ab_0 + - seaborn-base=0.11.2=pyhd8ed1ab_0 + - send2trash=1.8.0=pyhd8ed1ab_0 + - setuptools=59.1.1=py39hf3d152e_0 + - sgqlc=14.1=pyhd8ed1ab_0 + - sip=6.5.1=py39he80948d_2 + - six=1.16.0=pyh6c4a22f_0 + - snappy=1.1.8=he1b5a44_3 + - sniffio=1.2.0=py39hf3d152e_2 + - sqlite=3.36.0=h9cd32fc_2 + - 
statsmodels=0.13.1=py39hce5d2b2_0 + - tabulate=0.9.0=pyhd8ed1ab_1 + - tbb=2020.2=h4bd325d_4 + - tbb-devel=2020.2=h4bd325d_4 + - terminado=0.12.1=py39hf3d152e_1 + - testpath=0.5.0=pyhd8ed1ab_0 + - threadpoolctl=3.0.0=pyh8a188c0_0 + - tifffile=2021.11.2=pyhd8ed1ab_0 + - tk=8.6.11=h27826a3_1 + - toml=0.10.2=pyhd8ed1ab_0 + - toolz=0.11.2=pyhd8ed1ab_0 + - tornado=6.1=py39h3811e60_2 + - tqdm=4.62.3=pyhd8ed1ab_0 + - traitlets=5.1.1=pyhd8ed1ab_0 + - traits=6.3.2=py39h3811e60_0 + - traitsui=7.2.0=pyhd8ed1ab_0 + - typing_extensions=4.0.0=pyha770c72_0 + - tzdata=2021e=he74cb21_0 + - urllib3=1.26.7=pyhd8ed1ab_0 + - utfcpp=3.2.1=ha770c72_0 + - vtk=9.0.3=no_osmesa_py39h62d5dbf_106 + - wcwidth=0.2.5=pyh9f0ad1d_2 + - webencodings=0.5.1=py_1 + - wheel=0.37.0=pyhd8ed1ab_1 + - widgetsnbextension=3.5.2=py39hf3d152e_1 + - wrapt=1.13.3=py39h3811e60_1 + - wurlitzer=3.0.2=py39hf3d152e_1 + - x264=1!161.3030=h7f98852_1 + - x265=3.5=h4bd325d_1 + - xlrd=2.0.1=pyhd8ed1ab_3 + - xorg-kbproto=1.0.7=h7f98852_1002 + - xorg-libice=1.0.10=h7f98852_0 + - xorg-libsm=1.2.3=hd9c2040_1000 + - xorg-libx11=1.7.2=h7f98852_0 + - xorg-libxau=1.0.9=h7f98852_0 + - xorg-libxdmcp=1.1.3=h7f98852_0 + - xorg-libxext=1.3.4=h7f98852_1 + - xorg-libxt=1.2.1=h7f98852_2 + - xorg-xextproto=7.3.0=h7f98852_1002 + - xorg-xproto=7.0.31=h7f98852_1007 + - xz=5.2.5=h516909a_1 + - yaml=0.2.5=h516909a_0 + - zeromq=4.3.4=h9c3ff4c_1 + - zfp=0.5.5=h9c3ff4c_7 + - zipp=3.6.0=pyhd8ed1ab_0 + - zlib=1.2.11=h36c2ea0_1013 + - zstd=1.5.0=ha95c52a_0 + - pip: + - aiobotocore==2.4.1 + - aiohttp==3.8.3 + - aioitertools==0.11.0 + - aiosignal==1.3.1 + - async-timeout==4.0.2 + - autoreject==0.2.2 + - botocore==1.27.59 + - ci-info==0.3.0 + - contourpy==1.0.7 + - etelemetry==0.3.0 + - filelock==3.8.0 + - fpdf2==2.4.6 + - frites==0.4.3 + - frozenlist==1.3.3 + - fsspec==2022.11.0 + - h5io==0.1.7 + - ipycanvas==0.9.1 + - ipyevents==2.0.1 + - ipyvtklink==0.2.1 + - isodate==0.6.1 + - jmespath==1.0.1 + - looseversion==1.0.2 + - lxml==4.9.1 + - matplotlib==3.7.1 + - memory-profiler==0.60.0 + - mne-connectivity==0.2 + - mne-rsa==0.4 + - multidict==6.0.2 + - neuropythy==0.12.6 + - nice==0.1.dev0 + - nipype==1.8.5 + - pimms==0.3.20 + - pint==0.20.1 + - prov==2.0.0 + - py4j==0.10.9.7 + - pydeface==2.0.2 + - pydot==1.4.2 + - rdflib==6.2.0 + - s3fs==2022.11.0 + - simplejson==3.17.6 + - spyder-kernels==1.9.4 + - xarray==0.20.1 + - yarl==1.8.1 +prefix: /hpc/users/oscar.ferrante/.conda/envs/mne_meg01 diff --git a/requirements_cogitate_meg_lmm.yml b/requirements_cogitate_meg_lmm.yml new file mode 100644 index 0000000..ac758cc --- /dev/null +++ b/requirements_cogitate_meg_lmm.yml @@ -0,0 +1,36 @@ +name: meg_msp1_env_pymer4 +channels: + - ejolly + - conda-forge + - defaults +dependencies: + - python=3.8 + - pip + - contourpy + - hdf5 + - matplotlib + - pandas + - pango + - patsy + - pillow + - pixman + - pooch + - rpy2 + - r-emmeans + - r-base + - r-lme4 + - r-lmertest + - pyqt + - pytables + - scikit-learn + - scipy + - statsmodels + - pyprep + - pip: + - ipython + - matplotlib-inline + - mne==1.3.1 + - mne-bids==0.12 + - ptitprince + - seaborn + - pymer4 \ No newline at end of file diff --git a/requirements_cogitate_meg_lmm_exact.yml b/requirements_cogitate_meg_lmm_exact.yml new file mode 100644 index 0000000..ce7b413 --- /dev/null +++ b/requirements_cogitate_meg_lmm_exact.yml @@ -0,0 +1,302 @@ +name: pymer4 +channels: + - ejolly + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - _r-mutex=1.0.1=anacondar_1 + - alsa-lib=1.2.8=h166bdaf_0 
+ - attr=2.5.1=h166bdaf_1 + - backports=1.0=pyhd8ed1ab_3 + - backports.zoneinfo=0.2.1=py38h0a891b7_7 + - binutils_impl_linux-64=2.40=hf600244_0 + - blosc=1.21.3=hafa529b_0 + - brotli=1.0.9=h166bdaf_8 + - brotli-bin=1.0.9=h166bdaf_8 + - brotlipy=0.7.0=py38h0a891b7_1005 + - bwidget=1.9.14=ha770c72_1 + - bzip2=1.0.8=h7f98852_4 + - c-ares=1.18.1=h7f98852_0 + - c-blosc2=2.8.0=hf91038e_1 + - ca-certificates=2022.12.7=ha878542_0 + - cairo=1.16.0=h35add3b_1015 + - certifi=2022.12.7=pyhd8ed1ab_0 + - cffi=1.15.1=py38h4a40e3a_3 + - charset-normalizer=3.1.0=pyhd8ed1ab_0 + - contourpy=1.0.7=py38hfbd4bf9_0 + - cryptography=40.0.2=py38h3d167d9_0 + - curl=8.0.1=h588be90_0 + - cycler=0.11.0=pyhd8ed1ab_0 + - dbus=1.13.6=h5008d03_3 + - deepdish=0.3.7=pyhd8ed1ab_0 + - expat=2.5.0=hcb278e6_1 + - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 + - font-ttf-inconsolata=3.000=h77eed37_0 + - font-ttf-source-code-pro=2.038=h77eed37_0 + - font-ttf-ubuntu=0.83=hab24e00_0 + - fontconfig=2.14.2=h14ed4e7_0 + - fonts-conda-ecosystem=1=0 + - fonts-conda-forge=1=0 + - fonttools=4.39.3=py38h1de0b5d_0 + - freetype=2.12.1=hca18f0e_1 + - fribidi=1.0.10=h36c2ea0_0 + - gcc_impl_linux-64=12.2.0=hcc96c02_19 + - gettext=0.21.1=h27087fc_0 + - gfortran_impl_linux-64=12.2.0=h55be85b_19 + - glib=2.76.2=hfc55251_0 + - glib-tools=2.76.2=hfc55251_0 + - graphite2=1.3.13=h58526e2_1001 + - gsl=2.7=he838d99_0 + - gst-plugins-base=1.22.0=h4243ec0_2 + - gstreamer=1.22.0=h25f0c4b_2 + - gxx_impl_linux-64=12.2.0=hcc96c02_19 + - harfbuzz=6.0.0=h3ff4399_1 + - hdf5=1.14.0=nompi_hb72d44e_103 + - icu=72.1=hcb278e6_0 + - idna=3.4=pyhd8ed1ab_0 + - importlib-resources=5.12.0=pyhd8ed1ab_0 + - importlib_resources=5.12.0=pyhd8ed1ab_0 + - jinja2=3.1.2=pyhd8ed1ab_1 + - joblib=1.2.0=pyhd8ed1ab_0 + - kernel-headers_linux-64=2.6.32=he073ed8_15 + - keyutils=1.6.1=h166bdaf_0 + - kiwisolver=1.4.4=py38h43d8883_1 + - krb5=1.20.1=h81ceb04_0 + - lame=3.100=h166bdaf_1003 + - lcms2=2.15=haa2dc70_1 + - ld_impl_linux-64=2.40=h41732ed_0 + - lerc=4.0.0=h27087fc_0 + - libaec=1.0.6=hcb278e6_1 + - libblas=3.9.0=16_linux64_openblas + - libbrotlicommon=1.0.9=h166bdaf_8 + - libbrotlidec=1.0.9=h166bdaf_8 + - libbrotlienc=1.0.9=h166bdaf_8 + - libcap=2.67=he9d0100_0 + - libcblas=3.9.0=16_linux64_openblas + - libclang=16.0.2=default_h83cc7fd_0 + - libclang13=16.0.2=default_hd781213_0 + - libcups=2.3.3=h36d4200_3 + - libcurl=8.0.1=h588be90_0 + - libdeflate=1.18=h0b41bf4_0 + - libedit=3.1.20191231=he28a2e2_2 + - libev=4.33=h516909a_1 + - libevent=2.1.10=h28343ad_4 + - libexpat=2.5.0=hcb278e6_1 + - libffi=3.4.2=h7f98852_5 + - libflac=1.4.2=h27087fc_0 + - libgcc-devel_linux-64=12.2.0=h3b97bd3_19 + - libgcc-ng=12.2.0=h65d4601_19 + - libgcrypt=1.10.1=h166bdaf_0 + - libgfortran-ng=12.2.0=h69a702a_19 + - libgfortran5=12.2.0=h337968e_19 + - libglib=2.76.2=hebfc3b9_0 + - libgomp=12.2.0=h65d4601_19 + - libgpg-error=1.46=h620e276_0 + - libiconv=1.17=h166bdaf_0 + - libjpeg-turbo=2.1.5.1=h0b41bf4_0 + - liblapack=3.9.0=16_linux64_openblas + - libllvm16=16.0.2=hbf9e925_0 + - libnghttp2=1.52.0=h61bc06f_0 + - libnsl=2.0.0=h7f98852_0 + - libogg=1.3.4=h7f98852_1 + - libopenblas=0.3.21=pthreads_h78a6416_3 + - libopus=1.3.1=h7f98852_1 + - libpng=1.6.39=h753d276_0 + - libpq=15.2=hb675445_0 + - libsanitizer=12.2.0=h46fd767_19 + - libsndfile=1.2.0=hb75c966_0 + - libsqlite=3.40.0=h753d276_1 + - libssh2=1.10.0=hf14f497_3 + - libstdcxx-devel_linux-64=12.2.0=h3b97bd3_19 + - libstdcxx-ng=12.2.0=h46fd767_19 + - libsystemd0=253=h8c4010b_1 + - libtiff=4.5.0=ha587672_6 + - libuuid=2.38.1=h0b41bf4_0 + - 
libvorbis=1.3.7=h9c3ff4c_0 + - libwebp-base=1.3.0=h0b41bf4_0 + - libxcb=1.13=h7f98852_1004 + - libxkbcommon=1.5.0=h79f4944_1 + - libxml2=2.10.4=hfdac1af_0 + - libzlib=1.2.13=h166bdaf_4 + - lz4-c=1.9.4=hcb278e6_0 + - lzo=2.10=h516909a_1000 + - make=4.3=hd18ef5c_1 + - markupsafe=2.1.2=py38h1de0b5d_0 + - matplotlib=3.7.1=py38h578d9bd_0 + - matplotlib-base=3.7.1=py38hd6c3c57_0 + - mpg123=1.31.3=hcb278e6_0 + - munkres=1.1.4=pyh9f0ad1d_0 + - mysql-common=8.0.32=ha901b37_1 + - mysql-libs=8.0.32=hd7da12d_1 + - ncurses=6.3=h27087fc_1 + - nlopt=2.7.1=py38hca016a5_3 + - nomkl=1.0=h5ca1d4c_0 + - nspr=4.35=h27087fc_0 + - nss=3.89=he45b914_0 + - numexpr=2.8.4=py38h69a160b_100 + - openjpeg=2.5.0=hfec8fc6_2 + - openssl=3.1.0=hd590300_2 + - packaging=23.1=pyhd8ed1ab_0 + - pandas=2.0.1=py38h01efb38_0 + - pango=1.50.14=hd33c08f_0 + - patsy=0.5.3=pyhd8ed1ab_0 + - pcre2=10.40=hc3806b6_0 + - pillow=9.5.0=py38h961100d_0 + - pip=23.1.2=pyhd8ed1ab_0 + - pixman=0.40.0=h36c2ea0_0 + - platformdirs=3.5.0=pyhd8ed1ab_0 + - ply=3.11=py_1 + - pooch=1.7.0=pyha770c72_3 + - pthread-stubs=0.4=h36c2ea0_1001 + - pulseaudio-client=16.1=h5195f5e_3 + - py-cpuinfo=9.0.0=pyhd8ed1ab_0 + - pycparser=2.21=pyhd8ed1ab_0 + - pymer4=0.8.0.9001=py38 + - pyopenssl=23.1.1=pyhd8ed1ab_0 + - pyparsing=3.0.9=pyhd8ed1ab_0 + - pyqt=5.15.7=py38ha0d8c90_3 + - pyqt5-sip=12.11.0=py38h8dc9893_3 + - pysocks=1.7.1=pyha2e5f31_6 + - pytables=3.8.0=py38hf59a973_1 + - python=3.8.16=he550d4f_1_cpython + - python-dateutil=2.8.2=pyhd8ed1ab_0 + - python-tzdata=2023.3=pyhd8ed1ab_0 + - python_abi=3.8=3_cp38 + - pytz=2023.3=pyhd8ed1ab_0 + - pytz-deprecation-shim=0.1.0.post0=py38h578d9bd_3 + - qt-main=5.15.8=h5c52f38_9 + - r-base=4.2.3=h4a03800_2 + - r-boot=1.3_28.1=r42hc72bb7e_0 + - r-cli=3.6.1=r42h38f115c_0 + - r-colorspace=2.1_0=r42h133d619_0 + - r-crayon=1.5.2=r42hc72bb7e_1 + - r-ellipsis=0.3.2=r42h06615bd_1 + - r-emmeans=1.8.5=r42hc72bb7e_0 + - r-estimability=1.4.1=r42hc72bb7e_1 + - r-fansi=1.0.4=r42h133d619_0 + - r-farver=2.1.1=r42h7525677_1 + - r-ggplot2=3.4.2=r42hc72bb7e_0 + - r-glue=1.6.2=r42h06615bd_1 + - r-gtable=0.3.3=r42hc72bb7e_0 + - r-isoband=0.2.7=r42h38f115c_1 + - r-labeling=0.4.2=r42hc72bb7e_2 + - r-lattice=0.21_8=r42h133d619_0 + - r-lifecycle=1.0.3=r42hc72bb7e_1 + - r-lme4=1.1_33=r42ha503ecb_0 + - r-lmertest=3.1_3=r42hc72bb7e_1 + - r-magrittr=2.0.3=r42h06615bd_1 + - r-mass=7.3_59=r42h57805ef_0 + - r-matrix=1.5_4=r42he1ae0d6_0 + - r-mgcv=1.8_42=r42he1ae0d6_0 + - r-minqa=1.2.5=r42hb13c81a_0 + - r-munsell=0.5.0=r42hc72bb7e_1005 + - r-mvtnorm=1.1_3=r42h8da6f51_1 + - r-nlme=3.1_162=r42hac0b197_0 + - r-nloptr=2.0.3=r42hb13c81a_1 + - r-numderiv=2016.8_1.1=r42hc72bb7e_4 + - r-pillar=1.9.0=r42hc72bb7e_0 + - r-pkgconfig=2.0.3=r42hc72bb7e_2 + - r-r6=2.5.1=r42hc72bb7e_1 + - r-rcolorbrewer=1.1_3=r42h785f33e_1 + - r-rcpp=1.0.10=r42h38f115c_0 + - r-rcppeigen=0.3.3.9.3=r42h9f5de39_0 + - r-rlang=1.1.0=r42h38f115c_0 + - r-scales=1.2.1=r42hc72bb7e_1 + - r-statmod=1.5.0=r42h74f4db8_0 + - r-tibble=3.2.1=r42h133d619_1 + - r-utf8=1.2.3=r42h133d619_0 + - r-vctrs=0.6.2=r42ha503ecb_0 + - r-viridislite=0.4.1=r42hc72bb7e_1 + - r-withr=2.5.0=r42hc72bb7e_1 + - r-xtable=1.8_4=r42hc72bb7e_4 + - readline=8.2=h8228510_1 + - requests=2.29.0=pyhd8ed1ab_0 + - rpy2=3.5.11=py38r42h07e1bb6_0 + - scikit-learn=1.2.2=py38hd4b6e60_1 + - scipy=1.10.1=py38h59f1b5f_0 + - sed=4.8=he412f7d_0 + - setuptools=67.7.2=pyhd8ed1ab_0 + - simplegeneric=0.8.1=py_1 + - sip=6.7.9=py38h17151c0_0 + - six=1.16.0=pyh6c4a22f_0 + - snappy=1.1.10=h9fff704_0 + - statsmodels=0.13.5=py38h26c90d9_2 + - 
sysroot_linux-64=2.12=he073ed8_15 + - threadpoolctl=3.1.0=pyh8a188c0_0 + - tk=8.6.12=h27826a3_0 + - tktable=2.10=hb7b940f_3 + - toml=0.10.2=pyhd8ed1ab_0 + - tomli=2.0.1=pyhd8ed1ab_0 + - tornado=6.3=py38h1de0b5d_0 + - typing-extensions=4.5.0=hd8ed1ab_0 + - typing_extensions=4.5.0=pyha770c72_0 + - tzdata=2023c=h71feb2d_0 + - tzlocal=4.3=py38h578d9bd_0 + - unicodedata2=15.0.0=py38h0a891b7_0 + - urllib3=1.26.15=pyhd8ed1ab_0 + - wheel=0.40.0=pyhd8ed1ab_0 + - xcb-util=0.4.0=h166bdaf_0 + - xcb-util-image=0.4.0=h166bdaf_0 + - xcb-util-keysyms=0.4.0=h166bdaf_0 + - xcb-util-renderutil=0.3.9=h166bdaf_0 + - xcb-util-wm=0.4.1=h166bdaf_0 + - xkeyboard-config=2.38=h0b41bf4_0 + - xorg-kbproto=1.0.7=h7f98852_1002 + - xorg-libice=1.0.10=h7f98852_0 + - xorg-libsm=1.2.3=hd9c2040_1000 + - xorg-libx11=1.8.4=h0b41bf4_0 + - xorg-libxau=1.0.9=h7f98852_0 + - xorg-libxdmcp=1.1.3=h7f98852_0 + - xorg-libxext=1.3.4=h0b41bf4_2 + - xorg-libxrender=0.9.10=h7f98852_1003 + - xorg-libxt=1.2.1=h7f98852_2 + - xorg-renderproto=0.11.1=h7f98852_1002 + - xorg-xextproto=7.3.0=h0b41bf4_1003 + - xorg-xf86vidmodeproto=2.3.1=h7f98852_1002 + - xorg-xproto=7.0.31=h7f98852_1007 + - xz=5.2.6=h166bdaf_0 + - zipp=3.15.0=pyhd8ed1ab_0 + - zlib=1.2.13=h166bdaf_4 + - zlib-ng=2.0.7=h0b41bf4_0 + - zstd=1.5.2=h3eb15da_6 + - pip: + - asttokens==2.2.1 + - backcall==0.2.0 + - cloudpickle==2.2.1 + - comm==0.1.3 + - cython==0.29.34 + - debugpy==1.6.7 + - decorator==5.1.1 + - executing==1.2.0 + - importlib-metadata==6.6.0 + - ipykernel==6.22.0 + - ipython==8.12.1 + - jedi==0.18.2 + - jupyter-client==8.2.0 + - jupyter-core==5.3.0 + - matplotlib-inline==0.1.6 + - mne==1.3.1 + - mne-bids==0.12 + - nest-asyncio==1.5.6 + - numpy==1.21.4 + - parso==0.8.3 + - pexpect==4.8.0 + - pickleshare==0.7.5 + - prompt-toolkit==3.0.38 + - psutil==5.9.5 + - ptitprince==0.2.6 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 + - pygments==2.15.1 + - pyhamcrest==2.0.4 + - pyzmq==25.0.2 + - seaborn==0.11.0 + - spyder-kernels==1.9.4 + - stack-data==0.6.2 + - tqdm==4.65.0 + - traitlets==5.9.0 + - wcwidth==0.2.6 + - wurlitzer==3.0.3 +prefix: /home/oscar.ferrante/.conda/envs/pymer4 diff --git a/roi_mvpa/D01_ROI_MVPA_Cat.py b/roi_mvpa/D01_ROI_MVPA_Cat.py new file mode 100644 index 0000000..71feb59 --- /dev/null +++ b/roi_mvpa/D01_ROI_MVPA_Cat.py @@ -0,0 +1,502 @@ + +""" +==================== +D01. 
Decoding for MEG on source space of ROI +Category decoding +==================== +@author: ling liu ling.liu@pku.edu.cn + +decoding methods: CCD: Cross Condition Decoding +classifier: SVM (linear) +feature: spatial pattern (S) + +feature selection methods test + +""" + +import os.path as op + +import pickle + +import matplotlib.pyplot as plt +import mne +import numpy as np + +import argparse + + +from mne.decoding import (Vectorizer, SlidingEstimator, cross_val_multiscore) +# import a linear classifier from mne.decoding +from mne.decoding import LinearModel + +from skimage.measure import block_reduce + +import sklearn.svm +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.feature_selection import SelectKBest, f_classif +#from sklearn.feature_selection import SelectPercentile, chi2 +#from sklearn.decomposition import PCA +from sklearn.metrics import make_scorer +from sklearn.metrics import accuracy_score + +# from sklearn.linear_model import LogisticRegression +# from sklearn.model_selection import StratifiedKFold + + + +from scipy.ndimage import gaussian_filter1d +import matplotlib.patheffects as path_effects + + +from config import bids_root + +from D_MEG_function import set_path_ROI_MVPA, ATdata,sensor_data_for_ROI_MVPA +from D_MEG_function import source_data_for_ROI_MVPA,sub_ROI_for_ROI_MVPA + +####if need pop-up figures +# %matplotlib qt5 +#mpl.use('Qt5Agg') + +parser=argparse.ArgumentParser() +parser.add_argument('--sub',type=str,default='SA101',help='subject_id') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +parser.add_argument('--cT',type=str,nargs='*', default=['500ms','1000ms','1500ms'], help='condition in Time duration') +parser.add_argument('--cC',type=str,nargs='*', default=['FO'], + help='selected decoding category, FO for face and object, LF for letter and false,' + 'F for face ,O for object, L for letter, FA for false') +parser.add_argument('--cD',type=str,nargs='*', default=['Irrelevant', 'Relevant non-target'], + help='selected decoding Task, Relevant non Target or Irrelevant condition') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--out_fw', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/forward', + help='Path to the forward (derivative) directory') +parser.add_argument('--nF', + type=int, + default=30, + help='number of feature selected for source decoding') +parser.add_argument('--nT', + type=int, + default=5, + help='number of trial averaged for source decoding') +parser.add_argument('--nPCA', + type=float, + default=0.95, + help='percentile of PCA selected for source decoding') +# parser.add_argument('--coreg_path', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/coreg', +# help='Path to the coreg (derivative) directory') + + +opt = parser.parse_args() +con_C = opt.cC +con_D = opt.cD +con_T = opt.cT +select_F = opt.nF +n_trials = opt.nT +nPCA = opt.nPCA + + +# ============================================================================= +# SESSION-SPECIFIC SETTINGS +# ============================================================================= + + + +subject_id = opt.sub + +visit_id = opt.visit +space = opt.space +subjects_dir = 
opt.fs_path + + + # Now we define a function to decoding condition for one subject + # Category_CCD, train on condition A, test on condition B +def Category_CCD(epochs_rs,stcs,conditions_C,conditions_D,select_F,n_trials,roi_name,score_methods,fname_fig): + # setup SVM classifier + clf = make_pipeline( + Vectorizer(), + StandardScaler(), # Z-score data, because gradiometers and magnetometers have different scales + SelectKBest(f_classif,k=select_F), + #SelectPercentile(chi2,k=select_p), + #PCA(n_components=nPCA), + LinearModel(sklearn.svm.SVC( + kernel='linear'))) #LogisticRegression(), + + # The scorers can be either one of the predefined metric strings or a scorer + # callable, like the one returned by make_scorer + #scoring = {"Accuracy": make_scorer(accuracy_score)}#"AUC": "roc_auc", + # score methods could be AUC or Accuracy + # {"AUC": "roc_auc","Accuracy": make_scorer(accuracy_score)}# + + sliding = SlidingEstimator(clf, scoring=score_methods, n_jobs=1) + + + print(' Creating evoked datasets') + + temp = epochs_rs.events[:, 2] + temp[epochs_rs.metadata['Category'] == conditions_C[0]] = 1 # face + temp[epochs_rs.metadata['Category'] == conditions_C[1]] = 2 # object + + y = temp + X=np.array([stc.data for stc in stcs]) + + cond_a = np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[0])[0] + # Find indices of Irrelevant trials + cond_b = np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[1])[0] + + + + # # Run cross-validated decoding analyses: + # scores_a = cross_val_multiscore(sliding,X=X[cond_a], y=y[cond_a],cv=5,n_jobs=-1) + # # Run cross-validated decoding analyses: + # scores_b = cross_val_multiscore(sliding, X=X[cond_b], y=y[cond_b], cv=5, n_jobs=-1) + + ccd=dict() + cond_a = np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[0])[0] + # Find indices of Irrelevant trials + cond_b = np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[1])[0] + + group_xa=X[cond_a] + group_ya=y[cond_a] + group_xb=X[cond_b] + group_yb=y[cond_b] + + scores_ab_per=np.zeros([100,group_xa.shape[2]]) + scores_ba_per=np.zeros([100,group_xb.shape[2]]) + for num_per in range(100): + # do the average trial + new_xa = [] + new_ya = [] + new_xb = [] + new_yb = [] + for label in range(2): + # Extract the data: + data_a = group_xa[np.where(group_ya == label+1)] + data_a = np.take(data_a, np.random.permutation(data_a.shape[0]), axis=0) + avg_xa = block_reduce(data_a, block_size=tuple([n_trials, *[1] * len(data_a.shape[1:])]), + func=np.nanmean, cval=np.nan) + #block_size + #array_like or int + #Array containing down-sampling integer factor along each axis. Default block_size is 2. + + # funccallable + # Function object which is used to calculate the return value for each local block. This function must implement an axis parameter. Primary functions are numpy.sum, numpy.min, numpy.max, numpy.mean and numpy.median. See also func_kwargs. + + # cvalfloat + # Constant padding value if image is not perfectly divisible by the block size. + + # Now generating the labels and group: + new_xa.append(avg_xa) + new_ya += [label] * avg_xa.shape[0] + + # Extract the data: + data_b = group_xb[np.where(group_yb == label+1)] + data_b = np.take(data_b, np.random.permutation(data_b.shape[0]), axis=0) + avg_xb = block_reduce(data_b, block_size=tuple([n_trials, *[1] * len(data_b.shape[1:])]), + func=np.nanmean, cval=np.nan) + #block_size + #array_like or int + #Array containing down-sampling integer factor along each axis. Default block_size is 2. 
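+            # A minimal, self-contained sketch of what this averaging step does
+            # (synthetic shapes only, not the real data): with 43 trials of one
+            # class, 20 sources and 100 time points,
+            #   >>> import numpy as np
+            #   >>> from skimage.measure import block_reduce
+            #   >>> x = np.random.randn(43, 20, 100)
+            #   >>> block_reduce(x, block_size=(5, 1, 1),
+            #   ...              func=np.nanmean, cval=np.nan).shape
+            #   (9, 20, 100)
+            # i.e. every n_trials=5 trials are averaged into one pseudo-trial; the
+            # last pseudo-trial averages only the 3 remaining trials, because the
+            # NaN padding added by block_reduce is ignored by np.nanmean.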
+ + # funccallable + # Function object which is used to calculate the return value for each local block. This function must implement an axis parameter. Primary functions are numpy.sum, numpy.min, numpy.max, numpy.mean and numpy.median. See also func_kwargs. + + # cvalfloat + # Constant padding value if image is not perfectly divisible by the block size. + + # Now generating the labels and group: + new_xb.append(avg_xb) + new_yb += [label] * avg_xb.shape[0] + + new_xa = np.concatenate((new_xa[0],new_xa[1]),axis=0) + new_ya = np.array(new_ya) + + # average temporal feature (5 point average) + new_xa=ATdata(new_xa) + + new_xb = np.concatenate((new_xb[0],new_xb[1]),axis=0) + new_yb = np.array(new_yb) + + # average temporal feature (5 point average) + new_xb=ATdata(new_xb) + + # First: train condition a (cond_a) and Test on condition b (cond_b) cross condition decoding + # Fit + sliding.fit(X=new_xa, y=new_ya) + # Test + scores_ab = sliding.score(X=new_xb, y=new_yb) + + + scores_ab_per[num_per,:]=scores_ab + + # Then: train condition b (cond_b) and Test on condition a (cond_a) cross condition decoding + # Fit + sliding.fit(X=new_xb, y=new_yb) + # Test + scores_ba = sliding.score(X=new_xa, y=new_ya) + + + scores_ba_per[num_per,:]=scores_ba + + + + + # ccd['IR'] = np.mean(scores_a, axis=0) + # ccd['RE'] = np.mean(scores_b, axis=0) + ccd['IR2RE'] = np.mean(scores_ab_per, axis=0) + ccd['RE2IR'] = np.mean(scores_ba_per, axis=0) + + + + fig, ax = plt.subplots(1) + t = 1e3 * epochs_rs.times + pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5), path_effects.Normal()] + for condi, Sccd in ccd.items(): + ax.plot(t, gaussian_filter1d(Sccd,sigma=4), linewidth=1, label=str(condi), path_effects=pe) + ax.axhline(0.5,color='k',linestyle='--',label='chance') + ax.axvline(0, color='k') + ax.legend(loc='upper right') + ax.set_title(f'CCD_ {roi_name}') + ax.set(xlabel='Time(ms)', ylabel='decoding score') + mne.viz.tight_layout() + # Save figure + fig.savefig(fname_fig) + + return ccd + +def Category_WCD(epochs_rs,stcs, + conditions_C,conditions_D, + select_F, + n_trials, + # nPCA, + roi_name,score_methods,fname_fig): + # setup SVM classifier + clf = make_pipeline( + Vectorizer(), + StandardScaler(), # Z-score data, because gradiometers and magnetometers have different scales + SelectKBest(f_classif,k=select_F), + #SelectPercentile(chi2,k=select_p), + #(n_components=nPCA), + LinearModel(sklearn.svm.SVC( + kernel='linear'))) #LogisticRegression(), + + # The scorers can be either one of the predefined metric strings or a scorer + # callable, like the one returned by make_scorer + #scoring = {"Accuracy": make_scorer(accuracy_score)}#"AUC": "roc_auc", + # score methods could be AUC or Accuracy + # {"AUC": "roc_auc","Accuracy": make_scorer(accuracy_score)}# + + sliding = SlidingEstimator(clf, scoring=score_methods, n_jobs=1) + + + print(' Creating evoked datasets') + + temp = epochs_rs.events[:, 2] + temp[epochs_rs.metadata['Category'] == conditions_C[0]] = 1 # face + temp[epochs_rs.metadata['Category'] == conditions_C[1]] = 2 # object + + y = temp + X=np.array([stc.data for stc in stcs]) + + # cond_a = np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[0])[0] + # # # Find indices of Irrelevant trials + # cond_b = np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[1])[0] + + wcd=dict() + for condi in range(2): + con_index=np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[condi])[0] + group_x=X[con_index] + group_y=y[con_index] + + 
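+        # Repeat the pseudo-trial averaging 100 times, each time with a fresh random
+        # permutation of trials, and average the resulting decoding time courses to
+        # reduce the variance introduced by any particular grouping of trials.
+        # group_x has shape (n_trials, n_sources, n_times), so group_x.shape[2] is
+        # the number of time points in the decoding score.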
scores_per=np.zeros([100,group_x.shape[2]]) + for num_per in range(100): + # do the average trial + new_x = [] + new_y = [] + for label in range(2): + # Extract the data: + data = group_x[np.where(group_y == label+1)] + data = np.take(data, np.random.permutation(data.shape[0]), axis=0) + avg_x = block_reduce(data, block_size=tuple([n_trials, *[1] * len(data.shape[1:])]), + func=np.nanmean, cval=np.nan) + #block_size + #array_like or int + #Array containing down-sampling integer factor along each axis. Default block_size is 2. + + # funccallable + # Function object which is used to calculate the return value for each local block. This function must implement an axis parameter. Primary functions are numpy.sum, numpy.min, numpy.max, numpy.mean and numpy.median. See also func_kwargs. + + # cvalfloat + # Constant padding value if image is not perfectly divisible by the block size. + + # Now generating the labels and group: + new_x.append(avg_x) + new_y += [label] * avg_x.shape[0] + + new_x = np.concatenate((new_x[0],new_x[1]),axis=0) + new_y = np.array(new_y) + + # average temporal feature (5 point average) + new_x=ATdata(new_x) + + scores= cross_val_multiscore(sliding, X=new_x, y=new_y, cv=5, n_jobs=1) + scores_per[num_per,:]=np.mean(scores, axis=0) + + wcd[conditions_D[condi]]=np.mean(scores_per, axis=0) + + + + # wcd=dict() + # scores_a= cross_val_multiscore(sliding, X=X[cond_a], y=y[cond_a], cv=5, n_jobs=1) + # wcd[conditions_D[0]]=np.mean(scores_a, axis=0) + # scores_b = cross_val_multiscore(sliding, X=X[cond_b], y=y[cond_b], cv=5, n_jobs=1) + # wcd[conditions_D[1]] = np.mean(scores_b, axis=0) + + # pattern = dict() + # pattern['IR'] = coef_a + # pattern['RE'] = coef_b + + + fig, ax = plt.subplots(1) + t = 1e3 * epochs_rs.times + pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5), path_effects.Normal()] + for condi, Ti_name in wcd.items(): + ax.plot(t, gaussian_filter1d(Ti_name,sigma=4), linewidth=1, label=str(condi), path_effects=pe) + ax.axhline(0.5,color='k',linestyle='--',label='chance') + ax.axvline(0, color='k') + ax.legend(loc='upper right') + ax.set_title(f'WCD_ {roi_name}') + ax.set(xlabel='Time(ms)', ylabel='decoding score') + mne.viz.tight_layout() + # Save figure + + fig.savefig(fname_fig) + + return wcd + + +# ============================================================================= +# RUN +# ============================================================================= + + +# run roi decoding analysis + +if __name__ == "__main__": + + #opt INFO + + # subject_id = 'SB085' + # + # visit_id = 'V1' + # space = 'surface' + # + + # analysis info + + # con_C = ['LF'] + # con_D = ['Irrelevant', 'Relevant non-target'] + # con_T = ['500ms','1000ms','1500ms'] + + + analysis_name='Cat' + + # 1 Set Path + sub_info, \ + fpath_epo, fpath_fw, fpath_fs, \ + roi_data_root, roi_figure_root, roi_code_root = set_path_ROI_MVPA(bids_root, + subject_id, + visit_id, + analysis_name) + + # 2 Get Sub ROI + surf_label_list, ROI_Name = sub_ROI_for_ROI_MVPA(fpath_fs, subject_id,analysis_name) + + # 3 prepare the sensor data + epochs_rs, \ + rank, common_cov, \ + conditions_C, conditions_D, conditions_T, task_info = sensor_data_for_ROI_MVPA(fpath_epo, + sub_info, + con_T, + con_C, + con_D) + + roi_ccd_acc = dict() + #roi_ccd_auc = dict() + roi_wcd_acc = dict() + #roi_wcd_auc = dict() + + + for nroi, roi_name in enumerate(ROI_Name): + + # 4 Get Source Data for each ROI + stcs = [] + stcs = source_data_for_ROI_MVPA(epochs_rs, fpath_fw, rank, common_cov, sub_info, surf_label_list[nroi]) + + 
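+        # 5 Run the decoding for this ROI: cross-condition decoding (CCD; train on
+        #   one task-relevance condition, test on the other, and vice versa) first,
+        #   then within-condition decoding (WCD), both scored with accuracy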
+ fname_fig_acc = op.join(roi_figure_root, + sub_info + task_info + '_'+ roi_name + + "_acc_CCD" + '.png') + + + score_methods=make_scorer(accuracy_score) + ccd_acc = Category_CCD(epochs_rs, stcs, + conditions_C, conditions_D, + select_F, + n_trials, + # nPCA, + roi_name, score_methods, + fname_fig_acc) + + roi_ccd_acc[roi_name] = ccd_acc + + + ### WCD + + + fname_fig_acc = op.join(roi_figure_root, + sub_info + task_info + '_' + roi_name + "_acc_WCD" + '.png') + + + score_methods=make_scorer(accuracy_score) + wcd_acc= Category_WCD(epochs_rs, stcs, + conditions_C, conditions_D, + select_F, + n_trials, + # nPCA, + roi_name, score_methods, + fname_fig_acc) + + roi_wcd_acc[roi_name] = wcd_acc + + + + + roi_data=dict() + roi_data['ccd_acc']=roi_ccd_acc + + roi_data['wcd_acc']=roi_wcd_acc + + + + fname_data=op.join(roi_data_root, sub_info + '_' + task_info +"_ROIs_data_Cat" + '.pickle') + fw = open(fname_data,'wb') + pickle.dump(roi_data,fw) + fw.close() + + +# Save code +# shutil.copy(__file__, roi_code_root) diff --git a/roi_mvpa/D01_ROI_MVPA_Cat_PFC.py b/roi_mvpa/D01_ROI_MVPA_Cat_PFC.py new file mode 100644 index 0000000..abfbf33 --- /dev/null +++ b/roi_mvpa/D01_ROI_MVPA_Cat_PFC.py @@ -0,0 +1,450 @@ + +""" +==================== +D01. Decoding for MEG on source space of ROI +Category decoding +control analysis, +compare decoding performance with vs without PFC region. +==================== +@author: ling liu ling.liu@pku.edu.cn + +decoding methods: CCD: Cross Condition Decoding +classifier: SVM (linear) +feature: spatial pattern (S) + +compare the decodeing performance of postior region with or without prefrontal region + +""" +import warnings +import os.path as op +import pickle + +import matplotlib.pyplot as plt +import mne +import numpy as np +import matplotlib as mpl + +import argparse + + + + + +from skimage.measure import block_reduce + +import sklearn.svm +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import Vectorizer,StandardScaler +from sklearn.feature_selection import SelectKBest, f_classif +from sklearn.metrics import accuracy_score + + +# from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import StratifiedKFold + + + +from scipy.ndimage import gaussian_filter1d + +import matplotlib.patheffects as path_effects + + +#from config import no_eeg_sbj +#from config import site_id, subject_id, file_names, visit_id, data_path, out_path +# from config import l_freq, h_freq, sfreq +# from config import (bids_root, tmin, tmax) +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root, plot_param + +from D_MEG_function import set_path_ROI_MVPA, ATdata, sensor_data_for_ROI_MVPA +from D_MEG_function import source_data_for_ROI_MVPA, sub_ROI_for_ROI_MVPA + +warnings.simplefilter(action='ignore', category=FutureWarning) +warnings.simplefilter(action='ignore', category=DeprecationWarning) + + +####if need pop-up figures +# %matplotlib qt5 +#mpl.use('Qt5Agg') + +parser=argparse.ArgumentParser() +parser.add_argument('--sub',type=str,default='SA101',help='subject_id') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. 
"V1")') +parser.add_argument('--cT',type=str,nargs='*', default=['500ms','1000ms','1500ms'], help='condition in Time duration') +parser.add_argument('--cC',type=str,nargs='*', default=['FO'], + help='selected decoding category, FO for face and object, LF for letter and false,' + 'F for face ,O for object, L for letter, FA for false') +parser.add_argument('--cD',type=str,nargs='*', default=['Irrelevant', 'Relevant non-target'], + help='selected decoding Task, Relevant non Target or Irrelevant condition') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--out_fw', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/forward', + help='Path to the forward (derivative) directory') +parser.add_argument('--nF', + type=int, + default=30, + help='number of feature selected for source decoding') +parser.add_argument('--nT', + type=int, + default=5, + help='number of trial averaged for source decoding') +parser.add_argument('--nPCA', + type=float, + default=0.95, + help='percentile of PCA selected for source decoding') +# parser.add_argument('--coreg_path', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/coreg', +# help='Path to the coreg (derivative) directory') + + +opt = parser.parse_args() +con_C = opt.cC +con_D = opt.cD +con_T = opt.cT +select_F = opt.nF +n_trials = opt.nT +nPCA = opt.nPCA + + +# ============================================================================= +# SESSION-SPECIFIC SETTINGS +# ============================================================================= + + + +subject_id = opt.sub + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path + +# get the parameters dictionary +param = plot_param +colors=param['colors'] +fig_size = param["figure_size_mm"] +plt.rc('font', size=8) # controls default text size +plt.rc('axes', labelsize=20) +plt.rc('xtick',labelsize=18) +plt.rc('ytick',labelsize=18) +plt.rc('xtick.major', width=2, size=4) +plt.rc('ytick.major', width=2, size=4) +plt.rc('legend', fontsize=18) +new_rc_params = {'text.usetex': False, +"svg.fonttype": 'none' +} + + +mpl.rcParams.update(new_rc_params) + +# Color parameters: +cmap = "RdYlBu_r" + + +def Category_PFC(fpath_fw,rank,common_cov,sub_info,surf_label_list, + epochs_rs,conditions_C,conditions_D,conditions_T,task_info): + #get data + stcs_PFC = source_data_for_ROI_MVPA(epochs_rs, fpath_fw, rank, common_cov, sub_info, surf_label_list[0]) + stcs_IIT = source_data_for_ROI_MVPA(epochs_rs, fpath_fw, rank, common_cov, sub_info, surf_label_list[1]) + stcs_IITPFC = source_data_for_ROI_MVPA(epochs_rs, fpath_fw, rank, common_cov, sub_info, surf_label_list[2]) + + + + + # setup SVM classifier + select_Fn=[30,30,60] + clf={} + for n,roi in enumerate(['PFC', 'IIT','IITPFC']): + clf[roi] = make_pipeline(Vectorizer(), + StandardScaler(), # Z-score data, because gradiometers and magnetometers have different scales + SelectKBest(f_classif,k=select_Fn[n]), + sklearn.svm.SVC(kernel='linear',probability=True)) #LogisticRegression(), + + # # The scorers can be either one of the predefined metric strings or a scorer + # # callable, like the one returned by make_scorer + # #scoring = {"Accuracy": make_scorer(accuracy_score)}#"AUC": "roc_auc", + # # score methods could be AUC or Accuracy + # # 
{"AUC": "roc_auc","Accuracy": make_scorer(accuracy_score)}# + + #sliding = SlidingEstimator(clf, scoring=make_scorer(accuracy_score), n_jobs=-1) + + + print(' Creating evoked datasets') + + temp = epochs_rs.events[:, 2] + temp[epochs_rs.metadata['Category'] == conditions_C[0]] = 1 # face + temp[epochs_rs.metadata['Category'] == conditions_C[1]] = 2 # object + + times = epochs_rs.times + + y = temp + X_PFC=np.array([stc.data for stc in stcs_PFC]) + X_IIT=np.array([stc.data for stc in stcs_IIT]) + X_IITPFC=np.array([stc.data for stc in stcs_IITPFC]) + + # cond_a = np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[0])[0] + # # # Find indices of Irrelevant trials + # cond_b = np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[1])[0] + + wcd=dict() + + con_index=np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[0])[0] # only analysis Irrelevant condition + group_x_PFC=X_PFC[con_index] + group_x_IIT=X_IIT[con_index] + group_x_IITPFC=X_IITPFC[con_index] + group_y=y[con_index] + + + scores_per_IIT=np.zeros([100,len(times)]) + scores_per_IITPFC=np.zeros([100,len(times)]) + # scores_per_comb=np.zeros([100,len(times)]) + # scores_per_comb_bayes=np.zeros([100,len(times)]) + for num_per in range(100): + # do the average trial + new_x_PFC = [] + new_x_IIT = [] + new_x_IITPFC = [] + new_y = [] + for label in range(2): + #block_size + #array_like or int + #Array containing down-sampling integer factor along each axis. Default block_size is 2. + + # funccallable + # Function object which is used to calculate the return value for each local block. This function must implement an axis parameter. Primary functions are numpy.sum, numpy.min, numpy.max, numpy.mean and numpy.median. See also func_kwargs. + + # cvalfloat + # Constant padding value if image is not perfectly divisible by the block size. 
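+            # The same n_trials-wise averaging is applied independently to the PFC,
+            # IIT and combined IIT+PFC source data below (each with its own random
+            # permutation of trials), so every ROI ends up with the same number of
+            # pseudo-trials per category label.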
+ + #PFC + data_PFC = group_x_PFC[np.where(group_y == label+1)] + data_PFC = np.take(data_PFC, np.random.permutation(data_PFC.shape[0]), axis=0) + avg_x_PFC = block_reduce(data_PFC, block_size=tuple([n_trials, *[1] * len(data_PFC.shape[1:])]), + func=np.nanmean, cval=np.nan) + new_x_PFC.append(avg_x_PFC) + + #IIT + data_IIT = group_x_IIT[np.where(group_y == label+1)] + data_IIT = np.take(data_IIT, np.random.permutation(data_IIT.shape[0]), axis=0) + avg_x_IIT = block_reduce(data_IIT, block_size=tuple([n_trials, *[1] * len(data_IIT.shape[1:])]), + func=np.nanmean, cval=np.nan) + new_x_IIT.append(avg_x_IIT) + + + #IITPFC + data_IITPFC = group_x_IITPFC[np.where(group_y == label+1)] + data_IITPFC = np.take(data_IITPFC, np.random.permutation(data_IITPFC.shape[0]), axis=0) + avg_x_IITPFC = block_reduce(data_IITPFC, block_size=tuple([n_trials, *[1] * len(data_IITPFC.shape[1:])]), + func=np.nanmean, cval=np.nan) + new_x_IITPFC.append(avg_x_IITPFC) + + + # Now generating the labels and group: + new_y += [label] * avg_x_PFC.shape[0] + + new_x_PFC = np.concatenate((new_x_PFC[0],new_x_PFC[1]),axis=0) + new_x_IIT = np.concatenate((new_x_IIT[0],new_x_IIT[1]),axis=0) + new_x_IITPFC = np.concatenate((new_x_IITPFC[0],new_x_IITPFC[1]),axis=0) + new_y = np.array(new_y) + + # average temporal feature (5 point average) + new_x_PFC=ATdata(new_x_PFC) + new_x_IIT=ATdata(new_x_IIT) + new_x_IITPFC=ATdata(new_x_IITPFC) + + skf = StratifiedKFold(n_splits=5) + # Getting the indices of the test and train sets from cross folder validation: + cv_index = list(skf.split(new_x_PFC, new_y)) + + + # n_classes=2 + n_folds=5 + # initialize storage + decoding_scores_IIT = np.empty((n_folds, len(times))) + decoding_scores_IITPFC = np.empty((n_folds, len(times))) + # decoding_scores_comb = np.empty((n_folds, len(times))) + # decoding_scores_comb_bayes = np.empty((n_folds, len(times))) + # proba_IIT = np.zeros((len(new_y), n_classes, len(times)))*np.nan + # proba_PFC = np.zeros((len(new_y), n_classes, len(times)))*np.nan + + + + + for ind, train_test_ind in enumerate(cv_index): + y_train = new_y[train_test_ind[0]] + y_test = new_y[train_test_ind[1]] + for t, time in enumerate(times): + x_train_PFC = new_x_PFC[train_test_ind[0], :, t] + x_test_PFC = new_x_PFC[train_test_ind[1], :, t] + + x_train_IIT = new_x_IIT[train_test_ind[0], :, t] + x_test_IIT = new_x_IIT[train_test_ind[1], :, t] + + x_train_IITPFC = new_x_IITPFC[train_test_ind[0], :, t] + x_test_IITPFC = new_x_IITPFC[train_test_ind[1], :, t] + + # # original code w/o calibration + # # regular prediction for iit-alone + # mdl_iit = clf['iit'].fit(x_train_iit, y_train) + # mdl_gnw = clf['gnw'].fit(x_train_gnw, y_train) + + # y_pred = mdl_iit.predict(x_test_iit) + # decoding_scores_iit[ind,t] = balanced_accuracy_score(y_test, y_pred ) + + # iit+gnw feature model + mdl_IITPFC = clf['IITPFC'].fit(x_train_IITPFC, y_train) + + mdl_IIT = clf['IIT'].fit(x_train_IIT, y_train) + mdl_PFC = clf['PFC'].fit(x_train_PFC, y_train) + + # iit-only + y_pred = mdl_IIT.predict(x_test_IIT) + decoding_scores_IIT[ind,t] = accuracy_score(y_test, y_pred ) + + # iit+pfc feature model + y_pred = mdl_IITPFC.predict( x_test_IITPFC ) + decoding_scores_IITPFC[ind,t] = accuracy_score(y_test, y_pred ) + + # # for iit+pfc model, get posterior probabilities, sum them, then norm the result (softmax), and predict the label + # mdl_prob_IIT = mdl_IIT.predict_proba( x_test_IIT ) + # mdl_prob_PFC = mdl_PFC.predict_proba( x_test_PFC ) + + # # store the probabilities + # proba_IIT[train_test_ind[1], :, t] = mdl_prob_IIT 
+ # proba_PFC[train_test_ind[1], :, t] = mdl_prob_PFC + + # psum = mdl_prob_IIT+mdl_prob_PFC + # softmx = np.exp(psum) / np.expand_dims( np.sum(np.exp(psum),1),1) + # ypred_combined = np.argmax( softmx, 1) + # decoding_scores_comb[ind,t] = accuracy_score(y_test, mdl_IIT.classes_[ ypred_combined ] ) + + # # p_post = 1/( 1 + exp(log((1-Pgnw)/Pgnw) - log(Piit/(1-Piit)) ) ) + # PIIT = mdl_prob_IIT + # PPFC = mdl_prob_PFC + # bayes_int = 1/( 1 + np.exp(np.log((1-PPFC)/PPFC) - np.log(PIIT/(1-PPFC)) ) ) + # ypred_combined = np.argmax( bayes_int, 1) + # decoding_scores_comb_bayes[ind,t] = accuracy_score(y_test, mdl_IIT.classes_[ ypred_combined ] ) + + + + + + scores_per_IIT[num_per,:]=np.mean(decoding_scores_IIT, axis=0) + scores_per_IITPFC[num_per,:]=np.mean(decoding_scores_IITPFC, axis=0) + # scores_per_comb[num_per,:]=np.mean(decoding_scores_comb, axis=0) + # scores_per_comb_bayes[num_per,:]=np.mean(decoding_scores_comb_bayes, axis=0) + + wcd['IIT']=np.mean(scores_per_IIT, axis=0) + wcd['IITPFC_f']=np.mean(scores_per_IITPFC, axis=0) # feature combine score + # wcd['IITPFC_m']=np.mean(scores_per_comb, axis=0) # model combine score + # wcd['IITPFC_m_bayes']=np.mean(scores_per_comb_bayes, axis=0) # model combine score with bayes methods + + + + + + + + return wcd + + + + +# ============================================================================= +# RUN +# ============================================================================= + + +# run roi decoding analysis + +if __name__ == "__main__": + + #opt INFO + + # subject_id = 'SB085' + # + # visit_id = 'V1' + # space = 'surface' + # + + # analysis info + + # con_C = ['LF'] + # con_D = ['Irrelevant', 'Relevant non-target'] + # con_T = ['500ms','1000ms','1500ms'] + + + analysis_name='Cat_PFC' + + # 1 Set Path + sub_info, \ + fpath_epo, fpath_fw, fpath_fs, \ + roi_data_root, roi_figure_root, roi_code_root = set_path_ROI_MVPA(bids_root, + subject_id, + visit_id, + analysis_name) + + # 2 Get Sub ROI + surf_label_list, ROI_Name = sub_ROI_for_ROI_MVPA(fpath_fs, subject_id,analysis_name) + + + + # 3 prepare the sensor data + epochs_rs, \ + rank, common_cov, \ + conditions_C, conditions_D, conditions_T, task_info = sensor_data_for_ROI_MVPA(fpath_epo, + sub_info, + con_T, + con_C, + con_D) + + roi_ccd_acc = dict() + #roi_ccd_auc = dict() + roi_wcd_acc = dict() + + + fname_fig = op.join(roi_figure_root,sub_info + task_info + '_' + "IITPFC_acc_WCD" + '.png') + + wcd_acc=Category_PFC(fpath_fw,rank,common_cov,sub_info,surf_label_list, + epochs_rs,conditions_C,conditions_D,conditions_T,task_info) + + + + fname_data=op.join(roi_data_root, sub_info + '_' + task_info +"_IITPFC_data_Cat" + '.pickle') + fw = open(fname_data,'wb') + pickle.dump(wcd_acc,fw) + fw.close() + + + + fig, ax = plt.subplots(1) + t = 1e3 * epochs_rs.times + pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5), path_effects.Normal()] + for condi, Ti_name in wcd_acc.items(): + ax.plot(t, gaussian_filter1d(Ti_name,sigma=4), linewidth=1, label=str(condi), path_effects=pe) + ax.axhline(0.5,color='k',linestyle='--',label='chance') + ax.axvline(0, color='k') + ax.legend(loc='upper right') + ax.set_title('WCD_IIT_PFC') + ax.set(xlabel='Time(ms)', ylabel='decoding score') + mne.viz.tight_layout() + # Save figure + + fig.savefig(fname_fig) + + +# Save code +# shutil.copy(__file__, roi_code_root) diff --git a/roi_mvpa/D01_ROI_MVPA_Cat_subROI.py b/roi_mvpa/D01_ROI_MVPA_Cat_subROI.py new file mode 100644 index 0000000..f6c0f65 --- /dev/null +++ b/roi_mvpa/D01_ROI_MVPA_Cat_subROI.py @@ 
-0,0 +1,231 @@ +""" +==================== +D01. Decoding for MEG on source space of ROI, +Category decoding +control analysis, +decoding at subROI. +==================== +@author: ling liu ling.liu@pku.edu.cn + +decoding methods: CCD: Cross Condition Decoding +classifier: SVM (linear) +feature: spatial pattern (S) + +feature selection methods test + +""" +#import os +import os.path as op + +import pickle + + +import argparse + + + +from sklearn.metrics import make_scorer +from sklearn.metrics import accuracy_score + + +from config import bids_root +from D_MEG_function import set_path_ROI_MVPA, sensor_data_for_ROI_MVPA +from D_MEG_function import source_data_for_ROI_MVPA,sub_ROI_for_ROI_MVPA + +from D01_ROI_MVPA_Cat import Category_WCD + + +####if need pop-up figures +# %matplotlib qt5 +#mpl.use('Qt5Agg') + +parser=argparse.ArgumentParser() + +parser.add_argument('--sub',type=str,default='SA101',help='subject_id') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +parser.add_argument('--cT',type=str,nargs='*', default=['500ms','1000ms','1500ms'], help='condition in Time duration') +parser.add_argument('--cC',type=str,nargs='*', default=['FO'], + help='selected decoding category, FO for face and object, LF for letter and false,' + 'F for face ,O for object, L for letter, FA for false') +parser.add_argument('--cD',type=str,nargs='*', default=['Irrelevant', 'Relevant non-target'], + help='selected decoding Task, Relevant non Target or Irrelevant condition') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--out_fw', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/forward', + help='Path to the forward (derivative) directory') +parser.add_argument('--nF', + type=int, + default=30, + help='number of feature selected for source decoding') +parser.add_argument('--nT', + type=int, + default=5, + help='number of trial averaged for source decoding') + +# parser.add_argument('--coreg_path', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/coreg', +# help='Path to the coreg (derivative) directory') + + +opt = parser.parse_args() + +con_C = opt.cC +con_D = opt.cD +con_T = opt.cT +select_F = opt.nF +n_trials = opt.nT +#nPCA = opt.nPCA + + +# ============================================================================= +# SESSION-SPECIFIC SETTINGS +# ============================================================================= + + + +subject_id = opt.sub + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path + + + + + +# ============================================================================= +# RUN +# ============================================================================= + + +# run roi decoding analysis + +if __name__ == "__main__": + + #opt INFO + + # subject_id = 'SB085' + # + # visit_id = 'V1' + # space = 'surface' + # + + # analysis info + + # con_C = ['LF'] + # con_D = ['Irrelevant', 'Relevant non-target'] + # con_T = ['500ms','1000ms','1500ms'] + ROI_index = ['subF','subP'] + + for analysis_index in ROI_index: + analysis_name='Cat_' + analysis_index + '_control' + + + + # 1 Set Path + sub_info, \ + fpath_epo, fpath_fw, fpath_fs, \ + roi_data_root, roi_figure_root, roi_code_root = 
set_path_ROI_MVPA(bids_root, + subject_id, + visit_id, + analysis_name) + + # 2 Get Sub ROI + surf_label_list, ROI_Name = sub_ROI_for_ROI_MVPA(fpath_fs, subject_id,analysis_name) + + # 3 prepare the sensor data + epochs_rs, \ + rank, common_cov, \ + conditions_C, conditions_D, conditions_T, task_info = sensor_data_for_ROI_MVPA(fpath_epo, + sub_info, + con_T, + con_C, + con_D) + + #roi_ccd_acc = dict() + #roi_ccd_auc = dict() + roi_wcd_acc = dict() + #roi_wcd_auc = dict() + + + for nroi, roi_name in enumerate(ROI_Name): + + # 4 Get Source Data for each ROI + stcs = [] + stcs = source_data_for_ROI_MVPA(epochs_rs, fpath_fw, rank, common_cov, sub_info, surf_label_list[nroi]) + + + # fname_fig_acc = op.join(roi_figure_root, + # sub_info + task_info + '_'+ roi_name + # + "_acc_CCD" + '.png') + + + # score_methods=make_scorer(accuracy_score) + # ccd_acc = Category_CCD(epochs_rs, stcs, + # conditions_C, conditions_D, + # select_F, + # n_trials, + # # nPCA, + # roi_name, score_methods, + # fname_fig_acc) + + # roi_ccd_acc[roi_name] = ccd_acc + + + ### WCD + + + fname_fig_acc = op.join(roi_figure_root, + sub_info + task_info + '_' + roi_name + "_acc_WCD" + '.png') + + + score_methods=make_scorer(accuracy_score) + wcd_acc= Category_WCD(epochs_rs, stcs, + conditions_C, conditions_D, + select_F, + n_trials, + # nPCA, + roi_name, score_methods, + fname_fig_acc) + + roi_wcd_acc[roi_name] = wcd_acc + + + + + roi_data=dict() + # roi_data['ccd_acc']=roi_ccd_acc + + roi_data['wcd_acc']=roi_wcd_acc + + + + fname_data=op.join(roi_data_root, sub_info + '_' + task_info +"_ROIs_data_" + analysis_name + '.pickle') + fw = open(fname_data,'wb') + pickle.dump(roi_data,fw) + fw.close() + + # #load + # fr=open(fname_data,'rb') + # d2=pickle.load(fr) + # fr.close() + + # stc_mean=stc_feat_b.copy().crop(tmin=0, tmax=0.5).mean() + # brain_mean = stc_mean.plot(views='lateral',subject=f'sub-{subject_id}',hemi='lh',size=(800,400),subjects_dir=subjects_dir) + + + +# Save code +# shutil.copy(__file__, roi_code_root) diff --git a/roi_mvpa/D02_ROI_MVPA_Ori.py b/roi_mvpa/D02_ROI_MVPA_Ori.py new file mode 100644 index 0000000..a5ff3c5 --- /dev/null +++ b/roi_mvpa/D02_ROI_MVPA_Ori.py @@ -0,0 +1,334 @@ + +""" +==================== +D01. 
Decoding for MEG on source space of ROI +Orientation decoding +==================== +@author: ling liu ling.liu@pku.edu.cn + +decoding methods: CCD: Cross Condition Decoding +classifier: SVM (linear) +feature: spatial pattern (S) + +feature selection methods test + +""" + +import os +import os.path as op + +import pickle + +import matplotlib.pyplot as plt +import mne +import numpy as np + +import argparse + + +from mne.decoding import (Vectorizer, SlidingEstimator, cross_val_multiscore) +# import a linear classifier from mne.decoding +from mne.decoding import LinearModel + +from skimage.measure import block_reduce + +import sklearn.svm +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.feature_selection import SelectKBest, f_classif +#from sklearn.feature_selection import SelectPercentile, chi2 + +from sklearn.metrics import make_scorer +from sklearn.metrics import balanced_accuracy_score +# from sklearn.linear_model import LogisticRegression +# from sklearn.model_selection import StratifiedKFold + + + +from scipy.ndimage import gaussian_filter1d + +import matplotlib.patheffects as path_effects + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + +from D_MEG_function import set_path_ROI_MVPA, ATdata,sensor_data_for_ROI_MVPA +from D_MEG_function import source_data_for_ROI_MVPA,sub_ROI_for_ROI_MVPA + +####if need pop-up figures +# %matplotlib qt5 +#mpl.use('Qt5Agg') + +parser=argparse.ArgumentParser() +parser.add_argument('--sub',type=str,default='SA101',help='subject_id') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +parser.add_argument('--cT',type=str,nargs='*', default=['500ms','1000ms','1500ms'], help='condition in Time duration') +parser.add_argument('--cC',type=str,nargs='*', default=['FO'], + help='selected decoding category, FO for face and object, LF for letter and false,' + 'F for face ,O for object, L for letter, FA for false') +parser.add_argument('--cD',type=str,nargs='*', default=['Irrelevant', 'Relevant non-target'], + help='selected decoding Task, Relevant non Target or Irrelevant condition') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--out_fw', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/forward', + help='Path to the forward (derivative) directory') +parser.add_argument('--nF', + type=int, + default=30, + help='number of feature selected for source decoding') +parser.add_argument('--nT', + type=int, + default=5, + help='number of trial averaged for source decoding') +parser.add_argument('--nPCA', + type=float, + default=0.95, + help='percentile of PCA selected for source decoding') +# parser.add_argument('--coreg_path', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/coreg', +# help='Path to the coreg (derivative) directory') + + +opt = parser.parse_args() +con_C = opt.cC +con_D = opt.cD +con_T = opt.cT +select_F = opt.nF +n_trials = opt.nT +nPCA = opt.nPCA + + +# ============================================================================= +# SESSION-SPECIFIC SETTINGS +# ============================================================================= + + + 
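+# Orientation decoding reuses the category pipeline, but the labels are the three
+# orientations (Center/Left/Right, i.e. a 3-class problem with chance level 1/3),
+# decoding is run separately within each stimulus category, and balanced accuracy
+# is used to compensate for unequal class sizes.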
+subject_id = opt.sub + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path + + +def Orientation_WCD(epochs_rs,stcs,conditions_C,select_F,n_trials,roi_name,score_methods,fname_fig): + # setup SVM classifier + clf = make_pipeline( + Vectorizer(), + StandardScaler(), # Z-score data, because gradiometers and magnetometers have different scales + SelectKBest(f_classif,k=select_F), + #SelectPercentile(chi2,k=select_p), + #PCA(n_components=nPCA), + LinearModel(sklearn.svm.SVC( + kernel='linear',class_weight='balanced'))) #LogisticRegression(), + + # The scorers can be either one of the predefined metric strings or a scorer + # callable, like the one returned by make_scorer + #scoring = {"Accuracy": make_scorer(accuracy_score)}#"AUC": "roc_auc", + # score methods could be AUC or Accuracy + # {"AUC": "roc_auc","Accuracy": }# + # For multivariable decoding(e,g, 1,2,3, could not use roc_auc), + # deal with unbalanced trial number, score should use make_scorer(balanced_accuracy_score) + + sliding = SlidingEstimator(clf, scoring=score_methods, n_jobs=1) + + + print(' Creating evoked datasets') + + conditions_O = ['Center', 'Left', 'Right'] + + temp = epochs_rs.events[:, 2] + temp[epochs_rs.metadata['Orientation'] == conditions_O[0]] = 1 # center, straight + temp[epochs_rs.metadata['Orientation'] == conditions_O[1]] = 2 # left,side orientation + temp[epochs_rs.metadata['Orientation'] == conditions_O[2]] = 3 # right, also side orientation + + + y = temp + X=np.array([stc.data for stc in stcs]) + + # cond_a = np.where(epochs_rs.metadata['Category'] == conditions_C[0])[0] + # # Find indices of Object trials + # cond_b = np.where(epochs_rs.metadata['Category'] == conditions_C[1])[0] + + # # # Run cross-validated decoding analyses: + # scores_a = cross_val_multiscore(sliding,X=X[cond_a], y=y[cond_a],cv=5,n_jobs=1) + # # # Run cross-validated decoding analyses: + # scores_b = cross_val_multiscore(sliding, X=X[cond_b], y=y[cond_b], cv=5, n_jobs=1) + + # # + # wcd = dict() + # # ccd['IR'] = np.mean(scores_a, axis=0) + # # ccd['RE'] = np.mean(scores_b, axis=0) + # wcd[conditions_C[0]] = np.mean(scores_a, axis=0) + # wcd[conditions_C[1]] = np.mean(scores_b, axis=0) + + wcd=dict() + for condi in range(len(conditions_C)): + con_index=np.where(epochs_rs.metadata['Category'] == conditions_C[condi])[0] + group_x=X[con_index] + group_y=y[con_index] + + scores_per=np.zeros([100,group_x.shape[2]]) + for num_per in range(100): + # do the average trial + new_x = [] + new_y = [] + for label in range(len(conditions_O)): + # Extract the data: + data = group_x[np.where(group_y == label+1)] + data = np.take(data, np.random.permutation(data.shape[0]), axis=0) + avg_x = block_reduce(data, block_size=tuple([n_trials, *[1] * len(data.shape[1:])]), + func=np.nanmean, cval=np.nan) + #block_size + #array_like or int + #Array containing down-sampling integer factor along each axis. Default block_size is 2. + + # funccallable + # Function object which is used to calculate the return value for each local block. This function must implement an axis parameter. Primary functions are numpy.sum, numpy.min, numpy.max, numpy.mean and numpy.median. See also func_kwargs. + + # cvalfloat + # Constant padding value if image is not perfectly divisible by the block size. 
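+                # label runs over the three orientations (0/1/2 here; coded 1/2/3
+                # in y above for Center, Left and Right), so each pseudo-trial
+                # keeps its orientation label after averaging.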
+ + # Now generating the labels and group: + new_x.append(avg_x) + new_y += [label] * avg_x.shape[0] + + new_x = np.concatenate((new_x[0],new_x[1],new_x[2]),axis=0) + new_y = np.array(new_y) + + # average temporal feature (5 point average) + new_x=ATdata(new_x) + + scores= cross_val_multiscore(sliding, X=new_x, y=new_y, cv=5, n_jobs=1) + scores_per[num_per,:]=np.mean(scores, axis=0) + + wcd[conditions_C[condi]]=np.mean(scores_per, axis=0) + + + + + fig, ax = plt.subplots(1) + t = 1e3 * epochs_rs.times + pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5), path_effects.Normal()] + for condi, Cwcd in wcd.items(): + ax.plot(t, gaussian_filter1d(Cwcd,sigma=4), linewidth=3, label=str(condi), path_effects=pe) + ax.axhline(0.33,color='k',linestyle='--',label='chance') + ax.axvline(0, color='k') + ax.legend(loc='upper right') + ax.set_title(f'WCD_ori_ {roi_name}') + ax.set(xlabel='Time(ms)', ylabel='decoding score') + mne.viz.tight_layout() + # Save figure + + #fname_fig = op.join('/home/user/S10/Cogitate/HPC/mvpa/ROI/figure/',ROI_Name[nroi] +"_CCD" + '.png') + fig.savefig(fname_fig) + + return wcd +# ============================================================================= +# RUN +# ============================================================================= + + +# run roi decoding analysis + +if __name__ == "__main__": + + #opt INFO + + # subject_id = 'SB085' + # + # visit_id = 'V1' + # space = 'surface' + # + + # analysis info + + # con_C = ['face'] + # con_D = ['Irrelevant', 'Relevant non-target'] + # con_T = ['500ms','1000ms','1500ms'] + + analysis_name='Ori' # orientation decoding + + # 1 Set Path + sub_info, \ + fpath_epo, fpath_fw, fpath_fs, \ + roi_data_root, roi_figure_root, roi_code_root = set_path_ROI_MVPA(bids_root, + subject_id, + visit_id, + analysis_name) + + # 2 Get Sub ROI + surf_label_list, ROI_Name = sub_ROI_for_ROI_MVPA(fpath_fs, subject_id,analysis_name) + + # 3 prepare the sensor data + epochs_rs, \ + rank, common_cov, \ + conditions_C, conditions_D, conditions_T, task_info = sensor_data_for_ROI_MVPA(fpath_epo, + sub_info, + con_T, + con_C, + con_D) + + roi_wcd_ori_acc = dict() + + for nroi, roi_name in enumerate(ROI_Name): + + # 4 Get Source Data for each ROI + stcs = [] + stcs = source_data_for_ROI_MVPA(epochs_rs, fpath_fw, + rank, common_cov, + sub_info, surf_label_list[nroi]) + + + ### wcd_orientation + #1 scoring methods with balanced accuracy score + fname_fig_acc = op.join(roi_figure_root, + sub_info + task_info + '_' + roi_name + "_acc_WCD_ori" + '.png') + + + score_methods=make_scorer(balanced_accuracy_score) + wcd_ori_acc= Orientation_WCD(epochs_rs, stcs, + conditions_C, + select_F, + n_trials, + # nPCA, + roi_name, score_methods, + fname_fig_acc) + + roi_wcd_ori_acc[roi_name] = wcd_ori_acc + + roi_data=dict() + + roi_data['wcd_ori_acc']=roi_wcd_ori_acc + + fname_data=op.join(roi_data_root, sub_info + '_' + task_info +"_ROIs_data_Ori" + '.pickle') + fw = open(fname_data,'wb') + pickle.dump(roi_data,fw) + fw.close() + + # #load + # fr=open(fname_data,'rb') + # d2=pickle.load(fr) + # fr.close() + + # stc_mean=stc_feat_b.copy().crop(tmin=0, tmax=0.5).mean() + # brain_mean = stc_mean.plot(views='lateral',subject=f'sub-{subject_id}',hemi='lh',size=(800,400),subjects_dir=subjects_dir) + + + +# Save code +# shutil.copy(__file__, roi_code_root) diff --git a/roi_mvpa/D02_ROI_MVPA_Ori_PFC.py b/roi_mvpa/D02_ROI_MVPA_Ori_PFC.py new file mode 100644 index 0000000..7ea8ac2 --- /dev/null +++ b/roi_mvpa/D02_ROI_MVPA_Ori_PFC.py @@ -0,0 +1,480 @@ + +""" 
+==================== +D02. Decoding for MEG on source space of ROI +Orientation decoding +control analysis, +compare decoding performance with vs without PFC region. +==================== +@author: ling liu ling.liu@pku.edu.cn + +decoding methods: CCD: Cross Condition Decoding +classifier: SVM (linear) +feature: spatial pattern (S) + +compare the decodeing performance of postior region with or without prefrontal region + +""" +import warnings +import os.path as op + +import pickle + +import matplotlib.pyplot as plt +import mne +import numpy as np +import matplotlib as mpl + +import argparse + + + +from mne.decoding import (Vectorizer) + +from skimage.measure import block_reduce + +import sklearn.svm +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.feature_selection import SelectKBest, f_classif +#from sklearn.feature_selection import SelectPercentile, chi2 +#from sklearn.decomposition import PCA +#from sklearn.metrics import make_scorer +from sklearn.metrics import accuracy_score + + +# from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import StratifiedKFold + + + +from scipy.ndimage import gaussian_filter1d +import matplotlib.patheffects as path_effects + + +#from config import no_eeg_sbj +#from config import site_id, subject_id, file_names, visit_id, data_path, out_path +# from config import l_freq, h_freq, sfreq +# from config import (bids_root, tmin, tmax) +import os +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root, plot_param + +from D_MEG_function import set_path_ROI_MVPA, ATdata,sensor_data_for_ROI_MVPA +from D_MEG_function import source_data_for_ROI_MVPA,sub_ROI_for_ROI_MVPA + +warnings.simplefilter(action='ignore', category=FutureWarning) +warnings.simplefilter(action='ignore', category=DeprecationWarning) + + +####if need pop-up figures +# %matplotlib qt5 +#mpl.use('Qt5Agg') + +parser=argparse.ArgumentParser() +parser.add_argument('--sub',type=str,default='SA101',help='subject_id') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. 
"V1")') +parser.add_argument('--cT',type=str,nargs='*', default=['500ms','1000ms','1500ms'], help='condition in Time duration') +parser.add_argument('--cC',type=str,nargs='*', default=['F'], + help='selected decoding category, FO for face and object, LF for letter and false,' + 'F for face ,O for object, L for letter, FA for false') +parser.add_argument('--cD',type=str,nargs='*', default=['Irrelevant', 'Relevant non-target'], + help='selected decoding Task, Relevant non Target or Irrelevant condition') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--out_fw', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/forward', + help='Path to the forward (derivative) directory') +parser.add_argument('--nF', + type=int, + default=30, + help='number of feature selected for source decoding') +parser.add_argument('--nT', + type=int, + default=5, + help='number of trial averaged for source decoding') +parser.add_argument('--nPCA', + type=float, + default=0.95, + help='percentile of PCA selected for source decoding') +# parser.add_argument('--coreg_path', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/coreg', +# help='Path to the coreg (derivative) directory') + + +opt = parser.parse_args() +con_C = opt.cC +con_D = opt.cD +con_T = opt.cT +select_F = opt.nF +n_trials = opt.nT +nPCA = opt.nPCA + + +# ============================================================================= +# SESSION-SPECIFIC SETTINGS +# ============================================================================= + + + +subject_id = opt.sub + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path + +# get the parameters dictionary +param = plot_param +colors=param['colors'] +fig_size = param["figure_size_mm"] +plt.rc('font', size=8) # controls default text size +plt.rc('axes', labelsize=20) +plt.rc('xtick',labelsize=18) +plt.rc('ytick',labelsize=18) +plt.rc('xtick.major', width=2, size=4) +plt.rc('ytick.major', width=2, size=4) +plt.rc('legend', fontsize=18) +# plt.rcParams["font.family"] = "serif" +# plt.rcParams["font.serif"] = "Times New Roman" +# plt.rc('font', size=param["font_size"]) # controls default text sizes +# plt.rc('axes', titlesize=param["font_size"]) # fontsize of the axes title +# plt.rc('axes', labelsize=param["font_size"]) # fontsize of the x and y labels +# plt.rc('xtick', labelsize=param["font_size"]) # fontsize of the tick labels +# plt.rc('ytick', labelsize=param["font_size"]) # fontsize of the tick labels +# plt.rc('legend', fontsize=param["font_size"]) # legend fontsize +# plt.rc('figure', titlesize=param["font_size"]) # fontsize of the fi +new_rc_params = {'text.usetex': False, +"svg.fonttype": 'none' +} + + +mpl.rcParams.update(new_rc_params) + +# Color parameters: +cmap = "RdYlBu_r" + + +def Orientation_PFC(fpath_fw,rank,common_cov,sub_info,surf_label_list, + epochs_rs,conditions_C,conditions_D,conditions_T,task_info): + #get data + stcs_PFC = source_data_for_ROI_MVPA(epochs_rs, fpath_fw, rank, common_cov, sub_info, surf_label_list[0]) + stcs_IIT = source_data_for_ROI_MVPA(epochs_rs, fpath_fw, rank, common_cov, sub_info, surf_label_list[1]) + stcs_IITPFC = source_data_for_ROI_MVPA(epochs_rs, fpath_fw, rank, common_cov, sub_info, surf_label_list[2]) + + + + + # 
setup SVM classifier + select_Fn=[30,30,60] + clf={} + for n,roi in enumerate(['PFC', 'IIT','IITPFC']): + clf[roi] = make_pipeline(Vectorizer(), + StandardScaler(), # Z-score data, because gradiometers and magnetometers have different scales + SelectKBest(f_classif,k=select_Fn[n]), + #SelectPercentile(chi2,k=select_p), + #(n_components=nPCA), + sklearn.svm.SVC(kernel='linear',class_weight='balanced',probability=True)) #LogisticRegression(), + + # # The scorers can be either one of the predefined metric strings or a scorer + # # callable, like the one returned by make_scorer + # #scoring = {"Accuracy": make_scorer(accuracy_score)}#"AUC": "roc_auc", + # # score methods could be AUC or Accuracy + # # {"AUC": "roc_auc","Accuracy": make_scorer(accuracy_score)}# + + #sliding = SlidingEstimator(clf, scoring=make_scorer(accuracy_score), n_jobs=-1) + + + print(' Creating evoked datasets') + + + print(' Creating evoked datasets') + + conditions_O = ['Center', 'Left', 'Right'] + + temp = epochs_rs.events[:, 2] + temp[epochs_rs.metadata['Orientation'] == conditions_O[0]] = 1 # center, straight + temp[epochs_rs.metadata['Orientation'] == conditions_O[1]] = 2 # left,side orientation + temp[epochs_rs.metadata['Orientation'] == conditions_O[2]] = 3 # right, also side orientation + + times = epochs_rs.times + + y = temp + X_PFC=np.array([stc.data for stc in stcs_PFC]) + X_IIT=np.array([stc.data for stc in stcs_IIT]) + X_IITPFC=np.array([stc.data for stc in stcs_IITPFC]) + + # cond_a = np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[0])[0] + # # # Find indices of Irrelevant trials + # cond_b = np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[1])[0] + + wcd=dict() + + #con_index=np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[0])[0] # only analysis Irrelevant condition + #con_index=np.where(epochs_rs.metadata['Category'] == conditions_C[condi])[0] + group_x_PFC=X_PFC + group_x_IIT=X_IIT + group_x_IITPFC=X_IITPFC + group_y=y + + + scores_per_IIT=np.zeros([100,len(times)]) + scores_per_IITPFC=np.zeros([100,len(times)]) + scores_per_comb=np.zeros([100,len(times)]) + scores_per_comb_bayes=np.zeros([100,len(times)]) + for num_per in range(100): + # do the average trial + new_x_PFC = [] + new_x_IIT = [] + new_x_IITPFC = [] + new_y = [] + for label in range(3): + #block_size + #array_like or int + #Array containing down-sampling integer factor along each axis. Default block_size is 2. + + # funccallable + # Function object which is used to calculate the return value for each local block. This function must implement an axis parameter. Primary functions are numpy.sum, numpy.min, numpy.max, numpy.mean and numpy.median. See also func_kwargs. + + # cvalfloat + # Constant padding value if image is not perfectly divisible by the block size. 
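+            # As in the category control analysis (D01_ROI_MVPA_Cat_PFC), the trial
+            # averaging below is applied independently to the PFC, IIT and combined
+            # IIT+PFC data, here for each of the three orientation classes.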
+ + #PFC + data_PFC = group_x_PFC[np.where(group_y == label+1)] + data_PFC = np.take(data_PFC, np.random.permutation(data_PFC.shape[0]), axis=0) + avg_x_PFC = block_reduce(data_PFC, block_size=tuple([n_trials, *[1] * len(data_PFC.shape[1:])]), + func=np.nanmean, cval=np.nan) + new_x_PFC.append(avg_x_PFC) + + #IIT + data_IIT = group_x_IIT[np.where(group_y == label+1)] + data_IIT = np.take(data_IIT, np.random.permutation(data_IIT.shape[0]), axis=0) + avg_x_IIT = block_reduce(data_IIT, block_size=tuple([n_trials, *[1] * len(data_IIT.shape[1:])]), + func=np.nanmean, cval=np.nan) + new_x_IIT.append(avg_x_IIT) + + + #IITPFC + data_IITPFC = group_x_IITPFC[np.where(group_y == label+1)] + data_IITPFC = np.take(data_IITPFC, np.random.permutation(data_IITPFC.shape[0]), axis=0) + avg_x_IITPFC = block_reduce(data_IITPFC, block_size=tuple([n_trials, *[1] * len(data_IITPFC.shape[1:])]), + func=np.nanmean, cval=np.nan) + new_x_IITPFC.append(avg_x_IITPFC) + + + # Now generating the labels and group: + new_y += [label] * avg_x_PFC.shape[0] + + new_x_PFC = np.concatenate((new_x_PFC[0],new_x_PFC[1],new_x_PFC[2]),axis=0) + new_x_IIT = np.concatenate((new_x_IIT[0],new_x_IIT[1],new_x_IIT[2]),axis=0) + new_x_IITPFC = np.concatenate((new_x_IITPFC[0],new_x_IITPFC[1],new_x_IITPFC[2]),axis=0) + new_y = np.array(new_y) + + # average temporal feature (5 point average) + new_x_PFC=ATdata(new_x_PFC) + new_x_IIT=ATdata(new_x_IIT) + new_x_IITPFC=ATdata(new_x_IITPFC) + + skf = StratifiedKFold(n_splits=5) + # Getting the indices of the test and train sets from cross folder validation: + cv_index = list(skf.split(new_x_PFC, new_y)) + + + n_classes=3 + n_folds=5 + # initialize storage + decoding_scores_IIT = np.empty((n_folds, len(times))) + decoding_scores_IITPFC = np.empty((n_folds, len(times))) + decoding_scores_comb = np.empty((n_folds, len(times))) + decoding_scores_comb_bayes = np.empty((n_folds, len(times))) + proba_IIT = np.zeros((len(new_y), n_classes, len(times)))*np.nan + proba_PFC = np.zeros((len(new_y), n_classes, len(times)))*np.nan + + + + + for ind, train_test_ind in enumerate(cv_index): + y_train = new_y[train_test_ind[0]] + y_test = new_y[train_test_ind[1]] + for t, time in enumerate(times): + x_train_PFC = new_x_PFC[train_test_ind[0], :, t] + x_test_PFC = new_x_PFC[train_test_ind[1], :, t] + + x_train_IIT = new_x_IIT[train_test_ind[0], :, t] + x_test_IIT = new_x_IIT[train_test_ind[1], :, t] + + x_train_IITPFC = new_x_IITPFC[train_test_ind[0], :, t] + x_test_IITPFC = new_x_IITPFC[train_test_ind[1], :, t] + + # # original code w/o calibration + # # regular prediction for iit-alone + # mdl_iit = clf['iit'].fit(x_train_iit, y_train) + # mdl_gnw = clf['gnw'].fit(x_train_gnw, y_train) + + # y_pred = mdl_iit.predict(x_test_iit) + # decoding_scores_iit[ind,t] = balanced_accuracy_score(y_test, y_pred ) + + # iit+gnw feature model + mdl_IITPFC = clf['IITPFC'].fit(x_train_IITPFC, y_train) + + mdl_IIT = clf['IIT'].fit(x_train_IIT, y_train) + mdl_PFC = clf['PFC'].fit(x_train_PFC, y_train) + + # iit-only + y_pred = mdl_IIT.predict(x_test_IIT) + decoding_scores_IIT[ind,t] = accuracy_score(y_test, y_pred ) + + # iit+pfc feature model + y_pred = mdl_IITPFC.predict( x_test_IITPFC ) + decoding_scores_IITPFC[ind,t] = accuracy_score(y_test, y_pred ) + + # for iit+pfc model, get posterior probabilities, sum them, then norm the result (softmax), and predict the label + mdl_prob_IIT = mdl_IIT.predict_proba( x_test_IIT ) + mdl_prob_PFC = mdl_PFC.predict_proba( x_test_PFC ) + + # store the probabilities + proba_IIT[train_test_ind[1], 
:, t] = mdl_prob_IIT + proba_PFC[train_test_ind[1], :, t] = mdl_prob_PFC + + psum = mdl_prob_IIT+mdl_prob_PFC + softmx = np.exp(psum) / np.expand_dims( np.sum(np.exp(psum),1),1) + ypred_combined = np.argmax( softmx, 1) + decoding_scores_comb[ind,t] = accuracy_score(y_test, mdl_IIT.classes_[ ypred_combined ] ) + + # p_post = 1/( 1 + exp(log((1-Pgnw)/Pgnw) - log(Piit/(1-Piit)) ) ) + PIIT = mdl_prob_IIT + PPFC = mdl_prob_PFC + bayes_int = 1/( 1 + np.exp(np.log((1-PPFC)/PPFC) - np.log(PIIT/(1-PPFC)) ) ) + ypred_combined = np.argmax( bayes_int, 1) + decoding_scores_comb_bayes[ind,t] = accuracy_score(y_test, mdl_IIT.classes_[ ypred_combined ] ) + + + + + + scores_per_IIT[num_per,:]=np.mean(decoding_scores_IIT, axis=0) + scores_per_IITPFC[num_per,:]=np.mean(decoding_scores_IITPFC, axis=0) + scores_per_comb[num_per,:]=np.mean(decoding_scores_comb, axis=0) + scores_per_comb_bayes[num_per,:]=np.mean(decoding_scores_comb_bayes, axis=0) + + wcd['IIT']=np.mean(scores_per_IIT, axis=0) + wcd['IITPFC_f']=np.mean(scores_per_IITPFC, axis=0) # feature combine score + wcd['IITPFC_m']=np.mean(scores_per_comb, axis=0) # model combine score + wcd['IITPFC_m_bayes']=np.mean(scores_per_comb_bayes, axis=0) # model combine score with bayes methods + + + + # wcd=dict() + # scores_a= cross_val_multiscore(sliding, X=X[cond_a], y=y[cond_a], cv=5, n_jobs=1) + # wcd[conditions_D[0]]=np.mean(scores_a, axis=0) + # scores_b = cross_val_multiscore(sliding, X=X[cond_b], y=y[cond_b], cv=5, n_jobs=1) + # wcd[conditions_D[1]] = np.mean(scores_b, axis=0) + + # pattern = dict() + # pattern['IR'] = coef_a + # pattern['RE'] = coef_b + + + + return wcd + + + + +# ============================================================================= +# RUN +# ============================================================================= + + +# run roi decoding analysis + +if __name__ == "__main__": + + #opt INFO + + # subject_id = 'SB085' + # + # visit_id = 'V1' + # space = 'surface' + # + + # analysis info + + # con_C = ['LF'] + # con_D = ['Irrelevant', 'Relevant non-target'] + # con_T = ['500ms','1000ms','1500ms'] + + + analysis_name='Ori_PFC' + + # 1 Set Path + sub_info, \ + fpath_epo, fpath_fw, fpath_fs, \ + roi_data_root, roi_figure_root, roi_code_root = set_path_ROI_MVPA(bids_root, + subject_id, + visit_id, + analysis_name) + + # 2 Get Sub ROI + surf_label_list, ROI_Name = sub_ROI_for_ROI_MVPA(fpath_fs, subject_id,analysis_name) + + + + # 3 prepare the sensor data + epochs_rs, \ + rank, common_cov, \ + conditions_C, conditions_D, conditions_T, task_info = sensor_data_for_ROI_MVPA(fpath_epo, + sub_info, + con_T, + con_C, + con_D) + + roi_ccd_acc = dict() + #roi_ccd_auc = dict() + roi_wcd_acc = dict() + + + fname_fig = op.join(roi_figure_root,sub_info + task_info + '_' + "IITPFC_acc_WCD_Ori" + '.png') + + wcd_acc=Orientation_PFC(fpath_fw,rank,common_cov,sub_info,surf_label_list, + epochs_rs,conditions_C,conditions_D,conditions_T,task_info) + + + + fname_data=op.join(roi_data_root, sub_info + '_' + task_info +"_IITPFC_data_Ori" + '.pickle') + fw = open(fname_data,'wb') + pickle.dump(wcd_acc,fw) + fw.close() + + + + fig, ax = plt.subplots(1) + t = 1e3 * epochs_rs.times + pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5), path_effects.Normal()] + for condi, Ti_name in wcd_acc.items(): + ax.plot(t, gaussian_filter1d(Ti_name,sigma=4), linewidth=1, label=str(condi), path_effects=pe) + ax.axhline(0.5,color='k',linestyle='--',label='chance') + ax.axvline(0, color='k') + ax.legend(loc='upper right') + ax.set_title('WCD_IIT_PFC') + 
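# Each curve plotted above is one decoding scheme ('IIT', 'IITPFC_f', 'IITPFC_m', 'IITPFC_m_bayes'), smoothed with a Gaussian kernel (sigma = 4 samples). +
# Orientation decoding is a three-class problem (Center/Left/Right), so empirical chance is roughly 1/3; the 0.5 reference line drawn above is the two-class chance level. +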
ax.set(xlabel='Time(ms)', ylabel='decoding score') + mne.viz.tight_layout() + # Save figure + + fig.savefig(fname_fig) + + +# Save code +# shutil.copy(__file__, roi_code_root) diff --git a/roi_mvpa/D03_ROI_MVPA_GAT_Cat.py b/roi_mvpa/D03_ROI_MVPA_GAT_Cat.py new file mode 100644 index 0000000..743a676 --- /dev/null +++ b/roi_mvpa/D03_ROI_MVPA_GAT_Cat.py @@ -0,0 +1,553 @@ + +""" +==================== +D03. Decoding for MEG on source space of ROI : genelaization across time (GAT) +Category Decoding +==================== +@author: ling liu ling.liu@pku.edu.cn + +decoding methods: CTWCD: Cross Time Within Condition Decoding +classifier: SVM (linear) +feature: spatial pattern (S) + +""" + +import os.path as op +import pickle + +import matplotlib.pyplot as plt + +import numpy as np +import matplotlib as mpl + +import argparse + + + +from mne.decoding import (Vectorizer, cross_val_multiscore) +# import a linear classifier from mne.decoding +from mne.decoding import LinearModel +from mne.decoding import GeneralizingEstimator + + + +import sklearn.svm +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.feature_selection import SelectKBest, f_classif + +from sklearn.metrics import make_scorer +from sklearn.metrics import accuracy_score +# from sklearn.linear_model import LogisticRegression +# from sklearn.model_selection import StratifiedKFold + +from skimage.measure import block_reduce + + +from scipy.ndimage import gaussian_filter +import matplotlib.patheffects as path_effects + +import os +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + +from D_MEG_function import set_path_ROI_MVPA, ATdata,sensor_data_for_ROI_MVPA +from D_MEG_function import source_data_for_ROI_MVPA,sub_ROI_for_ROI_MVPA + + +####if need pop-up figures +# %matplotlib qt5 +#mpl.use('Qt5Agg') + +parser=argparse.ArgumentParser() +parser.add_argument('--sub',type=str,default='SA101',help='subject_id') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. 
"V1")') +parser.add_argument('--cT',type=str,nargs='*', default=['500ms','1000ms','1500ms'], help='condition in Time duration') +parser.add_argument('--cC',type=str,nargs='*', default=['FO'], + help='selected decoding category, FO for face and object, LF for letter and false,' + 'F for face ,O for object, L for letter, FA for false') +parser.add_argument('--cD',type=str,nargs='*', default=['Irrelevant', 'Relevant non-target'], + help='selected decoding Task, Relevant non Target or Irrelevant condition') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--out_fw', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/forward', + help='Path to the forward (derivative) directory') +parser.add_argument('--nF', + type=int, + default=30, + help='number of feature selected for source decoding') +parser.add_argument('--nT', + type=int, + default=5, + help='number of trial averaged for source decoding') +parser.add_argument('--nPCA', + type=float, + default=0.95, + help='percentile of PCA selected for source decoding') + +# parser.add_argument('--coreg_path', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/coreg', +# help='Path to the coreg (derivative) directory') + + +opt = parser.parse_args() +con_C = opt.cC +con_D = opt.cD +con_T = opt.cT +select_F = opt.nF +n_trials = opt.nT +nPCA = opt.nPCA +# ============================================================================= +# SESSION-SPECIFIC SETTINGS +# ============================================================================= + + + +subject_id = opt.sub + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path + + + + # Now we define a function to decoding condition for one subject + # Category_CTCCD, train on condition A, test on condition B +def Category_CTCCD(epochs_rs,stcs,conditions_C,conditions_D, + select_F, + roi_name,score_methods,fname_fig): + # setup SVM classifier + clf = make_pipeline( + Vectorizer(), + StandardScaler(), # Z-score data, because gradiometers and magnetometers have different scales + SelectKBest(f_classif,k=select_F), + LinearModel(sklearn.svm.SVC( + kernel='linear'))) #LogisticRegression(), + + # The scorers can be either one of the predefined metric strings or a scorer + # callable, like the one returned by make_scorer + #scoring = {"Accuracy": make_scorer(accuracy_score)}#"AUC": "roc_auc", + # score methods could be AUC or Accuracy + # {"AUC": "roc_auc","Accuracy": make_scorer(accuracy_score)}# + + sliding = GeneralizingEstimator(clf, scoring=score_methods, n_jobs=-1) + + + print(' Creating evoked datasets') + + temp = epochs_rs.events[:, 2] + temp[epochs_rs.metadata['Category'] == conditions_C[0]] = 1 # face + temp[epochs_rs.metadata['Category'] == conditions_C[1]] = 2 # object + + y = temp + X=np.array([stc.data for stc in stcs]) + + cond_a = np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[0])[0] + # Find indices of Irrelevant trials + cond_b = np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[1])[0] + + + + # # Run cross-validated decoding analyses: + # scores_a = cross_val_multiscore(sliding,X=X[cond_a], y=y[cond_a],cv=5,n_jobs=-1) + # # Run cross-validated decoding analyses: + # scores_b = cross_val_multiscore(sliding, X=X[cond_b], 
y=y[cond_b], cv=5, n_jobs=-1) + + # First: train condition a (cond_a) and Test on condition b (cond_b) cross condition decoding + # Fit + ctccd=dict() + + group_xa=X[cond_a] + group_ya=y[cond_a] + group_xb=X[cond_b] + group_yb=y[cond_b] + + scores_ab_per=np.zeros([100,group_xa.shape[2],group_xa.shape[2]]) + scores_ba_per=np.zeros([100,group_xb.shape[2],group_xa.shape[2]]) + for num_per in range(100): + # do the average trial + new_xa = [] + new_ya = [] + new_xb = [] + new_yb = [] + for label in range(2): + # Extract the data: + data_a = group_xa[np.where(group_ya == label+1)] + data_a = np.take(data_a, np.random.permutation(data_a.shape[0]), axis=0) + avg_xa = block_reduce(data_a, block_size=tuple([n_trials, *[1] * len(data_a.shape[1:])]), + func=np.nanmean, cval=np.nan) + #block_size + #array_like or int + #Array containing down-sampling integer factor along each axis. Default block_size is 2. + + # funccallable + # Function object which is used to calculate the return value for each local block. This function must implement an axis parameter. Primary functions are numpy.sum, numpy.min, numpy.max, numpy.mean and numpy.median. See also func_kwargs. + + # cvalfloat + # Constant padding value if image is not perfectly divisible by the block size. + + # Now generating the labels and group: + new_xa.append(avg_xa) + new_ya += [label] * avg_xa.shape[0] + + # Extract the data: + data_b = group_xb[np.where(group_yb == label+1)] + data_b = np.take(data_b, np.random.permutation(data_b.shape[0]), axis=0) + avg_xb = block_reduce(data_b, block_size=tuple([n_trials, *[1] * len(data_b.shape[1:])]), + func=np.nanmean, cval=np.nan) + #block_size + #array_like or int + #Array containing down-sampling integer factor along each axis. Default block_size is 2. + + # funccallable + # Function object which is used to calculate the return value for each local block. This function must implement an axis parameter. Primary functions are numpy.sum, numpy.min, numpy.max, numpy.mean and numpy.median. See also func_kwargs. + + # cvalfloat + # Constant padding value if image is not perfectly divisible by the block size. 
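+ # Pseudo-trial averaging: within each label, trials are shuffled and block_reduce averages them in non-overlapping blocks of n_trials; the last, possibly incomplete block is padded with cval=np.nan and collapsed by np.nanmean, so leftover trials are not dropped.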
+ + # Now generating the labels and group: + new_xb.append(avg_xb) + new_yb += [label] * avg_xb.shape[0] + + new_xa = np.concatenate((new_xa[0],new_xa[1]),axis=0) + new_ya = np.array(new_ya) + + # average temporal feature (5 point average) + new_xa=ATdata(new_xa) + + new_xb = np.concatenate((new_xb[0],new_xb[1]),axis=0) + new_yb = np.array(new_yb) + + # average temporal feature (5 point average) + new_xb=ATdata(new_xb) + + # First: train condition a (cond_a) and Test on condition b (cond_b) cross condition decoding + # Fit + sliding.fit(X=new_xa, y=new_ya) + # Test + scores_ab = sliding.score(X=new_xb, y=new_yb) + + + scores_ab_per[num_per,:,:]=scores_ab + + # Then: train condition b (cond_b) and Test on condition a (cond_a) cross condition decoding + # Fit + sliding.fit(X=new_xb, y=new_yb) + # Test + scores_ba = sliding.score(X=new_xa, y=new_ya) + + + scores_ba_per[num_per,:,:]=scores_ba + + + ctccd['IR2RE'] = scores_ab + ctccd['RE2IR'] = scores_ba + + fig, axes = plt.subplots(1, 2,figsize=(10,3),sharex=True,sharey=True) + plt.subplots_adjust(wspace=0.5, hspace=0) + fig.suptitle('CTCCD') + + t = 1e3 * epochs_rs.times + pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5), path_effects.Normal()] + cmap = mpl.cm.jet + vmin = 0.5 + vmax = 0.7 + bounds = np.linspace(vmin, vmax, 11) + # norm = mpl.colors.BoundaryNorm(bounds, cmap.N) + # diff setting + vmind = -0.15 + vmaxd = 0.15 + boundsd = np.linspace(vmind, vmaxd, 11) + normd = mpl.colors.BoundaryNorm(boundsd, cmap.N) + #plot + im = axes[0].imshow(gaussian_filter(scores_ab,sigma=2), interpolation='lanczos', origin='lower', cmap=cmap, + extent=epochs_rs.times[[0, -1, 0, -1]], vmin=vmin, vmax=vmax) + axes[0].set_xlabel('Testing Time (s)') + axes[0].set_ylabel('Training Time (s)') + axes[0].set_title('Train IR Test RE') + axes[0].axvline(0, color='k') + axes[0].axhline(0, color='k') + axes[0].axline((0, 0), slope=1, color='k') + plt.colorbar(im, ax=axes[0],fraction=0.03, pad=0.05) + + im = axes[1].imshow(gaussian_filter(scores_ba, sigma=2), interpolation='lanczos', origin='lower', cmap=cmap, + extent=epochs_rs.times[[0, -1, 0, -1]], vmin=vmin, vmax=vmax) + axes[1].set_xlabel('Testing Time (s)') + axes[1].set_ylabel('Training Time (s)') + axes[1].set_title('Train RE Test IR') + axes[1].axvline(0, color='k') + axes[1].axhline(0, color='k') + axes[1].axline((0,0), slope=1, color='k') + plt.colorbar(im, ax=axes[1],fraction=0.03, pad=0.05) + + # Save figure + + fig.savefig(fname_fig) + + return ctccd + + + +#cross time within condition decoding +def Category_CTWCD(epochs_rs,stcs, + conditions_C,conditions_D, + seletct_F, + roi_name,score_methods,fname_fig): + # setup SVM classifier + clf = make_pipeline( + Vectorizer(), + StandardScaler(), # Z-score data, because gradiometers and magnetometers have different scales + SelectKBest(f_classif,k=select_F), + #SelectPercentile(chi2,k=select_p), + #PCA(n_components=nPCA), + LinearModel(sklearn.svm.SVC( + kernel='linear'))) #LogisticRegression(), + + # The scorers can be either one of the predefined metric strings or a scorer + # callable, like the one returned by make_scorer + #scoring = {"Accuracy": make_scorer(accuracy_score)}#"AUC": "roc_auc", + # score methods could be AUC or Accuracy + # {"AUC": "roc_auc","Accuracy": make_scorer(accuracy_score)}# + + sliding = GeneralizingEstimator(clf, scoring=score_methods, n_jobs=-1) + + + print(' Creating evoked datasets') + + temp = epochs_rs.events[:, 2] + temp[epochs_rs.metadata['Category'] == conditions_C[0]] = 1 # face + 
temp[epochs_rs.metadata['Category'] == conditions_C[1]] = 2 # object + + y = temp + X=np.array([stc.data for stc in stcs]) + + # cond_a = np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[0])[0] + # # # Find indices of Irrelevant trials + # # cond_b = np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[1])[0] + + + # wcd=dict() + # scores_a= cross_val_multiscore(sliding, X=X[cond_a], y=y[cond_a], cv=5, n_jobs=2) + # wcd[conditions_D[0]]=np.mean(scores_a, axis=0) + + # scores_b = cross_val_multiscore(sliding, X=X[cond_b], y=y[cond_b], cv=5, n_jobs=-1) + # wcd[conditions_D[1]] = np.mean(scores_b, axis=0) + + ctwcd=dict() + for condi in range(len(conditions_D)): + con_index=np.where(epochs_rs.metadata['Task_relevance'] == conditions_D[condi])[0] + group_x=X[con_index] + group_y=y[con_index] + + scores_per=np.zeros([100,group_x.shape[2],group_x.shape[2]]) + for num_per in range(100): + # do the average trial + new_x = [] + new_y = [] + for label in range(2): + # Extract the data: + data = group_x[np.where(group_y == label+1)] + data = np.take(data, np.random.permutation(data.shape[0]), axis=0) + avg_x = block_reduce(data, block_size=tuple([n_trials, *[1] * len(data.shape[1:])]), + func=np.nanmean, cval=np.nan) + #block_size + #array_like or int + #Array containing down-sampling integer factor along each axis. Default block_size is 2. + + # funccallable + # Function object which is used to calculate the return value for each local block. This function must implement an axis parameter. Primary functions are numpy.sum, numpy.min, numpy.max, numpy.mean and numpy.median. See also func_kwargs. + + # cvalfloat + # Constant padding value if image is not perfectly divisible by the block size. + + # Now generating the labels and group: + new_x.append(avg_x) + new_y += [label] * avg_x.shape[0] + + new_x = np.concatenate((new_x[0],new_x[1]),axis=0) + new_y = np.array(new_y) + + # average temporal feature (5 point average) + new_x=ATdata(new_x) + + scores= cross_val_multiscore(sliding, X=new_x, y=new_y, cv=5, n_jobs=-1) + scores_per[num_per,:,:]=np.mean(scores, axis=0) + + ctwcd[conditions_D[condi]]=np.mean(scores_per, axis=0) + + + + + fig, axes = plt.subplots(1, 2,figsize=(10,3),sharex=True,sharey=True) + plt.subplots_adjust(wspace=0.5, hspace=0) + fig.suptitle('CTWCD') + + t = 1e3 * epochs_rs.times + pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5), path_effects.Normal()] + cmap = mpl.cm.jet + vmin = 0.5 + vmax = 0.7 + bounds = np.linspace(vmin, vmax, 11) + # norm = mpl.colors.BoundaryNorm(bounds, cmap.N) + # diff setting + vmind = -0.15 + vmaxd = 0.15 + boundsd = np.linspace(vmind, vmaxd, 11) + normd = mpl.colors.BoundaryNorm(boundsd, cmap.N) + #plot + im = axes[0].imshow(gaussian_filter(ctwcd[conditions_D[1]],sigma=2), interpolation='lanczos', origin='lower', cmap=cmap, + extent=epochs_rs.times[[0, -1, 0, -1]], vmin=vmin, vmax=vmax) + axes[0].set_xlabel('Testing Time (s)') + axes[0].set_ylabel('Training Time (s)') + axes[0].set_title('Train RE Test RE') + axes[0].axvline(0, color='k') + axes[0].axhline(0, color='k') + axes[0].axline((0, 0), slope=1, color='k') + plt.colorbar(im, ax=axes[0],fraction=0.03, pad=0.05) + + im = axes[1].imshow(gaussian_filter(ctwcd[conditions_D[0]], sigma=2), interpolation='lanczos', origin='lower', cmap=cmap, + extent=epochs_rs.times[[0, -1, 0, -1]], vmin=vmin, vmax=vmax) + axes[1].set_xlabel('Testing Time (s)') + axes[1].set_ylabel('Training Time (s)') + axes[1].set_title('Train IR Test IR') + axes[1].axvline(0, color='k') + 
axes[1].axhline(0, color='k') + axes[1].axline((0,0), slope=1, color='k') + plt.colorbar(im, ax=axes[1],fraction=0.03, pad=0.05) + + # Save figure + + fig.savefig(fname_fig) + + return ctwcd + + +# ============================================================================= +# RUN +# ============================================================================= + + +# run roi decoding analysis + +if __name__ == "__main__": + + #opt INFO + + # subject_id = 'SB085' + # + # visit_id = 'V1' + # space = 'surface' + # + + # analysis info + + # con_C = ['LF'] + # con_D = ['Irrelevant', 'Relevant non-target'] + # con_T = ['500ms','1000ms','1500ms'] + + analysis_name='GAT_Cat' + + # 1 Set Path + sub_info, \ + fpath_epo, fpath_fw, fpath_fs, \ + roi_data_root, roi_figure_root, roi_code_root = set_path_ROI_MVPA(bids_root, + subject_id, + visit_id, + analysis_name) + + # 2 Get Sub ROI + surf_label_list, ROI_Name = sub_ROI_for_ROI_MVPA(fpath_fs, subject_id,analysis_name) + + # 3 prepare the sensor data + epochs_rs, \ + rank, common_cov, \ + conditions_C, conditions_D, conditions_T, task_info = sensor_data_for_ROI_MVPA(fpath_epo, + sub_info, + con_T, + con_C, + con_D) + + + + roi_ctccd_acc = dict() + roi_ctwcd_acc = dict() + + + for nroi, roi_name in enumerate(ROI_Name): + + # 4 Get Source Data for each ROI + stcs = [] + stcs = source_data_for_ROI_MVPA(epochs_rs, fpath_fw, rank, common_cov, sub_info, surf_label_list[nroi]) + + # ### CTCCD + + # #1 scoring methods with accuracy score + fname_fig_acc = op.join(roi_figure_root, + sub_info + task_info + '_'+ roi_name + + "_acc_CTCCD" + '.png') + + + score_methods=make_scorer(accuracy_score) + ctccd_acc = Category_CTCCD(epochs_rs, stcs, + conditions_C, conditions_D, + select_F, + roi_name, score_methods, + fname_fig_acc) + + roi_ctccd_acc[roi_name] = ctccd_acc + + + ### CTWCD + + #1 scoring methods with accuracy score + fname_fig_acc = op.join(roi_figure_root, + sub_info + task_info + '_' + + roi_name + "_acc_CTWCD" + '.png') + + + score_methods=make_scorer(accuracy_score) + ctwcd_acc= Category_CTWCD(epochs_rs, stcs, + conditions_C, conditions_D, + select_F, + roi_name, score_methods, + fname_fig_acc) + + roi_ctwcd_acc[roi_name] = ctwcd_acc + + + roi_data=dict() + roi_data['ctccd_acc']=roi_ctccd_acc + roi_data['ctwcd_acc']=roi_ctwcd_acc + + + fname_data=op.join(roi_data_root, sub_info + '_' + task_info +"_ROIs_data_GAT_Cat" + '.pickle') + fw = open(fname_data,'wb') + pickle.dump(roi_data,fw) + fw.close() + + # #load + # fr=open(fname_data,'rb') + # d2=pickle.load(fr) + # fr.close() + + # stc_mean=stc_feat_b.copy().crop(tmin=0, tmax=0.5).mean() + # brain_mean = stc_mean.plot(views='lateral',subject=f'sub-{subject_id}',hemi='lh',size=(800,400),subjects_dir=subjects_dir) + + + +# Save code +# shutil.copy(__file__, roi_code_root) diff --git a/roi_mvpa/D04_ROI_MVPA_GAT_Ori.py b/roi_mvpa/D04_ROI_MVPA_GAT_Ori.py new file mode 100644 index 0000000..2ba8455 --- /dev/null +++ b/roi_mvpa/D04_ROI_MVPA_GAT_Ori.py @@ -0,0 +1,353 @@ + +""" +==================== +D04. 
Decoding for MEG on source space of ROI : genelaization across time (GAT) +Orienation Decoding +==================== +@author: ling liu ling.liu@pku.edu.cn + +decoding methods: CCD: Cross Condition Decoding +classifier: SVM (linear) +feature: spatial pattern (S) + +feature selection methods test + +""" + +import os.path as op + +import pickle + +import matplotlib.pyplot as plt + +import numpy as np +import matplotlib as mpl + +import argparse + + + +from mne.decoding import (Vectorizer, cross_val_multiscore) +# import a linear classifier from mne.decoding +from mne.decoding import LinearModel +from mne.decoding import GeneralizingEstimator + + +from skimage.measure import block_reduce + +import sklearn.svm +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.feature_selection import SelectKBest, f_classif +#from sklearn.feature_selection import SelectPercentile, chi2 + +from sklearn.metrics import make_scorer +from sklearn.metrics import balanced_accuracy_score + +# from sklearn.linear_model import LogisticRegression +# from sklearn.model_selection import StratifiedKFold + + + + +from scipy.ndimage import gaussian_filter +import matplotlib.patheffects as path_effects + +import os +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root +from D_MEG_function import set_path_ROI_MVPA, ATdata,sensor_data_for_ROI_MVPA +from D_MEG_function import source_data_for_ROI_MVPA,sub_ROI_for_ROI_MVPA + +####if need pop-up figures +# %matplotlib qt5 +#mpl.use('Qt5Agg') + +parser=argparse.ArgumentParser() +parser.add_argument('--sub',type=str,default='SA101',help='subject_id') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. 
"V1")') +parser.add_argument('--cT',type=str,nargs='*', default=['500ms','1000ms','1500ms'], help='condition in Time duration') +parser.add_argument('--cC',type=str,nargs='*', default=['FO'], + help='selected decoding category, FO for face and object, LF for letter and false,' + 'F for face ,O for object, L for letter, FA for false') +parser.add_argument('--cD',type=str,nargs='*', default=['Irrelevant', 'Relevant non-target'], + help='selected decoding Task, Relevant non Target or Irrelevant condition') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--out_fw', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/forward', + help='Path to the forward (derivative) directory') +parser.add_argument('--nF', + type=int, + default=30, + help='number of feature selected for source decoding') +parser.add_argument('--nT', + type=int, + default=5, + help='number of trial averaged for source decoding') +parser.add_argument('--nPCA', + type=float, + default=0.95, + help='percentile of PCA selected for source decoding') +# parser.add_argument('--coreg_path', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/coreg', +# help='Path to the coreg (derivative) directory') + + +opt = parser.parse_args() +con_C = opt.cC +con_D = opt.cD +con_T = opt.cT +select_F = opt.nF +n_trials = opt.nT +nPCA = opt.nPCA + + +# ============================================================================= +# SESSION-SPECIFIC SETTINGS +# ============================================================================= + + + +subject_id = opt.sub + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path + + + + +def Orientation_CTWCD(epochs_rs,stcs,conditions_C,select_F,n_trials,roi_name,score_methods,fname_fig): + # setup SVM classifier + clf = make_pipeline( + Vectorizer(), + StandardScaler(), # Z-score data, because gradiometers and magnetometers have different scales + SelectKBest(f_classif,k=select_F), + #SelectPercentile(chi2,k=select_p), + #PCA(n_components=nPCA), + LinearModel(sklearn.svm.SVC( + kernel='linear',class_weight='balanced'))) #LogisticRegression(), + + # The scorers can be either one of the predefined metric strings or a scorer + # callable, like the one returned by make_scorer + #scoring = {"Accuracy": make_scorer(accuracy_score)}#"AUC": "roc_auc", + # score methods could be AUC or Accuracy + # {"AUC": "roc_auc","Accuracy": }# + # For multivariable decoding(e,g, 1,2,3, could not use roc_auc), + # deal with unbalanced trial number, score should use make_scorer(balanced_accuracy_score) + + sliding = GeneralizingEstimator(clf, scoring=score_methods, n_jobs=-1) + + + print(' Creating evoked datasets') + + conditions_O = ['Center', 'Left', 'Right'] + + temp = epochs_rs.events[:, 2] + temp[epochs_rs.metadata['Orientation'] == conditions_O[0]] = 1 # center, straight + temp[epochs_rs.metadata['Orientation'] == conditions_O[1]] = 2 # left,side orientation + temp[epochs_rs.metadata['Orientation'] == conditions_O[2]] = 3 # right, also side orientation + + + y = temp + X=np.array([stc.data for stc in stcs]) + + # cond_a = np.where(epochs_rs.metadata['Category'] == conditions_C[0])[0] + # # Find indices of Object trials + # cond_b = np.where(epochs_rs.metadata['Category'] == 
conditions_C[1])[0] + + # # # Run cross-validated decoding analyses: + # scores_a = cross_val_multiscore(sliding,X=X[cond_a], y=y[cond_a],cv=5,n_jobs=1) + # # # Run cross-validated decoding analyses: + # scores_b = cross_val_multiscore(sliding, X=X[cond_b], y=y[cond_b], cv=5, n_jobs=1) + + # # + # wcd = dict() + # # ccd['IR'] = np.mean(scores_a, axis=0) + # # ccd['RE'] = np.mean(scores_b, axis=0) + # wcd[conditions_C[0]] = np.mean(scores_a, axis=0) + # wcd[conditions_C[1]] = np.mean(scores_b, axis=0) + + ctwcd=dict() + for condi in range(len(conditions_C)): + con_index=np.where(epochs_rs.metadata['Category'] == conditions_C[condi])[0] + group_x=X[con_index] + group_y=y[con_index] + + scores_per=np.zeros([100,group_x.shape[2],group_x.shape[2]]) + for num_per in range(100): + # do the average trial + new_x = [] + new_y = [] + for label in range(len(conditions_O)): + # Extract the data: + data = group_x[np.where(group_y == label+1)] + data = np.take(data, np.random.permutation(data.shape[0]), axis=0) + avg_x = block_reduce(data, block_size=tuple([n_trials, *[1] * len(data.shape[1:])]), + func=np.nanmean, cval=np.nan) + #block_size + #array_like or int + #Array containing down-sampling integer factor along each axis. Default block_size is 2. + + # funccallable + # Function object which is used to calculate the return value for each local block. This function must implement an axis parameter. Primary functions are numpy.sum, numpy.min, numpy.max, numpy.mean and numpy.median. See also func_kwargs. + + # cvalfloat + # Constant padding value if image is not perfectly divisible by the block size. + + # Now generating the labels and group: + new_x.append(avg_x) + new_y += [label] * avg_x.shape[0] + + new_x = np.concatenate((new_x[0],new_x[1],new_x[2]),axis=0) + new_y = np.array(new_y) + + # average temporal feature (5 point average) + new_x=ATdata(new_x) + + scores= cross_val_multiscore(sliding, X=new_x, y=new_y, cv=5, n_jobs=-1) + scores_per[num_per,:,:]=np.mean(scores, axis=0) + + ctwcd[conditions_C[condi]]=np.mean(scores_per, axis=0) + + + + + fig, axes = plt.subplots(1) + plt.subplots_adjust(wspace=0.5, hspace=0) + fig.suptitle('ORI CTWCD') + + t = 1e3 * epochs_rs.times + pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5), path_effects.Normal()] + cmap = mpl.cm.jet + vmin = 0.5 + vmax = 0.7 + bounds = np.linspace(vmin, vmax, 11) + # norm = mpl.colors.BoundaryNorm(bounds, cmap.N) + # diff setting + vmind = -0.15 + vmaxd = 0.15 + boundsd = np.linspace(vmind, vmaxd, 11) + normd = mpl.colors.BoundaryNorm(boundsd, cmap.N) + #plot + im = axes.imshow(gaussian_filter(ctwcd[conditions_C[0]],sigma=2), interpolation='lanczos', origin='lower', cmap=cmap, + extent=epochs_rs.times[[0, -1, 0, -1]], vmin=vmin, vmax=vmax) + axes.set_xlabel('Testing Time (s)') + axes.set_ylabel('Training Time (s)') + axes.axvline(0, color='k') + axes.axhline(0, color='k') + axes.axline((0, 0), slope=1, color='k') + plt.colorbar(im, ax=axes,fraction=0.03, pad=0.05) + + #fname_fig = op.join('/home/user/S10/Cogitate/HPC/mvpa/ROI/figure/',ROI_Name[nroi] +"_CCD" + '.png') + fig.savefig(fname_fig) + + return ctwcd +# ============================================================================= +# RUN +# ============================================================================= + + +# run roi decoding analysis + +if __name__ == "__main__": + + #opt INFO + + # subject_id = 'SB085' + # + # visit_id = 'V1' + # space = 'surface' + # + + # analysis info + + # con_C = ['face'] + # con_D = ['Irrelevant', 'Relevant 
non-target'] + # con_T = ['500ms','1000ms','1500ms'] + + analysis_name='GAT_Ori' + + # 1 Set Path + sub_info, \ + fpath_epo, fpath_fw, fpath_fs, \ + roi_data_root, roi_figure_root, roi_code_root = set_path_ROI_MVPA(bids_root, + subject_id, + visit_id, + analysis_name) + + # 2 Get Sub ROI + surf_label_list, ROI_Name = sub_ROI_for_ROI_MVPA(fpath_fs, subject_id, analysis_name) + + # 3 prepare the sensor data + epochs_rs, \ + rank, common_cov, \ + conditions_C, conditions_D, conditions_T, task_info = sensor_data_for_ROI_MVPA(fpath_epo, + sub_info, + con_T, + con_C, + con_D) + + roi_ctwcd_ori_acc = dict() + + for nroi, roi_name in enumerate(ROI_Name): + + # 4 Get Source Data for each ROI + stcs = [] + stcs = source_data_for_ROI_MVPA(epochs_rs, fpath_fw, + rank, common_cov, + sub_info, surf_label_list[nroi]) + + + ### wcd_orientation + #1 scoring methods with balanced accuracy score + fname_fig_acc = op.join(roi_figure_root, + sub_info + task_info + '_' + roi_name + "_acc_CTWCD_ori" + '.png') + + + score_methods=make_scorer(balanced_accuracy_score) + ctwcd_ori_acc= Orientation_CTWCD(epochs_rs, stcs, + conditions_C, + select_F, + n_trials, + # nPCA, + roi_name, score_methods, + fname_fig_acc) + + roi_ctwcd_ori_acc[roi_name] = ctwcd_ori_acc + + roi_data=dict() + + roi_data['ctwcd_ori_acc']=roi_ctwcd_ori_acc + + fname_data=op.join(roi_data_root, sub_info + '_' + task_info +"_ROIs_data_GAT_Ori" + '.pickle') + fw = open(fname_data,'wb') + pickle.dump(roi_data,fw) + fw.close() + + # #load + # fr=open(fname_data,'rb') + # d2=pickle.load(fr) + # fr.close() + + # stc_mean=stc_feat_b.copy().crop(tmin=0, tmax=0.5).mean() + # brain_mean = stc_mean.plot(views='lateral',subject=f'sub-{subject_id}',hemi='lh',size=(800,400),subjects_dir=subjects_dir) + + + +# Save code +# shutil.copy(__file__, roi_code_root) diff --git a/roi_mvpa/D05_ROI_MVPA_RSA_Cat.py b/roi_mvpa/D05_ROI_MVPA_RSA_Cat.py new file mode 100644 index 0000000..5c07cfb --- /dev/null +++ b/roi_mvpa/D05_ROI_MVPA_RSA_Cat.py @@ -0,0 +1,357 @@ + +""" +==================== +D05. RSA for MEG on source space of ROI +Category RSA +==================== +@author: ling liu ling.liu@pku.edu.cn + +decoding methods: CTWCD: Cross Time Within Condition Decoding +classifier: SVM (linear) +feature: spatial pattern (S) + +""" +import os +import os.path as op +#import joblib +import pickle + +import matplotlib.pyplot as plt +import numpy as np +import matplotlib as mpl + +import argparse + + +from joblib import Parallel, delayed +from tqdm import tqdm + + +#from scipy.ndimage import gaussian_filter1d +from scipy.ndimage import gaussian_filter +import matplotlib.patheffects as path_effects + + +#from config import no_eeg_sbj +#from config import site_id, subject_id, file_names, visit_id, data_path, out_path + + + +from rsa_helper_functions_meg import pseudotrials_rsa_all2all + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + +from D_MEG_function import set_path_ROI_MVPA, sensor_data_for_ROI_MVPA +from D_MEG_function import source_data_for_ROI_MVPA,sub_ROI_for_ROI_MVPA + + + +####if need pop-up figures +# %matplotlib qt5 +#mpl.use('Qt5Agg') + +parser=argparse.ArgumentParser() +parser.add_argument('--sub',type=str,default='SA101',help='subject_id') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. 
"V1")') +parser.add_argument('--cT',type=str,nargs='*', default=['1500ms'], help='condition in Time duration') +parser.add_argument('--cC',type=str,nargs='*', default=['FO'], + help='selected decoding category, FO for face and object, LF for letter and false,' + 'F for face ,O for object, L for letter, FA for false') +parser.add_argument('--cD',type=str,nargs='*', default=['Irrelevant', 'Relevant non-target'], + help='selected decoding Task, Relevant non Target or Irrelevant condition') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--out_fw', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/forward', + help='Path to the forward (derivative) directory') +parser.add_argument('--nF', + type=int, + default=30, + help='number of feature selected for source decoding') +parser.add_argument('--nT', + type=int, + default=5, + help='number of trial averaged for source decoding') +parser.add_argument('--nPCA', + type=float, + default=0.95, + help='percentile of PCA selected for source decoding') +parser.add_argument('--nPerm', + type=int, + default=100, + help='number of Permuation for pseudo-trials, if debug, could set to 2') + +# parser.add_argument('--coreg_path', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/coreg', +# help='Path to the coreg (derivative) directory') + + +opt = parser.parse_args() +con_C = opt.cC +con_D = opt.cD +con_T = opt.cT +select_F = opt.nF +n_trials = opt.nT +nPCA = opt.nPCA +per_num=opt.nPerm +# ============================================================================= +# SESSION-SPECIFIC SETTINGS +# ============================================================================= + + + +subject_id = opt.sub + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path + +def Category_RSA(epochs, stcs, conditions_C,conD,n_features=None, n_pseudotrials=5, n_iterations=20, n_jobs=-1, feat_sel_diag=True): + # Find the indices of the relevant conditions: + epochs.metadata["order"] = list(range(len(epochs.metadata))) + meta_data = epochs.metadata + trials_inds = meta_data.loc[((meta_data["Category"] == conditions_C[0]) | (meta_data["Category"] == conditions_C[1])) & + (meta_data["Task_relevance"] == conD), "order"].to_list() + # Extract these trials: + epochs = epochs[trials_inds] + # Aggregate single trials stcs into a numpy array: + X = np.array([stc.data for stc in stcs]) + # Select the trials of interest: + X = X[trials_inds, :, :] + # Get the labels: + y = epochs.metadata["Category"].to_numpy() + + rsa_results, rdm_diag, sel_features = \ + zip(*Parallel(n_jobs=n_jobs)(delayed(pseudotrials_rsa_all2all)( + X, y, n_pseudotrials, epochs.times, sample_rdm_times=None,n_features=n_features,metric="correlation", + fisher_transform=True,feat_sel_diag=feat_sel_diag + ) for i in tqdm(range(n_iterations)))) + return rsa_results, rdm_diag, sel_features + + +def Plot_RSA(rsa, sample, roi_name,fname_fig): + ##RSA plot## + + fig, axes = plt.subplots(1, 2,figsize=(10,3),sharex=True,sharey=True) + plt.subplots_adjust(wspace=0.5, hspace=0) + fig.suptitle(f'RSA_Cat_ {roi_name}') + time_point = np.array(range(-500,2001, 10))/1000 + t = time_point + pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5), path_effects.Normal()] + cmap = 
mpl.cm.jet +# bounds = np.linspace(vmin, vmax, 11) + # norm = mpl.colors.BoundaryNorm(bounds, cmap.N) + # diff setting + # vmind = -0.15 + # vmaxd = 0.15 +# boundsd = np.linspace(vmind, vmaxd, 11) +# normd = mpl.colors.BoundaryNorm(boundsd, cmap.N) + #plot + for condi, Ti_name in rsa.items(): + if condi=='Irrelevant': + im = axes[0].imshow(gaussian_filter(Ti_name,sigma=2), interpolation='lanczos', origin='lower', cmap=cmap, + extent=t[[0, -1, 0, -1]]) + axes[0].set(xlabel='First', ylabel='Second') + axes[0].set_title(f'RSA_IR_ {roi_name}') + axes[0].axvline(0, color='k') + axes[0].axhline(0, color='k') + axes[0].axline((0, 0), slope=1, color='k') + plt.colorbar(im, ax=axes[0],fraction=0.03, pad=0.05) + elif condi=='Relevant non-target': + im = axes[1].imshow(gaussian_filter(Ti_name,sigma=2), interpolation='lanczos', origin='lower', cmap=cmap, + extent=t[[0, -1, 0, -1]]) + axes[1].set(xlabel='First', ylabel='Second') + axes[1].set_title(f'RSA_RE_ {roi_name}') + axes[1].axvline(0, color='k') + axes[1].axhline(0, color='k') + axes[1].axline((0, 0), slope=1, color='k') + plt.colorbar(im, ax=axes[1],fraction=0.03, pad=0.05) + + # Save figure + + fig.savefig(op.join(fname_fig+ "_rsa_Cat" + '.png')) + + +# ##sample RDM plot## + +# fig, axes = plt.subplots(1, 2,figsize=(10,3),sharex=True,sharey=True) +# plt.subplots_adjust(wspace=0.5, hspace=0) +# fig.suptitle(f'Sample_RDM_ {roi_name}') + + +# pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5), path_effects.Normal()] +# cmap = mpl.cm.jet +# # bounds = np.linspace(vmin, vmax, 11) +# # norm = mpl.colors.BoundaryNorm(bounds, cmap.N) +# # diff setting +# # vmind = -0.15 +# # vmaxd = 0.15 +# # boundsd = np.linspace(vmind, vmaxd, 11) +# # normd = mpl.colors.BoundaryNorm(boundsd, cmap.N) +# #plot +# for condi, Ti_name in sample.items(): +# if condi=='Irrelevant': +# im = axes[0].imshow(Ti_name, interpolation='lanczos', origin='lower', cmap=cmap, +# extent=t[[0, -1, 0, -1]]) +# axes[0].set(xlabel='First', ylabel='Second') +# axes[0].set_title(f'Sample_IR_ {roi_name}') +# axes[0].axvline(0, color='k') +# axes[0].axhline(0, color='k') +# axes[0].axline((0, 0), slope=1, color='k') +# plt.colorbar(im, ax=axes[0],fraction=0.03, pad=0.05) +# elif condi=='Relevant non-target': +# im = axes[1].imshow(Ti_name, interpolation='lanczos', origin='lower', cmap=cmap, +# extent=t[[0, -1, 0, -1]]) +# axes[1].set(xlabel='First', ylabel='Second') +# axes[1].set_title(f'Sample_RE_ {roi_name}') +# axes[1].axvline(0, color='k') +# axes[1].axhline(0, color='k') +# axes[1].axline((0, 0), slope=1, color='k') +# plt.colorbar(im, ax=axes[1],fraction=0.03, pad=0.05) + +# # Save figure + +# fig.savefig(op.join(fname_fig+ "_sample_rdm_Cat" + '.png')) + +# ============================================================================= +# RUN +# ============================================================================= + + +# run roi decoding analysis + +if __name__ == "__main__": + + #opt INFO + + # subject_id = 'SB085' + # + # visit_id = 'V1' + # space = 'surface' + # + + # analysis info + + # con_C = ['LF'] + # con_D = ['Irrelevant', 'Relevant non-target'] + # con_T = ['500ms','1000ms','1500ms'] + #metric="correlation" or metric='euclidean' + + analysis_name='RSA_Cat' + + # 1 Set Path + sub_info, \ + fpath_epo, fpath_fw, fpath_fs, \ + roi_data_root, roi_figure_root, roi_code_root = set_path_ROI_MVPA(bids_root, + subject_id, + visit_id, + analysis_name) + + # 2 Get Sub ROI + surf_label_list, ROI_Name = sub_ROI_for_ROI_MVPA(fpath_fs, subject_id,analysis_name) + + # 3 
prepare the sensor data + epochs_rs, \ + rank, common_cov, \ + conditions_C, conditions_D, conditions_T, task_info = sensor_data_for_ROI_MVPA(fpath_epo, + sub_info, + con_T, + con_C, + con_D) + + + + roi_rsa = dict() + roi_sample = dict() + roi_feature = dict() + + + + + + for nroi, roi_name in enumerate(ROI_Name): + + # 4 Get Source Data for each ROI + stcs = [] + stcs = source_data_for_ROI_MVPA(epochs_rs, fpath_fw, rank, common_cov, sub_info, surf_label_list[nroi]) + + + + + + if roi_name=='GNW': + sample_times=[0.3, 0.5] + else: + sample_times=[0.3, 1.5] + + cT_rsa = dict() + cT_sample = dict() + cT_features = dict() + + for nd, conD in enumerate(conditions_D): + rsa, sample, sel_features = Category_RSA(epochs_rs, stcs, conditions_C, conD, n_features=None) + + # converting dictionary to + # numpy array + rsa_array = np.asarray(rsa) + sample_array = np.asarray(sample) + features_array = np.asarray(sel_features) + + + + cT_rsa[conD] = np.mean(rsa_array, axis=0) + cT_sample[conD] = np.mean(sample_array, axis=0) + cT_features[conD] = features_array + + roi_rsa[roi_name]=cT_rsa + roi_sample[roi_name] = cT_sample + roi_feature[roi_name] = cT_features + + roi_data=dict() + roi_data['rsa']=roi_rsa + roi_data['sample']=roi_sample + roi_data['feature']=roi_feature + + + fname_data=op.join(roi_data_root, sub_info + '_' + task_info + roi_name + "_ROIs_data_RSA_Cat" + '.pickle') + fw = open(fname_data,'wb') + pickle.dump(roi_data,fw) + fw.close() + + #pot results + # #1 scoring methods with accuracy score + fname_fig = op.join(roi_figure_root, + sub_info + task_info + '_'+ roi_name + ) + Plot_RSA(cT_rsa, cT_sample, roi_name,fname_fig) + + + + + + # #load + # fr=open(fname_data,'rb') + # d2=pickle.load(fr) + # fr.close() + + # stc_mean=stc_feat_b.copy().crop(tmin=0, tmax=0.5).mean() + # brain_mean = stc_mean.plot(views='lateral',subject=f'sub-{subject_id}',hemi='lh',size=(800,400),subjects_dir=subjects_dir) + + + +# Save code +# shutil.copy(__file__, roi_code_root) diff --git a/roi_mvpa/D06_ROI_MVPA_RSA_Ori.py b/roi_mvpa/D06_ROI_MVPA_RSA_Ori.py new file mode 100644 index 0000000..783e26d --- /dev/null +++ b/roi_mvpa/D06_ROI_MVPA_RSA_Ori.py @@ -0,0 +1,299 @@ + +""" +==================== +D05. RSA for MEG on source space of ROI +Orientation RSA +==================== +@author: ling liu ling.liu@pku.edu.cn + +decoding methods: CTWCD: Cross Time Within Condition Decoding +classifier: SVM (linear) +feature: spatial pattern (S) +without feature selection + +""" + +import os.path as op +import pickle + +from joblib import Parallel, delayed +from tqdm import tqdm + +import matplotlib.pyplot as plt +import mne +import numpy as np +import matplotlib as mpl +import argparse + +import os +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + +from rsa_helper_functions_meg import pseudotrials_rsa_all2all +from D_MEG_function import set_path_ROI_MVPA, sensor_data_for_ROI_MVPA +from D_MEG_function import source_data_for_ROI_MVPA,sub_ROI_for_ROI_MVPA + + + +####if need pop-up figures +# %matplotlib qt5 +#mpl.use('Qt5Agg') + +parser=argparse.ArgumentParser() +parser.add_argument('--sub',type=str,default='SA101',help='subject_id') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. 
"V1")') +parser.add_argument('--cT',type=str,nargs='*', default=['1500ms'], help='condition in Time duration') +parser.add_argument('--cC',type=str,nargs='*', default=['F'], + help='selected decoding category, FO for face and object, LF for letter and false,' + 'F for face ,O for object, L for letter, FA for false') +parser.add_argument('--cD',type=str,nargs='*', default=['Irrelevant', 'Relevant non-target'], + help='selected decoding Task, Relevant non Target or Irrelevant condition') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--out_fw', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/forward', + help='Path to the forward (derivative) directory') +parser.add_argument('--nF', + type=int, + default=30, + help='number of feature selected for source decoding') +parser.add_argument('--nT', + type=int, + default=5, + help='number of trial averaged for source decoding') +parser.add_argument('--nPCA', + type=float, + default=0.95, + help='percentile of PCA selected for source decoding') +parser.add_argument('--nPerm', + type=int, + default=100, + help='number of Permuation for pseudo-trials, if debug, could set to 2') +# parser.add_argument('--coreg_path', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/coreg', +# help='Path to the coreg (derivative) directory') + + +opt = parser.parse_args() +con_C = opt.cC +con_D = opt.cD +con_T = opt.cT +select_F = opt.nF +n_trials = opt.nT +nPCA = opt.nPCA +per_num=opt.nPerm +# ============================================================================= +# SESSION-SPECIFIC SETTINGS +# ============================================================================= + + + +subject_id = opt.sub + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path + +def Orientation_RSA(epochs, stcs, conditions_C,n_features=None, n_pseudotrials=5, n_iterations=20, n_jobs=-1, feat_sel_diag=True): + X = np.array([stc.data for stc in stcs]) + + # Get the labels: + conditions_O = ['Center', 'Left', 'Right'] + + temp = epochs.events[:, 2] + temp[epochs.metadata['Orientation'] == conditions_O[0]] = 1 # center, straight + temp[epochs.metadata['Orientation'] == conditions_O[1]] = 2 # left,side orientation + temp[epochs.metadata['Orientation'] == conditions_O[2]] = 2 # right, also side orientation + + y = temp + #y = epochs.metadata["Category"].to_numpy() + + rsa_results, rdm_diag, sel_features = \ + zip(*Parallel(n_jobs=n_jobs)(delayed(pseudotrials_rsa_all2all)( + X, y, n_pseudotrials, epochs.times, sample_rdm_times=None, + n_features=n_features,metric="correlation", + fisher_transform=True, feat_sel_diag=feat_sel_diag + ) for i in tqdm(range(n_iterations)))) + return rsa_results, rdm_diag, sel_features + +def Plot_RSA(rsa, sample, roi_name,fname_fig): + fig, ax = plt.subplots(1) + + time_point = np.array(range(-500,2001, 10))/1000 + cmap = mpl.cm.jet + im=ax.imshow(rsa, interpolation='lanczos', origin='lower', cmap=cmap,extent=time_point[[0, -1, 0, -1]])#, vmin=vmin, vmax=vmax)#, norm=norm + ax.axhline(0,color='k') + ax.axvline(0, color='k') + #ax.legend(loc='upper right') + ax.set_title(f'RSA_ {roi_name}') + ax.set(xlabel='Times', ylabel='Times') + plt.colorbar(im, ax=ax,fraction=0.03, pad=0.05) + mne.viz.tight_layout() + # Save 
figure + + fig.savefig(op.join(fname_fig+ "_rsa_ori" + '.png')) + + # fig, ax = plt.subplots(1) + + + # trial_index= np.array(range(0, sample.shape[0], 1)) + # #GAT setting + # cmap = mpl.cm.jet + # im=ax.imshow(sample, interpolation='lanczos', origin='lower', cmap=cmap,extent=trial_index[[0, -1, 0, -1]])#, vmin=vmin, vmax=vmax)#, norm=norm + # ax.axhline(0,color='k') + # ax.axvline(0, color='k') + # #ax.legend(loc='upper right') + # ax.set_title(f'Sample_RDM_ {roi_name}') + # ax.set(xlabel='First', ylabel='Second') + # plt.colorbar(im, ax=ax,fraction=0.03, pad=0.05) + # mne.viz.tight_layout() + # # Save figure + + # fig.savefig(op.join(fname_fig+ "_sample_rdm_ori" + '.png')) + +# ============================================================================= +# RUN +# ============================================================================= + + +# run roi decoding analysis + +if __name__ == "__main__": + + #opt INFO + + # subject_id = 'SB085' + # + # visit_id = 'V1' + # space = 'surface' + # + + # analysis info + + # con_C = ['LF'] + # con_D = ['Irrelevant', 'Relevant non-target'] + # con_T = ['500ms','1000ms','1500ms'] + #metric="correlation" or metric='euclidean' + + + analysis_name='RSA_Ori' + + # 1 Set Path + sub_info, \ + fpath_epo, fpath_fw, fpath_fs, \ + roi_data_root, roi_figure_root, roi_code_root = set_path_ROI_MVPA(bids_root, + subject_id, + visit_id, + analysis_name) + + # 2 Get Sub ROI + surf_label_list, ROI_Name = sub_ROI_for_ROI_MVPA(fpath_fs, subject_id,analysis_name) + + # 3 prepare the sensor data + epochs_rs, \ + rank, common_cov, \ + conditions_C, conditions_D, conditions_T, task_info = sensor_data_for_ROI_MVPA(fpath_epo, + sub_info, + con_T, + con_C, + con_D) + + + + roi_rsa = dict() + roi_sample = dict() + roi_feature = dict() + + + for nroi, roi_name in enumerate(ROI_Name): + + # 4 Get Source Data for each ROI + stcs = [] + stcs = source_data_for_ROI_MVPA(epochs_rs, fpath_fw, rank, common_cov, sub_info, surf_label_list[nroi]) + + + + + + + ### CTCCD + + # #1 scoring methods with accuracy score + fname_fig = op.join(roi_figure_root, + sub_info + task_info + '_'+ roi_name + ) + + if roi_name=='GNW': + sample_times=[0.3, 0.5] + else: + sample_times=[0.3, 1.5] + + + cT_rsa = dict() + cT_sample = dict() + cT_features = dict() + + + rsa, sample, sel_features = Orientation_RSA(epochs_rs, stcs, conditions_C, n_features=None) + + # converting dictionary to + # numpy array + rsa_array = np.asarray(rsa) + sample_array = np.asarray(sample) + features_array = np.asarray(sel_features) + + + + cT_rsa = np.mean(rsa_array, axis=0) + cT_sample = np.mean(sample_array, axis=0) + cT_features = features_array + + + roi_rsa[roi_name]=cT_rsa + roi_sample[roi_name] = cT_sample + roi_feature[roi_name] = cT_features + + roi_data=dict() + roi_data['rsa']=roi_rsa + roi_data['sample']=roi_sample + roi_data['feature']=roi_feature + + + fname_data=op.join(roi_data_root, sub_info + '_' + task_info + roi_name + "_ROIs_data_RSA_Ori" + '.pickle') + fw = open(fname_data,'wb') + pickle.dump(roi_data,fw) + fw.close() + + #pot results + # #1 scoring methods with accuracy score + fname_fig = op.join(roi_figure_root, + sub_info + task_info + '_'+ roi_name + ) + Plot_RSA(cT_rsa, cT_sample, roi_name,fname_fig) + + + # #load + # fr=open(fname_data,'rb') + # d2=pickle.load(fr) + # fr.close() + + # stc_mean=stc_feat_b.copy().crop(tmin=0, tmax=0.5).mean() + # brain_mean = stc_mean.plot(views='lateral',subject=f'sub-{subject_id}',hemi='lh',size=(800,400),subjects_dir=subjects_dir) + + + +# Save code +# 
shutil.copy(__file__, roi_code_root) diff --git a/roi_mvpa/D07_ROI_MVPA_RSA_ID.py b/roi_mvpa/D07_ROI_MVPA_RSA_ID.py new file mode 100644 index 0000000..4ac7b06 --- /dev/null +++ b/roi_mvpa/D07_ROI_MVPA_RSA_ID.py @@ -0,0 +1,316 @@ + +""" +==================== +D07. RSA for MEG on source space of ROI +Identification RSA +==================== +@author: ling liu ling.liu@pku.edu.cn + +decoding methods: CTWCD: Cross Time Within Condition Decoding +classifier: SVM (linear) +feature: spatial pattern (S) + +""" + +import os.path as op +import pickle + + +import matplotlib.pyplot as plt +import mne +import numpy as np +import matplotlib as mpl + + +import argparse + +import os +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + +from scipy.ndimage import gaussian_filter + + +from rsa_helper_functions_meg import all_to_all_within_class_dist + +from D_MEG_function import set_path_ROI_MVPA, ATdata,sensor_data_for_ROI_MVPA_ID +from D_MEG_function import source_data_for_ROI_MVPA,sub_ROI_for_ROI_MVPA + + + +####if need pop-up figures +# %matplotlib qt5 +#mpl.use('Qt5Agg') + +parser=argparse.ArgumentParser() +parser.add_argument('--sub',type=str,default='SA101',help='subject_id') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +parser.add_argument('--cT',type=str,nargs='*', default=['1000ms','1500ms'], help='condition in Time duration') +parser.add_argument('--cC',type=str,nargs='*', default=['F'], + help='selected decoding category, FO for face and object, LF for letter and false,' + 'F for face ,O for object, L for letter, FA for false') +parser.add_argument('--cD',type=str,nargs='*', default=['Irrelevant', 'Relevant non-target'], + help='selected decoding Task, Relevant non Target or Irrelevant condition') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--out_fw', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/forward', + help='Path to the forward (derivative) directory') +parser.add_argument('--nF', + type=int, + default=30, + help='number of feature selected for source decoding') +parser.add_argument('--nT', + type=int, + default=5, + help='number of trial averaged for source decoding') +parser.add_argument('--nPCA', + type=float, + default=0.95, + help='percentile of PCA selected for source decoding') +parser.add_argument('--metric', + type=str, + default="correlation", + help='methods for calculate the distance for RDM') + +# parser.add_argument('--coreg_path', +# type=str, +# default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/coreg', +# help='Path to the coreg (derivative) directory') + + +opt = parser.parse_args() +con_C = opt.cC +con_D = opt.cD +con_T = opt.cT +select_F = opt.nF +n_trials = opt.nT +nPCA = opt.nPCA +# ============================================================================= +# SESSION-SPECIFIC SETTINGS +# ============================================================================= + + + +subject_id = opt.sub + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path +metric=opt.metric + + + +def Identity_RSA(epochs, stcs, n_features=None, metric="correlation", n_jobs=-1, feat_sel_diag=True): + # Find the indices of the relevant 
conditions: + # epochs.metadata["order"] = list(range(len(epochs.metadata))) + # meta_data = epochs.metadata + # trials_inds = meta_data.loc[((meta_data["Category"] == conditions_C[0]) | (meta_data["Category"] == conditions_C[1])) & + # (meta_data["Task_relevance"] == conD), "order"].to_list() + # # Extract these trials: + # epochs = epochs[trials_inds] + # # Aggregate single trials stcs into a numpy array: + # X = np.array([stc.data for stc in stcs]) + # # Select the trials of interest: + # X = X[trials_inds, :, :] + # # Get the labels: + # y = epochs.metadata["Category"].to_numpy() + + temp = epochs.events[:, 2] # identity label, e.g face01, face02, + + y = temp + X=np.array([stc.data for stc in stcs]) + + + data=ATdata(X) + + label=y + + + + + #metric='euclidean' + + rsa_results, rdm_diag, sel_features = all_to_all_within_class_dist(data,label,metric=metric, + n_bootsstrap=20, + shuffle_labels=False, + fisher_transform=True, + verbose=True, + n_features=n_features, + n_folds=None, + feat_sel_diag=True) + + return rsa_results, rdm_diag, sel_features + + + + + +def Plot_RSA(rsa, roi_name,fname_fig): + + fig, ax = plt.subplots(1) + plt.subplots_adjust(wspace=0.5, hspace=0) + fig.suptitle(f'RSA_Cat_ {roi_name}') + time_point = np.array(range(-500,1501, 10))/1000 + t = time_point + #pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5), path_effects.Normal()] + cmap = mpl.cm.jet + cmap = mpl.cm.jet + im=ax.imshow(gaussian_filter(rsa,sigma=2), interpolation='lanczos', origin='lower', cmap=cmap,extent=t[[0, -1, 0, -1]]) + #im=plt.imshow(gaussian_filter(rsa,sigma=2), interpolation='lanczos', origin='lower', cmap=cmap,extent=t[[0, -1, 0, -1]]) + ax.axhline(0, color='k') + ax.axvline(0, color='k') + ax.legend(loc='upper right') + ax.set_title(f'RSA_ {roi_name}') + ax.set(xlabel='time (s)', ylabel='time (s)') + plt.colorbar(im, ax=ax,fraction=0.03, pad=0.05) + mne.viz.tight_layout() + # Save figure + + fig.savefig(op.join(fname_fig+ "_rsa_ID" + '.png')) + + # fig, ax = plt.subplots(1) + + # for condi, Si_name in sample.items(): + # trial_index= np.array(range(0, Si_name.shape[0], 1)) + # #GAT setting + # cmap = mpl.cm.jet + # im=ax.imshow(gaussian_filter(Si_name,sigma=4), interpolation='lanczos', origin='lower', cmap=cmap, + # extent=trial_index[[0, -1, 0, -1]])#, vmin=vmin, vmax=vmax)#, norm=norm + # ax.axhline(0,color='k') + # ax.axvline(0, color='k') + # ax.legend(loc='upper right') + # ax.set_title(f'Sample_RDM_ {roi_name}') + # ax.set(xlabel='First', ylabel='Second') + # plt.colorbar(im, ax=ax,fraction=0.03, pad=0.05) + # mne.viz.tight_layout() + # # Save figure + + # fig.savefig(op.join(fname_fig+ "_sample_rdm_ID" + '.png')) + + +# ============================================================================= +# RUN +# ============================================================================= + + +# run roi decoding analysis + +if __name__ == "__main__": + + #opt INFO + + # subject_id = 'SB085' + # + # visit_id = 'V1' + # space = 'surface' + # + + # analysis info + + # con_C = ['LF'] + # con_D = ['Irrelevant', 'Relevant non-target'] + # con_T = ['500ms','1000ms','1500ms'] + #metric="correlation" or metric='euclidean' + + analysis_name='RSA_ID' + + # 1 Set Path + sub_info, \ + fpath_epo, fpath_fw, fpath_fs, \ + roi_data_root, roi_figure_root, roi_code_root = set_path_ROI_MVPA(bids_root, + subject_id, + visit_id, + analysis_name) + + # 2 Get Sub ROI + surf_label_list, ROI_Name = sub_ROI_for_ROI_MVPA(fpath_fs, subject_id,analysis_name) + + # 3 prepare the sensor data + epochs_rs, 
\ + rank, common_cov, \ + conditions_C, conditions_D, conditions_T, task_info = sensor_data_for_ROI_MVPA_ID(fpath_epo, + sub_info, + con_T, + con_C, + con_D, + remove_too_few_trials=True) + + + + roi_rsa = dict() + roi_sample = dict() + roi_feature = dict() + + + for nroi, roi_name in enumerate(ROI_Name): + + # 4 Get Source Data for each ROI + stcs = [] + stcs = source_data_for_ROI_MVPA(epochs_rs, fpath_fw, rank, common_cov, sub_info, surf_label_list[nroi]) + + ### CTCCD + + # #1 scoring methods with accuracy score + fname_fig = op.join(roi_figure_root, + sub_info + task_info + '_'+ roi_name + ) + + if roi_name=='GNW': + sample_times=[0.3, 0.5] + else: + sample_times=[0.3, 1.5] + + rsa, sample, sel_features=Identity_RSA(epochs_rs,stcs,metric=metric) + + + + + roi_rsa[roi_name]=rsa + roi_sample[roi_name] =sample + roi_feature[roi_name] = sel_features + + roi_data=dict() + roi_data['rsa']=roi_rsa + roi_data['sample']=roi_sample + roi_data['feature']=roi_feature + + + fname_data=op.join(roi_data_root, sub_info + '_' + task_info + roi_name + "_ROIs_data_RSA_ID" + '.pickle') + fw = open(fname_data,'wb') + pickle.dump(roi_data,fw) + fw.close() + + #pot results + # #1 scoring methods with accuracy score + fname_fig = op.join(roi_figure_root, + sub_info + task_info + '_'+ roi_name + ) + Plot_RSA(rsa, roi_name,fname_fig) + + + + # #load + # fr=open(fname_data,'rb') + # d2=pickle.load(fr) + # fr.close() + + # stc_mean=stc_feat_b.copy().crop(tmin=0, tmax=0.5).mean() + # brain_mean = stc_mean.plot(views='lateral',subject=f'sub-{subject_id}',hemi='lh',size=(800,400),subjects_dir=subjects_dir) + + + +# Save code +# shutil.copy(__file__, roi_code_root) diff --git a/roi_mvpa/D98_group_stat_sROI_plot.py b/roi_mvpa/D98_group_stat_sROI_plot.py new file mode 100644 index 0000000..d218a67 --- /dev/null +++ b/roi_mvpa/D98_group_stat_sROI_plot.py @@ -0,0 +1,1053 @@ +""" +==================== +D98. Group analysis for decoding +==================== + +@author: Ling Liu ling.liu@pku.edu.cn + +""" + +import os.path as op +import os +import argparse + +import pickle +import mne + + +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +sns.set_theme(style='ticks') + +from mne.stats import fdr_correction + + +from scipy import stats as stats + + + +import matplotlib as mpl + + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root,plot_param +from sublist import sub_list + + +parser = argparse.ArgumentParser() +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +parser.add_argument('--cT', type=str, nargs='*', default=['500ms', '1000ms', '1500ms'], + help='condition in Time duration') + +parser.add_argument('--cC', type=str, nargs='*', default=['FO'], + help='selected decoding category, FO for face and object, LF for letter and false') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--analysis', + type=str, + default='Cat', + help='the name for anlaysis, e.g. 
Cat or Ori or GAT_Cat') + + +opt = parser.parse_args() + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path +analysis_name=opt.analysis + + +opt = parser.parse_args() +con_C = opt.cC +con_T = opt.cT + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path + +#1) Select Category +if con_C[0] == 'FO': + conditions_C = ['face', 'object'] + print(conditions_C) +elif con_C[0] == 'LF': + conditions_C = ['letter', 'false'] + print(conditions_C) +elif con_C[0] == 'F': + conditions_C = ['face'] + print(conditions_C) +elif con_C[0] == 'O': + conditions_C = ['object'] + print(conditions_C) +elif con_C[0] == 'L': + conditions_C = ['letter'] + print(conditions_C) +elif con_C[0] == 'FA': + conditions_C = ['false'] + print(conditions_C) + +#1) Select time duration +if con_T[0] == 'T_all': + con_T = ['500ms', '1000ms','1500ms'] + print(con_T) +elif con_T[0] == 'ML':# middle and long + con_T = ['1000ms','1500ms'] + print(con_T) + +# get the parameters dictionary +param = plot_param +colors=param['colors'] +fig_size = param["figure_size_mm"] +plt.rc('font', size=8) # controls default text size +plt.rc('axes', labelsize=20) +plt.rc('xtick',labelsize=18) +plt.rc('ytick',labelsize=18) +plt.rc('xtick.major', width=2, size=4) +plt.rc('ytick.major', width=2, size=4) +plt.rc('legend', fontsize=18) +new_rc_params = {'text.usetex': False, +"svg.fonttype": 'none' +} +mpl.rcParams.update(new_rc_params) + +def mm2inch(val): + return val / 25.4 + +# Color parameters: +cmap = "RdYlBu_r" + + +time_point = np.array(range(-200,2001, 10))/1000 +# set the path for decoding analysis +def set_path_plot(bids_root, visit_id, analysis_name,con_name): + + ### I Set the group Data Path + # Set path to decoding derivatives + decoding_path=op.join(bids_root, "derivatives",'decoding','roi_mvpa') + + data_path=op.join(decoding_path, analysis_name) + + # Set path to group analysis derivatives + group_deriv_root = op.join(data_path, "group") + if not op.exists(group_deriv_root): + os.makedirs(group_deriv_root) + + + # Set path to the ROI MVPA output(1) stat_data, 2) figures, 3) codes) + + # 1) output_stat_data + stat_data_root = op.join(group_deriv_root,"stat_data",con_name) + if not op.exists(stat_data_root): + os.makedirs(stat_data_root) + + # 2) output_figure + stat_figure_root = op.join(group_deriv_root,"stat_figures",con_name) + if not op.exists(stat_figure_root): + os.makedirs(stat_figure_root) + + return group_deriv_root,stat_data_root,stat_figure_root + +def df_plot(ts_df,T1,pval1,T2,pval2,time_point,test_win_on,roi_name,task_index,chance_index,y_index,fname_fig): + if roi_name=='GNW': + window=[0.3,0.5,0.5,0.3] + elif roi_name=='IIT': + window=[0.3,1.5,1.5,0.3] + elif roi_name=='MT': + window=[0.25,0.5,0.5,0.25] + elif roi_name=='FP': + window=[0.3,1.5,1.5,0.3] + #plot with sns + + + # talk_rc={'lines.linewidth':2,'lines.markersize':4} + # sns.set_context('paper',rc=talk_rc,font_scale=4) + + + g = sns.relplot(x="time(s)", y="decoding accuracy(%)", kind="line", data=ts_df,hue='Task',aspect=2,palette=colors,legend=False) + g.fig.set_size_inches(mm2inch(fig_size[0]),mm2inch(fig_size[1])) + #leg = g._legend + #leg.set_bbox_to_anchor([0.72,0.8]) + + plt.axhline(chance_index, color='k', linestyle='-', label='chance') + plt.axvline(0, color='k', linestyle='-', label='onset') + #plt.axvline(0.5, color='gray', linestyle='--') + #plt.axvline(1, color='gray', linestyle='--') + #plt.axvline(1.5, color='gray', linestyle='--') + + reject_fdr1, pval_fdr1 = fdr_correction(pval1, alpha=0.05, method='indep') + 
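+ # Note on the significance marking below: fdr_correction returns a boolean
+ # mask (reject_fdr1) over the tested samples together with FDR-adjusted
+ # p-values. The smallest |T| among rejected samples is used as a plotting
+ # threshold, T1 is left-padded with zeros so its indices align with
+ # time_point (which starts at -200 ms, i.e. offset test_win_on-30), and every
+ # supra-threshold sample inside the test window is drawn as a dot at y=40 and
+ # recorded in sig1; the same procedure is then repeated for the second task
+ # (dots at y=30, mask sig2).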
temp=reject_fdr1.nonzero() + sig1=np.full(time_point.shape,np.nan) + if len(temp[0])>=1: + threshold_fdr1 = np.min(np.abs(T1)[reject_fdr1]) + T11=np.concatenate((np.zeros((test_win_on-30,)),T1)) + clusters1 = np.where(T11 > threshold_fdr1)[0] + if len(clusters1)>1: + clusters1 = clusters1[clusters1 > test_win_on-30] + #times = range(0, 500, 10) + plt.plot(time_point[clusters1], np.zeros(clusters1.shape) + 40, 'o', linewidth=3,color=colors[task_index[0]]) + sig1[clusters1]=1 + + reject_fdr2, pval_fdr2 = fdr_correction(pval2, alpha=0.05, method='indep') + temp=reject_fdr2.nonzero() + sig2=np.full(time_point.shape,np.nan) + if len(temp[0])>=1: + threshold_fdr2 = np.min(np.abs(T2)[reject_fdr2]) + T22=np.concatenate((np.zeros((test_win_on-30,)),T2)) + clusters2 = np.where(T22 > threshold_fdr2)[0] + if len(clusters2)>1: + clusters2 = clusters2[clusters2 > test_win_on-30] + #times = range(0, 500, 10) + plt.plot(time_point[clusters2], np.zeros(clusters2.shape) + 30, 'o', linewidth=3,color=colors[task_index[1]]) + sig2[clusters2]=1 + + #plt.fill(window,[15,15,100,100],facecolor='g',alpha=0.2) + plt.xlim([-0.2,2]) + plt.ylim([15,100]) + plt.xticks([0,0.5,1.0,1.5,2]) + plt.yticks([20,40,60,80,100]) + + g.savefig(fname_fig,format="svg", transparent=True, dpi=300) + + return sig1, sig2 + +def df_plot_cluster(ts_df,C1_stat,C2_stat,time_point,test_win_on,test_win_off,roi_name,task_index,chance_index,y_index,fname_fig): + if roi_name=='GNW': + window=[0.3,0.5,0.5,0.3] + elif roi_name=='IIT': + window=[0.3,1.5,1.5,0.3] + elif roi_name=='MT': + window=[0.25,0.5,0.5,0.25] + elif roi_name=='FP': + window=[0.3,1.5,1.5,0.3] + + #plot with sns + + # talk_rc={'lines.linewidth':2,'lines.markersize':4} + # sns.set_context('paper',rc=talk_rc,font_scale=4) + + + g = sns.relplot(x="time(s)", y="decoding accuracy(%)", kind="line", data=ts_df,hue='Task',aspect=2,palette=colors,legend=False) + g.fig.set_size_inches(mm2inch(fig_size[0]),mm2inch(fig_size[1])) + #leg = g._legend + #leg.set_bbox_to_anchor([0.72,0.8]) + + plt.axhline(chance_index, color='k', linestyle='-', label='chance') + plt.axvline(0, color='k', linestyle='-', label='onset') + #plt.axvline(0.5, color='gray', linestyle='--') + #plt.axvline(1, color='gray', linestyle='--') + #plt.axvline(1.5, color='gray', linestyle='--') + + + temp=C1_stat['cluster'] + temp_p=C1_stat['cluster_p'] + sig1=np.full(time_point.shape,np.nan) + time_index=time_point[(test_win_on-30):(test_win_off-30)] + if len(temp)>=1: + for i in range(len(temp)): + if temp_p[i]<0.05:# plot the cluster which p < 0.05 + clusters1=temp[i][0] + plt.plot(time_index[clusters1], np.zeros(clusters1.shape) + 40, 'o', linewidth=3,color=colors[task_index[0]]) + sig1[clusters1]=i + + temp2=C2_stat['cluster'] + temp_p2=C2_stat['cluster_p'] + sig2=np.full(time_point.shape,np.nan) + if len(temp2)>=1: + for i in range(len(temp2)): + if temp_p2[i]<0.05:# plot the cluster which p < 0.05 + clusters2=temp2[i][0] + plt.plot(time_index[clusters2], np.zeros(clusters2.shape) + 30, 'o', linewidth=3,color=colors[task_index[1]]) + sig2[clusters2]=i + + + + #plt.fill(window,[15,15,100,100],facecolor='g',alpha=0.2) + plt.xlim([-0.2,2]) + plt.ylim([15,100]) + plt.xticks([0,0.5,1.0,1.5,2]) + plt.yticks([20,40,60,80,100]) + + g.savefig(fname_fig,format="svg", transparent=True, dpi=300) + + return sig1, sig2 + +def df_plot_cluster_ori(ts_df,C1_stat,time_point,test_win_on,test_win_off,roi_name,task_index,chance_index,y_index,fname_fig): + if roi_name=='GNW': + window=[0.3,0.5,0.5,0.3] + elif roi_name=='IIT': + 
window=[0.3,1.5,1.5,0.3] + elif roi_name=='MT': + window=[0.25,0.5,0.5,0.25] + elif roi_name=='FP': + window=[0.3,1.5,1.5,0.3] + + #plot with sns + + # talk_rc={'lines.linewidth':2,'lines.markersize':4} + # sns.set_context('paper',rc=talk_rc,font_scale=4) + + + g = sns.relplot(x="time(s)", y="decoding accuracy(%)", kind="line", data=ts_df,hue='Task',aspect=2,palette=colors,legend=False) + g.fig.set_size_inches(mm2inch(fig_size[0]),mm2inch(fig_size[1])) + + # leg = g._legend + # leg.remove() + #leg.set_bbox_to_anchor([0.72,0.8]) + + plt.axhline(chance_index, color='k', linestyle='-', label='chance') + plt.axvline(0, color='k', linestyle='-', label='onset') + #plt.axvline(0.5, color='gray', linestyle='--') + #plt.axvline(1, color='gray', linestyle='--') + #plt.axvline(1.5, color='gray', linestyle='--') + + + temp=C1_stat['cluster'] + temp_p=C1_stat['cluster_p'] + sig1=np.full(time_point.shape,np.nan) + time_index=time_point[(test_win_on-30):(test_win_off-30)] + if len(temp)>=1: + for i in range(len(temp)): + if temp_p[i]<0.05:# plot the cluster which p < 0.05 + clusters1=temp[i][0] + plt.plot(time_index[clusters1], np.zeros(clusters1.shape) + 30, 'o', linewidth=3,color=colors[task_index[0]]) + sig1[clusters1]=i + + + + # plt.fill(window,[15,15,100,100],facecolor='g',alpha=0.2) + plt.xlim([-0.2,2]) + plt.ylim([25,100]) + plt.xticks([0,0.5,1.0,1.5,2]) + plt.yticks([40,60,80,100]) + + + g.savefig(fname_fig,format="svg", transparent=True, dpi=300) + + return sig1 + +def df_plot_ROI_cluster(ts_df,C1_stat,time_point,test_win_on,test_win_off,chance_index,y_index,fname_fig): + + window=[0.3,1.5,1.5,0.3] + + + #plot with sns + + # talk_rc={'lines.linewidth':2,'lines.markersize':4} + # sns.set_context('talk',rc=talk_rc,font_scale=1) + + + g = sns.relplot(x="time(s)", y="decoding accuracy(%)", kind="line", data=ts_df,hue='ROI',aspect=2,palette=colors,legend=True) + g.fig.set_size_inches(mm2inch(fig_size[0]),mm2inch(fig_size[1])) + #sns.move_legend(g, "upper left", bbox_to_anchor=(.72, .8), frameon=False) + #leg = g._legend + # leg.remove() + #leg.set_bbox_to_anchor([0.72,0.8]) + + plt.axhline(chance_index, color='k', linestyle='-', label='chance') + plt.axvline(0, color='k', linestyle='-', label='onset') + #plt.axvline(0.5, color='gray', linestyle='--') + #plt.axvline(1, color='gray', linestyle='--') + #plt.axvline(1.5, color='gray', linestyle='--') + + + temp=C1_stat['cluster'] + temp_p=C1_stat['cluster_p'] + sig1=np.full(time_point.shape,np.nan) + time_index=time_point[(test_win_on-30):(test_win_off-30)] + if len(temp)>=1: + for i in range(len(temp)): + if temp_p[i]<0.05:# plot the cluster which p < 0.05 + clusters1=temp[i][0] + plt.plot(time_index[clusters1], np.zeros(clusters1.shape) + chance_index-5, 'o', linewidth=3,color=colors['IIT']) + sig1[clusters1]=i + + + + #plt.fill(window,[40,40,100,100],facecolor='g',alpha=0.2) + plt.xlim([-0.2,2]) + plt.ylim([25,100]) + plt.xticks([0,0.5,1.0,1.5,2]) + plt.yticks([40,60,80,100]) + + g.savefig(fname_fig,format="svg", transparent=True, dpi=300) + + return sig1 + +def g2gdat(roi_g,time_point,sig1,sig2): + roi_g_acc=np.mean(roi_g[:,:,30:251],axis=1) + roi_g_ci=1.96*stats.sem(roi_g[:,:,30:251],axis=1) + roi_g_dat=np.vstack((time_point,roi_g_acc,roi_g_ci,sig1,sig2)) + return roi_g_dat + +def g2gdat_ori(roi_g,time_point,sig1): + roi_g_acc=np.mean(roi_g[:,:,30:251],axis=1) + roi_g_ci=1.96*stats.sem(roi_g[:,:,30:251],axis=1) + roi_g_dat=np.vstack((time_point,roi_g_acc,roi_g_ci,sig1)) + return roi_g_dat + +def df2csv(np_data,task_index,csv_fname): + 
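+ # df2csv: writes the group time course assembled by g2gdat to a CSV file with
+ # one row per time point and columns for the mean accuracy, the 95% CI
+ # (1.96*SEM) and the significance mask of each task. Minimal usage sketch
+ # (the output filename is hypothetical):
+ #   np_data = g2gdat(roi_g, time_point, sig1_cluster, sig2_cluster)
+ #   df2csv(np_data, ['Irrelevant', 'Relevant non-target'], 'GNW_wcd_group.csv')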
columns_index=['Time', + 'ACC (' + task_index[0] + ')','ACC (' + task_index[1] + ')', + 'CI (' + task_index[0] + ')','CI (' + task_index[1] + ')', + 'sig (' + task_index[0] + ')','sig (' + task_index[1] + ')'] + df = pd.DataFrame(np_data.T, columns=columns_index) + df.to_csv(csv_fname,sep=',',index=False,header=True,na_rep='NaN') + +def df2csv_ori(np_data,task_index,csv_fname): + columns_index=['Time', + 'ACC (' + task_index[0] + ')', + 'CI (' + task_index[0] + ')', + 'sig (' + task_index[0] + ')'] + df = pd.DataFrame(np_data.T, columns=columns_index) + df.to_csv(csv_fname,sep=',',index=False,header=True,na_rep='NaN') + +def gc2df(gc_mean,test_win_on,test_win_off,task_index,chance_index): + + df1 = pd.DataFrame(gc_mean[0,:,30:251], columns=time_point) + df1.insert(loc=0, column='SUBID', value=sub_list) + df1.insert(loc=0, column='Task',value=task_index[0]) + + T1, pval1 = stats.ttest_1samp(gc_mean[0,:,test_win_on:test_win_off], chance_index) + + df2 = pd.DataFrame(gc_mean[1,:,30:251], columns=time_point) + df2.insert(loc=0, column='SUBID', value=sub_list) + df2.insert(loc=0, column='Task',value=task_index[1]) + + T2, pval2 = stats.ttest_1samp(gc_mean[1,:,test_win_on:test_win_off], chance_index) + + df=df1.append(df2) + + ts_df = pd.melt(df, id_vars=['SUBID','Task'], var_name='time(s)', value_name='decoding accuracy(%)', value_vars=time_point) + + return ts_df,T1,pval1,T2,pval2 + +def stat_cluster_1sample(gc_mean,test_win_on,test_win_off,task_index,chance_index): + # define theresh + pval = 0.05 # arbitrary + tail = 0 # two-tailed + n_observations=gc_mean.shape[1] + stat_time_points=gc_mean[:,:,test_win_on:test_win_off].shape[2] + df = n_observations - 1 # degrees of freedom for the test + thresh = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + df1 = pd.DataFrame(gc_mean[0,:,30:251], columns=time_point) + df1.insert(loc=0, column='SUBID', value=sub_list) + df1.insert(loc=0, column='Task',value=task_index[0]) + + T_obs_1, clusters_1, cluster_p_values_1, H0_1 = mne.stats.permutation_cluster_1samp_test( + gc_mean[0,:,test_win_on:test_win_off]-np.ones([n_observations,stat_time_points])*chance_index, + threshold=thresh, n_permutations=10000, tail=tail, out_type='indices',verbose=None) + + C1_stat=dict() + C1_stat['T_obs']=T_obs_1 + C1_stat['cluster']=clusters_1 + C1_stat['cluster_p']=cluster_p_values_1 + + df2 = pd.DataFrame(gc_mean[1,:,30:251], columns=time_point) + df2.insert(loc=0, column='SUBID', value=sub_list) + df2.insert(loc=0, column='Task',value=task_index[1]) + + T_obs_2, clusters_2, cluster_p_values_2, H0_2 = mne.stats.permutation_cluster_1samp_test( + gc_mean[1,:,test_win_on:test_win_off]-np.ones([n_observations,stat_time_points])*chance_index, + threshold=thresh, n_permutations=10000, tail=tail, out_type='indices',verbose=None) + + C2_stat=dict() + C2_stat['T_obs']=T_obs_2 + C2_stat['cluster']=clusters_2 + C2_stat['cluster_p']=cluster_p_values_2 + + + df=df1.append(df2) + + ts_df = pd.melt(df, id_vars=['SUBID','Task'], var_name='time(s)', value_name='decoding accuracy(%)', value_vars=time_point) + + return ts_df,C1_stat,C2_stat + +def stat_cluster_1sample_ori(gc_mean,test_win_on,test_win_off,task_index,chance_index): + # define theresh + pval = 0.05 # arbitrary + tail = 0 # two-tailed + n_observations=gc_mean.shape[1] + stat_time_points=gc_mean[:,:,test_win_on:test_win_off].shape[2] + df = n_observations - 1 # degrees of freedom for the test + thresh = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + df1 = pd.DataFrame(gc_mean[0,:,30:251], 
columns=time_point) + df1.insert(loc=0, column='SUBID', value=sub_list) + df1.insert(loc=0, column='Task',value=task_index[0]) + + T_obs_1, clusters_1, cluster_p_values_1, H0_1 = mne.stats.permutation_cluster_1samp_test( + gc_mean[0,:,test_win_on:test_win_off]-np.ones([n_observations,stat_time_points])*chance_index, + threshold=thresh, n_permutations=10000, tail=tail, out_type='indices',verbose=None) + + C1_stat=dict() + C1_stat['T_obs']=T_obs_1 + C1_stat['cluster']=clusters_1 + C1_stat['cluster_p']=cluster_p_values_1 + + + ts_df = pd.melt(df1, id_vars=['SUBID','Task'], var_name='time(s)', value_name='decoding accuracy(%)', value_vars=time_point) + + return ts_df,C1_stat + + +def stat_cluster_1sample_roi(ROI1_data,ROI2_data,test_win_on,test_win_off,ROI_name): + + # define theresh + pval = 0.05 # arbitrary + tail = 0 # two-tailed + n_observations=ROI1_data.shape[1] + + df = n_observations - 1 # degrees of freedom for the test + thresh = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + df1 = pd.DataFrame(ROI1_data[:,30:251], columns=time_point) + df1.insert(loc=0, column='SUBID', value=sub_list) + df1.insert(loc=0, column='ROI',value=ROI_name[0]) + + + + df2 = pd.DataFrame(ROI2_data[:,30:251], columns=time_point) + df2.insert(loc=0, column='SUBID', value=sub_list) + df2.insert(loc=0, column='ROI',value=ROI_name[1]) + + + df=df1.append(df2) + + ts_df = pd.melt(df, id_vars=['SUBID','ROI'], var_name='time(s)', value_name='decoding accuracy(%)', value_vars=time_point) + + T_obs_1, clusters_1, cluster_p_values_1, H0_1 = mne.stats.permutation_cluster_test( + [ROI1_data[:,test_win_on:test_win_off] , ROI2_data[:,test_win_on:test_win_off]], + threshold=thresh, n_permutations=10000, tail=tail, out_type='indices',verbose=None) + + C1_stat=dict() + C1_stat['T_obs']=T_obs_1 + C1_stat['cluster']=clusters_1 + C1_stat['cluster_p']=cluster_p_values_1 + + return ts_df,C1_stat + + +def dat2g(dat,roi_name,cond_name,decoding_name): + roi_ccd_g=np.zeros([2,len(sub_list),251]) + for ci, cond in enumerate(cond_name): + roi_ccd_gc=np.zeros([len(sub_list),251]) + for i, sbn in enumerate(sub_list): + roi_ccd_gc[i,:]=dat[sbn][decoding_name][roi_name][cond] + + + roi_ccd_g[ci,:,:]=roi_ccd_gc*100 + + return roi_ccd_g + +def dat2g_PFC(dat,cond_name): + roi_wcd_g=np.zeros([3,len(sub_list),251]) + for ci, cond in enumerate(cond_name): + roi_wcd_gc=np.zeros([len(sub_list),251]) + for i, sbn in enumerate(sub_list): + roi_wcd_gc[i,:]=dat[sbn][cond] + roi_wcd_g[ci,:,:]=roi_wcd_gc*100 + + return roi_wcd_g + + +def dat2g_ori(dat,roi_name,cond_name,decoding_name): + roi_ccd_g=np.zeros([1,len(sub_list),251]) + roi_ccd_gc=np.zeros([len(sub_list),251]) + for i, sbn in enumerate(sub_list): + roi_ccd_gc[i,:]=dat[sbn][decoding_name][roi_name][cond_name] + roi_ccd_g[0,:,:]=roi_ccd_gc*100 + + return roi_ccd_g + + +def dat2gat(dat,roi_name,cond_name,decoding_name): + roi_ccd_g=np.zeros([2,len(sub_list),251,251]) + for ci, cond in enumerate(cond_name): + roi_ccd_gc=np.zeros([len(sub_list),251,251]) + for i, sbn in enumerate(sub_list): + roi_ccd_gc[i,:,:]=np.diagonal(dat[sbn][decoding_name][roi_name][cond]) + + + roi_ccd_g[ci,:,:,:]=roi_ccd_gc*100 + + return roi_ccd_g + +def dat2gat2(dat,roi_name,cond_name,decoding_name): + roi_ccd_g=np.zeros([2,len(sub_list),251]) + for ci, cond in enumerate(cond_name): + roi_ccd_gc=np.zeros([len(sub_list),251]) + for i, sbn in enumerate(sub_list): + roi_ccd_gc[i,:]=np.diagonal(dat[sbn][decoding_name][roi_name][cond]) + + + roi_ccd_g[ci,:,:]=roi_ccd_gc*100 + + return roi_ccd_g + + +def 
ccd_plt(group_data,roi_name='GNW',test_win_on=50, test_win_off=200,chance_index=50,y_index=15): + + + time_point = np.array(range(-200,2001, 10))/1000 + task_index=['Relevant to Irrelevant','Irrelevant to Relevant'] + #get decoding data + ROI_ccd_g=dat2g(group_data,roi_name,cond_name=['RE2IR','IR2RE'],decoding_name='ccd_acc') + + + + # #FDR methods + + # #stat + # ts_df_fdr,T1,pval1,T2,pval2=gc2df(ROI_ccd_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + # #plot + # fname_fdr_fig= op.join(stat_figure_root, roi_name + '_'+ str(test_win_on) + '_'+ str(test_win_off) +"_acc_CCD_fdr" + '.png') + + # sig1_fdr,sig2_fdr=df_plot(ts_df_fdr,T1,pval1,T2,pval2,time_point,test_win_on, + # roi_name,task_index=task_index, + # chance_index=chance_index,y_index=y_index,fname_fig=fname_fdr_fig) + + + + + #cluster based methods + + #stat + ts_df_cluster,C1_stat,C2_stat=stat_cluster_1sample(ROI_ccd_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + fname_cluster_fig= op.join(stat_figure_root, roi_name + '_'+str(test_win_on) + '_' + str(test_win_off)+"_acc_CCD_cluster" + '.svg') + + #plot + sig1_cluster,sig2_cluster=df_plot_cluster(ts_df_cluster,C1_stat,C2_stat,time_point, + test_win_on,test_win_off, + roi_name,task_index=task_index, + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig) + + #prepare data for plt plot + ROI_ccd_g_dat=g2gdat(ROI_ccd_g,time_point,sig1_cluster,sig2_cluster) + + + csv_fname=op.join(stat_data_root, roi_name + '_'+str(test_win_on) + '_' + str(test_win_off)+"_acc_CCD_cluster" + '.csv') + + df2csv(ROI_ccd_g_dat,task_index,csv_fname) + + +def wcd_plt(group_data,roi_name='GNW',test_win_on=50, test_win_off=200,chance_index=50,y_index=15): + + + time_point = np.array(range(-200,2001, 10))/1000 + task_index=['Irrelevant','Relevant non-target'] + #get decoding data + ROI_wcd_g=dat2g(group_data,roi_name,cond_name=['Irrelevant','Relevant non-target'],decoding_name='wcd_acc') + + + # #FDR methods + + # #stat + # ts_df_fdr,T1,pval1,T2,pval2=gc2df(ROI_ccd_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + # #plot + # fname_fdr_fig= op.join(stat_figure_root, roi_name + '_'+ str(test_win_on) + '_'+ str(test_win_off) +"_acc_WCD_fdr" + '.png') + + # sig1_fdr,sig2_fdr=df_plot(ts_df_fdr,T1,pval1,T2,pval2,time_point,test_win_on, + # roi_name,task_index=task_index, + # chance_index=chance_index,y_index=y_index,fname_fig=fname_fdr_fig) + + + + + #cluster based methods + + #stat + ts_df_cluster,C1_stat,C2_stat=stat_cluster_1sample(ROI_wcd_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + fname_cluster_fig= op.join(stat_figure_root, roi_name + '_'+str(test_win_on) + '_' + str(test_win_off)+"_acc_WCD_cluster" + '.svg') + + #plot + sig1_cluster,sig2_cluster=df_plot_cluster(ts_df_cluster,C1_stat,C2_stat,time_point, + test_win_on,test_win_off, + roi_name,task_index=task_index, + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig) + #prepare data for plt plot + ROI_wcd_g_dat=g2gdat(ROI_wcd_g,time_point,sig1_cluster,sig2_cluster) + + + csv_fname=op.join(stat_data_root, roi_name + '_'+str(test_win_on) + '_' + str(test_win_off)+"_acc_WCD_cluster" + '.csv') + + df2csv(ROI_wcd_g_dat,task_index,csv_fname) + +def ROI_wcd_plt(group_data,decoding_method ='wcd', test_win_on=50, test_win_off=200,chance_index=50,y_index=40): + ROI_name=['IIT','FP'] + task_index=['Irrelevant','Relevant non-target'] + #get decoding data + 
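+ # dat2g returns the single-subject decoding accuracies as a
+ # [condition x subject x time] array scaled to percent (shape
+ # (2, len(sub_list), 251) here): ROI1_data holds the IIT ROI, ROI2_data the
+ # FP (IIT+GNW) ROI, and index 0/1 along the first axis corresponds to the
+ # Irrelevant / Relevant non-target conditions.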
ROI1_data=dat2g(group_data,ROI_name[0],cond_name=['Irrelevant','Relevant non-target'],decoding_name='wcd_acc') + ROI2_data=dat2g(group_data,ROI_name[1],cond_name=['Irrelevant','Relevant non-target'],decoding_name='wcd_acc') + + time_point = np.array(range(-200,2001, 10))/1000 + + #cluster based methods + + #stat + ts1_df_cluster,C1_stat=stat_cluster_1sample_roi(ROI1_data[0,:,:],ROI2_data[0,:,:],test_win_on,test_win_off,ROI_name) + + fname_cluster_fig= op.join(stat_figure_root, task_index[0] + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_FP_P_diff_acc_'+decoding_method + '_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts1_df_cluster,C1_stat,time_point, + test_win_on,test_win_off, + task_index=task_index[0], + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig) + + #stat + ts2_df_cluster,C2_stat=stat_cluster_1sample_roi(ROI1_data[1,:,:],ROI2_data[1,:,:],test_win_on,test_win_off,ROI_name) + + fname_cluster_fig2= op.join(stat_figure_root, task_index[1] + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_FP_P_diff_acc_'+decoding_method + '_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts2_df_cluster,C2_stat,time_point, + test_win_on,test_win_off, + task_index=task_index[1], + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig2) + + + +def ROI_ccd_plt(group_data,decoding_method ='ccd', test_win_on=50, test_win_off=200,chance_index=50,y_index=40): + ROI_name=['IIT','FP'] + task_index=['Relevant to Irrelevant','Irrelevant to Relevant'] + #get decoding data + ROI1_data=dat2g(group_data,ROI_name[0],cond_name=['RE2IR','IR2RE'],decoding_name='ccd_acc') + ROI2_data=dat2g(group_data,ROI_name[1],cond_name=['RE2IR','IR2RE'],decoding_name='ccd_acc') + + time_point = np.array(range(-200,2001, 10))/1000 + + #cluster based methods + + #stat + ts1_df_cluster,C1_stat=stat_cluster_1sample_roi(ROI1_data[0,:,:],ROI2_data[0,:,:],test_win_on,test_win_off,ROI_name) + + fname_cluster_fig= op.join(stat_figure_root, task_index[0] + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_FP_P_diff_acc_'+decoding_method + '_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts1_df_cluster,C1_stat,time_point, + test_win_on,test_win_off, + task_index=task_index[1], + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig) + + #stat + ts2_df_cluster,C2_stat=stat_cluster_1sample_roi(ROI1_data[1,:,:],ROI2_data[1,:,:],test_win_on,test_win_off,ROI_name) + + fname_cluster_fig2= op.join(stat_figure_root, task_index[1] + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_FP_P_diff_acc_'+decoding_method + '_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts2_df_cluster,C2_stat,time_point, + test_win_on,test_win_off, + task_index=task_index[1], + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig2) + + + +def wcd_ori_plt(group_data,roi_name='GNW',test_win_on=50, test_win_off=200,chance_index=33.3,y_index=15): + + + time_point = np.array(range(-200,2001, 10))/1000 + task_index=conditions_C #Face/Object/Letter/False + #get decoding data + ROI_ori_g=dat2g_ori(group_data,roi_name,cond_name=conditions_C[0],decoding_name='wcd_ori_acc') + + + # #FDR methods + + # #stat + # ts_df_fdr,T1,pval1,T2,pval2=gc2df(ROI_ccd_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + # #plot + # fname_fdr_fig= op.join(stat_figure_root, roi_name + '_'+ str(test_win_on) + '_'+ str(test_win_off) +"_acc_WCD_fdr" + '.png') + + # 
sig1_fdr,sig2_fdr=df_plot(ts_df_fdr,T1,pval1,T2,pval2,time_point,test_win_on, + # roi_name,task_index=task_index, + # chance_index=chance_index,y_index=y_index,fname_fig=fname_fdr_fig) + + + + + #cluster based methods + + #stat + ts_df_cluster,C1_stat=stat_cluster_1sample_ori(ROI_ori_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + fname_cluster_fig= op.join(stat_figure_root, roi_name + '_'+str(test_win_on) + '_' + str(test_win_off)+"_acc_WCD_ori_cluster" + '.svg') + + #plot + sig1_cluster=df_plot_cluster_ori(ts_df_cluster,C1_stat,time_point, + test_win_on,test_win_off, + roi_name,task_index=task_index, + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig) + #prepare data for plt plot + ROI_ori_g_dat=g2gdat_ori(ROI_ori_g,time_point,sig1_cluster) + + + csv_fname=op.join(stat_data_root, roi_name + '_'+str(test_win_on) + '_' + str(test_win_off)+"_acc_WCD_ori_cluster" + '.csv') + + df2csv_ori(ROI_ori_g_dat,task_index,csv_fname) + + +def ROI_wcd_ori_plt(group_data,decoding_method ='wcd', test_win_on=50, test_win_off=200,chance_index=33.3,y_index=40): + ROI_name=['IIT','FP'] + task_index=['Irrelevant','Relevant non-target'] + #get decoding data + ROI1_data=dat2g_ori(group_data,ROI_name[0],cond_name=conditions_C[0],decoding_name='wcd_ori_acc') + ROI2_data=dat2g_ori(group_data,ROI_name[1],cond_name=conditions_C[0],decoding_name='wcd_ori_acc') + + time_point = np.array(range(-200,2001, 10))/1000 + + #cluster based methods + + #stat + ts1_df_cluster,C1_stat=stat_cluster_1sample_roi(ROI1_data[0,:,:],ROI2_data[0,:,:],test_win_on,test_win_off,ROI_name) + + fname_cluster_fig= op.join(stat_figure_root, task_index[0] + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_FP_P_diff_acc_'+decoding_method + '_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts1_df_cluster,C1_stat,time_point, + test_win_on,test_win_off, + task_index=task_index[0], + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig) + + + +######### +#set data root +group_deriv_root,stat_data_root,stat_figure_root=set_path_plot(bids_root,visit_id, analysis_name,con_C[0]) + + +# ######## +# #debug +# decoding_path=op.join(bids_root, "derivatives",'decoding') + +# data_path=op.join(decoding_path, analysis_name) + +# # Set path to group analysis derivatives +# group_deriv_root = op.join(data_path, "group") +# if not op.exists(group_deriv_root): +# os.makedirs(group_deriv_root) + + + + +# analysis/task info +## analysis/task info +if con_T.__len__() == 3: + con_Tname = 'T_all' +elif con_T.__len__() == 2: + con_Tname = con_T[0]+'_'+con_T[1] +else: + con_Tname = con_T[0] + +task_info = "_" + "".join(con_Tname) + "_" + "".join(con_C[0]) +print(task_info) + + +fname_data=op.join(group_deriv_root, task_info +"_data_group_" + analysis_name + + '.pickle') + +fr=open(fname_data,'rb') +group_data=pickle.load(fr) + + + +if analysis_name=='Cat' or analysis_name=='Cat_offset_control': + #CCD: cross condition decoding + #GNW + + # # 300ms to 500ms + # ccd_plt(group_data2,roi_name='GNW',test_win_on=130, test_win_off=150,chance_index=50,y_index=15) + + # 0ms to 1500ms + ccd_plt(group_data,roi_name='GNW',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + + #IIT + + # # 300ms to 500ms + # ccd_plt(group_data2,roi_name='IIT',test_win_on=130, test_win_off=251,chance_index=50,y_index=40) + + # 0ms to 1500ms + ccd_plt(group_data,roi_name='IIT',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + + + #WCD: within condition decoding + #GNW + + # # 300ms to 500ms 
+ # wcd_plt(group_data2,roi_name='GNW',test_win_on=130, test_win_off=150,chance_index=50,y_index=15) + + # 0ms to 1500ms + wcd_plt(group_data,roi_name='GNW',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + + #IIT + + # # 300ms to 500ms + # wcd_plt(group_data2,roi_name='IIT',test_win_on=130, test_win_off=251,chance_index=50,y_index=40) + + # 0ms to 1500ms + wcd_plt(group_data,roi_name='IIT',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + + #compare IIT with IIT+GNW(FP) + ROI_ccd_plt(group_data,decoding_method ='ccd', test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + + ROI_wcd_plt(group_data,decoding_method ='wcd', test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + + +elif analysis_name=='Cat_MT_control': + ccd_plt(group_data,roi_name='MT',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + wcd_plt(group_data,roi_name='MT',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + +elif analysis_name=='Cat_baseline': + + wcd_plt(group_data,roi_name='GNW',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + wcd_plt(group_data,roi_name='IIT',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + +elif analysis_name=='Ori': + + wcd_ori_plt(group_data,roi_name='GNW',test_win_on=50, test_win_off=200,chance_index=33.3,y_index=40) + wcd_ori_plt(group_data,roi_name='IIT',test_win_on=50, test_win_off=200,chance_index=33.3,y_index=40) + +elif analysis_name=='Cat_PFC': + cond_name=['IIT','IITPFC_f','IITPFC_m'] + colors = { + "IIT": [1,0,0 + ], + "IITPFC_f": [0,0,1 + ], + "IITPFC_m": [0,0,1 + ]} + decoding_method=analysis_name + #task_index=['Irrelevant','Relevant non-target'] + #get decoding data + PFC_data=dat2g_PFC(group_data,cond_name) + + + time_point = np.array(range(-200,2001, 10))/1000 + + #cluster based methods + test_win_on=50 + test_win_off=200 + #stat + ts1_df_cluster,C1_stat=stat_cluster_1sample_roi(PFC_data[0,:,:],PFC_data[1,:,:],test_win_on,test_win_off,['IIT','IITPFC_f']) + + fname_cluster_fig= op.join(stat_figure_root, decoding_method + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_IITPFC_feature_diff_acc_cluster.svg') + + # fname_cluster_fig= op.join(data_path, decoding_method + + # '_'+str(test_win_on) + '_' + str(test_win_off) + + # '_IITPFC_feature_diff_acc_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts1_df_cluster,C1_stat,time_point, + test_win_on,test_win_off, + chance_index=50,y_index=50, + fname_fig=fname_cluster_fig) + + #stat + ts2_df_cluster,C2_stat=stat_cluster_1sample_roi(PFC_data[0,:,:],PFC_data[2,:,:],test_win_on,test_win_off,['IIT','IITPFC_m']) + + fname_cluster_fig2= op.join(stat_figure_root, decoding_method + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_IITPFC_model_diff_acc_cluster.svg') + + # fname_cluster_fig2= op.join(data_path, decoding_method + + # '_'+str(test_win_on) + '_' + str(test_win_off) + + # '_IITPFC_model_diff_acc_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts2_df_cluster,C2_stat,time_point, + test_win_on,test_win_off, + chance_index=50,y_index=50, + fname_fig=fname_cluster_fig2) + + +elif analysis_name=='Ori_PFC': + cond_name=['IIT','IITPFC_f','IITPFC_m'] + colors = { + "IIT": [1,0,0 + ], + "IITPFC_f": [0,0,1 + ], + "IITPFC_m": [0,0,1 + ]} + decoding_method=analysis_name + #task_index=['Irrelevant','Relevant non-target'] + #get decoding data + PFC_data=dat2g_PFC(group_data,cond_name) + + + time_point = np.array(range(-200,2001, 10))/1000 + + #cluster based methods + test_win_on=50 + test_win_off=200 + #stat + 
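+ # The two tests below run a cluster-based permutation contrast across
+ # subjects within the 0-1500 ms test window (samples 50-200), comparing
+ # orientation decoding in the IIT ROI against the IIT+PFC feature
+ # ('IITPFC_f') and model ('IITPFC_m') variants; chance level is 33.3%
+ # (three orientations).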
ts1_df_cluster,C1_stat=stat_cluster_1sample_roi(PFC_data[0,:,:],PFC_data[1,:,:],test_win_on,test_win_off,['IIT','IITPFC_f']) + + fname_cluster_fig= op.join(stat_figure_root, decoding_method + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_IITPFC_feature_diff_acc_cluster.svg') + + # fname_cluster_fig= op.join(data_path, decoding_method + + # '_'+str(test_win_on) + '_' + str(test_win_off) + + # '_IITPFC_feature_diff_acc_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts1_df_cluster,C1_stat,time_point, + test_win_on,test_win_off, + chance_index=33.3,y_index=50, + fname_fig=fname_cluster_fig) + + #stat + ts2_df_cluster,C2_stat=stat_cluster_1sample_roi(PFC_data[0,:,:],PFC_data[2,:,:],test_win_on,test_win_off,['IIT','IITPFC_m']) + + fname_cluster_fig2= op.join(stat_figure_root, decoding_method + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_IITPFC_model_diff_acc_cluster.svg') + + # fname_cluster_fig2= op.join(data_path, decoding_method + + # '_'+str(test_win_on) + '_' + str(test_win_off) + + # '_IITPFC_model_diff_acc_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts2_df_cluster,C2_stat,time_point, + test_win_on,test_win_off, + chance_index=33.3,y_index=50, + fname_fig=fname_cluster_fig2) + #ROI_wcd_ori_plt(group_data,decoding_method ='wcd', test_win_on=50, test_win_off=200,chance_index=33.3,y_index=40) diff --git a/roi_mvpa/D98_group_stat_sROI_plot_GAT.py b/roi_mvpa/D98_group_stat_sROI_plot_GAT.py new file mode 100644 index 0000000..4f73896 --- /dev/null +++ b/roi_mvpa/D98_group_stat_sROI_plot_GAT.py @@ -0,0 +1,603 @@ +""" +==================== +D98. Group analysis for decoding +genelaization across time (GAT) +==================== + +@author: Ling Liu ling.liu@pku.edu.cn + +""" + +import os.path as op +import os +import argparse + +import pickle +import mne + + +import numpy as np +import matplotlib.pyplot as plt + +import seaborn as sns +sns.set_theme(style='ticks') + + +from scipy import stats as stats + +from scipy.ndimage import gaussian_filter + +import matplotlib.patheffects as path_effects + +import matplotlib.colors as mcolors + + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + +from sublist import sub_list + +parser = argparse.ArgumentParser() +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +parser.add_argument('--cT', type=str, nargs='*', default=['500ms', '1000ms', '1500ms'], + help='condition in Time duration') + +parser.add_argument('--cC', type=str, nargs='*', default=['FO'], + help='selected decoding category, FO for face and object, LF for letter and false') +parser.add_argument('--cD',type=str,nargs='*', default=['Irrelevant', 'Relevant non-target'], + help='selected decoding Task, Relevant non Target or Irrelevant condition') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--methods', + type=str, + default='roi_mvpa', + help='decoding methods name, for the data folder') +parser.add_argument('--analysis', + type=str, + default='GAT_Cat', + help='the name for anlaysis, e.g. 
Tall for 3 durations combined analysis') + + +opt = parser.parse_args() + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path +methods_name=opt.methods +analysis_name=opt.analysis + + +opt = parser.parse_args() +con_C = opt.cC +con_D = opt.cD +con_T = opt.cT + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path + +#1) Select Category +if con_C[0] == 'FO': + conditions_C = ['face', 'object'] + print(conditions_C) +elif con_C[0] == 'LF': + conditions_C = ['letter', 'false'] + print(conditions_C) +elif con_C[0] == 'F': + conditions_C = ['face'] + print(conditions_C) +elif con_C[0] == 'O': + conditions_C = ['object'] + print(conditions_C) +elif con_C[0] == 'L': + conditions_C = ['letter'] + print(conditions_C) +elif con_C[0] == 'FA': + conditions_C = ['false'] + print(conditions_C) + + + +# Color parameters: +cmap = "RdYlBu_r" +#color_blind_palette = sns.color_palette("colorblind") +colors = { + "IIT": [ + 0.00392156862745098, + 0.45098039215686275, + 0.6980392156862745 + ], + "GNW": [ + 0.00784313725490196, + 0.6196078431372549, + 0.45098039215686275 + ], + "MT": [ + 0.8352941176470589, + 0.3686274509803922, + 0.0 + ], + "FP": [ + 0.5450980392156862, + 0.16862745098039217, + 0.8862745098039215 + ], + "Relevant to Irrelevant": [ + 0.8352941176470589, + 0.3686274509803922, + 0.0 + ], + "Irrelevant to Relevant": [ + 0.5450980392156862, + 0.16862745098039217, + 0.8862745098039215 + ], + "Relevant non-target": [ + 0.8352941176470589, + 0.3686274509803922, + 0.0 + ], + "Irrelevant": [ + 0.5450980392156862, + 0.16862745098039217, + 0.8862745098039215 + ], +} + + +time_point = np.array(range(-200,2001, 10))/1000 +# set the path for decoding analysis +def set_path_plot(bids_root, visit_id, analysis_name, con_name): + + ### I Set the group Data Path + # Set path to decoding derivatives + decoding_path=op.join(bids_root, "derivatives",'decoding','roi_mvpa') + + data_path=op.join(decoding_path,analysis_name) + + # Set path to group analysis derivatives + group_deriv_root = op.join(data_path, "group") + if not op.exists(group_deriv_root): + os.makedirs(group_deriv_root) + + + # Set path to the ROI MVPA output(1) stat_data, 2) figures, 3) codes) + + # 1) output_stat_data + stat_data_root = op.join(group_deriv_root,"stat_data",con_name) + if not op.exists(stat_data_root): + os.makedirs(stat_data_root) + + # 2) output_figure + stat_figure_root = op.join(group_deriv_root,"stat_figures",con_name) + if not op.exists(stat_figure_root): + os.makedirs(stat_figure_root) + + return group_deriv_root,stat_data_root,stat_figure_root + + + +def df_plot_cluster_GAT(gc_mean,C1_stat,C2_stat,time_point,test_win_on,test_win_off,roi_name,task_index,chance_index,y_index,fname_fig): + + talk_rc={'lines.linewidth':1,'lines.markersize':1} + sns.set_context('paper',rc=talk_rc,font_scale=4) + + fig, axes = plt.subplots(1, 1,figsize=(10,10),sharex=True,sharey=True) + plt.subplots_adjust(wspace=0.5, hspace=0) + + + t = time_point + pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5), path_effects.Normal()] + #cmap = mpl.cm.RdYlBu_r + cmap = mcolors.LinearSegmentedColormap.from_list('my_colormap', + np.vstack((plt.cm.Blues_r(np.linspace(0, 1, 220) ), + plt.cm.Blues_r( np.linspace(1, 1, 36) ), + plt.cm.Reds( np.linspace(0, 0, 36) ), + plt.cm.Reds( np.linspace(0, 1, 220) ) ) ) ) + vmin = 0 + vmax = 100 + # bounds = np.linspace(vmin, vmax, 11) + # norm = mpl.colors.BoundaryNorm(bounds, cmap.N) + #plot + GAT_avg=np.mean(gc_mean[0,:,30:251,30:251],axis=0) + GAT_avg_plot = np.nan * 
np.ones_like(GAT_avg) + for c, p_val in zip(C1_stat['cluster'], C1_stat['cluster_p']): + if p_val <= 0.05: + GAT_avg_plot[c] = GAT_avg[c] + + im = axes.imshow(gaussian_filter(GAT_avg,sigma=2), interpolation='lanczos', origin='lower', cmap=cmap,alpha=0.9, + extent=t[[0, -1, 0, -1]], vmin=vmin, vmax=vmax) + axes.contour(GAT_avg_plot > 0, GAT_avg_plot > 0, colors="black", linewidths=1.5, origin="lower",extent=t[[0, -1, 0, -1]]) + im = axes.imshow(GAT_avg_plot, origin='lower', cmap=cmap,aspect='auto', + extent=t[[0, -1, 0, -1]], vmin=vmin, vmax=vmax) + axes.set_xlabel('Testing Time (s)') + axes.set_ylabel('Training Time (s)') + axes.set_xticks([0,1,2]) + axes.set_yticks([0,1,2]) + axes.set_title(task_index[0]) + axes.axvline(0, color='k',linestyle='--') + axes.axhline(0, color='k',linestyle='--') + axes.axline((0, 0), slope=1, color='k',linestyle='--') + plt.colorbar(im, ax=axes,fraction=0.03, pad=0.05) + cb = axes.figure.axes[-1] + m = axes.figure.axes[-2] + pos = m.get_position().bounds + cb.set_position([pos[2]+pos[0]+0.01, pos[1], 0.1, pos[3]]) + + + + fname_fig_1=op.join(fname_fig+'_'+task_index[0]+'.svg') + + fig.savefig(fname_fig_1,format="svg") + + + talk_rc={'lines.linewidth':1,'lines.markersize':1} + sns.set_context('paper',rc=talk_rc,font_scale=4) + + + fig, axes = plt.subplots(1, 1,figsize=(10,10),sharex=True,sharey=True) + plt.subplots_adjust(wspace=0.5, hspace=0) + + t = time_point + pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5), path_effects.Normal()] + #cmap = mpl.cm.RdYlBu_r + + cmap = mcolors.LinearSegmentedColormap.from_list('my_colormap', + np.vstack((plt.cm.Blues_r(np.linspace(0, 1, 220) ), + plt.cm.Blues_r( np.linspace(1, 1, 36) ), + plt.cm.Reds( np.linspace(0, 0, 36) ), + plt.cm.Reds( np.linspace(0, 1, 220) ) ) ) ) + vmin = 0 + vmax = 100 + # bounds = np.linspace(vmin, vmax, 11) + + GAT2_avg=np.mean(gc_mean[1,:,30:251,30:251],axis=0) + GAT2_avg_plot = np.nan * np.ones_like(GAT_avg) + for c2, p2_val in zip(C2_stat['cluster'], C2_stat['cluster_p']): + if p2_val <= 0.05: + GAT2_avg_plot[c2] = GAT2_avg[c2] + im = axes.imshow(gaussian_filter(np.mean(gc_mean[1,:,30:251,30:251],axis=0), sigma=2), interpolation='lanczos', origin='lower', cmap=cmap,alpha=0.9, + extent=t[[0, -1, 0, -1]], vmin=vmin, vmax=vmax) + axes.contour(GAT2_avg_plot > 0, GAT2_avg_plot > 0, colors="black", linewidths=1.5, origin="lower",extent=t[[0, -1, 0, -1]]) + im = axes.imshow(GAT2_avg_plot, origin='lower', cmap=cmap,aspect='auto', + extent=t[[0, -1, 0, -1]], vmin=vmin, vmax=vmax) + axes.set_xlabel('Testing Time (s)') + axes.set_ylabel('Training Time (s)') + axes.set_xticks([0,1,2]) + axes.set_yticks([0,1,2]) + axes.set_title(task_index[1]) + axes.axvline(0, color='k',linestyle='--') + axes.axhline(0, color='k',linestyle='--') + axes.axline((0,0), slope=1, color='k',linestyle='--') + plt.colorbar(im, ax=axes,fraction=0.03, pad=0.05) + cb = axes.figure.axes[-1] + m = axes.figure.axes[-2] + pos = m.get_position().bounds + cb.set_position([pos[2]+pos[0]+0.01, pos[1], 0.1, pos[3]]) + + fname_fig_2=op.join(fname_fig+'_'+task_index[1]+'.svg') + + fig.savefig(fname_fig_2,format="svg") + + +def df_plot_cluster_GAT_ori(gc_mean,C1_stat,time_point,test_win_on,test_win_off,roi_name,task_index,chance_index,y_index,fname_fig): + + + fig, axes = plt.subplots(1, 1,figsize=(5,5),sharex=True,sharey=True) + plt.subplots_adjust(wspace=0.5, hspace=0) + + + t = time_point + pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5), path_effects.Normal()] + + #cmap = mpl.cm.RdYlBu_r + cmap = 
mcolors.LinearSegmentedColormap.from_list('my_colormap', + np.vstack((plt.cm.Blues_r(np.linspace(0, 1, 220) ), + plt.cm.Blues_r( np.linspace(1, 1, 36) ), + plt.cm.Reds( np.linspace(0, 0, 36) ), + plt.cm.Reds( np.linspace(0, 1, 220) ) ) ) ) + vmin = 0 + vmax = 100 + # bounds = np.linspace(0, 100, 11) + # norm = mpl.colors.BoundaryNorm(bounds, cmap.N) + #plot + GAT_avg=np.mean(gc_mean[0,:,30:251,30:251],axis=0) + GAT_avg_plot = np.nan * np.ones_like(GAT_avg) + for c, p_val in zip(C1_stat['cluster'], C1_stat['cluster_p']): + if p_val <= 0.05: + GAT_avg_plot[c] = GAT_avg[c] + + im = axes.imshow(gaussian_filter(GAT_avg,sigma=2), interpolation='lanczos', origin='lower', cmap=cmap,alpha=0.9, + extent=t[[0, -1, 0, -1]], vmin=vmin, vmax=vmax) + axes.contour(GAT_avg_plot > 0, GAT_avg_plot > 0, colors="black", linewidths=1.5, origin="lower",extent=t[[0, -1, 0, -1]]) + im = axes.imshow(GAT_avg_plot, origin='lower', cmap=cmap,aspect='auto', + extent=t[[0, -1, 0, -1]], vmin=vmin, vmax=vmax) + axes.set_xlabel('Testing Time (s)') + axes.set_ylabel('Training Time (s)') + axes.set_title('GAT_Ori') + axes.axvline(0, color='k',linestyle='--') + axes.axhline(0, color='k',linestyle='--') + axes.axline((0, 0), slope=1, color='k',linestyle='--') + plt.colorbar(im, ax=axes,fraction=0.03, pad=0.05) + cb = axes.figure.axes[-1] + m = axes.figure.axes[-2] + pos = m.get_position().bounds + cb.set_position([pos[2]+pos[0]+0.01, pos[1], 0.1, pos[3]]) + + talk_rc={'lines.linewidth':1,'lines.markersize':1} + sns.set_context('paper',rc=talk_rc,font_scale=4) + + fig.savefig(fname_fig,format="svg") + + + +def stat_cluster_1sample_GAT(gc_mean,test_win_on,test_win_off,task_index,chance_index): + # define theresh + pval = 0.05 # arbitrary + tail = 0 # two-tailed + n_observations=gc_mean.shape[1] + stat_time_points=gc_mean[:,:,test_win_on:test_win_off,test_win_on:test_win_off].shape[2] + df = n_observations - 1 # degrees of freedom for the test + thresh = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + + + T_obs_1, clusters_1, cluster_p_values_1, H0_1 = mne.stats.permutation_cluster_1samp_test( + gc_mean[0,:,test_win_on:test_win_off,test_win_on:test_win_off]-np.ones([n_observations,stat_time_points,stat_time_points])*chance_index, + threshold=thresh, n_permutations=1000, tail=tail, out_type='mask',verbose=None) + + C1_stat=dict() + C1_stat['T_obs']=T_obs_1 + C1_stat['cluster']=clusters_1 + C1_stat['cluster_p']=cluster_p_values_1 + + + T_obs_2, clusters_2, cluster_p_values_2, H0_2 = mne.stats.permutation_cluster_1samp_test( + gc_mean[1,:,test_win_on:test_win_off,test_win_on:test_win_off]-np.ones([n_observations,stat_time_points,stat_time_points])*chance_index, + threshold=thresh, n_permutations=1000, tail=tail, out_type='indices',verbose=None) + + C2_stat=dict() + C2_stat['T_obs']=T_obs_2 + C2_stat['cluster']=clusters_2 + C2_stat['cluster_p']=cluster_p_values_2 + + + return C1_stat,C2_stat + +def stat_cluster_1sample_GAT_ori(gc_mean,test_win_on,test_win_off,task_index,chance_index): + # define theresh + pval = 0.05 # arbitrary + tail = 0 # two-tailed + n_observations=gc_mean.shape[1] + stat_time_points=gc_mean[:,:,test_win_on:test_win_off,test_win_on:test_win_off].shape[2] + df = n_observations - 1 # degrees of freedom for the test + thresh = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + + + T_obs_1, clusters_1, cluster_p_values_1, H0_1 = mne.stats.permutation_cluster_1samp_test( + 
gc_mean[0,:,test_win_on:test_win_off,test_win_on:test_win_off]-np.ones([n_observations,stat_time_points,stat_time_points])*chance_index, + threshold=thresh, n_permutations=1000, tail=tail, out_type='mask',verbose=None) + + C1_stat=dict() + C1_stat['T_obs']=T_obs_1 + C1_stat['cluster']=clusters_1 + C1_stat['cluster_p']=cluster_p_values_1 + + return C1_stat + + + + +def dat2gat(dat,roi_name,cond_name,decoding_name): + roi_gat=np.zeros([2,len(sub_list),251,251]) + for ci, cond in enumerate(cond_name): + roi_gat_gc=np.zeros([len(sub_list),251,251]) + for i, sbn in enumerate(sub_list): + roi_gat_gc[i,:,:]=dat[sbn][decoding_name][roi_name][cond] + + + roi_gat[ci,:,:,:]=roi_gat_gc*100 + + return roi_gat + +def dat2gat_ori(dat,roi_name,cond_name,decoding_name): + roi_gat=np.zeros([1,len(sub_list),251,251]) + #for ci, cond in enumerate(cond_name): + roi_gat_gc=np.zeros([len(sub_list),251,251]) + for i, sbn in enumerate(sub_list): + roi_gat_gc[i,:,:]=dat[sbn][decoding_name][roi_name][cond_name] + + + roi_gat[0,:,:,:]=roi_gat_gc*100 + + return roi_gat + + +def ctccd_plt(group_data,con_name,roi_name='IIT',test_win_on=50, test_win_off=200,chance_index=50,y_index=15): + + + time_point = np.array(range(-200,2001, 10))/1000 + task_index=['Relevant to Irrelevant','Irrelevant to Relevant'] + #get decoding data + ROI_gat_g=dat2gat(group_data,roi_name,cond_name=['RE2IR','IR2RE'],decoding_name='ctccd_acc') + + + #cluster based methods + + #stat + C1_stat,C2_stat=stat_cluster_1sample_GAT(ROI_gat_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + fname_cluster_fig_index= op.join(stat_figure_root, roi_name + '_'+str(con_name)+"_acc_CTCCD_cluster" ) + + #plot + df_plot_cluster_GAT(ROI_gat_g,C1_stat,C2_stat,time_point, + test_win_on,test_win_off, + roi_name,task_index=task_index, + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig_index) + + +def ctwcd_plt(group_data,con_name,roi_name='IIT',test_win_on=50, test_win_off=200,chance_index=50,y_index=15): + + + time_point = np.array(range(-200,2001, 10))/1000 + task_index=['Irrelevant','Relevant non-target'] + #get decoding data + ROI_gat_g=dat2gat(group_data,roi_name,cond_name=['Irrelevant','Relevant non-target'],decoding_name='ctwcd_acc') + + + + #cluster based methods + + #stat + C1_stat,C2_stat=stat_cluster_1sample_GAT(ROI_gat_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + fname_cluster_fig= op.join(stat_figure_root, roi_name + '_'+str(con_name)+"_acc_CTWCD_cluster") + + #plot + df_plot_cluster_GAT(ROI_gat_g,C1_stat,C2_stat,time_point, + test_win_on,test_win_off, + roi_name,task_index=task_index, + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig) + + + + + + +def ctwcd_ori_plt(group_data,con_name,roi_name='IIT',test_win_on=50, test_win_off=200,chance_index=33.3,y_index=15): + + + time_point = np.array(range(-200,2001, 10))/1000 + task_index=conditions_C #Face/Object/Letter/False + #get decoding data + ROI_ori_g=dat2gat_ori(group_data,roi_name,conditions_C[0],decoding_name='ctwcd_ori_acc') + + + # #FDR methods + + # #stat + # ts_df_fdr,T1,pval1,T2,pval2=gc2df(ROI_ccd_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + # #plot + # fname_fdr_fig= op.join(stat_figure_root, roi_name + '_'+ str(test_win_on) + '_'+ str(test_win_off) +"_acc_WCD_fdr" + '.png') + + # sig1_fdr,sig2_fdr=df_plot(ts_df_fdr,T1,pval1,T2,pval2,time_point,test_win_on, + # roi_name,task_index=task_index, + # 
chance_index=chance_index,y_index=y_index,fname_fig=fname_fdr_fig) + + + + + #cluster based methods + + #stat + C1_stat=stat_cluster_1sample_GAT_ori(ROI_ori_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + fname_cluster_fig= op.join(stat_figure_root, roi_name + '_'+str(con_name)+"_acc_CTWCD_ori_cluster" + '.svg') + + #plot + df_plot_cluster_GAT_ori(ROI_ori_g,C1_stat,time_point, + test_win_on,test_win_off, + roi_name,task_index=task_index, + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig) + + + + +######### +#set data root +group_deriv_root,stat_data_root,stat_figure_root=set_path_plot(bids_root,visit_id, analysis_name,con_C[0]) + + +# ######## +# #debug +# decoding_path=op.join(bids_root, "derivatives",'decoding') + +# data_path=op.join(decoding_path,methods_name) + +# # Set path to group analysis derivatives +# group_deriv_root = op.join(data_path, "group", analysis_name) +# if not op.exists(group_deriv_root): +# os.makedirs(group_deriv_root) + + + + +# analysis/task info +## analysis/task info +if con_T.__len__() == 3: + con_Tname = 'T_all' +elif con_T.__len__() == 2: + con_Tname = con_T[0]+'_'+con_T[1] +else: + con_Tname = con_T[0] + +task_info = "_" + "".join(con_Tname) + "_" + "".join(con_C[0]) +print(task_info) + + +fname_data=op.join(group_deriv_root, task_info +"_data_group_" + analysis_name + + '.pickle') + +fr=open(fname_data,'rb') +group_data=pickle.load(fr) + + + +if analysis_name=='GAT_Cat': + #CCD: cross condition decoding + #GNW + + # # 300ms to 500ms + # ccd_plt(group_data2,roi_name='GNW',test_win_on=130, test_win_off=150,chance_index=50,y_index=15) + + # 0ms to 1500ms + ctccd_plt(group_data,con_Tname,roi_name='GNW',test_win_on=30, test_win_off=251,chance_index=50,y_index=40) + + #IIT + + # # 300ms to 500ms + # ccd_plt(group_data2,roi_name='IIT',test_win_on=130, test_win_off=250,chance_index=50,y_index=40) + + # 0ms to 1500ms + ctccd_plt(group_data,con_Tname,roi_name='IIT',test_win_on=30, test_win_off=251,chance_index=50,y_index=40) + + + #WCD: within condition decoding + #GNW + + # # 300ms to 500ms + # wcd_plt(group_data2,roi_name='GNW',test_win_on=130, test_win_off=150,chance_index=50,y_index=15) + + # 0ms to 1500ms + ctwcd_plt(group_data,con_Tname,roi_name='GNW',test_win_on=30, test_win_off=251,chance_index=50,y_index=40) + + #IIT + + # # 300ms to 500ms + # wcd_plt(group_data2,roi_name='IIT',test_win_on=130, test_win_off=250,chance_index=50,y_index=40) + + # 0ms to 1500ms + ctwcd_plt(group_data,con_Tname,roi_name='IIT',test_win_on=30, test_win_off=251,chance_index=50,y_index=40) + + +elif analysis_name=='GAT_Ori': + + ctwcd_ori_plt(group_data,con_Tname,roi_name='GNW',test_win_on=30, test_win_off=251,chance_index=33.3,y_index=40) + ctwcd_ori_plt(group_data,con_Tname,roi_name='IIT',test_win_on=30, test_win_off=251,chance_index=33.3,y_index=40) + diff --git a/roi_mvpa/D98_group_stat_sROI_plot_RSA.py b/roi_mvpa/D98_group_stat_sROI_plot_RSA.py new file mode 100644 index 0000000..3f8b802 --- /dev/null +++ b/roi_mvpa/D98_group_stat_sROI_plot_RSA.py @@ -0,0 +1,626 @@ +""" +==================== +D08. 
Group analysis for RSA +==================== + +@author: Ling Liu ling.liu@pku.edu.cn + +""" + +import os.path as op +import os +import argparse + +import pickle +import mne + +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +sns.set_theme(style='ticks') + + +from scipy import stats as stats + + +import matplotlib as mpl + +from matplotlib.patches import Rectangle + +from rsa_helper_functions_meg import subsample_matrices,compute_correlation_theories + +import ptitprince as pt + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root,plot_param +from sublist import sub_list + + +parser = argparse.ArgumentParser() +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +parser.add_argument('--cT', type=str, nargs='*', default=['500ms', '1000ms', '1500ms'], + help='condition in Time duration') + +parser.add_argument('--cC', type=str, nargs='*', default=['FO'], + help='selected decoding category, FO for face and object, LF for letter and false') +parser.add_argument('--cD',type=str,nargs='*', default=['Irrelevant', 'Relevant non-target'], + help='selected decoding Task, Relevant non Target or Irrelevant condition') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--methods', + type=str, + default='roi_mvpa', + help='decoding methods name, for the data folder') +parser.add_argument('--analysis', + type=str, + default='RSA_ID', + help='the name for anlaysis, e.g. Tall for 3 durations combined analysis') + + +opt = parser.parse_args() + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path +methods_name=opt.methods +analysis_name=opt.analysis + + +opt = parser.parse_args() +con_C = opt.cC +con_D = opt.cD +con_T = opt.cT + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path + +#1) Select Category +if con_C[0] == 'FO': + conditions_C = ['face', 'object'] + print(conditions_C) +elif con_C[0] == 'LF': + conditions_C = ['letter', 'false'] + print(conditions_C) +elif con_C[0] == 'F': + conditions_C = ['face'] + print(conditions_C) +elif con_C[0] == 'O': + conditions_C = ['object'] + print(conditions_C) +elif con_C[0] == 'L': + conditions_C = ['letter'] + print(conditions_C) +elif con_C[0] == 'FA': + conditions_C = ['false'] + print(conditions_C) + +#1) Select time duration +if con_T[0] == 'T_all': + con_T = ['500ms', '1000ms','1500ms'] + print(con_T) +elif con_T[0] == 'ML':# middle and long + con_T = ['1000ms','1500ms'] + print(con_T) + +# get the parameters dictionary +def mm2inch(val): + return val / 25.4 + +param = plot_param +colors=param['colors'] +fig_size = param["figure_size_mm"] +# plt.rcParams["font.family"] = "serif" +# plt.rcParams["font.serif"] = "Times New Roman" +plt.rc('font', size=param["font_size"]) # controls default text sizes +plt.rc('axes', titlesize=param["font_size"]) # fontsize of the axes title +plt.rc('axes', labelsize=param["font_size"]) # fontsize of the x and y labels +plt.rc('xtick', labelsize=param["font_size"]) # fontsize of the tick labels +plt.rc('ytick', labelsize=param["font_size"]) # fontsize of the tick labels +plt.rc('legend', fontsize=param["font_size"]) # legend fontsize +plt.rc('figure', titlesize=param["font_size"]) # fontsize of the fi +new_rc_params 
= {'text.usetex': False, +"svg.fonttype": 'none' +} +mpl.rcParams.update(new_rc_params) + + + +# set the path for decoding analysis +def set_path_plot(bids_root, visit_id, analysis_name,con_name): + + ### I Set the group Data Path + # Set path to decoding derivatives + decoding_path=op.join(bids_root, "derivatives",'decoding','roi_mvpa') + + data_path=op.join(decoding_path, analysis_name) + + # Set path to group analysis derivatives + group_deriv_root = op.join(data_path, "group") + if not op.exists(group_deriv_root): + os.makedirs(group_deriv_root) + + + # Set path to the ROI MVPA output(1) stat_data, 2) figures, 3) codes) + + # 1) output_stat_data + stat_data_root = op.join(group_deriv_root,"stat_data",con_name) + if not op.exists(stat_data_root): + os.makedirs(stat_data_root) + + # 2) output_figure + stat_figure_root = op.join(group_deriv_root,"stat_figures",con_name) + if not op.exists(stat_figure_root): + os.makedirs(stat_figure_root) + + return group_deriv_root,stat_data_root,stat_figure_root + + +def rsa2gat(dat,roi_name,cond_name,decoding_name,analysis): + if analysis=='RSA_ID': + time_points=201 + roi_rsa_g=np.zeros([len(sub_list),time_points,time_points]) + for ci, cond in enumerate(cond_name): + roi_rsa_gc=np.zeros([len(sub_list),time_points,time_points]) + for i, sbn in enumerate(sub_list): + roi_rsa_gc[i,:,:]=dat[sbn][roi_name][cond][roi_name] + roi_rsa_g[:,:,:]=roi_rsa_gc + if analysis=='RSA_Ori': + time_points=251 + roi_rsa_g=np.zeros([len(sub_list),time_points,time_points]) + for ci, cond in enumerate(cond_name): + roi_rsa_gc=np.zeros([len(sub_list),time_points,time_points]) + for i, sbn in enumerate(sub_list): + roi_rsa_gc[i,:,:]=dat[sbn][roi_name][cond][roi_name] + roi_rsa_g[:,:,:]=roi_rsa_gc + elif analysis=='RSA_Cat': + time_points=251 + roi_rsa_g=np.zeros([len(sub_list),time_points,time_points]) + for ci, cond in enumerate(cond_name): + roi_rsa_gc=np.zeros([len(sub_list),time_points,time_points]) + for i, sbn in enumerate(sub_list): + roi_rsa_gc[i,:,:]=dat[sbn][roi_name][cond][roi_name][decoding_name] + roi_rsa_g[:,:,:]=roi_rsa_gc + + return roi_rsa_g + +def rsa_plot(roi_rsa_data,C1_stat,time_points,fname_fig): + + #fig, ax = plt.subplots(1) + fig, ax = plt.subplots(figsize=[mm2inch(fig_size[0]),mm2inch(fig_size[0])]) + + roi_rsa_mean=np.mean(roi_rsa_data,0) + RDM_avg_plot = np.nan * np.ones_like(roi_rsa_mean) + for c, p_val in zip(C1_stat['cluster'], C1_stat['cluster_p']): + if p_val <= 0.05: + RDM_avg_plot[c] = roi_rsa_mean[c] + + cmap = mpl.cm.RdYlBu_r + im=ax.imshow(roi_rsa_mean, interpolation='lanczos', origin='lower', cmap=cmap, alpha=0.9,aspect='equal', + extent=time_points[[0, -1, 0, -1]],vmin=-0.3, vmax=0.3) + ax.contour(RDM_avg_plot > 0, RDM_avg_plot > 0, colors="grey", linewidths=2, origin="lower",extent=time_points[[0, -1, 0, -1]]) + im = ax.imshow(RDM_avg_plot, origin='lower', cmap=cmap,aspect='equal', + extent=time_points[[0, -1, 0, -1]], vmin=-0.3, vmax=0.3) + + # Define the size and position of the squares + square_size = 0.2 + x=[0.3,0.8,1.3,1.8] + y=[0.3,0.8,1.3,1.8] + squares=[] + for ii in range(16): + for nn in range(4): + for mm in range(4): + squares.append((x[nn],y[mm])) + + + # Draw the squares + for square in squares: + + rect = Rectangle(square, square_size, square_size, linewidth=3,edgecolor=[0, 0, 0], facecolor='none', linestyle=":") + ax.add_patch(rect) + + ax.axhline(0,color='k') + ax.axvline(0, color='k') + ax.legend(loc='upper right') + #ax.set_title(f'RSA_ {roi_name}') + ax.set(xlabel='Time (s)', ylabel='Time (s)') + ax.set_xticks([0, 
0.5, 1.0, 1.5]) + ax.set_yticks([0, 0.5, 1.0, 1.5]) + #plt.colorbar(im, ax=ax,fraction=0.03, pad=0.05) + cb = plt.colorbar(im, fraction=0.046, pad=0.04) + #cb.ax.set_ylabel(cbar_label) + cb.ax.set_yscale('linear') # To make sure that the spacing is correct despite normalizat + mne.viz.tight_layout() + # Save figure + + fig.savefig(fname_fig,format="svg", transparent=True, dpi=300) + + #mne.stats.permutation_cluster_1samp_test + +def rsa_ID_plot(roi_rsa_data,C1_stat,time_points,fname_fig): + + #fig, ax = plt.subplots(1) + fig, ax = plt.subplots(figsize=[mm2inch(fig_size[0]),mm2inch(fig_size[0])]) + + roi_rsa_mean=np.mean(roi_rsa_data,0) + RDM_avg_plot = np.nan * np.ones_like(roi_rsa_mean) + for c, p_val in zip(C1_stat['cluster'], C1_stat['cluster_p']): + if p_val <= 0.05: + RDM_avg_plot[c] = roi_rsa_mean[c] + + cmap = mpl.cm.RdYlBu_r + im=ax.imshow(roi_rsa_mean, interpolation='lanczos', origin='lower', cmap=cmap, alpha=0.9,aspect='equal', + extent=time_points[[0, -1, 0, -1]],vmin=-0.3, vmax=0.3) + ax.contour(RDM_avg_plot > 0, RDM_avg_plot > 0, colors="grey", linewidths=2, origin="lower",extent=time_points[[0, -1, 0, -1]]) + im = ax.imshow(RDM_avg_plot, origin='lower', cmap=cmap,aspect='equal', + extent=time_points[[0, -1, 0, -1]], vmin=-0.3, vmax=0.3) + + # Define the size and position of the squares + square_size = 0.2 + x=[0.3,0.8,1.3] + y=[0.3,0.8,1.3] + squares=[] + for ii in range(9): + for nn in range(3): + for mm in range(3): + squares.append((x[nn],y[mm])) + + + # Draw the squares + for square in squares: + + rect = Rectangle(square, square_size, square_size, linewidth=3,edgecolor=[0, 0, 0], facecolor='none', linestyle=":") + ax.add_patch(rect) + + ax.axhline(0,color='k') + ax.axvline(0, color='k') + ax.legend(loc='upper right') + #ax.set_title(f'RSA_ {roi_name}') + ax.set(xlabel='Time (s)', ylabel='Time (s)') + ax.set_xticks([0, 0.5, 1.0, 1.5]) + ax.set_yticks([0, 0.5, 1.0, 1.5]) + #plt.colorbar(im, ax=ax,fraction=0.03, pad=0.05) + cb = plt.colorbar(im, fraction=0.046, pad=0.04) + #cb.ax.set_ylabel(cbar_label) + cb.ax.set_yscale('linear') # To make sure that the spacing is correct despite normalizat + mne.viz.tight_layout() + # Save figure + + fig.savefig(fname_fig,format="svg", transparent=True, dpi=300) + + #mne.stats.permutation_cluster_1samp_test + +def rsa_subsample_plot(roi_rsa_mean, subsampled_time_ref, matrices_delimitations_ref, sub_matrix_dict,vmin,vmax,cmap,fname_fig): + + fig, ax = plt.subplots(1) + + + + cmap = mpl.cm.RdYlBu_r + im=ax.imshow(roi_rsa_mean, interpolation='lanczos', origin='lower', cmap=cmap, + aspect='equal',vmin=vmin, vmax=vmax) + ax.axhline(0,color='k') + ax.axvline(0, color='k') + ax.legend(loc='upper right') + #ax.set_title(f'RSA_ {roi_name}') + ax.set(xlabel='Time (s)', ylabel='Time (s)') + ax.set_xticks([0, 0.5, 1.0, 1.5]) + plt.colorbar(im, ax=ax,fraction=0.03, pad=0.05) + + # Adding the matrices demarcations in case of subsampling: + [ax.axhline(ind + 0.5, color='k', linestyle='--') + for ind in matrices_delimitations_ref] + [ax.axvline(ind + 0.5, color='k', linestyle='--') + for ind in matrices_delimitations_ref] + # Adding axis break to mark the difference: + d = 0.01 + kwargs = dict(transform=ax.transAxes, color='k', clip_on=False) + # Looping through each demarcations to mark them:: + for ind in matrices_delimitations_ref: + ind_trans = (ind + 1) / len(roi_rsa_mean) + ax.plot((ind_trans - 0.005 - d, ind_trans + - 0.005 + d), (-d, +d), **kwargs) + ax.plot((ind_trans + 0.005 - d, ind_trans + + 0.005 + d), (-d, +d), **kwargs) + ax.plot((-d, +d), 
(ind_trans - 0.005 - d, + ind_trans - 0.005 + d), **kwargs) + ax.plot((-d, +d), (ind_trans + 0.005 - d, + ind_trans + 0.005 + d), **kwargs) + # Generate the ticks: + ticks_pos = np.linspace(0, roi_rsa_mean.shape[0] - 1, 8) + # Generate the tick position and labels: + ticks_labels = [str(subsampled_time_ref[int(ind)]) for ind in ticks_pos] + ax.set_xticks(ticks_pos) + ax.set_yticks(ticks_pos) + ax.set_xticklabels(ticks_labels) + ax.set_yticklabels(ticks_labels) + plt.tight_layout() + + fig.savefig(fname_fig,format="svg", transparent=True, dpi=300) + +# def sign_test(data): +# seed=1999 +# random_state = check_random_state(seed) +# p=np.mean(data * random_state.choice([1, -1], len(data))) +# return p + +def theory_rdm(RSA_methods): + if RSA_methods=='RSA_ID': + GNW_rdm=np.zeros([63,63]) + GNW_rdm[0:21,0:21]=1 + GNW_rdm[0:21,42:63]=1 + GNW_rdm[42:63,0:21]=1 + GNW_rdm[42:63,42:63]=1 + + IIT_rdm=np.zeros([63,63]) + IIT_rdm[0:42,0:42]=1 + + theory_rdm=dict() + theory_rdm['IIT']=IIT_rdm + theory_rdm['GNW']=GNW_rdm + elif RSA_methods=='RSA_Cat'or'RSA_Ori': + GNW_rdm=np.zeros([84,84]) + GNW_rdm[0:21,0:21]=1 + GNW_rdm[0:21,63:84]=1 + GNW_rdm[63:84,0:21]=1 + GNW_rdm[63:84,63:84]=1 + + IIT_rdm=np.zeros([84,84]) + IIT_rdm[0:63,0:63]=1 + + theory_rdm=dict() + theory_rdm['IIT']=IIT_rdm + theory_rdm['GNW']=GNW_rdm + + return theory_rdm + +def corr_theory(rsa_subsample,analysis_name,decoding_name): + + #1:generated theory_rdm + theory_rdm_matrix=theory_rdm(analysis_name) + + #2:correlate the theories matrices with the observed matrices for each subjects + for n in range(len(sub_list)): + observed_matrix=rsa_subsample[n,:,:] + if n==0: + correlation_results, correlation_results_corrected=compute_correlation_theories([observed_matrix], theory_rdm_matrix, method="kendall") + group_corr_corrected=correlation_results_corrected + group_corr=correlation_results + else: + correlation_results, correlation_results_corrected=compute_correlation_theories([observed_matrix], theory_rdm_matrix, method="kendall") + group_corr_corrected=group_corr_corrected.append(correlation_results_corrected,ignore_index=True) + group_corr=group_corr.append(correlation_results,ignore_index=True) + + #stat + + p_value=dict() + stat,p_value['IIT']=stats.wilcoxon(group_corr['IIT']) + stat,p_value['GNW']=stats.wilcoxon(group_corr['GNW']) + stat,p_value['diff']=stats.mannwhitneyu(group_corr_corrected['GNW'],group_corr_corrected['IIT']) + + fname_p_value=op.join(stat_data_root, task_info +"_" + analysis_name + roi_name + decoding_name +'_stat_value.npz') + np.savez(fname_p_value,p_value,group_corr,group_corr_corrected) + + + corr_palette=[colors['IIT'],colors['GNW']] + #plot + group_corr_plot=group_corr.melt(var_name='theory',value_name='corr') + + fig, ax = plt.subplots(1) + ax=pt.RainCloud(x='theory',y='corr',data=group_corr_plot,palette=corr_palette,bw=.2,width_viol=.5,ax=ax,orient='v') + plt.title(analysis_name+'_corr_'+roi_name) + fname_corr_fig=op.join(stat_figure_root, task_info +"_" + analysis_name + roi_name+'_'+analysis_name + decoding_name +'_corr.svg') + fig.savefig(fname_corr_fig,format="svg", transparent=True, dpi=300) + +def RSA_ID_plot(roi_name): + analysis_name='RSA_ID' + time_point = np.array(range(-500, 1501, 10))/1000 + + #get decoding data + roi_rsa_g=rsa2gat(group_data,roi_name,cond_name=['rsa'],decoding_name='ID',analysis=analysis_name) + C1_stat=stat_cluster_1sample_RDM(roi_rsa_g,test_win_on=0, test_win_off=201,chance_index=0) + + fname_fig=op.join(stat_figure_root, task_info +"_" + analysis_name +'_' + 
roi_name+'_.svg') + #plot + rsa_ID_plot(roi_rsa_g,C1_stat=C1_stat,time_points=time_point,fname_fig=fname_fig) + + #subsample data + intervals_of_interest={"x": [[0.3,0.5],[0.8,1.0],[1.3,1.5]],"y":[[0.3,0.5],[0.8,1.0],[1.3,1.5]]} + + rsa_subsample=np.zeros([len(sub_list),63,63]) + + for n in range(len(sub_list)): + rsa_subsample[n,:,:], subsampled_time_ref, matrices_delimitations_ref, sub_matrix_dict=subsample_matrices(roi_rsa_g[n,:,:], -0.5, 1.5, intervals_of_interest) + + + roi_rsa_mean=np.mean(rsa_subsample,0) + fname_fig_sub=op.join(stat_figure_root, task_info +"_" + analysis_name + roi_name +'_subsample.svg') + cmap = mpl.cm.RdYlBu_r + vmin=-0.02 + vmax=0.1 + #plot + rsa_subsample_plot(roi_rsa_mean, subsampled_time_ref, matrices_delimitations_ref, sub_matrix_dict,vmin,vmax,cmap,fname_fig_sub) + + #correlated with theory rdm + corr_theory(rsa_subsample,analysis_name,decoding_name='ID') + + +def RSA_Cat_plot(roi_name,condition): + analysis_name='RSA_Cat' + + time_point = np.array(range(-500,2001, 10))/1000 + if condition=='Irrelevant': + conD='IR' + elif condition=='Relevant non-target': + conD='RE' + + #get decoding data + roi_rsa_g=rsa2gat(group_data,roi_name,cond_name=['rsa'],decoding_name=condition,analysis=analysis_name) + + + C1_stat=stat_cluster_1sample_RDM(roi_rsa_g,test_win_on=0, test_win_off=251,chance_index=0) + + fname_fig=op.join(stat_figure_root, task_info +"_" + analysis_name +'_' + roi_name + '_' + conD +'.svg') + #plot + rsa_plot(roi_rsa_g,C1_stat=C1_stat,time_points=time_point,fname_fig=fname_fig) + + #subsample data + intervals_of_interest={"x": [[0.3,0.5],[0.8,1.0],[1.3,1.5],[1.8,2.0]],"y":[[0.3,0.5],[0.8,1.0],[1.3,1.5],[1.8,2.0]]} + + rsa_subsample=np.zeros([len(sub_list),84,84]) + + for n in range(len(sub_list)): + rsa_subsample[n,:,:], subsampled_time_ref, matrices_delimitations_ref, sub_matrix_dict=subsample_matrices(roi_rsa_g[n,:,:], -0.5, 2, intervals_of_interest) + + roi_rsa_mean=np.mean(rsa_subsample,0) + fname_fig_sub=op.join(stat_figure_root, task_info +"_" + analysis_name + roi_name + '_' + conD + '_subsample.svg') + cmap = mpl.cm.RdYlBu_r + vmin=-0.02 + vmax=0.1 + #plot + rsa_subsample_plot(roi_rsa_mean, subsampled_time_ref, matrices_delimitations_ref, sub_matrix_dict,vmin,vmax,cmap,fname_fig_sub) + + #correlated with theory rdm + corr_theory(rsa_subsample,analysis_name,decoding_name=condition) + +def RSA_Ori_plot(roi_name): + analysis='RSA_Ori' + time_point = np.array(range(-500,2001, 10))/1000 + + + #get decoding data + roi_rsa_g=rsa2gat(group_data,roi_name,cond_name=['rsa'],decoding_name='Ori',analysis=analysis) + + C1_stat=stat_cluster_1sample_RDM(roi_rsa_g,test_win_on=0, test_win_off=251,chance_index=0) + + fname_fig=op.join(stat_figure_root, task_info +"_" + analysis_name +'_' + roi_name +'.svg') + #plot + rsa_plot(roi_rsa_g,C1_stat=C1_stat,time_points=time_point,fname_fig=fname_fig) + + #subsample data + intervals_of_interest={"x": [[0.3,0.5],[0.8,1.0],[1.3,1.5],[1.8,2.0]],"y":[[0.3,0.5],[0.8,1.0],[1.3,1.5],[1.8,2.0]]} + + rsa_subsample=np.zeros([len(sub_list),84,84]) + + for n in range(len(sub_list)): + rsa_subsample[n,:,:], subsampled_time_ref, matrices_delimitations_ref, sub_matrix_dict=subsample_matrices(roi_rsa_g[n,:,:], -0.5, 2, intervals_of_interest) + + roi_rsa_mean=np.mean(rsa_subsample,0) + + + + fname_fig_sub=op.join(stat_figure_root, task_info +"_" + analysis_name + roi_name +'_subsample.svg') + cmap = mpl.cm.RdYlBu_r + vmin=-0.02 + vmax=0.1 + #plot + rsa_subsample_plot(roi_rsa_mean, subsampled_time_ref, matrices_delimitations_ref, 
sub_matrix_dict,vmin,vmax,cmap,fname_fig_sub) + + #correlated with theory rdm + corr_theory(rsa_subsample,analysis_name,decoding_name='Ori') + + +def stat_cluster_1sample_RDM(gc_mean,test_win_on,test_win_off,chance_index): + # define theresh + pval = 0.05 # arbitrary + tail = 1 # two-tailed + n_observations=gc_mean.shape[0] + stat_time_points=gc_mean[:,test_win_on:test_win_off,test_win_on:test_win_off].shape[2] + df = n_observations - 1 # degrees of freedom for the test + thresh = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + + + T_obs_1, clusters_1, cluster_p_values_1, H0_1 = mne.stats.permutation_cluster_1samp_test( + gc_mean[:,test_win_on:test_win_off,test_win_on:test_win_off]-np.ones([n_observations,stat_time_points,stat_time_points])*chance_index, + threshold=thresh, n_permutations=1000, tail=tail, out_type='mask',verbose=None) + + C1_stat=dict() + C1_stat['T_obs']=T_obs_1 + C1_stat['cluster']=clusters_1 + C1_stat['cluster_p']=cluster_p_values_1 + + return C1_stat + + +################ +#set data root +if analysis_name=='RSA_Cat': + analysis_index='RSA_Cat_NoFS' + group_deriv_root,stat_data_root,stat_figure_root=set_path_plot(bids_root,visit_id, analysis_index,con_C[0]) +else: + group_deriv_root,stat_data_root,stat_figure_root=set_path_plot(bids_root,visit_id, analysis_name,con_C[0]) + + +## analysis/task info +if con_T.__len__() == 3: + con_Tname = 'T_all' +elif con_T.__len__() == 2: + con_Tname = con_T[0]+'_'+con_T[1] +else: + con_Tname = con_T[0] + +task_info = "_" + "".join(con_Tname) + "_" + "".join(con_C[0]) +print(task_info) + +fname_data=op.join(group_deriv_root, task_info +"_data_group_" + analysis_name + + '.pickle') + +fr=open(fname_data,'rb') +group_data=pickle.load(fr) + + + + +if analysis_name=='RSA_ID': + #sub_list.remove('SB006') + #sub_list.remove('SB003') + # GNW ROI + roi_name='GNW' + RSA_ID_plot(roi_name) + + # IIT ROI + roi_name='IIT' + RSA_ID_plot(roi_name) + +elif analysis_name=='RSA_Cat': + #sub_list.remove('SB006') + #sub_list.remove('SB003') + # GNW ROI + roi_name='GNW' + condition='Irrelevant' + RSA_Cat_plot(roi_name,condition) + condition='Relevant non-target' + RSA_Cat_plot(roi_name,condition) + + # IIT ROI + roi_name='IIT' + condition='Irrelevant' + RSA_Cat_plot(roi_name,condition) + condition='Relevant non-target' + RSA_Cat_plot(roi_name,condition) + +elif analysis_name=='RSA_Ori': + #sub_list.remove('SB006') + #sub_list.remove('SB003') + # GNW ROI + roi_name='GNW' + RSA_Ori_plot(roi_name) + + # IIT ROI + roi_name='IIT' + RSA_Ori_plot(roi_name) + + + diff --git a/roi_mvpa/D98_group_stat_sROI_plot_RSA_phaseII.py b/roi_mvpa/D98_group_stat_sROI_plot_RSA_phaseII.py new file mode 100644 index 0000000..f82550b --- /dev/null +++ b/roi_mvpa/D98_group_stat_sROI_plot_RSA_phaseII.py @@ -0,0 +1,629 @@ +""" +==================== +D08. 
Group analysis for RSA +for phaseII +==================== + +@author: Ling Liu ling.liu@pku.edu.cn + +""" + +import os.path as op +import os +import argparse + +import pickle +import mne + + +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +sns.set_theme(style='ticks') + + +from scipy import stats as stats + + +import matplotlib as mpl + +from matplotlib.patches import Rectangle + +from rsa_helper_functions_meg import subsample_matrices,compute_correlation_theories + +import ptitprince as pt + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root,plot_param + +import matplotlib.colors as mcolors + +from sublist_phase2 import sub_list + + +parser = argparse.ArgumentParser() +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +parser.add_argument('--cT', type=str, nargs='*', default=['500ms', '1000ms', '1500ms'], + help='condition in Time duration') + +parser.add_argument('--cC', type=str, nargs='*', default=['FO'], + help='selected decoding category, FO for face and object, LF for letter and false') +parser.add_argument('--cD',type=str,nargs='*', default=['Irrelevant', 'Relevant non-target'], + help='selected decoding Task, Relevant non Target or Irrelevant condition') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--methods', + type=str, + default='roi_mvpa', + help='decoding methods name, for the data folder') +parser.add_argument('--analysis', + type=str, + default='RSA_ID', + help='the name for anlaysis, e.g. 
Tall for 3 durations combined analysis') + + +opt = parser.parse_args() + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path +methods_name=opt.methods +analysis_name=opt.analysis + + +opt = parser.parse_args() +con_C = opt.cC +con_D = opt.cD +con_T = opt.cT + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path + +#1) Select Category +if con_C[0] == 'FO': + conditions_C = ['face', 'object'] + print(conditions_C) +elif con_C[0] == 'LF': + conditions_C = ['letter', 'false'] + print(conditions_C) +elif con_C[0] == 'F': + conditions_C = ['face'] + print(conditions_C) +elif con_C[0] == 'O': + conditions_C = ['object'] + print(conditions_C) +elif con_C[0] == 'L': + conditions_C = ['letter'] + print(conditions_C) +elif con_C[0] == 'FA': + conditions_C = ['false'] + print(conditions_C) + +#1) Select time duration +if con_T[0] == 'T_all': + con_T = ['500ms', '1000ms','1500ms'] + print(con_T) +elif con_T[0] == 'ML':# middle and long + con_T = ['1000ms','1500ms'] + print(con_T) + +# get the parameters dictionary +param = plot_param +colors=param['colors'] +fig_size = param["figure_size_mm"] +# plt.rcParams["font.family"] = "serif" +# plt.rcParams["font.serif"] = "Times New Roman" +plt.rc('font', size=param["font_size"]) # controls default text sizes +plt.rc('axes', titlesize=param["font_size"]) # fontsize of the axes title +plt.rc('axes', labelsize=param["font_size"]) # fontsize of the x and y labels +plt.rc('xtick', labelsize=param["font_size"]) # fontsize of the tick labels +plt.rc('ytick', labelsize=param["font_size"]) # fontsize of the tick labels +plt.rc('legend', fontsize=param["font_size"]) # legend fontsize +plt.rc('figure', titlesize=param["font_size"]) # fontsize of the fi +new_rc_params = {'text.usetex': False, +"svg.fonttype": 'none' +} +mpl.rcParams.update(new_rc_params) + +def mm2inch(val): + return val / 25.4 + +# set the path for decoding analysis +def set_path_plot(bids_root, visit_id, analysis_name,con_name): + + ### I Set the group Data Path + # Set path to decoding derivatives + decoding_path=op.join(bids_root, "derivatives",'decoding','roi_mvpa') + + data_path=op.join(decoding_path, analysis_name) + + # Set path to group analysis derivatives + group_deriv_root = op.join(data_path, "group_phase2") + if not op.exists(group_deriv_root): + os.makedirs(group_deriv_root) + + + # Set path to the ROI MVPA output(1) stat_data, 2) figures, 3) codes) + + # 1) output_stat_data + stat_data_root = op.join(group_deriv_root,"stat_data",con_name) + if not op.exists(stat_data_root): + os.makedirs(stat_data_root) + + # 2) output_figure + stat_figure_root = op.join(group_deriv_root,"stat_figures",con_name) + if not op.exists(stat_figure_root): + os.makedirs(stat_figure_root) + + return group_deriv_root,stat_data_root,stat_figure_root + + +def rsa2gat(dat,roi_name,cond_name,decoding_name,analysis): + if analysis=='RSA_ID': + time_points=201 + roi_rsa_g=np.zeros([len(sub_list),time_points,time_points]) + for ci, cond in enumerate(cond_name): + roi_rsa_gc=np.zeros([len(sub_list),time_points,time_points]) + for i, sbn in enumerate(sub_list): + roi_rsa_gc[i,:,:]=dat[sbn][roi_name][cond][roi_name] + roi_rsa_g[:,:,:]=roi_rsa_gc + if analysis=='RSA_Ori': + time_points=251 + roi_rsa_g=np.zeros([len(sub_list),time_points,time_points]) + for ci, cond in enumerate(cond_name): + roi_rsa_gc=np.zeros([len(sub_list),time_points,time_points]) + for i, sbn in enumerate(sub_list): + roi_rsa_gc[i,:,:]=dat[sbn][roi_name][cond][roi_name] + roi_rsa_g[:,:,:]=roi_rsa_gc + elif 
analysis=='RSA_Cat': + time_points=251 + roi_rsa_g=np.zeros([len(sub_list),time_points,time_points]) + for ci, cond in enumerate(cond_name): + roi_rsa_gc=np.zeros([len(sub_list),time_points,time_points]) + for i, sbn in enumerate(sub_list): + roi_rsa_gc[i,:,:]=dat[sbn][roi_name][cond][roi_name][decoding_name] + roi_rsa_g[:,:,:]=roi_rsa_gc + + return roi_rsa_g + +def rsa_plot(roi_rsa_data,C1_stat,time_points,fname_fig): + + #fig, ax = plt.subplots(1) + fig, ax = plt.subplots(figsize=[mm2inch(fig_size[0]),mm2inch(fig_size[0])]) + + roi_rsa_mean=np.mean(roi_rsa_data,0) + RDM_avg_plot = np.nan * np.ones_like(roi_rsa_mean) + for c, p_val in zip(C1_stat['cluster'], C1_stat['cluster_p']): + if p_val <= 0.05: + RDM_avg_plot[c] = roi_rsa_mean[c] + + cmap = mpl.cm.RdYlBu_r + im=ax.imshow(roi_rsa_mean, interpolation='lanczos', origin='lower', cmap=cmap, alpha=0.9,aspect='equal', + extent=time_points[[0, -1, 0, -1]],vmin=-0.02, vmax=0.1) + ax.contour(RDM_avg_plot > 0, RDM_avg_plot > 0, colors="grey", linewidths=2, origin="lower",extent=time_points[[0, -1, 0, -1]]) + im = ax.imshow(RDM_avg_plot, origin='lower', cmap=cmap,aspect='equal', + extent=time_points[[0, -1, 0, -1]], vmin=-0.02, vmax=0.1) + + # Define the size and position of the squares + square_size = 0.2 + x=[0.3,0.8,1.3,1.8] + y=[0.3,0.8,1.3,1.8] + squares=[] + for ii in range(16): + for nn in range(4): + for mm in range(4): + squares.append((x[nn],y[mm])) + + + # Draw the squares + for square in squares: + + rect = Rectangle(square, square_size, square_size, linewidth=3,edgecolor=[0, 0, 0], facecolor='none', linestyle=":") + ax.add_patch(rect) + + ax.axhline(0,color='k') + ax.axvline(0, color='k') + ax.legend(loc='upper right') + #ax.set_title(f'RSA_ {roi_name}') + ax.set(xlabel='Time (s)', ylabel='Time (s)') + ax.set_xticks([0, 0.5, 1.0, 1.5]) + ax.set_yticks([0, 0.5, 1.0, 1.5]) + #plt.colorbar(im, ax=ax,fraction=0.03, pad=0.05) + cb = plt.colorbar(im, fraction=0.046, pad=0.04) + #cb.ax.set_ylabel(cbar_label) + cb.ax.set_yscale('linear') # To make sure that the spacing is correct despite normalizat + mne.viz.tight_layout() + # Save figure + + fig.savefig(fname_fig,format="svg", transparent=True, dpi=300) + + #mne.stats.permutation_cluster_1samp_test + +def rsa_ID_plot(roi_rsa_data,C1_stat,time_points,fname_fig): + + #fig, ax = plt.subplots(1) + fig, ax = plt.subplots(figsize=[mm2inch(fig_size[0]),mm2inch(fig_size[0])]) + + roi_rsa_mean=np.mean(roi_rsa_data,0) + RDM_avg_plot = np.nan * np.ones_like(roi_rsa_mean) + for c, p_val in zip(C1_stat['cluster'], C1_stat['cluster_p']): + if p_val <= 0.05: + RDM_avg_plot[c] = roi_rsa_mean[c] + + #cmap = mpl.cm.RdYlBu_r + cmap = mcolors.LinearSegmentedColormap.from_list('my_colormap', + np.vstack((plt.cm.Blues_r(np.linspace(0, 1, 220) ), + plt.cm.Blues_r( np.linspace(1, 1, 36) ), + plt.cm.Reds( np.linspace(0, 0, 36) ), + plt.cm.Reds( np.linspace(0, 1, 220) ) ) ) ) + im=ax.imshow(roi_rsa_mean, interpolation='lanczos', origin='lower', cmap=cmap, alpha=0.9,aspect='equal', + extent=time_points[[0, -1, 0, -1]],vmin=-0.02, vmax=0.1) + ax.contour(RDM_avg_plot > 0, RDM_avg_plot > 0, colors="grey", linewidths=2, origin="lower",extent=time_points[[0, -1, 0, -1]]) + im = ax.imshow(RDM_avg_plot, origin='lower', cmap=cmap,aspect='equal', + extent=time_points[[0, -1, 0, -1]], vmin=-0.02, vmax=0.1) + + # Define the size and position of the squares + square_size = 0.2 + x=[0.3,0.8,1.3] + y=[0.3,0.8,1.3] + squares=[] + for ii in range(9): + for nn in range(3): + for mm in range(3): + squares.append((x[nn],y[mm])) + + + 
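+    # The dotted rectangles outline the 3x3 grid of 200-ms windows
+    # (0.3-0.5 s, 0.8-1.0 s and 1.3-1.5 s on both axes) that are later
+    # subsampled and compared against the theory RDMs. Note that the outer
+    # loop over `ii` re-appends the same 9 (x, y) pairs, so `squares`
+    # contains duplicates; they are simply drawn on top of each other.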
# Draw the squares + for square in squares: + + rect = Rectangle(square, square_size, square_size, linewidth=3,edgecolor=[0, 0, 0], facecolor='none', linestyle=":") + ax.add_patch(rect) + + ax.axhline(0,color='k') + ax.axvline(0, color='k') + ax.legend(loc='upper right') + #ax.set_title(f'RSA_ {roi_name}') + ax.set(xlabel='Time (s)', ylabel='Time (s)') + ax.set_xticks([0, 0.5, 1.0, 1.5]) + ax.set_yticks([0, 0.5, 1.0, 1.5]) + #plt.colorbar(im, ax=ax,fraction=0.03, pad=0.05) + cb = plt.colorbar(im, fraction=0.046, pad=0.04) + #cb.ax.set_ylabel(cbar_label) + cb.ax.set_yscale('linear') # To make sure that the spacing is correct despite normalizat + mne.viz.tight_layout() + # Save figure + + fig.savefig(fname_fig,format="svg", transparent=True, dpi=300) + + #mne.stats.permutation_cluster_1samp_test + +def rsa_subsample_plot(roi_rsa_mean, subsampled_time_ref, matrices_delimitations_ref, sub_matrix_dict,vmin,vmax,cmap,fname_fig): + + fig, ax = plt.subplots(1) + + + + cmap = mpl.cm.RdYlBu_r + im=ax.imshow(roi_rsa_mean, interpolation='lanczos', origin='lower', cmap=cmap, + aspect='equal',vmin=vmin, vmax=vmax) + ax.axhline(0,color='k') + ax.axvline(0, color='k') + ax.legend(loc='upper right') + #ax.set_title(f'RSA_ {roi_name}') + ax.set(xlabel='Time (s)', ylabel='Time (s)') + ax.set_xticks([0, 0.5, 1.0, 1.5]) + plt.colorbar(im, ax=ax,fraction=0.03, pad=0.05) + + # Adding the matrices demarcations in case of subsampling: + [ax.axhline(ind + 0.5, color='k', linestyle='--') + for ind in matrices_delimitations_ref] + [ax.axvline(ind + 0.5, color='k', linestyle='--') + for ind in matrices_delimitations_ref] + # Adding axis break to mark the difference: + d = 0.01 + kwargs = dict(transform=ax.transAxes, color='k', clip_on=False) + # Looping through each demarcations to mark them:: + for ind in matrices_delimitations_ref: + ind_trans = (ind + 1) / len(roi_rsa_mean) + ax.plot((ind_trans - 0.005 - d, ind_trans + - 0.005 + d), (-d, +d), **kwargs) + ax.plot((ind_trans + 0.005 - d, ind_trans + + 0.005 + d), (-d, +d), **kwargs) + ax.plot((-d, +d), (ind_trans - 0.005 - d, + ind_trans - 0.005 + d), **kwargs) + ax.plot((-d, +d), (ind_trans + 0.005 - d, + ind_trans + 0.005 + d), **kwargs) + # Generate the ticks: + ticks_pos = np.linspace(0, roi_rsa_mean.shape[0] - 1, 8) + # Generate the tick position and labels: + ticks_labels = [str(subsampled_time_ref[int(ind)]) for ind in ticks_pos] + ax.set_xticks(ticks_pos) + ax.set_yticks(ticks_pos) + ax.set_xticklabels(ticks_labels) + ax.set_yticklabels(ticks_labels) + plt.tight_layout() + + fig.savefig(fname_fig,format="svg", transparent=True, dpi=300) + +# def sign_test(data): +# seed=1999 +# random_state = check_random_state(seed) +# p=np.mean(data * random_state.choice([1, -1], len(data))) +# return p + +def theory_rdm(RSA_methods): + if RSA_methods=='RSA_ID': + GNW_rdm=np.zeros([63,63]) + GNW_rdm[0:21,0:21]=1 + GNW_rdm[0:21,42:63]=1 + GNW_rdm[42:63,0:21]=1 + GNW_rdm[42:63,42:63]=1 + + IIT_rdm=np.zeros([63,63]) + IIT_rdm[0:42,0:42]=1 + + theory_rdm=dict() + theory_rdm['IIT']=IIT_rdm + theory_rdm['GNW']=GNW_rdm + elif RSA_methods=='RSA_Cat'or'RSA_Ori': + GNW_rdm=np.zeros([84,84]) + GNW_rdm[0:21,0:21]=1 + GNW_rdm[0:21,63:84]=1 + GNW_rdm[63:84,0:21]=1 + GNW_rdm[63:84,63:84]=1 + + IIT_rdm=np.zeros([84,84]) + IIT_rdm[0:63,0:63]=1 + + theory_rdm=dict() + theory_rdm['IIT']=IIT_rdm + theory_rdm['GNW']=GNW_rdm + + return theory_rdm + +def corr_theory(rsa_subsample,analysis_name,decoding_name): + + #1:generated theory_rdm + theory_rdm_matrix=theory_rdm(analysis_name) + + 
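+    # theory_rdm() returns one binary model matrix per theory (IIT, GNW),
+    # defined over the subsampled time windows; these serve as the reference
+    # patterns for the per-subject Kendall correlations computed below.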
#2:correlate the theories matrices with the observed matrices for each subjects + for n in range(len(sub_list)): + observed_matrix=rsa_subsample[n,:,:] + if n==0: + correlation_results, correlation_results_corrected=compute_correlation_theories([observed_matrix], theory_rdm_matrix, method="kendall") + group_corr_corrected=correlation_results_corrected + group_corr=correlation_results + else: + correlation_results, correlation_results_corrected=compute_correlation_theories([observed_matrix], theory_rdm_matrix, method="kendall") + group_corr_corrected=group_corr_corrected.append(correlation_results_corrected,ignore_index=True) + group_corr=group_corr.append(correlation_results,ignore_index=True) + + #stat + + p_value=dict() + stat,p_value['IIT']=stats.wilcoxon(group_corr['IIT']) + stat,p_value['GNW']=stats.wilcoxon(group_corr['GNW']) + stat,p_value['diff']=stats.mannwhitneyu(group_corr_corrected['GNW'],group_corr_corrected['IIT']) + + fname_p_value=op.join(stat_data_root, task_info +"_" + analysis_name + roi_name + decoding_name +'_stat_value.npz') + np.savez(fname_p_value,p_value,group_corr,group_corr_corrected) + + + corr_palette=[colors['IIT'],colors['GNW']] + #plot + group_corr_plot=group_corr.melt(var_name='theory',value_name='corr') + + fig, ax = plt.subplots(1) + ax=pt.RainCloud(x='theory',y='corr',data=group_corr_plot,palette=corr_palette,bw=.2,width_viol=.5,ax=ax,orient='v') + plt.title(analysis_name+'_corr_'+roi_name) + fname_corr_fig=op.join(stat_figure_root, task_info +"_" + analysis_name + roi_name+'_'+analysis_name + decoding_name +'_corr.svg') + fig.savefig(fname_corr_fig,format="svg", transparent=True, dpi=300) + +def RSA_ID_plot(roi_name): + analysis_name='RSA_ID' + time_point = np.array(range(-500, 1501, 10))/1000 + + #get decoding data + roi_rsa_g=rsa2gat(group_data,roi_name,cond_name=['rsa'],decoding_name='ID',analysis=analysis_name) + C1_stat=stat_cluster_1sample_RDM(roi_rsa_g,test_win_on=0, test_win_off=201,chance_index=0) + + fname_fig=op.join(stat_figure_root, task_info +"_" + analysis_name +'_' + roi_name+'_.svg') + #plot + rsa_ID_plot(roi_rsa_g,C1_stat=C1_stat,time_points=time_point,fname_fig=fname_fig) + + #subsample data + intervals_of_interest={"x": [[0.3,0.5],[0.8,1.0],[1.3,1.5]],"y":[[0.3,0.5],[0.8,1.0],[1.3,1.5]]} + + rsa_subsample=np.zeros([len(sub_list),63,63]) + + for n in range(len(sub_list)): + rsa_subsample[n,:,:], subsampled_time_ref, matrices_delimitations_ref, sub_matrix_dict=subsample_matrices(roi_rsa_g[n,:,:], -0.5, 1.5, intervals_of_interest) + + + roi_rsa_mean=np.mean(rsa_subsample,0) + fname_fig_sub=op.join(stat_figure_root, task_info +"_" + analysis_name + roi_name +'_subsample.svg') + cmap = mpl.cm.RdYlBu_r + vmin=-0.02 + vmax=0.1 + #plot + rsa_subsample_plot(roi_rsa_mean, subsampled_time_ref, matrices_delimitations_ref, sub_matrix_dict,vmin,vmax,cmap,fname_fig_sub) + + #correlated with theory rdm + corr_theory(rsa_subsample,analysis_name,decoding_name='ID') + + +def RSA_Cat_plot(roi_name,condition): + analysis_name='RSA_Cat' + + time_point = np.array(range(-500,2001, 10))/1000 + if condition=='Irrelevant': + conD='IR' + elif condition=='Relevant non-target': + conD='RE' + + #get decoding data + roi_rsa_g=rsa2gat(group_data,roi_name,cond_name=['rsa'],decoding_name=condition,analysis=analysis_name) + + + C1_stat=stat_cluster_1sample_RDM(roi_rsa_g,test_win_on=0, test_win_off=251,chance_index=0) + + fname_fig=op.join(stat_figure_root, task_info +"_" + analysis_name +'_' + roi_name + '_' + conD +'.svg') + #plot + 
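+    # rsa_plot draws the group-mean time-by-time RSA matrix, overlays a grey
+    # contour around cluster-corrected significant samples (cluster p <= 0.05)
+    # and saves the figure as SVG.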
rsa_plot(roi_rsa_g,C1_stat=C1_stat,time_points=time_point,fname_fig=fname_fig) + + #subsample data + intervals_of_interest={"x": [[0.3,0.5],[0.8,1.0],[1.3,1.5],[1.8,2.0]],"y":[[0.3,0.5],[0.8,1.0],[1.3,1.5],[1.8,2.0]]} + + rsa_subsample=np.zeros([len(sub_list),84,84]) + + for n in range(len(sub_list)): + rsa_subsample[n,:,:], subsampled_time_ref, matrices_delimitations_ref, sub_matrix_dict=subsample_matrices(roi_rsa_g[n,:,:], -0.5, 2, intervals_of_interest) + + roi_rsa_mean=np.mean(rsa_subsample,0) + fname_fig_sub=op.join(stat_figure_root, task_info +"_" + analysis_name + roi_name + '_' + conD + '_subsample.svg') + cmap = mpl.cm.RdYlBu_r + vmin=-0.02 + vmax=0.1 + #plot + rsa_subsample_plot(roi_rsa_mean, subsampled_time_ref, matrices_delimitations_ref, sub_matrix_dict,vmin,vmax,cmap,fname_fig_sub) + + #correlated with theory rdm + corr_theory(rsa_subsample,analysis_name,decoding_name=condition) + +def RSA_Ori_plot(roi_name): + analysis='RSA_Ori' + time_point = np.array(range(-500,2001, 10))/1000 + + + #get decoding data + roi_rsa_g=rsa2gat(group_data,roi_name,cond_name=['rsa'],decoding_name='Ori',analysis=analysis) + + C1_stat=stat_cluster_1sample_RDM(roi_rsa_g,test_win_on=0, test_win_off=251,chance_index=0) + + fname_fig=op.join(stat_figure_root, task_info +"_" + analysis_name +'_' + roi_name +'.svg') + #plot + rsa_plot(roi_rsa_g,C1_stat=C1_stat,time_points=time_point,fname_fig=fname_fig) + + #subsample data + intervals_of_interest={"x": [[0.3,0.5],[0.8,1.0],[1.3,1.5],[1.8,2.0]],"y":[[0.3,0.5],[0.8,1.0],[1.3,1.5],[1.8,2.0]]} + + rsa_subsample=np.zeros([len(sub_list),84,84]) + + for n in range(len(sub_list)): + rsa_subsample[n,:,:], subsampled_time_ref, matrices_delimitations_ref, sub_matrix_dict=subsample_matrices(roi_rsa_g[n,:,:], -0.5, 2, intervals_of_interest) + + roi_rsa_mean=np.mean(rsa_subsample,0) + + + + fname_fig_sub=op.join(stat_figure_root, task_info +"_" + analysis_name + roi_name +'_subsample.svg') + cmap = mpl.cm.RdYlBu_r + vmin=-0.02 + vmax=0.1 + #plot + rsa_subsample_plot(roi_rsa_mean, subsampled_time_ref, matrices_delimitations_ref, sub_matrix_dict,vmin,vmax,cmap,fname_fig_sub) + + #correlated with theory rdm + corr_theory(rsa_subsample,analysis_name,decoding_name='Ori') + + +def stat_cluster_1sample_RDM(gc_mean,test_win_on,test_win_off,chance_index): + # define theresh + pval = 0.05 # arbitrary + tail = 1 # two-tailed + n_observations=gc_mean.shape[0] + stat_time_points=gc_mean[:,test_win_on:test_win_off,test_win_on:test_win_off].shape[2] + df = n_observations - 1 # degrees of freedom for the test + thresh = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + + + T_obs_1, clusters_1, cluster_p_values_1, H0_1 = mne.stats.permutation_cluster_1samp_test( + gc_mean[:,test_win_on:test_win_off,test_win_on:test_win_off]-np.ones([n_observations,stat_time_points,stat_time_points])*chance_index, + threshold=thresh, n_permutations=1000, tail=tail, out_type='mask',verbose=None) + + C1_stat=dict() + C1_stat['T_obs']=T_obs_1 + C1_stat['cluster']=clusters_1 + C1_stat['cluster_p']=cluster_p_values_1 + + return C1_stat + + +################ +#set data root +if analysis_name=='RSA_Cat': + analysis_index='RSA_Cat_NoFS' + group_deriv_root,stat_data_root,stat_figure_root=set_path_plot(bids_root,visit_id, analysis_index,con_C[0]) +else: + group_deriv_root,stat_data_root,stat_figure_root=set_path_plot(bids_root,visit_id, analysis_name,con_C[0]) + + +## analysis/task info +if con_T.__len__() == 3: + con_Tname = 'T_all' +elif con_T.__len__() == 2: + con_Tname = con_T[0]+'_'+con_T[1] 
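+# a single selected duration keeps its own label (e.g. "500ms")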
+else: + con_Tname = con_T[0] + +task_info = "_" + "".join(con_Tname) + "_" + "".join(con_C[0]) +print(task_info) + +fname_data=op.join(group_deriv_root, task_info +"_data_group_" + analysis_name + + '.pickle') + +fr=open(fname_data,'rb') +group_data=pickle.load(fr) + + + + +if analysis_name=='RSA_ID': + #sub_list.remove('SB006') + # GNW ROI + roi_name='GNW' + RSA_ID_plot(roi_name) + + # IIT ROI + roi_name='IIT' + RSA_ID_plot(roi_name) + +elif analysis_name=='RSA_Cat': + # GNW ROI + roi_name='GNW' + condition='Irrelevant' + RSA_Cat_plot(roi_name,condition) + condition='Relevant non-target' + RSA_Cat_plot(roi_name,condition) + + # IIT ROI + roi_name='IIT' + condition='Irrelevant' + RSA_Cat_plot(roi_name,condition) + condition='Relevant non-target' + RSA_Cat_plot(roi_name,condition) + +elif analysis_name=='RSA_Ori': + # GNW ROI + roi_name='GNW' + RSA_Ori_plot(roi_name) + + # IIT ROI + roi_name='IIT' + RSA_Ori_plot(roi_name) + + + diff --git a/roi_mvpa/D98_group_stat_sROI_plot_phaseII.py b/roi_mvpa/D98_group_stat_sROI_plot_phaseII.py new file mode 100644 index 0000000..766d0cb --- /dev/null +++ b/roi_mvpa/D98_group_stat_sROI_plot_phaseII.py @@ -0,0 +1,1116 @@ +""" +==================== +D98. Group analysis for decoding +for phaseII +==================== + +@author: Ling Liu ling.liu@pku.edu.cn + +""" + +import os.path as op +import os +import argparse + +import pickle +import mne + + +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +sns.set_theme(style='ticks') + +from mne.stats import fdr_correction + +from scipy import stats as stats + + + +import matplotlib as mpl + + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root,plot_param + +from sublist_phase2 import sub_list + + +parser = argparse.ArgumentParser() +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +parser.add_argument('--cT', type=str, nargs='*', default=['500ms', '1000ms', '1500ms'], + help='condition in Time duration') + +parser.add_argument('--cC', type=str, nargs='*', default=['FO'], + help='selected decoding category, FO for face and object, LF for letter and false') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--analysis', + type=str, + default='Cat', + help='the name for anlaysis, e.g. 
Cat or Ori or GAT_Cat') + + +opt = parser.parse_args() + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path +analysis_name=opt.analysis + + +opt = parser.parse_args() +con_C = opt.cC +con_T = opt.cT + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path + +#1) Select Category +if con_C[0] == 'FO': + conditions_C = ['face', 'object'] + print(conditions_C) +elif con_C[0] == 'LF': + conditions_C = ['letter', 'false'] + print(conditions_C) +elif con_C[0] == 'F': + conditions_C = ['face'] + print(conditions_C) +elif con_C[0] == 'O': + conditions_C = ['object'] + print(conditions_C) +elif con_C[0] == 'L': + conditions_C = ['letter'] + print(conditions_C) +elif con_C[0] == 'FA': + conditions_C = ['false'] + print(conditions_C) + +#1) Select time duration +if con_T[0] == 'T_all': + con_T = ['500ms', '1000ms','1500ms'] + print(con_T) +elif con_T[0] == 'ML':# middle and long + con_T = ['1000ms','1500ms'] + print(con_T) + +# get the parameters dictionary +param = plot_param +colors=param['colors'] +fig_size = param["figure_size_mm"] +plt.rc('font', size=8) # controls default text size +plt.rc('axes', labelsize=20) +plt.rc('xtick',labelsize=18) +plt.rc('ytick',labelsize=18) +plt.rc('xtick.major', width=2, size=4) +plt.rc('ytick.major', width=2, size=4) +plt.rc('legend', fontsize=18) +new_rc_params = {'text.usetex': False, +"svg.fonttype": 'none' +} +mpl.rcParams.update(new_rc_params) +def mm2inch(val): + return val / 25.4 + +# Color parameters: +cmap = "RdYlBu_r" +# #color_blind_palette = sns.color_palette("colorblind") +# colors = { +# "IIT": [ +# 0.00392156862745098, +# 0.45098039215686275, +# 0.6980392156862745 +# ], +# "GNW": [ +# 0.00784313725490196, +# 0.6196078431372549, +# 0.45098039215686275 +# ], +# "MT": [ +# 0.8352941176470589, +# 0.3686274509803922, +# 0.0 +# ], +# "FP": [ +# 0.5450980392156862, +# 0.16862745098039217, +# 0.8862745098039215 +# ], +# "Relevant to Irrelevant": [ +# 0.5450980392156862, +# 0.16862745098039217, +# 0.8862745098039215 +# ], +# "Irrelevant to Relevant": [ +# 0.8352941176470589, +# 0.3686274509803922, +# 0.0 +# ], +# "Relevant non-target": [ +# 0.8352941176470589, +# 0.3686274509803922, +# 0.0 +# ], +# "Irrelevant": [ +# 0.5450980392156862, +# 0.16862745098039217, +# 0.8862745098039215 +# ], +# "face": [ +# 0.00392156862745098, +# 0.45098039215686275, +# 0.6980392156862745 +# ], +# "object": [ +# 0.00784313725490196, +# 0.6196078431372549, +# 0.45098039215686275 +# ], +# "letter": [ +# 0.8352941176470589, +# 0.3686274509803922, +# 0.0 +# ], +# "false": [ +# 0.5450980392156862, +# 0.16862745098039217, +# 0.8862745098039215 +# ], +# } + + +time_point = np.array(range(-200,2001, 10))/1000 +# set the path for decoding analysis +def set_path_plot(bids_root, visit_id, analysis_name,con_name): + + ### I Set the group Data Path + # Set path to decoding derivatives + decoding_path=op.join(bids_root, "derivatives",'decoding','roi_mvpa') + + data_path=op.join(decoding_path, analysis_name) + + # Set path to group analysis derivatives + group_deriv_root = op.join(data_path, "group_phase2") + if not op.exists(group_deriv_root): + os.makedirs(group_deriv_root) + + + # Set path to the ROI MVPA output(1) stat_data, 2) figures, 3) codes) + + # 1) output_stat_data + stat_data_root = op.join(group_deriv_root,"stat_data",con_name) + if not op.exists(stat_data_root): + os.makedirs(stat_data_root) + + # 2) output_figure + stat_figure_root = op.join(group_deriv_root,"stat_figures",con_name) + if not op.exists(stat_figure_root): + 
os.makedirs(stat_figure_root) + + return group_deriv_root,stat_data_root,stat_figure_root + +def df_plot(ts_df,T1,pval1,T2,pval2,time_point,test_win_on,roi_name,task_index,chance_index,y_index,fname_fig): + if roi_name=='GNW': + window=[0.3,0.5,0.5,0.3] + elif roi_name=='IIT': + window=[0.3,1.5,1.5,0.3] + elif roi_name=='MT': + window=[0.25,0.5,0.5,0.25] + elif roi_name=='FP': + window=[0.3,1.5,1.5,0.3] + #plot with sns + + + # talk_rc={'lines.linewidth':2,'lines.markersize':4} + # sns.set_context('paper',rc=talk_rc,font_scale=4) + + + g = sns.relplot(x="time(s)", y="decoding accuracy(%)", kind="line", data=ts_df,hue='Task',aspect=2,palette=colors,legend=False) + g.fig.set_size_inches(mm2inch(fig_size[0]),mm2inch(fig_size[1])) + #leg = g._legend + #leg.set_bbox_to_anchor([0.72,0.8]) + + plt.axhline(chance_index, color='k', linestyle='-', label='chance') + plt.axvline(0, color='k', linestyle='-', label='onset') + #plt.axvline(0.5, color='gray', linestyle='--') + #plt.axvline(1, color='gray', linestyle='--') + #plt.axvline(1.5, color='gray', linestyle='--') + + reject_fdr1, pval_fdr1 = fdr_correction(pval1, alpha=0.05, method='indep') + temp=reject_fdr1.nonzero() + sig1=np.full(time_point.shape,np.nan) + if len(temp[0])>=1: + threshold_fdr1 = np.min(np.abs(T1)[reject_fdr1]) + T11=np.concatenate((np.zeros((test_win_on-30,)),T1)) + clusters1 = np.where(T11 > threshold_fdr1)[0] + if len(clusters1)>1: + clusters1 = clusters1[clusters1 > test_win_on-30] + #times = range(0, 500, 10) + plt.plot(time_point[clusters1], np.zeros(clusters1.shape) + 40, 'o', linewidth=3,color=colors[task_index[0]]) + sig1[clusters1]=1 + + reject_fdr2, pval_fdr2 = fdr_correction(pval2, alpha=0.05, method='indep') + temp=reject_fdr2.nonzero() + sig2=np.full(time_point.shape,np.nan) + if len(temp[0])>=1: + threshold_fdr2 = np.min(np.abs(T2)[reject_fdr2]) + T22=np.concatenate((np.zeros((test_win_on-30,)),T2)) + clusters2 = np.where(T22 > threshold_fdr2)[0] + if len(clusters2)>1: + clusters2 = clusters2[clusters2 > test_win_on-30] + #times = range(0, 500, 10) + plt.plot(time_point[clusters2], np.zeros(clusters2.shape) + 30, 'o', linewidth=3,color=colors[task_index[1]]) + sig2[clusters2]=1 + + #plt.fill(window,[15,15,100,100],facecolor='g',alpha=0.2) + plt.xlim([-0.2,2]) + plt.ylim([15,100]) + plt.xticks([0,0.5,1.0,1.5,2]) + plt.yticks([20,40,60,80,100]) + + g.savefig(fname_fig,format="svg", transparent=True, dpi=300) + + return sig1, sig2 + +def df_plot_cluster(ts_df,C1_stat,C2_stat,time_point,test_win_on,test_win_off,roi_name,task_index,chance_index,y_index,fname_fig): + if roi_name=='GNW': + window=[0.3,0.5,0.5,0.3] + elif roi_name=='IIT': + window=[0.3,1.5,1.5,0.3] + elif roi_name=='MT': + window=[0.25,0.5,0.5,0.25] + elif roi_name=='FP': + window=[0.3,1.5,1.5,0.3] + + #plot with sns + + # talk_rc={'lines.linewidth':2,'lines.markersize':4} + # sns.set_context('paper',rc=talk_rc,font_scale=4) + + + g = sns.relplot(x="time(s)", y="decoding accuracy(%)", kind="line", data=ts_df,hue='Task',aspect=2,palette=colors,legend=False) + g.fig.set_size_inches(mm2inch(fig_size[0]),mm2inch(fig_size[1])) + #leg = g._legend + #leg.set_bbox_to_anchor([0.72,0.8]) + + plt.axhline(chance_index, color='k', linestyle='-', label='chance') + plt.axvline(0, color='k', linestyle='-', label='onset') + #plt.axvline(0.5, color='gray', linestyle='--') + #plt.axvline(1, color='gray', linestyle='--') + #plt.axvline(1.5, color='gray', linestyle='--') + + + temp=C1_stat['cluster'] + temp_p=C1_stat['cluster_p'] + sig1=np.full(time_point.shape,np.nan) + 
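+    # The decoding arrays hold 251 samples at 10-ms steps (here assumed to
+    # span -0.5 to 2.0 s), while the plotted axis starts at -0.2 s (sample 30);
+    # shifting the test-window indices by 30 maps clusters onto plot time.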
time_index=time_point[(test_win_on-30):(test_win_off-30)] + if len(temp)>=1: + for i in range(len(temp)): + if temp_p[i]<0.05:# plot the cluster which p < 0.05 + clusters1=temp[i][0] + plt.plot(time_index[clusters1], np.zeros(clusters1.shape) + 40, 'o', linewidth=3,color=colors[task_index[0]]) + sig1[clusters1]=i + + temp2=C2_stat['cluster'] + temp_p2=C2_stat['cluster_p'] + sig2=np.full(time_point.shape,np.nan) + if len(temp2)>=1: + for i in range(len(temp2)): + if temp_p2[i]<0.05:# plot the cluster which p < 0.05 + clusters2=temp2[i][0] + plt.plot(time_index[clusters2], np.zeros(clusters2.shape) + 30, 'o', linewidth=3,color=colors[task_index[1]]) + sig2[clusters2]=i + + + + #plt.fill(window,[15,15,100,100],facecolor='g',alpha=0.2) + plt.xlim([-0.2,2]) + plt.ylim([15,100]) + plt.xticks([0,0.5,1.0,1.5,2]) + plt.yticks([20,40,60,80,100]) + + g.savefig(fname_fig,format="svg", transparent=True, dpi=300) + + return sig1, sig2 + +def df_plot_cluster_ori(ts_df,C1_stat,time_point,test_win_on,test_win_off,roi_name,task_index,chance_index,y_index,fname_fig): + if roi_name=='GNW': + window=[0.3,0.5,0.5,0.3] + elif roi_name=='IIT': + window=[0.3,1.5,1.5,0.3] + elif roi_name=='MT': + window=[0.25,0.5,0.5,0.25] + elif roi_name=='FP': + window=[0.3,1.5,1.5,0.3] + + #plot with sns + + # talk_rc={'lines.linewidth':2,'lines.markersize':4} + # sns.set_context('paper',rc=talk_rc,font_scale=4) + + + g = sns.relplot(x="time(s)", y="decoding accuracy(%)", kind="line", data=ts_df,hue='Task',aspect=2,palette=colors,legend=False) + g.fig.set_size_inches(mm2inch(fig_size[0]),mm2inch(fig_size[1])) + + # leg = g._legend + # leg.remove() + #leg.set_bbox_to_anchor([0.72,0.8]) + + plt.axhline(chance_index, color='k', linestyle='-', label='chance') + plt.axvline(0, color='k', linestyle='-', label='onset') + #plt.axvline(0.5, color='gray', linestyle='--') + #plt.axvline(1, color='gray', linestyle='--') + #plt.axvline(1.5, color='gray', linestyle='--') + + + temp=C1_stat['cluster'] + temp_p=C1_stat['cluster_p'] + sig1=np.full(time_point.shape,np.nan) + time_index=time_point[(test_win_on-30):(test_win_off-30)] + if len(temp)>=1: + for i in range(len(temp)): + if temp_p[i]<0.05:# plot the cluster which p < 0.05 + clusters1=temp[i][0] + plt.plot(time_index[clusters1], np.zeros(clusters1.shape) + 30, 'o', linewidth=3,color=colors[task_index[0]]) + sig1[clusters1]=i + + + + # plt.fill(window,[15,15,100,100],facecolor='g',alpha=0.2) + plt.xlim([-0.2,2]) + plt.ylim([25,100]) + plt.xticks([0,0.5,1.0,1.5,2]) + plt.yticks([40,60,80,100]) + + + g.savefig(fname_fig,format="svg", transparent=True, dpi=300) + + return sig1 + +def df_plot_ROI_cluster(ts_df,C1_stat,time_point,test_win_on,test_win_off,chance_index,y_index,fname_fig): + + window=[0.3,1.5,1.5,0.3] + + + #plot with sns + + # talk_rc={'lines.linewidth':2,'lines.markersize':4} + # sns.set_context('talk',rc=talk_rc,font_scale=1) + + + g = sns.relplot(x="time(s)", y="decoding accuracy(%)", kind="line", data=ts_df,hue='ROI',aspect=2,palette=colors,legend=True) + g.fig.set_size_inches(mm2inch(fig_size[0]),mm2inch(fig_size[1])) + #sns.move_legend(g, "upper left", bbox_to_anchor=(.72, .8), frameon=False) + #leg = g._legend + # leg.remove() + #leg.set_bbox_to_anchor([0.72,0.8]) + + plt.axhline(chance_index, color='k', linestyle='-', label='chance') + plt.axvline(0, color='k', linestyle='-', label='onset') + #plt.axvline(0.5, color='gray', linestyle='--') + #plt.axvline(1, color='gray', linestyle='--') + #plt.axvline(1.5, color='gray', linestyle='--') + + + temp=C1_stat['cluster'] + 
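+    # Significant clusters (cluster p < 0.05) are marked with a row of dots
+    # drawn just below the chance level.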
temp_p=C1_stat['cluster_p'] + sig1=np.full(time_point.shape,np.nan) + time_index=time_point[(test_win_on-30):(test_win_off-30)] + if len(temp)>=1: + for i in range(len(temp)): + if temp_p[i]<0.05:# plot the cluster which p < 0.05 + clusters1=temp[i][0] + plt.plot(time_index[clusters1], np.zeros(clusters1.shape) + chance_index-5, 'o', linewidth=3,color=colors['IIT']) + sig1[clusters1]=i + + + + #plt.fill(window,[40,40,100,100],facecolor='g',alpha=0.2) + plt.xlim([-0.2,2]) + plt.ylim([25,100]) + plt.xticks([0,0.5,1.0,1.5,2]) + plt.yticks([40,60,80,100]) + + g.savefig(fname_fig,format="svg", transparent=True, dpi=300) + + return sig1 + +def g2gdat(roi_g,time_point,sig1,sig2): + roi_g_acc=np.mean(roi_g[:,:,30:251],axis=1) + roi_g_ci=1.96*stats.sem(roi_g[:,:,30:251],axis=1) + roi_g_dat=np.vstack((time_point,roi_g_acc,roi_g_ci,sig1,sig2)) + return roi_g_dat + +def g2gdat_ori(roi_g,time_point,sig1): + roi_g_acc=np.mean(roi_g[:,:,30:251],axis=1) + roi_g_ci=1.96*stats.sem(roi_g[:,:,30:251],axis=1) + roi_g_dat=np.vstack((time_point,roi_g_acc,roi_g_ci,sig1)) + return roi_g_dat + +def df2csv(np_data,task_index,csv_fname): + columns_index=['Time', + 'ACC (' + task_index[0] + ')','ACC (' + task_index[1] + ')', + 'CI (' + task_index[0] + ')','CI (' + task_index[1] + ')', + 'sig (' + task_index[0] + ')','sig (' + task_index[1] + ')'] + df = pd.DataFrame(np_data.T, columns=columns_index) + df.to_csv(csv_fname,sep=',',index=False,header=True,na_rep='NaN') + +def df2csv_ori(np_data,task_index,csv_fname): + columns_index=['Time', + 'ACC (' + task_index[0] + ')', + 'CI (' + task_index[0] + ')', + 'sig (' + task_index[0] + ')'] + df = pd.DataFrame(np_data.T, columns=columns_index) + df.to_csv(csv_fname,sep=',',index=False,header=True,na_rep='NaN') + +def gc2df(gc_mean,test_win_on,test_win_off,task_index,chance_index): + + df1 = pd.DataFrame(gc_mean[0,:,30:251], columns=time_point) + df1.insert(loc=0, column='SUBID', value=sub_list) + df1.insert(loc=0, column='Task',value=task_index[0]) + + T1, pval1 = stats.ttest_1samp(gc_mean[0,:,test_win_on:test_win_off], chance_index) + + df2 = pd.DataFrame(gc_mean[1,:,30:251], columns=time_point) + df2.insert(loc=0, column='SUBID', value=sub_list) + df2.insert(loc=0, column='Task',value=task_index[1]) + + T2, pval2 = stats.ttest_1samp(gc_mean[1,:,test_win_on:test_win_off], chance_index) + + df=df1.append(df2) + + ts_df = pd.melt(df, id_vars=['SUBID','Task'], var_name='time(s)', value_name='decoding accuracy(%)', value_vars=time_point) + + return ts_df,T1,pval1,T2,pval2 + +def stat_cluster_1sample(gc_mean,test_win_on,test_win_off,task_index,chance_index): + # define theresh + pval = 0.05 # arbitrary + tail = 0 # two-tailed + n_observations=gc_mean.shape[1] + stat_time_points=gc_mean[:,:,test_win_on:test_win_off].shape[2] + df = n_observations - 1 # degrees of freedom for the test + thresh = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + df1 = pd.DataFrame(gc_mean[0,:,30:251], columns=time_point) + df1.insert(loc=0, column='SUBID', value=sub_list) + df1.insert(loc=0, column='Task',value=task_index[0]) + + T_obs_1, clusters_1, cluster_p_values_1, H0_1 = mne.stats.permutation_cluster_1samp_test( + gc_mean[0,:,test_win_on:test_win_off]-np.ones([n_observations,stat_time_points])*chance_index, + threshold=thresh, n_permutations=10000, tail=tail, out_type='indices',verbose=None) + + C1_stat=dict() + C1_stat['T_obs']=T_obs_1 + C1_stat['cluster']=clusters_1 + C1_stat['cluster_p']=cluster_p_values_1 + + df2 = pd.DataFrame(gc_mean[1,:,30:251], columns=time_point) + 
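+    # Second condition: label its frame and run the same one-sample cluster
+    # permutation test against chance as above.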
df2.insert(loc=0, column='SUBID', value=sub_list) + df2.insert(loc=0, column='Task',value=task_index[1]) + + T_obs_2, clusters_2, cluster_p_values_2, H0_2 = mne.stats.permutation_cluster_1samp_test( + gc_mean[1,:,test_win_on:test_win_off]-np.ones([n_observations,stat_time_points])*chance_index, + threshold=thresh, n_permutations=10000, tail=tail, out_type='indices',verbose=None) + + C2_stat=dict() + C2_stat['T_obs']=T_obs_2 + C2_stat['cluster']=clusters_2 + C2_stat['cluster_p']=cluster_p_values_2 + + + df=df1.append(df2) + + ts_df = pd.melt(df, id_vars=['SUBID','Task'], var_name='time(s)', value_name='decoding accuracy(%)', value_vars=time_point) + + return ts_df,C1_stat,C2_stat + +def stat_cluster_1sample_ori(gc_mean,test_win_on,test_win_off,task_index,chance_index): + # define theresh + pval = 0.05 # arbitrary + tail = 0 # two-tailed + n_observations=gc_mean.shape[1] + stat_time_points=gc_mean[:,:,test_win_on:test_win_off].shape[2] + df = n_observations - 1 # degrees of freedom for the test + thresh = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + df1 = pd.DataFrame(gc_mean[0,:,30:251], columns=time_point) + df1.insert(loc=0, column='SUBID', value=sub_list) + df1.insert(loc=0, column='Task',value=task_index[0]) + + T_obs_1, clusters_1, cluster_p_values_1, H0_1 = mne.stats.permutation_cluster_1samp_test( + gc_mean[0,:,test_win_on:test_win_off]-np.ones([n_observations,stat_time_points])*chance_index, + threshold=thresh, n_permutations=10000, tail=tail, out_type='indices',verbose=None) + + C1_stat=dict() + C1_stat['T_obs']=T_obs_1 + C1_stat['cluster']=clusters_1 + C1_stat['cluster_p']=cluster_p_values_1 + + + ts_df = pd.melt(df1, id_vars=['SUBID','Task'], var_name='time(s)', value_name='decoding accuracy(%)', value_vars=time_point) + + return ts_df,C1_stat + + +def stat_cluster_1sample_roi(ROI1_data,ROI2_data,test_win_on,test_win_off,ROI_name): + + # define theresh + pval = 0.05 # arbitrary + tail = 0 # two-tailed + n_observations=ROI1_data.shape[1] + + df = n_observations - 1 # degrees of freedom for the test + thresh = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + df1 = pd.DataFrame(ROI1_data[:,30:251], columns=time_point) + df1.insert(loc=0, column='SUBID', value=sub_list) + df1.insert(loc=0, column='ROI',value=ROI_name[0]) + + + + df2 = pd.DataFrame(ROI2_data[:,30:251], columns=time_point) + df2.insert(loc=0, column='SUBID', value=sub_list) + df2.insert(loc=0, column='ROI',value=ROI_name[1]) + + + df=df1.append(df2) + + ts_df = pd.melt(df, id_vars=['SUBID','ROI'], var_name='time(s)', value_name='decoding accuracy(%)', value_vars=time_point) + + T_obs_1, clusters_1, cluster_p_values_1, H0_1 = mne.stats.permutation_cluster_test( + [ROI1_data[:,test_win_on:test_win_off] , ROI2_data[:,test_win_on:test_win_off]], + threshold=thresh, n_permutations=10000, tail=tail, out_type='indices',verbose=None) + + C1_stat=dict() + C1_stat['T_obs']=T_obs_1 + C1_stat['cluster']=clusters_1 + C1_stat['cluster_p']=cluster_p_values_1 + + return ts_df,C1_stat + + +def dat2g(dat,roi_name,cond_name,decoding_name): + roi_ccd_g=np.zeros([2,len(sub_list),251]) + for ci, cond in enumerate(cond_name): + roi_ccd_gc=np.zeros([len(sub_list),251]) + for i, sbn in enumerate(sub_list): + roi_ccd_gc[i,:]=dat[sbn][decoding_name][roi_name][cond] + + + roi_ccd_g[ci,:,:]=roi_ccd_gc*100 + + return roi_ccd_g + +def dat2g_PFC(dat,cond_name): + roi_wcd_g=np.zeros([3,len(sub_list),251]) + for ci, cond in enumerate(cond_name): + roi_wcd_gc=np.zeros([len(sub_list),251]) + for i, sbn in 
enumerate(sub_list): + roi_wcd_gc[i,:]=dat[sbn][cond] + roi_wcd_g[ci,:,:]=roi_wcd_gc*100 + + return roi_wcd_g + + +def dat2g_ori(dat,roi_name,cond_name,decoding_name): + roi_ccd_g=np.zeros([1,len(sub_list),251]) + roi_ccd_gc=np.zeros([len(sub_list),251]) + for i, sbn in enumerate(sub_list): + roi_ccd_gc[i,:]=dat[sbn][decoding_name][roi_name][cond_name] + roi_ccd_g[0,:,:]=roi_ccd_gc*100 + + return roi_ccd_g + + +def dat2gat(dat,roi_name,cond_name,decoding_name): + roi_ccd_g=np.zeros([2,len(sub_list),251,251]) + for ci, cond in enumerate(cond_name): + roi_ccd_gc=np.zeros([len(sub_list),251,251]) + for i, sbn in enumerate(sub_list): + roi_ccd_gc[i,:,:]=np.diagonal(dat[sbn][decoding_name][roi_name][cond]) + + + roi_ccd_g[ci,:,:,:]=roi_ccd_gc*100 + + return roi_ccd_g + +def dat2gat2(dat,roi_name,cond_name,decoding_name): + roi_ccd_g=np.zeros([2,len(sub_list),251]) + for ci, cond in enumerate(cond_name): + roi_ccd_gc=np.zeros([len(sub_list),251]) + for i, sbn in enumerate(sub_list): + roi_ccd_gc[i,:]=np.diagonal(dat[sbn][decoding_name][roi_name][cond]) + + + roi_ccd_g[ci,:,:]=roi_ccd_gc*100 + + return roi_ccd_g + + +def ccd_plt(group_data,roi_name='GNW',test_win_on=50, test_win_off=200,chance_index=50,y_index=15): + + + time_point = np.array(range(-200,2001, 10))/1000 + task_index=['Relevant to Irrelevant','Irrelevant to Relevant'] + #get decoding data + ROI_ccd_g=dat2g(group_data,roi_name,cond_name=['RE2IR','IR2RE'],decoding_name='ccd_acc') + + + + # #FDR methods + + # #stat + # ts_df_fdr,T1,pval1,T2,pval2=gc2df(ROI_ccd_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + # #plot + # fname_fdr_fig= op.join(stat_figure_root, roi_name + '_'+ str(test_win_on) + '_'+ str(test_win_off) +"_acc_CCD_fdr" + '.png') + + # sig1_fdr,sig2_fdr=df_plot(ts_df_fdr,T1,pval1,T2,pval2,time_point,test_win_on, + # roi_name,task_index=task_index, + # chance_index=chance_index,y_index=y_index,fname_fig=fname_fdr_fig) + + + + + #cluster based methods + + #stat + ts_df_cluster,C1_stat,C2_stat=stat_cluster_1sample(ROI_ccd_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + fname_cluster_fig= op.join(stat_figure_root, roi_name + '_'+str(test_win_on) + '_' + str(test_win_off)+"_acc_CCD_cluster" + '.svg') + + #plot + sig1_cluster,sig2_cluster=df_plot_cluster(ts_df_cluster,C1_stat,C2_stat,time_point, + test_win_on,test_win_off, + roi_name,task_index=task_index, + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig) + + #prepare data for plt plot + ROI_ccd_g_dat=g2gdat(ROI_ccd_g,time_point,sig1_cluster,sig2_cluster) + + + csv_fname=op.join(stat_data_root, roi_name + '_'+str(test_win_on) + '_' + str(test_win_off)+"_acc_CCD_cluster" + '.csv') + + df2csv(ROI_ccd_g_dat,task_index,csv_fname) + + +def wcd_plt(group_data,roi_name='GNW',test_win_on=50, test_win_off=200,chance_index=50,y_index=15): + + + time_point = np.array(range(-200,2001, 10))/1000 + task_index=['Irrelevant','Relevant non-target'] + #get decoding data + ROI_wcd_g=dat2g(group_data,roi_name,cond_name=['Irrelevant','Relevant non-target'],decoding_name='wcd_acc') + + + # #FDR methods + + # #stat + # ts_df_fdr,T1,pval1,T2,pval2=gc2df(ROI_ccd_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + # #plot + # fname_fdr_fig= op.join(stat_figure_root, roi_name + '_'+ str(test_win_on) + '_'+ str(test_win_off) +"_acc_WCD_fdr" + '.png') + + # sig1_fdr,sig2_fdr=df_plot(ts_df_fdr,T1,pval1,T2,pval2,time_point,test_win_on, + # roi_name,task_index=task_index, + # 
chance_index=chance_index,y_index=y_index,fname_fig=fname_fdr_fig) + + + + + #cluster based methods + + #stat + ts_df_cluster,C1_stat,C2_stat=stat_cluster_1sample(ROI_wcd_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + fname_cluster_fig= op.join(stat_figure_root, roi_name + '_'+str(test_win_on) + '_' + str(test_win_off)+"_acc_WCD_cluster" + '.svg') + + #plot + sig1_cluster,sig2_cluster=df_plot_cluster(ts_df_cluster,C1_stat,C2_stat,time_point, + test_win_on,test_win_off, + roi_name,task_index=task_index, + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig) + #prepare data for plt plot + ROI_wcd_g_dat=g2gdat(ROI_wcd_g,time_point,sig1_cluster,sig2_cluster) + + + csv_fname=op.join(stat_data_root, roi_name + '_'+str(test_win_on) + '_' + str(test_win_off)+"_acc_WCD_cluster" + '.csv') + + df2csv(ROI_wcd_g_dat,task_index,csv_fname) + +def ROI_wcd_plt(group_data,decoding_method ='wcd', test_win_on=50, test_win_off=200,chance_index=50,y_index=40): + ROI_name=['IIT','FP'] + task_index=['Irrelevant','Relevant non-target'] + #get decoding data + ROI1_data=dat2g(group_data,ROI_name[0],cond_name=['Irrelevant','Relevant non-target'],decoding_name='wcd_acc') + ROI2_data=dat2g(group_data,ROI_name[1],cond_name=['Irrelevant','Relevant non-target'],decoding_name='wcd_acc') + + time_point = np.array(range(-200,2001, 10))/1000 + + #cluster based methods + + #stat + ts1_df_cluster,C1_stat=stat_cluster_1sample_roi(ROI1_data[0,:,:],ROI2_data[0,:,:],test_win_on,test_win_off,ROI_name) + + fname_cluster_fig= op.join(stat_figure_root, task_index[0] + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_FP_P_diff_acc_'+decoding_method + '_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts1_df_cluster,C1_stat,time_point, + test_win_on,test_win_off, + task_index=task_index[0], + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig) + + #stat + ts2_df_cluster,C2_stat=stat_cluster_1sample_roi(ROI1_data[1,:,:],ROI2_data[1,:,:],test_win_on,test_win_off,ROI_name) + + fname_cluster_fig2= op.join(stat_figure_root, task_index[1] + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_FP_P_diff_acc_'+decoding_method + '_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts2_df_cluster,C2_stat,time_point, + test_win_on,test_win_off, + task_index=task_index[1], + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig2) + + + +def ROI_ccd_plt(group_data,decoding_method ='ccd', test_win_on=50, test_win_off=200,chance_index=50,y_index=40): + ROI_name=['IIT','FP'] + task_index=['Relevant to Irrelevant','Irrelevant to Relevant'] + #get decoding data + ROI1_data=dat2g(group_data,ROI_name[0],cond_name=['RE2IR','IR2RE'],decoding_name='ccd_acc') + ROI2_data=dat2g(group_data,ROI_name[1],cond_name=['RE2IR','IR2RE'],decoding_name='ccd_acc') + + time_point = np.array(range(-200,2001, 10))/1000 + + #cluster based methods + + #stat + ts1_df_cluster,C1_stat=stat_cluster_1sample_roi(ROI1_data[0,:,:],ROI2_data[0,:,:],test_win_on,test_win_off,ROI_name) + + fname_cluster_fig= op.join(stat_figure_root, task_index[0] + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_FP_P_diff_acc_'+decoding_method + '_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts1_df_cluster,C1_stat,time_point, + test_win_on,test_win_off, + task_index=task_index[1], + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig) + + #stat + 
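+    # Repeat the IIT vs. FP (combined IIT+GNW) contrast for the second
+    # cross-decoding direction (irrelevant-to-relevant); stat_cluster_1sample_roi
+    # compares the two ROIs' accuracy time courses with a cluster-based
+    # permutation test.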
ts2_df_cluster,C2_stat=stat_cluster_1sample_roi(ROI1_data[1,:,:],ROI2_data[1,:,:],test_win_on,test_win_off,ROI_name) + + fname_cluster_fig2= op.join(stat_figure_root, task_index[1] + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_FP_P_diff_acc_'+decoding_method + '_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts2_df_cluster,C2_stat,time_point, + test_win_on,test_win_off, + task_index=task_index[1], + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig2) + + + +def wcd_ori_plt(group_data,roi_name='GNW',test_win_on=50, test_win_off=200,chance_index=33.3,y_index=15): + + + time_point = np.array(range(-200,2001, 10))/1000 + task_index=conditions_C #Face/Object/Letter/False + #get decoding data + ROI_ori_g=dat2g_ori(group_data,roi_name,cond_name=conditions_C[0],decoding_name='wcd_ori_acc') + + + # #FDR methods + + # #stat + # ts_df_fdr,T1,pval1,T2,pval2=gc2df(ROI_ccd_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + # #plot + # fname_fdr_fig= op.join(stat_figure_root, roi_name + '_'+ str(test_win_on) + '_'+ str(test_win_off) +"_acc_WCD_fdr" + '.png') + + # sig1_fdr,sig2_fdr=df_plot(ts_df_fdr,T1,pval1,T2,pval2,time_point,test_win_on, + # roi_name,task_index=task_index, + # chance_index=chance_index,y_index=y_index,fname_fig=fname_fdr_fig) + + + + + #cluster based methods + + #stat + ts_df_cluster,C1_stat=stat_cluster_1sample_ori(ROI_ori_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + fname_cluster_fig= op.join(stat_figure_root, roi_name + '_'+str(test_win_on) + '_' + str(test_win_off)+"_acc_WCD_ori_cluster" + '.svg') + + #plot + sig1_cluster=df_plot_cluster_ori(ts_df_cluster,C1_stat,time_point, + test_win_on,test_win_off, + roi_name,task_index=task_index, + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig) + #prepare data for plt plot + ROI_ori_g_dat=g2gdat_ori(ROI_ori_g,time_point,sig1_cluster) + + + csv_fname=op.join(stat_data_root, roi_name + '_'+str(test_win_on) + '_' + str(test_win_off)+"_acc_WCD_ori_cluster" + '.csv') + + df2csv_ori(ROI_ori_g_dat,task_index,csv_fname) + + +def ROI_wcd_ori_plt(group_data,decoding_method ='wcd', test_win_on=50, test_win_off=200,chance_index=33.3,y_index=40): + ROI_name=['IIT','FP'] + task_index=['Irrelevant','Relevant non-target'] + #get decoding data + ROI1_data=dat2g_ori(group_data,ROI_name[0],cond_name=conditions_C[0],decoding_name='wcd_ori_acc') + ROI2_data=dat2g_ori(group_data,ROI_name[1],cond_name=conditions_C[0],decoding_name='wcd_ori_acc') + + time_point = np.array(range(-200,2001, 10))/1000 + + #cluster based methods + + #stat + ts1_df_cluster,C1_stat=stat_cluster_1sample_roi(ROI1_data[0,:,:],ROI2_data[0,:,:],test_win_on,test_win_off,ROI_name) + + fname_cluster_fig= op.join(stat_figure_root, task_index[0] + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_FP_P_diff_acc_'+decoding_method + '_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts1_df_cluster,C1_stat,time_point, + test_win_on,test_win_off, + task_index=task_index[0], + chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig) + + + +######### +#set data root +group_deriv_root,stat_data_root,stat_figure_root=set_path_plot(bids_root,visit_id, analysis_name,con_C[0]) + + +# ######## +# #debug +# decoding_path=op.join(bids_root, "derivatives",'decoding') + +# data_path=op.join(decoding_path, analysis_name) + +# # Set path to group analysis derivatives +# group_deriv_root = op.join(data_path, "group") +# if not 
op.exists(group_deriv_root): +# os.makedirs(group_deriv_root) + + + + +# analysis/task info +## analysis/task info +if con_T.__len__() == 3: + con_Tname = 'T_all' +elif con_T.__len__() == 2: + con_Tname = con_T[0]+'_'+con_T[1] +else: + con_Tname = con_T[0] + +task_info = "_" + "".join(con_Tname) + "_" + "".join(con_C[0]) +print(task_info) + + +fname_data=op.join(group_deriv_root, task_info +"_data_group_" + analysis_name + + '.pickle') + +fr=open(fname_data,'rb') +group_data=pickle.load(fr) + + + +if analysis_name=='Cat' or analysis_name=='Cat_offset_control': + #CCD: cross condition decoding + #GNW + + # # 300ms to 500ms + # ccd_plt(group_data2,roi_name='GNW',test_win_on=130, test_win_off=150,chance_index=50,y_index=15) + + # 0ms to 1500ms + ccd_plt(group_data,roi_name='GNW',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + + #IIT + + # # 300ms to 500ms + # ccd_plt(group_data2,roi_name='IIT',test_win_on=130, test_win_off=251,chance_index=50,y_index=40) + + # 0ms to 1500ms + ccd_plt(group_data,roi_name='IIT',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + + + #WCD: within condition decoding + #GNW + + # # 300ms to 500ms + # wcd_plt(group_data2,roi_name='GNW',test_win_on=130, test_win_off=150,chance_index=50,y_index=15) + + # 0ms to 1500ms + wcd_plt(group_data,roi_name='GNW',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + + #IIT + + # # 300ms to 500ms + # wcd_plt(group_data2,roi_name='IIT',test_win_on=130, test_win_off=251,chance_index=50,y_index=40) + + # 0ms to 1500ms + wcd_plt(group_data,roi_name='IIT',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + + #compare IIT with IIT+GNW(FP) + ROI_ccd_plt(group_data,decoding_method ='ccd', test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + + ROI_wcd_plt(group_data,decoding_method ='wcd', test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + + +elif analysis_name=='Cat_MT_control': + ccd_plt(group_data,roi_name='MT',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + wcd_plt(group_data,roi_name='MT',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + +elif analysis_name=='Cat_baseline': + + wcd_plt(group_data,roi_name='GNW',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + wcd_plt(group_data,roi_name='IIT',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + +elif analysis_name=='Ori': + + wcd_ori_plt(group_data,roi_name='GNW',test_win_on=50, test_win_off=200,chance_index=33.3,y_index=40) + wcd_ori_plt(group_data,roi_name='IIT',test_win_on=50, test_win_off=200,chance_index=33.3,y_index=40) + +elif analysis_name=='Cat_PFC': + cond_name=['IIT','IITPFC_f','IITPFC_m'] + colors = { + "IIT": [1,0,0 + ], + "IITPFC_f": [0,0,1 + ], + "IITPFC_m": [0,0,1 + ]} + decoding_method=analysis_name + #task_index=['Irrelevant','Relevant non-target'] + #get decoding data + PFC_data=dat2g_PFC(group_data,cond_name) + + + time_point = np.array(range(-200,2001, 10))/1000 + + #cluster based methods + test_win_on=50 + test_win_off=200 + #stat + ts1_df_cluster,C1_stat=stat_cluster_1sample_roi(PFC_data[0,:,:],PFC_data[1,:,:],test_win_on,test_win_off,['IIT','IITPFC_f']) + + fname_cluster_fig= op.join(stat_figure_root, decoding_method + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_IITPFC_feature_diff_acc_cluster.svg') + + # fname_cluster_fig= op.join(data_path, decoding_method + + # '_'+str(test_win_on) + '_' + str(test_win_off) + + # '_IITPFC_feature_diff_acc_cluster.svg') + + #plot + 
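+    # Plot the IIT vs. IIT+PFC ('IITPFC_f', feature-based) comparison;
+    # df_plot_ROI_cluster marks clusters surviving the permutation test (p < 0.05)
+    # and saves the figure as an SVG under stat_figure_root.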
sig1_cluster=df_plot_ROI_cluster(ts1_df_cluster,C1_stat,time_point, + test_win_on,test_win_off, + chance_index=50,y_index=50, + fname_fig=fname_cluster_fig) + + #stat + ts2_df_cluster,C2_stat=stat_cluster_1sample_roi(PFC_data[0,:,:],PFC_data[2,:,:],test_win_on,test_win_off,['IIT','IITPFC_m']) + + fname_cluster_fig2= op.join(stat_figure_root, decoding_method + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_IITPFC_model_diff_acc_cluster.svg') + + # fname_cluster_fig2= op.join(data_path, decoding_method + + # '_'+str(test_win_on) + '_' + str(test_win_off) + + # '_IITPFC_model_diff_acc_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts2_df_cluster,C2_stat,time_point, + test_win_on,test_win_off, + chance_index=50,y_index=50, + fname_fig=fname_cluster_fig2) + + +elif analysis_name=='Ori_PFC': + cond_name=['IIT','IITPFC_f','IITPFC_m'] + colors = { + "IIT": [1,0,0 + ], + "IITPFC_f": [0,0,1 + ], + "IITPFC_m": [0,0,1 + ]} + decoding_method=analysis_name + #task_index=['Irrelevant','Relevant non-target'] + #get decoding data + PFC_data=dat2g_PFC(group_data,cond_name) + + + time_point = np.array(range(-200,2001, 10))/1000 + + #cluster based methods + test_win_on=50 + test_win_off=200 + #stat + ts1_df_cluster,C1_stat=stat_cluster_1sample_roi(PFC_data[0,:,:],PFC_data[1,:,:],test_win_on,test_win_off,['IIT','IITPFC_f']) + + fname_cluster_fig= op.join(stat_figure_root, decoding_method + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_IITPFC_feature_diff_acc_cluster.svg') + + # fname_cluster_fig= op.join(data_path, decoding_method + + # '_'+str(test_win_on) + '_' + str(test_win_off) + + # '_IITPFC_feature_diff_acc_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts1_df_cluster,C1_stat,time_point, + test_win_on,test_win_off, + chance_index=33.3,y_index=50, + fname_fig=fname_cluster_fig) + + #stat + ts2_df_cluster,C2_stat=stat_cluster_1sample_roi(PFC_data[0,:,:],PFC_data[2,:,:],test_win_on,test_win_off,['IIT','IITPFC_m']) + + fname_cluster_fig2= op.join(stat_figure_root, decoding_method + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_IITPFC_model_diff_acc_cluster.svg') + + # fname_cluster_fig2= op.join(data_path, decoding_method + + # '_'+str(test_win_on) + '_' + str(test_win_off) + + # '_IITPFC_model_diff_acc_cluster.svg') + + #plot + sig1_cluster=df_plot_ROI_cluster(ts2_df_cluster,C2_stat,time_point, + test_win_on,test_win_off, + chance_index=33.3,y_index=50, + fname_fig=fname_cluster_fig2) + #ROI_wcd_ori_plt(group_data,decoding_method ='wcd', test_win_on=50, test_win_off=200,chance_index=33.3,y_index=40) diff --git a/roi_mvpa/D98_group_stat_sROI_plot_subROI_phaseII.py b/roi_mvpa/D98_group_stat_sROI_plot_subROI_phaseII.py new file mode 100644 index 0000000..00442b5 --- /dev/null +++ b/roi_mvpa/D98_group_stat_sROI_plot_subROI_phaseII.py @@ -0,0 +1,1373 @@ +""" +==================== +D98. Group analysis for decoding pattern +Category decoding +control analysis, +decoding at subROI. 
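+
+Sub-ROI decoding accuracies (frontal F1-F6, parietal P1-P6) are plotted as time
+courses and projected onto the fsaverage surface for the 250-500 ms window.
+
+Note: subjects_dir and data_path below are hard-coded local paths and may need
+to be adapted. Assuming the group pickles written by
+D99_group_data_pkl_phaseII.py already exist, the script is run directly, e.g.:
+
+    python D98_group_stat_sROI_plot_subROI_phaseII.py
+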
+==================== + +@author: Ling Liu ling.liu@pku.edu.cn + +""" + +import os.path as op +import os + +import pickle +import mne + + +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +sns.set_theme(style='ticks') + +from mne.stats import fdr_correction +from scipy import stats as stats + + + +import matplotlib as mpl +from matplotlib import cm + + +from matplotlib.cm import ScalarMappable +from matplotlib.colors import Normalize + + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import plot_param +from sublist_phase2 import sub_list + +# get the parameters dictionary +param = plot_param +pcolors=param['colors'] +fig_size = param["figure_size_mm"] +# plt.rcParams["font.family"] = "serif" +# plt.rcParams["font.serif"] = "Times New Roman" +plt.rc('font', size=param["font_size"]*2) # controls default text sizes +plt.rc('axes', titlesize=param["font_size"]*2) # fontsize of the axes title +plt.rc('axes', labelsize=param["font_size"]*2) # fontsize of the x and y labels +plt.rc('xtick', labelsize=param["font_size"]*2) # fontsize of the tick labels +plt.rc('ytick', labelsize=param["font_size"]*2) # fontsize of the tick labels +plt.rc('legend', fontsize=param["font_size"]*2) # legend fontsize +plt.rc('figure', titlesize=param["font_size"]*2) # fontsize of the fi +new_rc_params = {'text.usetex': False, +"svg.fonttype": 'none' +} +mpl.rcParams.update(new_rc_params) + +def mm2inch(val): + return val / 25.4 + +#set data path +subjects_dir = r'Y:\HPC\fs' + + +# decoding_path=op.join(bids_root, "derivatives",'decoding') + +# data_path=op.join(decoding_path,'roi_mvpa') + +data_path=r'D:\COGITATE_xps\data_analysis\MSP\leakage_control' + +# Set path to group analysis derivatives +group_deriv_root = op.join(data_path, "group_phase2",) +if not op.exists(group_deriv_root): + os.makedirs(group_deriv_root) + + + + +stat_figure_root = op.join(group_deriv_root,"figures") +if not op.exists(stat_figure_root): + os.makedirs(stat_figure_root) + + + +con_C = ['FO'] +con_D = ['Irrelevant', 'Relevant non-target'] +con_T = ['500ms','1000ms','1500ms'] + + + + +if con_C[0] == 'FO': + conditions_C = ['face', 'object'] + print(conditions_C) +elif con_C[0] == 'LF': + conditions_C = ['letter', 'false'] + print(conditions_C) + + +# analysis/task info +## analysis/task info +if con_T.__len__() == 3: + con_Tname = 'T_all' +else: + con_Tname = con_T[0] + +task_info = "_" + "".join(con_Tname) + "_" + "".join(con_C[0]) +print(task_info) + +Ffname_data=op.join(data_path,'group_phase2', task_info +"_data_group_Cat_subF_control" + '.pickle') +Pfname_data=op.join(data_path,'group_phase2', task_info +"_data_group_Cat_subP_control" + '.pickle') + + +Ffr=open(Ffname_data,'rb') +Fgroup_data=pickle.load(Ffr) + + + +Pfr=open(Pfname_data,'rb') +Pgroup_data=pickle.load(Pfr) + + + +# Color parameters: +cmap = "RdYlBu_r" +#color_blind_palette = sns.color_palette("colorblind") +colors = { + "F1": [ + 0.00392156862745098, + 0.45098039215686275, + 0.6980392156862745 + ], + "F2": [ + 0.00784313725490196, + 0.6196078431372549, + 0.45098039215686275 + ], + "F3": [ + 0.8352941176470589, + 0.3686274509803922, + 0.0 + ], + "Middle frontal gyrus": [ + 0/255, + 0/255, + 130/255 + ], + "Inferior frontal sulcus": [ + 0/255, + 0/255, + 235/255 + ], + "Superios frontal sulcus": [ + 0/255, + 18/255, + 255/255 + ], + "Intraparietal sulcus & transverse parietal sulci": [ + 128/255, + 0/255, + 0/255 + ], + "Post-central sulcus": [ + 255/255, + 149/255, + 
0/255 + ], + "Post-central gyrus": [ + 0/255, + 114/255, + 235/255 + ], + "Central sulcus": [ + 125/255, + 236/255, + 104/255 + ], + "Central gyrus": [ + 0/255, + 120/255, + 255/255 + ], + 'G_and_S_cingul-Ant': [ + 0/255, + 0/255, + 203/255 + ], + 'G_and_S_cingul-Mid-Ant': [ + 0/255, + 219/255, + 255/255 + ], + 'G_and_S_cingul-Mid-Post': [ + 133/255, + 255/255, + 143/255 + ], + "Precentral infrior sulcus": [ + 210/255, + 255/255, + 51/255 + ], + "Relevant non-target": [ + 0.8352941176470589, + 0.3686274509803922, + 0.0 + ], + "Irrelevant": [ + 0.5450980392156862, + 0.16862745098039217, + 0.8862745098039215 + ], +} + + +time_point = np.array(range(-200,2001, 10))/1000 + + + +def df_plot(ts_df,T1,pval1,T2,pval2,time_point,test_win_on,roi_name,task_index,chance_index,y_index,fname_fig): + window=[0.3,0.5,0.5,0.3] + # if roi_name=='GNW': + # window=[0.3,0.5,0.5,0.3] + # elif roi_name=='IIT': + # window=[0.3,1.5,1.5,0.3] + # elif roi_name=='MT': + # window=[0.25,0.5,0.5,0.25] + # elif roi_name=='FP': + # window=[0.3,1.5,1.5,0.3] + # #plot with sns + + + talk_rc={'lines.linewidth':2,'lines.markersize':4} + sns.set_context('talk',rc=talk_rc,font_scale=1) + + + g = sns.relplot(x="Times(s)", y="Accuracy(%)", kind="line", data=ts_df,hue='Task',aspect=2,palette=colors) + leg = g._legend + leg.set_bbox_to_anchor([0.72,0.8]) + + plt.axhline(chance_index, color='k', linestyle='-', label='chance') + plt.axvline(0, color='k', linestyle='-', label='onset') + plt.axvline(0.5, color='gray', linestyle='--') + plt.axvline(1, color='gray', linestyle='--') + plt.axvline(1.5, color='gray', linestyle='--') + + reject_fdr1, pval_fdr1 = fdr_correction(pval1, alpha=0.05, method='indep') + temp=reject_fdr1.nonzero() + sig1=np.full(time_point.shape,np.nan) + if len(temp[0])>=1: + threshold_fdr1 = np.min(np.abs(T1)[reject_fdr1]) + T11=np.concatenate((np.zeros((test_win_on-30,)),T1)) + clusters1 = np.where(T11 > threshold_fdr1)[0] + if len(clusters1)>1: + clusters1 = clusters1[clusters1 > test_win_on-30] + #times = range(0, 500, 10) + plt.plot(time_point[clusters1], np.zeros(clusters1.shape) + chance_index-4, 'o', linewidth=3,color=colors[task_index[0]]) + sig1[clusters1]=1 + + reject_fdr2, pval_fdr2 = fdr_correction(pval2, alpha=0.05, method='indep') + temp=reject_fdr2.nonzero() + sig2=np.full(time_point.shape,np.nan) + if len(temp[0])>=1: + threshold_fdr2 = np.min(np.abs(T2)[reject_fdr2]) + T22=np.concatenate((np.zeros((test_win_on-30,)),T2)) + clusters2 = np.where(T22 > threshold_fdr2)[0] + if len(clusters2)>1: + clusters2 = clusters2[clusters2 > test_win_on-30] + #times = range(0, 500, 10) + plt.plot(time_point[clusters2], np.zeros(clusters2.shape) + chance_index-6, 'o', linewidth=3,color=colors[task_index[1]]) + sig2[clusters2]=1 + + plt.fill(window,[chance_index-10,chance_index-10,chance_index+y_index,chance_index+y_index],facecolor='g',alpha=0.2) + plt.xlim([-0.2,2]) + plt.ylim([chance_index-10,chance_index+y_index]) + + g.savefig(fname_fig) + + return sig1, sig2 + +def df_plot_cluster(ts_df,C1_stat,C2_stat,time_point,test_win_on,test_win_off,roi_name,task_index,chance_index,y_index,fname_fig): + # if roi_name=='GNW': + # window=[0.3,0.5,0.5,0.3] + # elif roi_name=='IIT': + # window=[0.3,1.5,1.5,0.3] + # elif roi_name=='MT': + # window=[0.25,0.5,0.5,0.25] + # elif roi_name=='FP': + # window=[0.3,1.5,1.5,0.3] + window=[0.3,0.5,0.5,0.3] + + #plot with sns + + talk_rc={'lines.linewidth':2,'lines.markersize':4} + sns.set_context('talk',rc=talk_rc,font_scale=1) + + + g = sns.relplot(x="Times(s)", y="Accuracy(%)", 
kind="line", data=ts_df,hue='Task',aspect=2,palette=colors) + leg = g._legend + leg.set_bbox_to_anchor([0.72,0.8]) + + plt.axhline(chance_index, color='k', linestyle='-', label='chance') + plt.axvline(0, color='k', linestyle='-', label='onset') + plt.axvline(0.5, color='gray', linestyle='--') + plt.axvline(1, color='gray', linestyle='--') + plt.axvline(1.5, color='gray', linestyle='--') + + + temp=C1_stat['cluster'] + temp_p=C1_stat['cluster_p'] + sig1=np.full(time_point.shape,np.nan) + time_index=time_point[(test_win_on-30):(test_win_off-30)] + if len(temp)>=1: + for i in range(len(temp)): + if temp_p[i]<0.05:# plot the cluster which p < 0.05 + clusters1=temp[i][0] + plt.plot(time_index[clusters1], np.zeros(clusters1.shape) + chance_index-4, 'o', linewidth=3,color=colors[task_index[0]]) + sig1[clusters1]=i + + temp2=C2_stat['cluster'] + temp_p2=C2_stat['cluster_p'] + sig2=np.full(time_point.shape,np.nan) + if len(temp2)>=1: + for i in range(len(temp2)): + if temp_p2[i]<0.05:# plot the cluster which p < 0.05 + clusters2=temp2[i][0] + plt.plot(time_index[clusters2], np.zeros(clusters2.shape) + chance_index-6, 'o', linewidth=3,color=colors[task_index[1]]) + sig2[clusters2]=i + + + + plt.fill(window,[chance_index-10,chance_index-10,chance_index+y_index,chance_index+y_index],facecolor='g',alpha=0.2) + plt.xlim([-0.2,2]) + plt.ylim([chance_index-10,chance_index+y_index]) + + g.savefig(fname_fig) + + return sig1, sig2 + +def df_plot_cluster_ori(ts_df,C1_stat,time_point,test_win_on,test_win_off,roi_name,task_index,chance_index,y_index,fname_fig): + # if roi_name=='GNW': + # window=[0.3,0.5,0.5,0.3] + # elif roi_name=='IIT': + # window=[0.3,1.5,1.5,0.3] + # elif roi_name=='MT': + # window=[0.25,0.5,0.5,0.25] + # elif roi_name=='FP': + # window=[0.3,1.5,1.5,0.3] + window=[0.3,0.5,0.5,0.3] + + #plot with sns + + talk_rc={'lines.linewidth':2,'lines.markersize':4} + sns.set_context('talk',rc=talk_rc,font_scale=1) + + + g = sns.relplot(x="Times(s)", y="Accuracy(%)", kind="line", data=ts_df,hue='Task',aspect=2,palette=colors) + leg = g._legend + leg.set_bbox_to_anchor([0.72,0.8]) + + plt.axhline(chance_index, color='k', linestyle='-', label='chance') + plt.axvline(0, color='k', linestyle='-', label='onset') + plt.axvline(0.5, color='gray', linestyle='--') + plt.axvline(1, color='gray', linestyle='--') + plt.axvline(1.5, color='gray', linestyle='--') + + + temp=C1_stat['cluster'] + temp_p=C1_stat['cluster_p'] + sig1=np.full(time_point.shape,np.nan) + time_index=time_point[(test_win_on-30):(test_win_off-30)] + if len(temp)>=1: + for i in range(len(temp)): + if temp_p[i]<0.05:# plot the cluster which p < 0.05 + clusters1=temp[i][0] + plt.plot(time_index[clusters1], np.zeros(clusters1.shape) + chance_index-4, 'o', linewidth=3,color=colors[task_index[0]]) + sig1[clusters1]=i + + + + plt.fill(window,[chance_index-10,chance_index-10,chance_index+y_index,chance_index+y_index],facecolor='g',alpha=0.2) + plt.xlim([-0.2,2]) + plt.ylim([chance_index-10,chance_index+y_index]) + + g.savefig(fname_fig) + + return sig1 + +def df_plot_ROI12_cluster(ts_df,time_point,test_win_on,test_win_off,task_index,chance_index,y_index,fname_fig): + + #window=[0.3,0.5,0.5,0.3] + + + #plot with sns + + talk_rc={'lines.linewidth':2,'lines.markersize':4} + sns.set_context('talk',rc=talk_rc,font_scale=2) + + + + + g = sns.relplot(x="Times(s)", y="Accuracy(%)", kind="line", data=ts_df,col='ROI',hue='ROI',aspect=4,palette=colors,col_wrap=5,legend=False) + g.map(plt.axhline, y=50, color='k', linestyle='-', label='chance') + 
g.map(plt.axvline, x=0, color='k', linestyle='-', label='onset') + g.map(plt.axvline, x=0.5, color='gray', linestyle='--') + + + + + g.fig.set_size_inches(mm2inch(fig_size[0])*5,mm2inch(fig_size[0])*2) + #leg = g._legend + #leg.set_bbox_to_anchor([0.72,0.8]) + + + + #g.map( plt.axvline(1, color='gray', linestyle='--')) + #g.map(plt.axvline(1.5, color='gray', linestyle='--')) + + + # temp=C1_stat['cluster'] + # temp_p=C1_stat['cluster_p'] + # sig1=np.full(time_point.shape,np.nan) + # time_index=time_point[(test_win_on-30):(test_win_off-30)] + # if len(temp)>=1: + # for i in range(len(temp)): + # if temp_p[i]<0.05:# plot the cluster which p < 0.05 + # clusters1=temp[i][0] + # plt.plot(time_index[clusters1], np.zeros(clusters1.shape) + chance_index-4, 'o', linewidth=3,color=colors[task_index[0]]) + # sig1[clusters1]=i + + + + #plt.fill(window,[chance_index-10,chance_index-10,chance_index+y_index,chance_index+y_index],facecolor='g',alpha=0.2) + plt.xlim([-0.2,2]) + plt.ylim([chance_index-10,chance_index+y_index]) + plt.xticks([0,0.5,1.0,1.5,2]) + plt.yticks([20,40,60,80,100]) + + + g.savefig(fname_fig,format="svg", transparent=True, dpi=300) + + + +def df2csv(np_data,task_index,csv_fname): + columns_index=['Time', + 'ACC (' + task_index[0] + ')','ACC (' + task_index[1] + ')', + 'CI (' + task_index[0] + ')','CI (' + task_index[1] + ')', + 'sig (' + task_index[0] + ')','sig (' + task_index[1] + ')'] + df = pd.DataFrame(np_data.T, columns=columns_index) + df.to_csv(csv_fname,sep=',',index=False,header=True,na_rep='NaN') + +def gc2df(gc_mean,test_win_on,test_win_off,task_index,chance_index): + + df1 = pd.DataFrame(gc_mean[0,:,30:251], columns=time_point) + df1.insert(loc=0, column='SUBID', value=sub_list) + df1.insert(loc=0, column='Task',value=task_index[0]) + + T1, pval1 = stats.ttest_1samp(gc_mean[0,:,test_win_on:test_win_off], chance_index) + + df2 = pd.DataFrame(gc_mean[1,:,30:251], columns=time_point) + df2.insert(loc=0, column='SUBID', value=sub_list) + df2.insert(loc=0, column='Task',value=task_index[1]) + + T2, pval2 = stats.ttest_1samp(gc_mean[1,:,test_win_on:test_win_off], chance_index) + + df=df1.append(df2) + + ts_df = pd.melt(df, id_vars=['SUBID','Task'], var_name='Times(s)', value_name='Accuracy(%)', value_vars=time_point) + + return ts_df,T1,pval1,T2,pval2 + +def stat_cluster_1sample(gc_mean,test_win_on,test_win_off,task_index,chance_index): + # define theresh + pval = 0.05 # arbitrary + tail = 0 # two-tailed + n_observations=gc_mean.shape[1] + stat_time_points=gc_mean[:,:,test_win_on:test_win_off].shape[2] + df = n_observations - 1 # degrees of freedom for the test + thresh = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + df1 = pd.DataFrame(gc_mean[0,:,30:251], columns=time_point) + df1.insert(loc=0, column='SUBID', value=sub_list) + df1.insert(loc=0, column='Task',value=task_index[0]) + + T_obs_1, clusters_1, cluster_p_values_1, H0_1 = mne.stats.permutation_cluster_1samp_test( + gc_mean[0,:,test_win_on:test_win_off]-np.ones([n_observations,stat_time_points])*chance_index, + threshold=thresh, n_permutations=10000, tail=tail, out_type='indices',verbose=None) + + C1_stat=dict() + C1_stat['T_obs']=T_obs_1 + C1_stat['cluster']=clusters_1 + C1_stat['cluster_p']=cluster_p_values_1 + + df2 = pd.DataFrame(gc_mean[1,:,30:251], columns=time_point) + df2.insert(loc=0, column='SUBID', value=sub_list) + df2.insert(loc=0, column='Task',value=task_index[1]) + + T_obs_2, clusters_2, cluster_p_values_2, H0_2 = mne.stats.permutation_cluster_1samp_test( + 
gc_mean[1,:,test_win_on:test_win_off]-np.ones([n_observations,stat_time_points])*chance_index, + threshold=thresh, n_permutations=10000, tail=tail, out_type='indices',verbose=None) + + C2_stat=dict() + C2_stat['T_obs']=T_obs_2 + C2_stat['cluster']=clusters_2 + C2_stat['cluster_p']=cluster_p_values_2 + + + df=df1.append(df2) + + ts_df = pd.melt(df, id_vars=['SUBID','Task'], var_name='Times(s)', value_name='Accuracy(%)', value_vars=time_point) + + return ts_df,C1_stat,C2_stat + + + + +def stat_cluster_1sample_roi(ROI1_data,ROI2_data,ROI3_data,test_win_on,test_win_off,ROI_name): + + # define theresh + pval = 0.05 # arbitrary + tail = 0 # two-tailed + n_observations=ROI1_data.shape[1] + + df = n_observations - 1 # degrees of freedom for the test + thresh = stats.t.ppf(1 - pval / 2, df) # two-tailed, t distribution + + df1 = pd.DataFrame(ROI1_data[:,30:251], columns=time_point) + df1.insert(loc=0, column='SUBID', value=sub_list) + df1.insert(loc=0, column='ROI',value=ROI_name[0]) + + + + df2 = pd.DataFrame(ROI2_data[:,30:251], columns=time_point) + df2.insert(loc=0, column='SUBID', value=sub_list) + df2.insert(loc=0, column='ROI',value=ROI_name[1]) + + + + + df=df1.append(df2) + + + df3 = pd.DataFrame(ROI2_data[:,30:251], columns=time_point) + df3.insert(loc=0, column='SUBID', value=sub_list) + df2.insert(loc=0, column='ROI',value=ROI_name[1]) + + + ts_df = pd.melt(df, id_vars=['SUBID','ROI'], var_name='Times(s)', value_name='Accuracy(%)', value_vars=time_point) + + T_obs_1, clusters_1, cluster_p_values_1, H0_1 = mne.stats.permutation_cluster_test( + [ROI1_data[:,test_win_on:test_win_off] , ROI2_data[:,test_win_on:test_win_off]], + threshold=thresh, n_permutations=10000, tail=tail, out_type='indices',verbose=None) + + C1_stat=dict() + C1_stat['T_obs']=T_obs_1 + C1_stat['cluster']=clusters_1 + C1_stat['cluster_p']=cluster_p_values_1 + + return ts_df,C1_stat + + +def dat2g(dat,roi_name,cond_name,decoding_name): + roi_ccd_g=np.zeros([2,len(sub_list),251]) + for ci, cond in enumerate(cond_name): + roi_ccd_gc=np.zeros([len(sub_list),251]) + for i, sbn in enumerate(sub_list): + roi_ccd_gc[i,:]=dat[sbn][decoding_name][roi_name][cond] + + + roi_ccd_g[ci,:,:]=roi_ccd_gc*100 + + return roi_ccd_g + + +def wcd_plt(group_data,roi_name='GNW',test_win_on=50, test_win_off=200,chance_index=50,y_index=15): + + + time_point = np.array(range(-200,2001, 10))/1000 + task_index=['Irrelevant','Relevant non-target'] + #get decoding data + ROI_ccd_g=dat2g(group_data,roi_name,cond_name=['Irrelevant','Relevant non-target'],decoding_name='wcd_acc') + + + # #FDR methods + + # #stat + # ts_df_fdr,T1,pval1,T2,pval2=gc2df(ROI_ccd_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + # #plot + # fname_fdr_fig= op.join(stat_figure_root, roi_name + '_'+ str(test_win_on) + '_'+ str(test_win_off) +"_acc_WCD_fdr" + '.png') + + # sig1_fdr,sig2_fdr=df_plot(ts_df_fdr,T1,pval1,T2,pval2,time_point,test_win_on, + # roi_name,task_index=task_index, + # chance_index=chance_index,y_index=y_index,fname_fig=fname_fdr_fig) + + + + + #cluster based methods + + #stat + ts_df_cluster,C1_stat,C2_stat=stat_cluster_1sample(ROI_ccd_g,test_win_on,test_win_off,task_index=task_index,chance_index=chance_index) + + fname_cluster_fig= op.join(stat_figure_root, roi_name + '_'+str(test_win_on) + '_' + str(test_win_off)+"_acc_WCD_cluster" + '.svg') + + #plot + sig1_cluster,sig2_cluster=df_plot_cluster(ts_df_cluster,C1_stat,C2_stat,time_point, + test_win_on,test_win_off, + roi_name,task_index=task_index, + 
chance_index=chance_index,y_index=y_index, + fname_fig=fname_cluster_fig) + + + +def ROI10_wcd_plt(group_data1,group_data2,decoding_method ='wcd', test_win_on=50, test_win_off=200,chance_index=50,y_index=40): + ROI_name1=['P1','P2','P3','P4','P5','F1','F2','F3','F4','F5']# ,'P6','F6' + ROI_name2=['Intraparietal sulcus & transverse parietal sulci', + 'Post-central sulcus', + 'Post-central gyrus', + 'Central sulcus', + 'Central gyrus', + #'Precentral infrior sulcus', + 'G_and_S_cingul-Ant', + 'G_and_S_cingul-Mid-Ant', + 'G_and_S_cingul-Mid-Post', + 'Middle frontal gyrus', + 'Inferior frontal sulcus'] + #'Superios frontal sulcus'] + task_index=['Irrelevant','Relevant non-target'] + #get decoding data + ROI9_data=np.zeros([10,2,len(sub_list),251]) + for i in range(5): + print('i=',i) + ROI9_data[i]=dat2g(group_data1,ROI_name1[i],cond_name=['Irrelevant','Relevant non-target'],decoding_name='wcd_acc') + + for i2 in range(5): + print('i2=',i2+5) + ROI9_data[i2+5]=dat2g(group_data2,ROI_name1[i2+5],cond_name=['Irrelevant','Relevant non-target'],decoding_name='wcd_acc') + + time_point = np.array(range(-200,2001, 10))/1000 + + IR_df1 = pd.DataFrame(ROI9_data[0,0,:,30:251], columns=time_point) + IR_df1.insert(loc=0, column='SUBID', value=sub_list) + IR_df1.insert(loc=0, column='ROI',value=ROI_name2[0]) + + IR_df1 = pd.melt(IR_df1, id_vars=['SUBID','ROI'], var_name='Times(s)', value_name='Accuracy(%)', value_vars=time_point) + + + for i in range(9): + n=i+1 + IR_df2 = [] + IR_df2 = pd.DataFrame(ROI9_data[n,0,:,30:251], columns=time_point) + IR_df2.insert(loc=0, column='SUBID', value=sub_list) + IR_df2.insert(loc=0, column='ROI',value=ROI_name2[n]) + IR_df1=IR_df1.append(IR_df2) + + RE_df1 = pd.DataFrame(ROI9_data[0,1,:,30:251], columns=time_point) + RE_df1.insert(loc=0, column='SUBID', value=sub_list) + RE_df1.insert(loc=0, column='ROI',value=ROI_name2[0]) + + + for i in range(9): + n=i+1 + RE_df2 = [] + RE_df2 = pd.DataFrame(ROI9_data[n,1,:,30:251], columns=time_point) + RE_df2.insert(loc=0, column='SUBID', value=sub_list) + RE_df2.insert(loc=0, column='ROI',value=ROI_name2[n]) + RE_df1=RE_df1.append(RE_df2) + + RE_df1 = pd.melt(RE_df1, id_vars=['SUBID','ROI'], var_name='Times(s)', value_name='Accuracy(%)', value_vars=time_point) + + + #cluster based methods + + + + # fname_cluster_fig= op.join(stat_figure_root, task_index[0] + + # '_'+str(test_win_on) + '_' + str(test_win_off) + + # '_ROI9_acc_'+decoding_method + '_cluster.png') + + # #plot + + # df_plot_ROI9_cluster(IR_df1,time_point, + # test_win_on,test_win_off, + # task_index=task_index[0], + # chance_index=chance_index, + # y_index=y_index, + # fname_fig=fname_cluster_fig) + + #stat + + + fname_cluster_fig2= op.join(stat_figure_root, task_index[1] + + '_'+str(test_win_on) + '_' + str(test_win_off) + + '_ROI10_acc_'+decoding_method + '_cluster.svg') + + #plot + df_plot_ROI12_cluster(RE_df1,time_point, + test_win_on,test_win_off, + task_index=task_index[1], + chance_index=chance_index, + y_index=y_index, + fname_fig=fname_cluster_fig2) + + return IR_df1,RE_df1 + +IR_df1,RE_df1=ROI10_wcd_plt(Pgroup_data,Fgroup_data,decoding_method ='wcd', test_win_on=50, test_win_off=200,chance_index=50,y_index=30) + + + + + +# #ccd_plt(group_data,roi_name='MT',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) +# wcd_plt(Fgroup_data,roi_name='F1',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) +# wcd_plt(Fgroup_data,roi_name='F2',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) +# wcd_plt(Fgroup_data,roi_name='F3',test_win_on=50, 
test_win_off=200,chance_index=50,y_index=40) +# wcd_plt(Fgroup_data,roi_name='F4',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) +# wcd_plt(Fgroup_data,roi_name='F5',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) +# wcd_plt(Fgroup_data,roi_name='F6',test_win_on=50, test_win_off=200,chance_index=50,y_index=40) + +FROI6_g = np.zeros([6, 2, len(sub_list), 251]) +FROI_list=['F1','F2','F3','F4','F5','F6'] +for ri,rname in enumerate(FROI_list): + FROI6_g[ri] = dat2g(Fgroup_data,roi_name=rname,cond_name=['Irrelevant','Relevant non-target'],decoding_name='wcd_acc') + +FROI6_g_1=np.mean(np.mean(FROI6_g[:,1,:,50:75],axis=2),axis=1) +FROI6_g_2=np.mean(np.mean(FROI6_g[:,1,:,75:100],axis=2),axis=1) + +Froi_g=FROI6_g_2 +#Froi_g[0:2]=ROI6_g_1[3:5] + +#subjects_dir =r'Y:\HPC\fs' + + +labels_parc_fs = mne.read_labels_from_annot(subject='fsaverage', parc='aparc.a2009s', subjects_dir=subjects_dir) + +F1_ts_list=['G_and_S_cingul-Ant'] +F2_ts_list=['G_and_S_cingul-Mid-Ant'] +F3_ts_list=['G_and_S_cingul-Mid-Post'] +F4_ts_list=['G_front_middle'] +F5_ts_list=['S_front_inf'] +F6_ts_list=['S_front_sup'] + +F1_ts_index = [] +for ii in range(len(labels_parc_fs)): + label_name = [] + label_name = labels_parc_fs[ii].name + if label_name[:-3] in F1_ts_list: + F1_ts_index.append(ii) +F2_ts_index = [] +for ii in range(len(labels_parc_fs)): + label_name = [] + label_name = labels_parc_fs[ii].name + if label_name[:-3] in F2_ts_list: + F2_ts_index.append(ii) +F3_ts_index = [] +for ii in range(len(labels_parc_fs)): + label_name = [] + label_name = labels_parc_fs[ii].name + if label_name[:-3] in F3_ts_list: + F3_ts_index.append(ii) +F4_ts_index = [] +for ii in range(len(labels_parc_fs)): + label_name = [] + label_name = labels_parc_fs[ii].name + if label_name[:-3] in F4_ts_list: + F4_ts_index.append(ii) + +F5_ts_index = [] +for ii in range(len(labels_parc_fs)): + label_name = [] + label_name = labels_parc_fs[ii].name + if label_name[:-3] in F5_ts_list: + F5_ts_index.append(ii) +F6_ts_index = [] +for ii in range(len(labels_parc_fs)): + label_name = [] + label_name = labels_parc_fs[ii].name + if label_name[:-3] in F6_ts_list: + F6_ts_index.append(ii) + + +for ni, n_label in enumerate(F1_ts_index): + F1_label = [label for label in labels_parc_fs if label.name == labels_parc_fs[n_label].name][0] + if ni == 0: + rF1_label = F1_label + elif ni == 1: + lF1_label = F1_label + elif ni % 2 == 0: + rF1_label = rF1_label + F1_label # , hemi="both" + else: + lF1_label = lF1_label + F1_label + +for ni, n_label in enumerate(F2_ts_index): + F2_label = [label for label in labels_parc_fs if label.name == labels_parc_fs[n_label].name][0] + if ni == 0: + rF2_label = F2_label + elif ni == 1: + lF2_label = F2_label + elif ni % 2 == 0: + rF2_label = rF2_label + F2_label # , hemi="both" + else: + lF2_label = lF2_label + F2_label + +for ni, n_label in enumerate(F3_ts_index): + F3_label = [label for label in labels_parc_fs if label.name == labels_parc_fs[n_label].name][0] + if ni == 0: + rF3_label = F3_label + elif ni == 1: + lF3_label = F3_label + elif ni % 2 == 0: + rF3_label = rF3_label + F3_label # , hemi="both" + else: + lF3_label = lF3_label + F3_label + +for ni, n_label in enumerate(F4_ts_index): + F4_label = [label for label in labels_parc_fs if label.name == labels_parc_fs[n_label].name][0] + if ni == 0: + rF4_label = F4_label + elif ni == 1: + lF4_label = F4_label + elif ni % 2 == 0: + rF4_label = rF4_label + F4_label # , hemi="both" + else: + lF4_label = lF4_label + F4_label + +for ni, n_label in 
enumerate(F5_ts_index): + F5_label = [label for label in labels_parc_fs if label.name == labels_parc_fs[n_label].name][0] + if ni == 0: + rF5_label = F5_label + elif ni == 1: + lF5_label = F5_label + elif ni % 2 == 0: + rF5_label = rF5_label + F5_label # , hemi="both" + else: + lF5_label = lF5_label + F5_label + +for ni, n_label in enumerate(F6_ts_index): + F6_label = [label for label in labels_parc_fs if label.name == labels_parc_fs[n_label].name][0] + if ni == 0: + rF6_label = F6_label + elif ni == 1: + lF6_label = F6_label + elif ni % 2 == 0: + rF6_label = rF6_label + F6_label # , hemi="both" + else: + lF6_label = lF6_label + F6_label + + + +PROI6_g = np.zeros([6, 2, len(sub_list), 251]) +PROI_list=['P1','P2','P3','P4','P5','P6'] +for ri,rname in enumerate(PROI_list): + PROI6_g[ri] = dat2g(Pgroup_data,roi_name=rname,cond_name=['Irrelevant','Relevant non-target'],decoding_name='wcd_acc') + +PROI6_g_1=np.mean(np.mean(PROI6_g[:,1,:,50:75],axis=2),axis=1) +PROI6_g_2=np.mean(np.mean(PROI6_g[:,1,:,75:100],axis=2),axis=1) + +Proi_g=PROI6_g_2 +#Proi_g[0:2]=ROI6_g_1[3:5] + +FProi_g=np.zeros([12,]) +FProi_g[:6]=FROI6_g_2 +FProi_g[6:]=PROI6_g_2 +#subjects_dir =r'Y:\HPC\fs' + + +labels_parc_fs = mne.read_labels_from_annot(subject='fsaverage', parc='aparc.a2009s', subjects_dir=subjects_dir) + +P1_ts_list=['S_intrapariet_and_P_trans'] +P2_ts_list=['S_postcentral'] +P3_ts_list=['G_postcentral'] +P4_ts_list=['S_central'] +P5_ts_list=['G_precentral'] +P6_ts_list=['S_precentral-inf-part'] + +P1_ts_index = [] +for ii in range(len(labels_parc_fs)): + label_name = [] + label_name = labels_parc_fs[ii].name + if label_name[:-3] in P1_ts_list: + P1_ts_index.append(ii) +P2_ts_index = [] +for ii in range(len(labels_parc_fs)): + label_name = [] + label_name = labels_parc_fs[ii].name + if label_name[:-3] in P2_ts_list: + P2_ts_index.append(ii) +P3_ts_index = [] +for ii in range(len(labels_parc_fs)): + label_name = [] + label_name = labels_parc_fs[ii].name + if label_name[:-3] in P3_ts_list: + P3_ts_index.append(ii) +P4_ts_index = [] +for ii in range(len(labels_parc_fs)): + label_name = [] + label_name = labels_parc_fs[ii].name + if label_name[:-3] in P4_ts_list: + P4_ts_index.append(ii) + +P5_ts_index = [] +for ii in range(len(labels_parc_fs)): + label_name = [] + label_name = labels_parc_fs[ii].name + if label_name[:-3] in P5_ts_list: + P5_ts_index.append(ii) +P6_ts_index = [] +for ii in range(len(labels_parc_fs)): + label_name = [] + label_name = labels_parc_fs[ii].name + if label_name[:-3] in P6_ts_list: + P6_ts_index.append(ii) + + +for ni, n_label in enumerate(P1_ts_index): + P1_label = [label for label in labels_parc_fs if label.name == labels_parc_fs[n_label].name][0] + if ni == 0: + rP1_label = P1_label + elif ni == 1: + lP1_label = P1_label + elif ni % 2 == 0: + rP1_label = rP1_label + P1_label # , hemi="both" + else: + lP1_label = lP1_label + P1_label + +for ni, n_label in enumerate(P2_ts_index): + P2_label = [label for label in labels_parc_fs if label.name == labels_parc_fs[n_label].name][0] + if ni == 0: + rP2_label = P2_label + elif ni == 1: + lP2_label = P2_label + elif ni % 2 == 0: + rP2_label = rP2_label + P2_label # , hemi="both" + else: + lP2_label = lP2_label + P2_label + +for ni, n_label in enumerate(P3_ts_index): + P3_label = [label for label in labels_parc_fs if label.name == labels_parc_fs[n_label].name][0] + if ni == 0: + rP3_label = P3_label + elif ni == 1: + lP3_label = P3_label + elif ni % 2 == 0: + rP3_label = rP3_label + P3_label # , hemi="both" + else: + lP3_label = lP3_label + 
P3_label + +for ni, n_label in enumerate(P4_ts_index): + P4_label = [label for label in labels_parc_fs if label.name == labels_parc_fs[n_label].name][0] + if ni == 0: + rP4_label = P4_label + elif ni == 1: + lP4_label = P4_label + elif ni % 2 == 0: + rP4_label = rP4_label + P4_label # , hemi="both" + else: + lP4_label = lP4_label + P4_label + +for ni, n_label in enumerate(P5_ts_index): + P5_label = [label for label in labels_parc_fs if label.name == labels_parc_fs[n_label].name][0] + if ni == 0: + rP5_label = P5_label + elif ni == 1: + lP5_label = P5_label + elif ni % 2 == 0: + rP5_label = rP5_label + P5_label # , hemi="both" + else: + lP5_label = lP5_label + P5_label + +for ni, n_label in enumerate(P6_ts_index): + P6_label = [label for label in labels_parc_fs if label.name == labels_parc_fs[n_label].name][0] + if ni == 0: + rP6_label = P6_label + elif ni == 1: + lP6_label = P6_label + elif ni % 2 == 0: + rP6_label = rP6_label + P6_label # , hemi="both" + else: + lP6_label = lP6_label + P6_label + + + +#250-500ms +Brain = mne.viz.get_brain_class() +brain = Brain('fsaverage', 'rh', 'pial', subjects_dir=subjects_dir, + background='white', size=(800, 800), alpha=1) +# FROI_list=['F1_ts_index','F2_ts_index','F3_ts_index','F4_ts_index','F5_ts_index','F6_ts_index'] + +for ni, n_label in enumerate(F4_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(F4_label, color=cmap(norm(FROI6_g_2[3])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(F4_label, color=cmap(norm(FROI6_g_2[3])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + +for ni, n_label in enumerate(F5_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(F5_label, color=cmap(norm(FROI6_g_2[4])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(F5_label, color=cmap(norm(FROI6_g_2[4])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + + + + +for ni, n_label in enumerate(F6_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(F6_label, color=cmap(norm(FROI6_g_2[5])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(F6_label, color=cmap(norm(FROI6_g_2[5])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + +for ni, n_label in enumerate(P1_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(P1_label, color=cmap(norm(PROI6_g_2[0])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(P1_label, color=cmap(norm(PROI6_g_2[0])), alpha=1, hemi="lh", + borders=False) # , hemi="both" +for ni, n_label in enumerate(P2_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(P2_label, color=cmap(norm(PROI6_g_2[1])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(P2_label, color=cmap(norm(PROI6_g_2[1])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + +for ni, n_label in enumerate(P3_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm 
= Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(P3_label, color=cmap(norm(PROI6_g_2[2])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(P3_label, color=cmap(norm(PROI6_g_2[2])), alpha=1, hemi="lh", + borders=False) # , hemi="both" +for ni, n_label in enumerate(P4_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(P4_label, color=cmap(norm(PROI6_g_2[3])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(P4_label, color=cmap(norm(PROI6_g_2[3])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + +for ni, n_label in enumerate(P5_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(P5_label, color=cmap(norm(PROI6_g_2[4])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(P5_label, color=cmap(norm(PROI6_g_2[4])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + + + + +for ni, n_label in enumerate(P6_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(P6_label, color=cmap(norm(PROI6_g_2[5])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(P6_label, color=cmap(norm(PROI6_g_2[5])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + +views=['lateral','lateral'] +for view in views: + brain.show_view(view) + +pial_brain=brain.screenshot() + + + +# before/after results +# fig = plt.figure(figsize=(4, 12)) +# axes = ImageGrid(fig, 111, nrows_ncols=(2, 1), axes_pad=0.5) +# for ax, image, title in zip( +# axes, [screenshot, cropped_screenshot], ["Before", "After"] +# ): +# ax.imshow(image) +# ax.set_title("{} cropping".format(title)) + +fig, ax = plt.subplots(figsize=[mm2inch(fig_size[0]),mm2inch(fig_size[0])]) +ax.imshow(pial_brain) +ax.axis("off") +file_path=stat_figure_root +fnamec='pial_brain.svg' +fname_fig_c=op.join(file_path,fnamec) + +fig.savefig(fname_fig_c,format="svg", transparent=True, dpi=300) + + +#250-500ms +Brain = mne.viz.get_brain_class() +brain = Brain('fsaverage', 'rh', 'inflated', subjects_dir=subjects_dir, + background='white', size=(800, 800), alpha=1) +# FROI_list=['F1_ts_index','F2_ts_index','F3_ts_index','F4_ts_index','F5_ts_index','F6_ts_index'] +for ni, n_label in enumerate(F1_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(F1_label, color=cmap(norm(PROI6_g_2[0])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(F1_label, color=cmap(norm(PROI6_g_2[0])), alpha=1, hemi="lh", + borders=False) # , hemi="both" +for ni, n_label in enumerate(F2_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(F2_label, color=cmap(norm(PROI6_g_2[1])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(F2_label, color=cmap(norm(PROI6_g_2[1])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + +for ni, n_label in enumerate(F3_ts_index): + cmap = 'jet' + cmap = 
cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(F3_label, color=cmap(norm(PROI6_g_2[2])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(F3_label, color=cmap(norm(PROI6_g_2[2])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + + + + + + +for ni, n_label in enumerate(F4_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(F4_label, color=cmap(norm(FROI6_g_2[3])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(F4_label, color=cmap(norm(FROI6_g_2[3])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + +for ni, n_label in enumerate(F5_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(F5_label, color=cmap(norm(FROI6_g_2[4])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(F5_label, color=cmap(norm(FROI6_g_2[4])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + + +for ni, n_label in enumerate(F6_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(F6_label, color=cmap(norm(FROI6_g_2[5])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(F6_label, color=cmap(norm(FROI6_g_2[5])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + +for ni, n_label in enumerate(P1_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(P1_label, color=cmap(norm(PROI6_g_2[0])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(P1_label, color=cmap(norm(PROI6_g_2[0])), alpha=1, hemi="lh", + borders=False) # , hemi="both" +for ni, n_label in enumerate(P2_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(P2_label, color=cmap(norm(PROI6_g_2[1])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(P2_label, color=cmap(norm(PROI6_g_2[1])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + +for ni, n_label in enumerate(P3_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(P3_label, color=cmap(norm(PROI6_g_2[2])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(P3_label, color=cmap(norm(PROI6_g_2[2])), alpha=1, hemi="lh", + borders=False) # , hemi="both" +for ni, n_label in enumerate(P4_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(P4_label, color=cmap(norm(PROI6_g_2[3])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(P4_label, color=cmap(norm(PROI6_g_2[3])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + +for ni, n_label in enumerate(P5_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = 
Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(P5_label, color=cmap(norm(PROI6_g_2[4])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(P5_label, color=cmap(norm(PROI6_g_2[4])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + + + + +for ni, n_label in enumerate(P6_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(P6_label, color=cmap(norm(PROI6_g_2[5])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(P6_label, color=cmap(norm(PROI6_g_2[5])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + +views=['lateral','lateral'] +for view in views: + brain.show_view(view) + +inflated_brain=brain.screenshot() + + +fig, ax = plt.subplots(figsize=[mm2inch(fig_size[0]),mm2inch(fig_size[0])]) +ax.imshow(inflated_brain) +ax.axis("off") +file_path=stat_figure_root +fnamec='inflated_brain.svg' +fname_fig_c=op.join(file_path,fnamec) + +fig.savefig(fname_fig_c,format="svg", transparent=True, dpi=300) + + +#plot colorbar +fig, ax = plt.subplots(figsize=[mm2inch(fig_size[0]),mm2inch(fig_size[1])]) +ax.axis("off") +cax = fig.add_axes([0.5, 0.1, 0.05, 0.8]) +norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) +fig.colorbar(ScalarMappable(norm=norm, cmap=cm.get_cmap(cmap)), cax=cax,label="Decoding Accuracy (%)") +file_path=stat_figure_root +fnamec='decoding_plot_colorbar.svg' +fname_fig_c=op.join(file_path,fnamec) + +fig.savefig(fname_fig_c,format="svg", transparent=True, dpi=300) + + + +#250-500ms +Brain = mne.viz.get_brain_class() +brain = Brain('fsaverage', 'rh', 'inflated', subjects_dir=subjects_dir, + background='white', size=(800, 800), alpha=1) +# FROI_list=['F1_ts_index','F2_ts_index','F3_ts_index','F4_ts_index','F5_ts_index','F6_ts_index'] +for ni, n_label in enumerate(F1_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(F1_label, color=cmap(norm(FROI6_g_2[0])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(F1_label, color=cmap(norm(FROI6_g_2[0])), alpha=1, hemi="lh", + borders=False) # , hemi="both" +for ni, n_label in enumerate(F2_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(F2_label, color=cmap(norm(FROI6_g_2[1])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(F2_label, color=cmap(norm(FROI6_g_2[1])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + +for ni, n_label in enumerate(F3_ts_index): + cmap = 'jet' + cmap = cm.get_cmap(cmap) + #norm = Normalize(vmin=0, vmax=6) + norm = Normalize(vmin=min(FProi_g), vmax=max(FProi_g)) + if (ni % 2) == 0: + brain.add_label(F3_label, color=cmap(norm(FROI6_g_2[2])), alpha=1, hemi="rh", + borders=False) # , hemi="both" + else: + brain.add_label(F3_label, color=cmap(norm(FROI6_g_2[2])), alpha=1, hemi="lh", + borders=False) # , hemi="both" + +views=['medial','medial'] +for view in views: + brain.show_view(view) + +inflated_brain=brain.screenshot() + +fig, ax = plt.subplots(figsize=[mm2inch(fig_size[0]),mm2inch(fig_size[0])]) +ax.imshow(inflated_brain) +ax.axis("off") +file_path=stat_figure_root +fnamec='inflated_brain_medial.svg' +fname_fig_c=op.join(file_path,fnamec) + 
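+# Export the medial-view screenshot (anterior/mid cingulate labels F1-F3) as a
+# transparent SVG alongside the lateral-view figures in stat_figure_root.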
+fig.savefig(fname_fig_c,format="svg", transparent=True, dpi=300) diff --git a/roi_mvpa/D99_group_data_pkl.py b/roi_mvpa/D99_group_data_pkl.py new file mode 100644 index 0000000..01bfaaa --- /dev/null +++ b/roi_mvpa/D99_group_data_pkl.py @@ -0,0 +1,202 @@ +""" +==================== +D99. Group analysis for decoding pattern +prepare for ploting + +==================== + +@author: ling liu ling.liu@pku.edu.cn +""" + +import os.path as op +import os +import argparse + +import pickle + +from config import bids_root +from sublist import sub_list + +parser = argparse.ArgumentParser() +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +parser.add_argument('--cT', type=str, nargs='*', default=['500ms','1000ms','1500ms'], + help='condition in Time duration: [500ms],[1000ms],[1500ms]') +parser.add_argument('--cC', type=str, nargs='*', default=['FO'], + help='selected decoding category, FO for face and object, LF for letter and false') +parser.add_argument('--cD',type=str,nargs='*', default=['Irrelevant', 'Relevant non-target'], + help='selected decoding Task, Relevant non Target or Irrelevant condition') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--analysis', + type=str, + default='Cat', + help='the name for anlaysis, e.g. Cat or Ori or GAT_Cat') +parser.add_argument('--nF', + type=int, + default=30, + help='number of feature selected for source decoding') +parser.add_argument('--nT', + type=int, + default=5, + help='number of trial averaged for source decoding') +parser.add_argument('--nPCA', + type=float, + default=0.95, + help='percentile of PCA selected for source decoding') + + +opt = parser.parse_args() + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path +analysis_name=opt.analysis + + +opt = parser.parse_args() +con_C = opt.cC +con_D = opt.cD +con_T = opt.cT + +# if analysis_name=='Cat' or analysis_name=='Ori': +# if methods_name=='T_all': +# con_T=['500ms','1000ms','1500ms'] +# else: +# con_T = methods_name[0] + + +select_F = opt.nF +n_trials = opt.nT +nPCA = opt.nPCA + + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path + +#1) Select Category +if con_C[0] == 'FO': + conditions_C = ['face', 'object'] + print(conditions_C) +elif con_C[0] == 'LF': + conditions_C = ['letter', 'false'] + print(conditions_C) +elif con_C[0] == 'F': + conditions_C = ['face'] + print(conditions_C) +elif con_C[0] == 'O': + conditions_C = ['object'] + print(conditions_C) +elif con_C[0] == 'L': + conditions_C = ['letter'] + print(conditions_C) +elif con_C[0] == 'FA': + conditions_C = ['false'] + print(conditions_C) + +#1) Select time duration +if con_T[0] == 'T_all': + con_T = ['500ms', '1000ms','1500ms'] + print(con_T) +elif con_T[0] == 'ML':# middle and long + con_T = ['1000ms','1500ms'] + print(con_T) + + + +decoding_path=op.join(bids_root, "derivatives",'decoding','roi_mvpa') + +data_path=op.join(decoding_path,analysis_name) + +# Set path to group analysis derivatives +group_deriv_root = op.join(data_path, "group") +if not op.exists(group_deriv_root): + os.makedirs(group_deriv_root) + +# evokeds_group = [] +# stc_group = [] + + +sb_list=sub_list + + + +# analysis/task info +## analysis/task info + +## analysis/task info +if con_T.__len__() == 3: + con_Tname = 'T_all' +elif 
con_T.__len__() == 2: + con_Tname = con_T[0]+'_'+con_T[1] +else: + con_Tname = con_T[0] + +task_info = "_" + "".join(con_Tname) + "_" + "".join(con_C[0]) +print(task_info) + +group_data=dict() +for i, sbn in enumerate(sb_list): + # if 'SB' in sbn: + # sub and visit info + sub_info = 'sub-' + sbn + '_ses-' + visit_id + + sub_data_root = op.join(data_path, + f"sub-{sbn}", f"ses-{visit_id}", "meg", + "data") + # if analysis_name=='Cat': + # pkl_name = "_ROIs_data_Cat" + # elif analysis_name=='Ori': + # pkl_name = "_ROIs_data_Ori" + # elif analysis_name=='GAT_Cat': + # pkl_name = "_ROIs_data_GAT_Cat" + # elif analysis_name=='GAT_Ori': + # pkl_name = "_ROIs_data_GAT_Cat" + # elif analysis_name=='RSA_Cat': + # pkl_name = "_ROIs_RSA_Cat" + # elif analysis_name=='RSA_Ori': + # pkl_name = "_ROIs_RSA_Ori" + # elif analysis_name=='RSA_ID': + # pkl_name = "_ROIs_RSA_ID" + rsa_data=dict() + if analysis_name == "RSA_Cat" or analysis_name=="RSA_Ori" or analysis_name=="RSA_ID": + for ri,roi_name in enumerate(['GNW','IIT']): + fname_data=op.join(sub_data_root, sub_info + '_' + task_info + roi_name +'_ROIs_data_'+ analysis_name +'.pickle') + fr=open(fname_data,'rb') + rsa_data[roi_name]=pickle.load(fr) + group_data[sbn]=rsa_data + + elif analysis_name == "Cat_PFC": + fname_data=op.join(sub_data_root, sub_info + '_' + task_info +'_IITPFC_data_Cat.pickle') + fr=open(fname_data,'rb') + roi_data=pickle.load(fr) + group_data[sbn]=roi_data + + elif analysis_name == "Ori_PFC": + fname_data=op.join(sub_data_root, sub_info + '_' + task_info +'_IITPFC_data_Ori.pickle') + fr=open(fname_data,'rb') + roi_data=pickle.load(fr) + group_data[sbn]=roi_data + + else: + fname_data=op.join(sub_data_root, sub_info + '_' + task_info +'_ROIs_data_'+analysis_name +'.pickle') + fr=open(fname_data,'rb') + roi_data=pickle.load(fr) + group_data[sbn]=roi_data + +fname_data=op.join(group_deriv_root, task_info +"_data_group_" + analysis_name + + '.pickle') +fw = open(fname_data,'wb') +pickle.dump(group_data,fw) +fw.close() + + diff --git a/roi_mvpa/D99_group_data_pkl_phaseII.py b/roi_mvpa/D99_group_data_pkl_phaseII.py new file mode 100644 index 0000000..91d65b5 --- /dev/null +++ b/roi_mvpa/D99_group_data_pkl_phaseII.py @@ -0,0 +1,208 @@ +""" +==================== +D99. Group analysis for decoding pattern +prepare for ploting +==================== + +@author: ling liu ling.liu@pku.edu.cn + + + +""" + +import os.path as op +import os +import argparse + +import pickle + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import bids_root + +from sublist_phase2 import sub_list + +parser = argparse.ArgumentParser() +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. 
"V1")') +parser.add_argument('--cT', type=str, nargs='*', default=['500ms','1000ms','1500ms'], + help='condition in Time duration: [500ms],[1000ms],[1500ms]') +parser.add_argument('--cC', type=str, nargs='*', default=['FO'], + help='selected decoding category, FO for face and object, LF for letter and false') +parser.add_argument('--cD',type=str,nargs='*', default=['Irrelevant', 'Relevant non-target'], + help='selected decoding Task, Relevant non Target or Irrelevant condition') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--analysis', + type=str, + default='Cat', + help='the name for anlaysis, e.g. Cat or Ori or GAT_Cat') +parser.add_argument('--nF', + type=int, + default=30, + help='number of feature selected for source decoding') +parser.add_argument('--nT', + type=int, + default=5, + help='number of trial averaged for source decoding') +parser.add_argument('--nPCA', + type=float, + default=0.95, + help='percentile of PCA selected for source decoding') + + +opt = parser.parse_args() + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path +analysis_name=opt.analysis + + +opt = parser.parse_args() +con_C = opt.cC +con_D = opt.cD +con_T = opt.cT + +# if analysis_name=='Cat' or analysis_name=='Ori': +# if methods_name=='T_all': +# con_T=['500ms','1000ms','1500ms'] +# else: +# con_T = methods_name[0] + + +select_F = opt.nF +n_trials = opt.nT +nPCA = opt.nPCA + + +visit_id = opt.visit +space = opt.space +subjects_dir = opt.fs_path + +#1) Select Category +if con_C[0] == 'FO': + conditions_C = ['face', 'object'] + print(conditions_C) +elif con_C[0] == 'LF': + conditions_C = ['letter', 'false'] + print(conditions_C) +elif con_C[0] == 'F': + conditions_C = ['face'] + print(conditions_C) +elif con_C[0] == 'O': + conditions_C = ['object'] + print(conditions_C) +elif con_C[0] == 'L': + conditions_C = ['letter'] + print(conditions_C) +elif con_C[0] == 'FA': + conditions_C = ['false'] + print(conditions_C) + +#1) Select time duration +if con_T[0] == 'T_all': + con_T = ['500ms', '1000ms','1500ms'] + print(con_T) +elif con_T[0] == 'ML':# middle and long + con_T = ['1000ms','1500ms'] + print(con_T) + + + +decoding_path=op.join(bids_root, "derivatives",'decoding','roi_mvpa') + +data_path=op.join(decoding_path,analysis_name) + +# Set path to group analysis derivatives +group_deriv_root = op.join(data_path, "group_phase2") +if not op.exists(group_deriv_root): + os.makedirs(group_deriv_root) + +# evokeds_group = [] +# stc_group = [] + + +sb_list=sub_list + + + +# analysis/task info +## analysis/task info + +## analysis/task info +if con_T.__len__() == 3: + con_Tname = 'T_all' +elif con_T.__len__() == 2: + con_Tname = con_T[0]+'_'+con_T[1] +else: + con_Tname = con_T[0] + +task_info = "_" + "".join(con_Tname) + "_" + "".join(con_C[0]) +print(task_info) + +group_data=dict() +for i, sbn in enumerate(sb_list): + # if 'SB' in sbn: + # sub and visit info + sub_info = 'sub-' + sbn + '_ses-' + visit_id + + sub_data_root = op.join(data_path, + f"sub-{sbn}", f"ses-{visit_id}", "meg", + "data") + # if analysis_name=='Cat': + # pkl_name = "_ROIs_data_Cat" + # elif analysis_name=='Ori': + # pkl_name = "_ROIs_data_Ori" + # elif analysis_name=='GAT_Cat': + # pkl_name = "_ROIs_data_GAT_Cat" + # elif analysis_name=='GAT_Ori': + # pkl_name = 
"_ROIs_data_GAT_Cat" + # elif analysis_name=='RSA_Cat': + # pkl_name = "_ROIs_RSA_Cat" + # elif analysis_name=='RSA_Ori': + # pkl_name = "_ROIs_RSA_Ori" + # elif analysis_name=='RSA_ID': + # pkl_name = "_ROIs_RSA_ID" + rsa_data=dict() + if analysis_name == "RSA_Cat" or analysis_name=="RSA_Ori" or analysis_name=="RSA_ID": + for ri,roi_name in enumerate(['GNW','IIT']): + fname_data=op.join(sub_data_root, sub_info + '_' + task_info + roi_name +'_ROIs_data_'+ analysis_name +'.pickle') + fr=open(fname_data,'rb') + rsa_data[roi_name]=pickle.load(fr) + group_data[sbn]=rsa_data + + elif analysis_name == "Cat_PFC": + fname_data=op.join(sub_data_root, sub_info + '_' + task_info +'_IITPFC_data_Cat.pickle') + fr=open(fname_data,'rb') + roi_data=pickle.load(fr) + group_data[sbn]=roi_data + + elif analysis_name == "Ori_PFC": + fname_data=op.join(sub_data_root, sub_info + '_' + task_info +'_IITPFC_data_Ori.pickle') + fr=open(fname_data,'rb') + roi_data=pickle.load(fr) + group_data[sbn]=roi_data + + else: + fname_data=op.join(sub_data_root, sub_info + '_' + task_info +'_ROIs_data_'+analysis_name +'.pickle') + fr=open(fname_data,'rb') + roi_data=pickle.load(fr) + group_data[sbn]=roi_data + +fname_data=op.join(group_deriv_root, task_info +"_data_group_" + analysis_name + + '.pickle') +fw = open(fname_data,'wb') +pickle.dump(group_data,fw) +fw.close() + + diff --git a/roi_mvpa/D_MEG_function.py b/roi_mvpa/D_MEG_function.py new file mode 100644 index 0000000..a74c6c8 --- /dev/null +++ b/roi_mvpa/D_MEG_function.py @@ -0,0 +1,939 @@ +# -*- coding: utf-8 -*- +""" +Created on Mon Dec 5 20:43:32 2022 + +@author: Ling Liu ling.liu@pku.edu.cn +================================= +Functions for MEG decoding +================================= +""" + +import os +import os.path as op + + +import mne +import numpy as np + +import sys +sys.path.insert(1, op.dirname(op.dirname(os.path.abspath(__file__)))) + +from config.config import l_freq, h_freq, sfreq + +from rsa_helper_functions import equate_offset +from mne.minimum_norm import apply_inverse_epochs + + + +# set the path for decoding analysis +def set_path_ROI_MVPA(bids_root,subject_id, visit_id, analysis_name): + ### I Set subject information + # sub and visit info + sub_info = 'sub-' + subject_id + '_ses-' + visit_id + print(sub_info) + + ### II Set the Input Data Path + # 1 Set path to the data root path + fpath_root = op.join(bids_root, "derivatives") #data_path + + # 2 Set path to preprocessed sensor (xxx_epo.fif) + fpath_epo = op.join(fpath_root, "preprocessing", + f"sub-{subject_id}", f"ses-{visit_id}", "meg") + + # 2 Set path to the preprocessed source model data + fpath_fw = op.join(fpath_root,'forward', f"sub-{subject_id}", "ses-" + visit_id, "meg") + + # 3 Set path to the freesufer subjects_dir for source analysis + fpath_fs=op.join(fpath_root, "fs") + # subjects_dir = r'/home/user/S10/Cogitate/HPC/fs' + + + ### III Set the Output Data Path + # Set path to decoding derivatives + mvpa_deriv_root = op.join(fpath_root, "decoding") + if not op.exists(mvpa_deriv_root): + os.makedirs(mvpa_deriv_root) + + + # Set path to the ROI MVPA output(1) data, 2) figures, 3) codes) + roi_deriv_root = op.join(mvpa_deriv_root, "roi_mvpa", analysis_name) + if not op.exists(roi_deriv_root): + os.makedirs(roi_deriv_root) + # 1) output_data + roi_data_root = op.join(roi_deriv_root, + f"sub-{subject_id}", f"ses-{visit_id}", "meg", + "data") + if not op.exists(roi_data_root): + os.makedirs(roi_data_root) + + # 2) output_figure + roi_figure_root = op.join(roi_deriv_root, + 
f"sub-{subject_id}", f"ses-{visit_id}", "meg", + "figures") + if not op.exists(roi_figure_root): + os.makedirs(roi_figure_root) + + # 3) output_code + roi_code_root = op.join(roi_deriv_root, + f"sub-{subject_id}", f"ses-{visit_id}", "meg", + "codes") + if not op.exists(roi_code_root): + os.makedirs(roi_code_root) + + return sub_info,fpath_epo,fpath_fw,fpath_fs, roi_data_root,roi_figure_root, roi_code_root + +# functions for use both spatial and temporal feature as the decoding feature +def STdata(Xraw): + #spatial + temporal decoding + # temporal feature window + #Xraw=epochs_cd.get_data() + twd=5 # how many time points will used as temporal feature + Xtemp=[]; + for twd_index in range(twd): + if twd_index==0: + #Xtemp1=np.append(Xraw[:,:,:1],Xraw[:,:,:-1],axis=2) + Xtemp = Xraw + else: + Xtemp1=np.append(Xraw[:,:,:twd_index],Xraw[:,:,:-twd_index],axis=2) + Xtemp=np.append(Xtemp,Xtemp1,axis=1) + + return Xtemp + +# sliding windows (twd,) for MEG data +def ATdata(Xraw): + + #Xraw=epochs_cd.get_data() + twd=5 # how many time points will be used as sliding windows + [t1,t2,t3]=Xraw.shape + Xtemp=np.zeros([5,t1,t2,t3]); + for twd_index in range(twd): + if twd_index==0: + #Xtemp1=np.append(Xraw[:,:,:1],Xraw[:,:,:-1],axis=2) + #Xtemp = np.expand_dims(Xraw, axis=0) + Xtemp[twd_index,:,:,:] = Xraw + else: + Xtemp1=np.append(Xraw[:,:,:twd_index],Xraw[:,:,:-twd_index],axis=2) + Xtemp[twd_index,:,:,:] = Xtemp1 + + Xnew=np.mean(Xtemp,axis=0) + + return Xnew + + +def sensor_data_for_ROI_MVPA(fpath_epo,sub_info,con_T,con_C,con_D): + ### Loading the epochs data + # fname_epo = file_name + fname_epo=op.join(fpath_epo,sub_info + '_task-dur_epo.fif') + epochs = mne.read_epochs(fname_epo, + preload=True, + verbose=True).pick('meg') + + ### Choose condition + # e.g + # conditions_T=['500ms','1000ms','1500ms'] + # conditions_D = ['Irrelevant', 'Relevant non-target'] + # conditions_C = ['face', 'object'] or conditions_C = ['letter', 'false'] + + + #1) Select Category + #1) Select Category + if con_C[0] == 'FO': + conditions_C = ['face', 'object'] + print(conditions_C) + elif con_C[0] == 'LF': + conditions_C = ['letter', 'false'] + print(conditions_C) + elif con_C[0] == 'F': + conditions_C = ['face'] + print(conditions_C) + elif con_C[0] == 'O': + conditions_C = ['object'] + print(conditions_C) + elif con_C[0] == 'L': + conditions_C = ['letter'] + print(conditions_C) + elif con_C[0] == 'FA': + conditions_C = ['false'] + print(conditions_C) + + + epochs_cdc = epochs['Category in {}'.format(conditions_C)] + del epochs + + #2) Select Duration Time + conditions_T = con_T + print(conditions_T) + + epochs_cdd = epochs_cdc['Duration in {}'.format(conditions_T)] + del epochs_cdc + #3) Select Task relevance Design + + conditions_D = con_D + print(conditions_D) + + epochs_cd = epochs_cdd['Task_relevance in {}'.format(conditions_D)] + del epochs_cdd + + # Downsample and filter to speed the decoding + # Downsample copy of raw + epochs_rs = epochs_cd.copy().resample(sfreq, n_jobs=-1) + # Band-pass filter raw copy + epochs_rs.filter(l_freq, h_freq, n_jobs=-1) + + epochs_rs.crop(tmin=-0.5, tmax=2,include_tmax=True, verbose=None) + + # Baseline correction + b_tmin = -.5 + b_tmax = -.0 + baseline = (b_tmin, b_tmax) + epochs_rs.apply_baseline(baseline=baseline) + + # projecting sensor-space data to source space ###TODO:shrunk or ? 
+ rank = mne.compute_rank(epochs_rs, tol=1e-6, tol_kind='relative') + + baseline_cov = mne.compute_covariance(epochs_rs, tmin=-0.5, tmax=0, method='empirical', rank=rank, n_jobs=-1, + verbose=True) + active_cov = mne.compute_covariance(epochs_rs, tmin=0, tmax=2, method='empirical', rank=rank, n_jobs=-1, + verbose=True) + + common_cov = baseline_cov + active_cov + + ## analysis/task info + if con_T.__len__() == 3: + con_Tname = 'T_all' + elif con_T.__len__() == 2: + con_Tname = con_T[0]+'_'+con_T[1] + else: + con_Tname = con_T[0] + + task_info = "_" + "".join(con_Tname) + "_" + "".join(con_C[0]) + print(task_info) + + return epochs_rs, rank, common_cov, conditions_C, conditions_D, conditions_T, task_info + +def sensor_data_for_ROI_MVPA_baseline(fpath_epo,sub_info,con_T,con_C,con_D): + ### Loading the epochs data + # fname_epo = file_name + fname_epo=op.join(fpath_epo,sub_info + '_task-dur_epo.fif') + epochs = mne.read_epochs(fname_epo, + preload=True, + verbose=True).pick('meg') + + ### Choose condition + # e.g + # conditions_T=['500ms','1000ms','1500ms'] + # conditions_D = ['Irrelevant', 'Relevant non-target'] + # conditions_C = ['face', 'object'] or conditions_C = ['letter', 'false'] + + + #1) Select Category + #1) Select Category + if con_C[0] == 'FO': + conditions_C = ['face', 'object'] + print(conditions_C) + elif con_C[0] == 'LF': + conditions_C = ['letter', 'false'] + print(conditions_C) + elif con_C[0] == 'F': + conditions_C = ['face'] + print(conditions_C) + elif con_C[0] == 'O': + conditions_C = ['object'] + print(conditions_C) + elif con_C[0] == 'L': + conditions_C = ['letter'] + print(conditions_C) + elif con_C[0] == 'FA': + conditions_C = ['false'] + print(conditions_C) + + + epochs_cdc = epochs['Category in {}'.format(conditions_C)] + del epochs + + #2) Select Duration Time + conditions_T = con_T + print(conditions_T) + + epochs_cdd = epochs_cdc['Duration in {}'.format(conditions_T)] + del epochs_cdc + #3) Select Task relevance Design + + conditions_D = con_D + print(conditions_D) + + epochs_cd = epochs_cdd['Task_relevance in {}'.format(conditions_D)] + del epochs_cdd + + # Downsample and filter to speed the decoding + # Downsample copy of raw + epochs_rs = epochs_cd.copy().resample(sfreq, n_jobs=-1) + # Band-pass filter raw copy + epochs_rs.filter(l_freq, h_freq, n_jobs=-1) + + epochs_rs.crop(tmin=-0.5, tmax=2,include_tmax=True, verbose=None) + + # # Baseline correction + # b_tmin = -.5 + # b_tmax = -.0 + # baseline = (b_tmin, b_tmax) + # epochs_rs.apply_baseline(baseline=baseline) + + # projecting sensor-space data to source space ###TODO:shrunk or ? 
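+    # Note: unlike sensor_data_for_ROI_MVPA, the baseline-correction block above is
+    # left commented out in this *_baseline variant; rank and covariance are computed
+    # in the same way as in the standard function.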
+ rank = mne.compute_rank(epochs_rs, tol=1e-6, tol_kind='relative') + + baseline_cov = mne.compute_covariance(epochs_rs, tmin=-0.5, tmax=0, method='empirical', rank=rank, n_jobs=-1, + verbose=True) + active_cov = mne.compute_covariance(epochs_rs, tmin=0, tmax=2, method='empirical', rank=rank, n_jobs=-1, + verbose=True) + + common_cov = baseline_cov + active_cov + + ## analysis/task info + if con_T.__len__() == 3: + con_Tname = 'T_all' + elif con_T.__len__() == 2: + con_Tname = con_T[0]+'_'+con_T[1] + else: + con_Tname = con_T[0] + + task_info = "_" + "".join(con_Tname) + "_" + "".join(con_C[0]) + print(task_info) + + return epochs_rs, rank, common_cov, conditions_C, conditions_D, conditions_T, task_info + +def sensor_data_for_ROI_MVPA_equal_offset(fpath_epo,sub_info,con_T,con_C,con_D): + ### Loading the epochs data + # fname_epo = file_name + fname_epo=op.join(fpath_epo,sub_info + '_task-dur_epo.fif') + epochs = mne.read_epochs(fname_epo, + preload=True, + verbose=True).pick('meg') + + ### Choose condition + # e.g + # conditions_T=['500ms','1000ms','1500ms'] + # conditions_D = ['Irrelevant', 'Relevant non-target'] + # conditions_C = ['face', 'object'] or conditions_C = ['letter', 'false'] + + + #1) Select Category + if con_C[0] == 'FO': + conditions_C = ['face', 'object'] + print(conditions_C) + elif con_C[0] == 'LF': + conditions_C = ['letter', 'false'] + print(conditions_C) + elif con_C[0] == 'F': + conditions_C = ['face'] + print(conditions_C) + elif con_C[0] == 'O': + conditions_C = ['object'] + print(conditions_C) + elif con_C[0] == 'L': + conditions_C = ['letter'] + print(conditions_C) + elif con_C[0] == 'FA': + conditions_C = ['false'] + print(conditions_C) + + epochs_cdc = epochs['Category in {}'.format(conditions_C)] + + #2) Select Duration Time + conditions_T = con_T + print(conditions_T) + + epochs_cdd = epochs_cdc['Duration in {}'.format(conditions_T)] + + #3) Select Task relevance Design + + conditions_D = con_D + print(conditions_D) + + epochs_cd = epochs_cdd['Task_relevance in {}'.format(conditions_D)] + + # Downsample and filter to speed the decoding + # Downsample copy of raw + epochs_rs_temp = epochs_cd.copy().resample(sfreq, n_jobs=-1) + # Band-pass filter raw copy + epochs_rs_temp.filter(l_freq, h_freq, n_jobs=-1) + + #equal offset for 1000ms and 1500ms + equate_offset_dict= { + "1500ms":{ + "excise_onset": 1.0, + "excise_offset": 1.5}, + "1000ms":{ + "excise_onset": 1.5, + "excise_offset": 2} + } + + epochs_rs=equate_offset(epochs_rs_temp, equate_offset_dict) + + epochs_rs.crop(tmin=-0.5, tmax=2,include_tmax=True, verbose=None) + + # Baseline correction + b_tmin = -.5 + b_tmax = -.0 + baseline = (b_tmin, b_tmax) + epochs_rs.apply_baseline(baseline=baseline) + + + + # projecting sensor-space data to source space ###TODO:shrunk or ? 
+ rank = mne.compute_rank(epochs_rs, tol=1e-6, tol_kind='relative') + + baseline_cov = mne.compute_covariance(epochs_rs, tmin=-0.5, tmax=0, method='empirical', rank=rank, n_jobs=-1, + verbose=True) + active_cov = mne.compute_covariance(epochs_rs, tmin=0, tmax=2, method='empirical', rank=rank, n_jobs=-1, + verbose=True) + + common_cov = baseline_cov + active_cov + + ## analysis/task info + if con_T.__len__() == 3: + con_Tname = 'T_all' + elif con_T.__len__() == 2: + con_Tname = con_T[0]+'_'+con_T[1] + else: + con_Tname = con_T[0] + + task_info = "_" + "".join(con_Tname) + "_" + "".join(con_C[0]) + print(task_info) + + return epochs_rs, rank, common_cov, conditions_C, conditions_D, conditions_T, task_info + +def sensor_data_for_ROI_MVPA_ID(fpath_epo,sub_info,con_T,con_C,con_D,remove_too_few_trials=True): + ### Loading the epochs data + # fname_epo = file_name + fname_epo=op.join(fpath_epo,sub_info + '_task-dur_epo.fif') + #fname_epo=op.join(fpath_epo,fname) + epochs = mne.read_epochs(fname_epo, + preload=True, + verbose=True).pick('meg') + + ### Choose condition + # e.g + # conditions_T=['500ms','1000ms','1500ms'] + # conditions_D = ['Irrelevant', 'Relevant non-target'] + # conditions_C = ['face', 'object'] or conditions_C = ['letter', 'false'] + + + #1) Select Category + if con_C[0] == 'FO': + conditions_C = ['face', 'object'] + print(conditions_C) + elif con_C[0] == 'LF': + conditions_C = ['letter', 'false'] + print(conditions_C) + elif con_C[0] == 'F': + conditions_C = ['face'] + print(conditions_C) + elif con_C[0] == 'O': + conditions_C = ['object'] + print(conditions_C) + elif con_C[0] == 'L': + conditions_C = ['letter'] + print(conditions_C) + elif con_C[0] == 'FA': + conditions_C = ['false'] + print(conditions_C) + + epochs_cdc = epochs['Category in {}'.format(conditions_C)] + + #2) Select Duration Time + conditions_T = con_T + print(conditions_T) + + epochs_cdd = epochs_cdc['Duration in {}'.format(conditions_T)] + + #3) Select Task relevance Design + + conditions_D = con_D + print(conditions_D) + + epochs_cd = epochs_cdd['Task_relevance in {}'.format(conditions_D)] + + + #remove_too_few_trials: + min_n_repeats=2 + sub_metadata = epochs_cd.metadata.reset_index(drop=True) + # Find the identity for which we have less than two trials: + cts = sub_metadata.groupby(["Stim_trigger"])["Stim_trigger"].count() + id_to_remove = [identity for identity in cts.keys() if cts[identity] < min_n_repeats] + # Get the indices of the said identity to drop the trials: + id_idx = sub_metadata.loc[sub_metadata["Stim_trigger"].isin(id_to_remove)].index.values.tolist() + # Dropping those: + epochs_cd.drop(id_idx) + # epochs_cd = remove_too_few_trials(epochs_cd, condition="Stim_trigger", min_n_repeats=2, verbose=False) + + # Downsample and filter to speed the decoding + # Downsample copy of raw + epochs_rs_temp = epochs_cd.copy().resample(sfreq, n_jobs=-1) + # Band-pass filter raw copy + epochs_rs_temp.filter(l_freq, h_freq, n_jobs=-1) + + #equal offset for 1000ms and 1500ms + equate_offset_dict= { + "1500ms":{ + "excise_onset": 1.0, + "excise_offset": 1.5}, + "1000ms":{ + "excise_onset": 1.5, + "excise_offset": 2} + } + + epochs_rs=equate_offset(epochs_rs_temp, equate_offset_dict) + + epochs_rs.crop(tmin=-0.5, tmax=1.5,include_tmax=True, verbose=None) + + # Baseline correction + b_tmin = -.5 + b_tmax = -.0 + baseline = (b_tmin, b_tmax) + epochs_rs.apply_baseline(baseline=baseline) + + + + + # projecting sensor-space data to source space ###TODO:shrunk or ? 
+ rank = mne.compute_rank(epochs_rs, tol=1e-6, tol_kind='relative') + + baseline_cov = mne.compute_covariance(epochs_rs, tmin=-0.5, tmax=0, method='empirical', rank=rank, n_jobs=-1, + verbose=True) + active_cov = mne.compute_covariance(epochs_rs, tmin=0, tmax=2, method='empirical', rank=rank, n_jobs=-1, + verbose=True) + + common_cov = baseline_cov + active_cov + + ## analysis/task info + if con_T.__len__() == 3: + con_Tname = 'T_all' + elif con_T.__len__() == 2: + con_Tname = con_T[0]+'_'+con_T[1] + else: + con_Tname = con_T[0] + + task_info = "_" + "".join(con_Tname) + "_" + "".join(con_C[0]) + print(task_info) + + return epochs_rs, rank, common_cov, conditions_C, conditions_D, conditions_T, task_info + + +def source_data_for_ROI_MVPA(epochs_rs, fpath_fw, rank, common_cov, sub_info, surf_label): + + # projecting sensor-space data to source space + # the path of forward solution + fname_fwd = op.join(fpath_fw, sub_info + "_surface_fwd.fif") + + fwd = mne.read_forward_solution(fname_fwd) + + #make inverse operator + # Make inverse operator + + inv = mne.minimum_norm.make_inverse_operator(epochs_rs.info, fwd, common_cov, + loose=.2,depth=.8,fixed=False, + rank=rank,use_cps=True) # cov= baseline + active, compute rank, same as the LCMV + + snr = 3.0 + lambda2 = 1.0 / snr ** 2 + stcs = apply_inverse_epochs(epochs_rs, inv, 1. / lambda2, 'dSPM', pick_ori="normal", label=surf_label) + + return stcs + +def sub_ROI_for_ROI_MVPA(fpath_fs,subject_id,analysis_name): + + # prepare the label for extract data + if subject_id in ['SA102', 'SA104', 'SA110', 'SA111', 'SA152']: + labels_parc_sub = mne.read_labels_from_annot(subject="fsaverage", + parc='aparc.a2009s', + subjects_dir=fpath_fs) + else: + labels_parc_sub = mne.read_labels_from_annot(subject=f"sub-{subject_id}", + parc='aparc.a2009s', + subjects_dir=fpath_fs) + + + # replace "&" and "_and_" for indisual MRI or fsaverage + if subject_id in ['SA102', 'SA104', 'SA110', 'SA111', 'SA152']: + #ROI info, could change ###TODO: the final defined ROI + GNW_ts_list = ['G_and_S_cingul-Ant','G_and_S_cingul-Mid-Ant', + 'G_and_S_cingul-Mid-Post', 'G_front_middle', + 'S_front_inf', 'S_front_sup', + ] + + PFC_ts_list = ['G_and_S_cingul-Ant','G_and_S_cingul-Mid-Ant', + 'G_and_S_cingul-Mid-Post', 'G_front_middle', 'S_front_sup', + ] #'S_front_inf' # remove S_front_inf, since this GNW ROI is also in the extented IIT ROI list. + + IIT_ts_list = ['G_cuneus', + 'G_oc-temp_lat-fusifor', 'G_oc-temp_med-Lingual', + 'Pole_occipital', 'S_calcarine', + 'S_oc_sup_and_transversal'] + + MT_ts_list = ['S_central','S_postcentral'] + + F1_ts_list=['G_and_S_cingul-Ant'] + F2_ts_list=['G_and_S_cingul-Mid-Ant'] + F3_ts_list=['G_and_S_cingul-Mid-Post'] + F4_ts_list=['G_front_middle'] + F5_ts_list=['S_front_inf'] + F6_ts_list=['S_front_sup'] + + P1_ts_list=['S_intrapariet_and_P_trans'] + P2_ts_list=['S_postcentral'] + P3_ts_list=['G_postcentral'] + P4_ts_list=['S_central'] + P5_ts_list=['G_precentral'] + P6_ts_list=['S_precentral-inf-part'] + + + else: + #ROI info, could change ###TODO: the final defined ROI + GNW_ts_list = ['G&S_cingul-Ant','G&S_cingul-Mid-Ant', + 'G&S_cingul-Mid-Post', 'G_front_middle', + 'S_front_inf', 'S_front_sup', + ] + + PFC_ts_list = ['G&S_cingul-Ant','G&S_cingul-Mid-Ant', + 'G&S_cingul-Mid-Post', 'G_front_middle', 'S_front_sup', + ] #'S_front_inf' # remove S_front_inf, since this GNW ROI is also in the extented IIT ROI list. 
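+        # The IIT list below groups posterior occipito-temporal labels, while the
+        # single-label F1-F6 and P1-P6 lists define the control ROIs used by the
+        # Cat_subF_control and Cat_subP_control analyses further down.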
+ + + IIT_ts_list = ['G_cuneus', + 'G_oc-temp_lat-fusifor', 'G_oc-temp_med-Lingual', + 'Pole_occipital', 'S_calcarine', + 'S_oc_sup&transversal'] + + #MT_ts_list = ['S_central','S_postcentral'] + + MT_ts_list = ['S_central'] + + F1_ts_list=['G&S_cingul-Ant'] + F2_ts_list=['G&S_cingul-Mid-Ant'] + F3_ts_list=['G&S_cingul-Mid-Post'] + F4_ts_list=['G_front_middle'] + F5_ts_list=['S_front_inf'] + F6_ts_list=['S_front_sup'] + + P1_ts_list=['S_intrapariet&P_trans'] + P2_ts_list=['S_postcentral'] + P3_ts_list=['G_postcentral'] + P4_ts_list=['S_central'] + P5_ts_list=['G_precentral'] + P6_ts_list=['S_precentral-inf-part'] + + GNW_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in GNW_ts_list: + GNW_ts_index.append(ii) + + PFC_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in PFC_ts_list: + PFC_ts_index.append(ii) + + + + IIT_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in IIT_ts_list: + IIT_ts_index.append(ii) + + MT_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in MT_ts_list: + MT_ts_index.append(ii) + + F1_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in F1_ts_list: + F1_ts_index.append(ii) + F2_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in F2_ts_list: + F2_ts_index.append(ii) + F3_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in F3_ts_list: + F3_ts_index.append(ii) + F4_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in F4_ts_list: + F4_ts_index.append(ii) + + F5_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in F5_ts_list: + F5_ts_index.append(ii) + F6_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in F6_ts_list: + F6_ts_index.append(ii) + + + + P1_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in P1_ts_list: + P1_ts_index.append(ii) + P2_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in P2_ts_list: + P2_ts_index.append(ii) + P3_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in P3_ts_list: + P3_ts_index.append(ii) + P4_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in P4_ts_list: + P4_ts_index.append(ii) + + P5_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in P5_ts_list: + P5_ts_index.append(ii) + P6_ts_index = [] + for ii in range(len(labels_parc_sub)): + label_name = [] + label_name = labels_parc_sub[ii].name + if label_name[:-3] in P6_ts_list: + P6_ts_index.append(ii) + + for ni, n_label in 
enumerate(GNW_ts_index): + GNW_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + rGNW_label = GNW_label + elif ni == 1: + lGNW_label = GNW_label + elif ni % 2 == 0: + rGNW_label = rGNW_label + GNW_label # , hemi="both" + else: + lGNW_label = lGNW_label + GNW_label + + for ni, n_label in enumerate(PFC_ts_index): + PFC_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + rPFC_label = PFC_label + elif ni == 1: + lPFC_label = PFC_label + elif ni % 2 == 0: + rPFC_label = rPFC_label + PFC_label # , hemi="both" + else: + lPFC_label = lPFC_label + PFC_label + + for ni, n_label in enumerate(IIT_ts_index): + IIT_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + rIIT_label = IIT_label + elif ni == 1: + lIIT_label = IIT_label + elif ni % 2 == 0: + rIIT_label = rIIT_label + IIT_label # , hemi="both" + else: + lIIT_label = lIIT_label + IIT_label + + for ni, n_label in enumerate(MT_ts_index): + MT_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + rMT_label = MT_label + elif ni == 1: + lMT_label = MT_label + elif ni % 2 == 0: + rMT_label = rMT_label + MT_label # , hemi="both" + else: + lMT_label = lMT_label + MT_label + + for ni, n_label in enumerate(F1_ts_index): + F1_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + rF1_label = F1_label + elif ni == 1: + lF1_label = F1_label + elif ni % 2 == 0: + rF1_label = rF1_label + F1_label # , hemi="both" + else: + lF1_label = lF1_label + F1_label + + for ni, n_label in enumerate(F2_ts_index): + F2_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + rF2_label = F2_label + elif ni == 1: + lF2_label = F2_label + elif ni % 2 == 0: + rF2_label = rF2_label + F2_label # , hemi="both" + else: + lF2_label = lF2_label + F2_label + + for ni, n_label in enumerate(F3_ts_index): + F3_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + rF3_label = F3_label + elif ni == 1: + lF3_label = F3_label + elif ni % 2 == 0: + rF3_label = rF3_label + F3_label # , hemi="both" + else: + lF3_label = lF3_label + F3_label + + for ni, n_label in enumerate(F4_ts_index): + F4_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + rF4_label = F4_label + elif ni == 1: + lF4_label = F4_label + elif ni % 2 == 0: + rF4_label = rF4_label + F4_label # , hemi="both" + else: + lF4_label = lF4_label + F4_label + + for ni, n_label in enumerate(F5_ts_index): + F5_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + rF5_label = F5_label + elif ni == 1: + lF5_label = F5_label + elif ni % 2 == 0: + rF5_label = rF5_label + F5_label # , hemi="both" + else: + lF5_label = lF5_label + F5_label + + for ni, n_label in enumerate(F6_ts_index): + F6_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + rF6_label = F6_label + elif ni == 1: + lF6_label = F6_label + elif ni % 2 == 0: + rF6_label = rF6_label + F6_label # , hemi="both" + else: + lF6_label = lF6_label + F6_label + + for ni, n_label in enumerate(P1_ts_index): + P1_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + 
rP1_label = P1_label + elif ni == 1: + lP1_label = P1_label + elif ni % 2 == 0: + rP1_label = rP1_label + P1_label # , hemi="both" + else: + lP1_label = lP1_label + P1_label + + for ni, n_label in enumerate(P2_ts_index): + P2_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + rP2_label = P2_label + elif ni == 1: + lP2_label = P2_label + elif ni % 2 == 0: + rP2_label = rP2_label + P2_label # , hemi="both" + else: + lP2_label = lP2_label + P2_label + + for ni, n_label in enumerate(P3_ts_index): + P3_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + rP3_label = P3_label + elif ni == 1: + lP3_label = P3_label + elif ni % 2 == 0: + rP3_label = rP3_label + P3_label # , hemi="both" + else: + lP3_label = lP3_label + P3_label + + for ni, n_label in enumerate(P4_ts_index): + P4_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + rP4_label = P4_label + elif ni == 1: + lP4_label = P4_label + elif ni % 2 == 0: + rP4_label = rP4_label + P4_label # , hemi="both" + else: + lP4_label = lP4_label + P4_label + + for ni, n_label in enumerate(P5_ts_index): + P5_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + rP5_label = P5_label + elif ni == 1: + lP5_label = P5_label + elif ni % 2 == 0: + rP5_label = rP5_label + P5_label # , hemi="both" + else: + lP5_label = lP5_label + P5_label + + for ni, n_label in enumerate(P6_ts_index): + P6_label = [label for label in labels_parc_sub if label.name == labels_parc_sub[n_label].name][0] + if ni == 0: + rP6_label = P6_label + elif ni == 1: + lP6_label = P6_label + elif ni % 2 == 0: + rP6_label = rP6_label + P6_label # , hemi="both" + else: + lP6_label = lP6_label + P6_label + + + + + if analysis_name=='Cat' or analysis_name=='Ori' or analysis_name=='Cat_offset_control': + surf_label_list = [rGNW_label+lGNW_label, rIIT_label+lIIT_label,rGNW_label+lGNW_label+rIIT_label+lIIT_label] + ROI_Name = ['GNW', 'IIT','FP'] + + elif analysis_name=='Cat_MT_control': + surf_label_list = [rMT_label+lMT_label] + ROI_Name = ['MT'] + + elif analysis_name=='Cat_subF_control': + surf_label_list = [rF1_label+lF1_label,rF2_label+lF2_label,rF3_label+lF3_label, + rF4_label+lF4_label,rF5_label+lF5_label,rF6_label+lF6_label] + ROI_Name = ['F1','F2','F3','F4','F5','F6'] + + elif analysis_name=='Cat_subP_control': + surf_label_list = [rP1_label+lP1_label,rP2_label+lP2_label,rP3_label+lP3_label, + rP4_label+lP4_label,rP5_label+lP5_label,rP6_label+lP6_label] + ROI_Name = ['P1','P2','P3','P4','P5','P6'] + + elif analysis_name=='Cat_PFC' or analysis_name=='Ori_PFC': + surf_label_list = [rPFC_label+lPFC_label, rIIT_label+lIIT_label,rPFC_label+lPFC_label+rIIT_label+lIIT_label] + ROI_Name = ['PFC', 'IIT','IITPFC'] + + else: + surf_label_list = [rGNW_label+lGNW_label, rIIT_label+lIIT_label] + ROI_Name = ['GNW', 'IIT'] + + return surf_label_list, ROI_Name diff --git a/roi_mvpa/about.md b/roi_mvpa/about.md new file mode 100644 index 0000000..8f1d866 --- /dev/null +++ b/roi_mvpa/about.md @@ -0,0 +1,45 @@ +# About + +## MEG_ROI_MVPA +This folder contains code to perform the MVPA analysis on ROI based MEG source signal. +Contributors: Ling Liu + +Usage: + +1.0 Subject level analysis: +Use D0X_ROI_MVPA_XX.py to run subject level analysis, each code represent a analysis method. 
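+All of these scripts expose their options as ordinary argparse flags (e.g. `--sub`, `--visit`, `--cC`, `--cT`, `--cD`); run any of them with `--help` to see the accepted values.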
+To run the Face vs Object Category decoding analysis for Subject SA001 for experiment 1, simply use the parameter. + +`python D01_ROI_MVPA_Cat.py --sub SA001 --visit V1 --cC FO` + +2.0 Group level anaylsis: +D99_group_data_xx.py is used for concatenate individual subject data to one file of group data. +To concatenate Face vs Object Category decoding analysis, simply use the parameter: + +`python D99_group_data_pkl.py --cC FO --analysis Cat` + +3.0 Group level statistical analysis and plotting +D98_group_stat_sROI_xx.py is used for generate final results figure with the data that generated from D99_group_data_xx.py code. To generate the main Figure of Category decoding analysis, simply use the parameter: + +`Python D99_group_stat_sROI_plot.py --cC FO --analysis Cat` + +- config file contain the parameter used for MEG analysis +- D_MEG_function.py contain the function used on ROI_MVPA analysis +- rsa_helper_functions_meg.py revised from ieeg team rsa analysis function, used for RSA analysis +- Sublist.py/sublist_phase2.py is subject list for ROI_MVPA analysis + +## Information + +| | | +| --- | --- | +author_name | Ling Liu +author_affiliation | 1 School of Psychological and Cognitive Sciences, Peking University, Beijing, China; 2 School of Communication Science, Beijing Language and Culture University, Beijing, China +author_email | ling.liu@pku.edu.cn +PI_name | Huan Luo +PI_affiliation | School of Psychological and Cognitive Sciences, Peking University, Beijing, China +PI_email | huan.luo@pku.edu.cn +programming_language | python +Is a readme file included with detailed instructions for running the code? | Yes. MEG_ROI_MVPA_readme. +Is the environment file provided? | No. +Is there a config file provided to change runtime parameters? | config.py +Does the code run on the sample dataset? | no diff --git a/roi_mvpa/config.py b/roi_mvpa/config.py new file mode 100644 index 0000000..426fa7c --- /dev/null +++ b/roi_mvpa/config.py @@ -0,0 +1,259 @@ +# -*- coding: utf-8 -*- +""" +=========== +Config file +=========== + +Configurate the parameters of the study. 
+ +""" + +import os + +# ============================================================================= +# BIDS SETTINGS +# ============================================================================= +# if os.getlogin() in ['oscfe', 'ferranto', 'FerrantO']: #TODO: doesn't work on the HPC +# bids_root = r'Z:\_bids_' +# else: +bids_root = r'/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids' + + +# ============================================================================= +# MAXWELL FILTERING SETTINGS +# ============================================================================= + +# Set filtering method +method='sss' +if method == 'tsss': + st_duration = 10 +else: + st_duration = None + + +# ============================================================================= +# FILTERING AND DOWNSAMPLING SETTINGS +# ============================================================================= + +# Filter and resampling params +l_freq = 1 +h_freq = 40 +sfreq = 100 + + +# ============================================================================= +# EPOCHING SETTINGS +# ============================================================================= + +# Set timewindow +tmin = -1 +tmax = 2.5 + +# Epoch rejection criteria +reject_meg_eeg = dict(grad=4000e-13, # T / m (gradiometers) + mag=6e-12 # T (magnetometers) + #eeg=200e-6 # V (EEG channels) + ) +reject_meg = dict(grad=4000e-13, # T / m (gradiometers) + mag=6e-12 # T (magnetometers) + ) + + +# ============================================================================= +# ICA SETTINGS +# ============================================================================= + +ica_method = 'fastica' +n_components = 0.99 +max_iter = 800 +random_state = 1688 + + +# ============================================================================= +# FACTOR AND CONDITIONS OF INTEREST +# ============================================================================= + +# factor = ['Category'] +# conditions = ['face', 'object', 'letter', 'false'] + +# factor = ['Duration'] +# conditions = ['500ms', '1000ms', '1500ms'] + +# factor = ['Task_relevance'] +# conditions = ['Relevant_target','Relevant_non-target','Irrelevant'] + +# factor = ['Duration', 'Task_relevance'] +# conditions = [['500ms', '1000ms', '1500ms'], +# ['Relevant target','Relevant non-target','Irrelevant']] + +factor = ['Category', 'Task_relevance'] +conditions = [['face', 'object', 'letter', 'false'], + ['Relevant target','Relevant non-target','Irrelevant']] + + +# ============================================================================= +# TIME-FREQUENCY REPRESENTATION SETTINGS +# ============================================================================= + +baseline_w = [-0.5, -0.25] #only for plotting +freq_band = 'both' #can be 'low', 'high' or 'both' + + +# ============================================================================= +# SOURCE MODELING +# ============================================================================= + +# Forward model +spacing='oct6' #from forward_model + +# Inverse model +# Beamforming +beam_method = 'dics' #'lcmv' or 'dics' + +active_win = (0.75, 1.25) +baseline_win = (-.5, 0) + + +# ============================================================================= +# PLOTTING +# ============================================================================= + +# Subset of posterior sensors +post_sens = {'grad': ['MEG1932', 'MEG1933', 'MEG2122', 'MEG2123', + 'MEG2332', 'MEG2333', 'MEG1922', 'MEG1923', + 'MEG2112', 'MEG2113', 'MEG2342', 'MEG2343'], + 'mag': 
['MEG1931', 'MEG2121', + 'MEG2331', 'MEG1921', + 'MEG2111', 'MEG2341'], + 'eeg': ['EEG056', 'EEG030', + 'EEG057', 'EEG018', + 'EEG032', 'EEG019']} + +plot_param = { + "font": "times new roman", + "font_size": 22, + "figure_size_mm": [183, 108], + "fig_res_dpi": 300, + "colors": { + "iit": [ + 0.00392156862745098, + 0.45098039215686275, + 0.6980392156862745], + "gnw": [ + 0.00784313725490196, + 0.6196078431372549, + 0.45098039215686275 + ], + "IIT": [ + 0.00392156862745098, + 0.45098039215686275, + 0.6980392156862745 + ], + "GNW": [ + 0.00784313725490196, + 0.6196078431372549, + 0.45098039215686275 + ], + "MT": [ + 0.8352941176470589, + 0.3686274509803922, + 0.0 + ], + "FP": [ + 0.5450980392156862, + 0.16862745098039217, + 0.8862745098039215 + ], + "IITPFC_f": [ + 0.5450980392156862, + 0.16862745098039217, + 0.8862745098039215 + ], + "Relevant to Irrelevant": [ + 0.5450980392156862, + 0.16862745098039217, + 0.8862745098039215 + ], + "Irrelevant to Relevant": [ + 0.8352941176470589, + 0.3686274509803922, + 0.0 + ], + "Relevant non-target": [ + 0.8352941176470589, + 0.3686274509803922, + 0.0 + ], + "Irrelevant": [ + 0.5450980392156862, + 0.16862745098039217, + 0.8862745098039215 + ], + "task relevant": [ + 0.8352941176470589, + 0.3686274509803922, + 0.0 + ], + "Irrelevant": [ + 0.5450980392156862, + 0.16862745098039217, + 0.8862745098039215 + ], + # "face": [ + # 0.00784313725490196, + # 0.24313725490196078, + # 1.0 + # ], + # "object": [ + # 0.10196078431372549, + # 0.788235294117647, + # 0.2196078431372549 + # ], + # "letter": [ + # 0.9098039215686274, + # 0.0, + # 0.043137254901960784 + # ], + # "false": [ + # 0.9450980392156862, + # 0.2980392156862745, + # 0.7568627450980392 + # ], + "500ms": [ + 1.0, + 0.48627450980392156, + 0.0 + ], + "1000ms": [ + 0.6235294117647059, + 0.2823529411764706, + 0.0 + ], + "1500ms": [ + 1.0, + 0.7686274509803922, + 0.0 + ], + "face": [ + 0/255, + 53/255, + 68/255 + ], + "object": [ + 173/255, + 80/255, + 29/255 + ], + "letter": [ + 57/255, + 115/255, + 132/255 + ], + "false": [ + 97/255, + 15/255, + 0/255 + ], + "cmap": "RdYlBu_r" + } +} diff --git a/roi_mvpa/rsa_helper_functions_meg.py b/roi_mvpa/rsa_helper_functions_meg.py new file mode 100644 index 0000000..0ba26ca --- /dev/null +++ b/roi_mvpa/rsa_helper_functions_meg.py @@ -0,0 +1,1153 @@ +""" +Alex's code for RSA analysis for ECOG's data + +@author: Alex + +revised by Ling +""" + +import random +import numpy as np +from sklearn.preprocessing import StandardScaler +from sklearn.model_selection import StratifiedKFold, KFold +from sklearn.feature_selection import SelectKBest, f_classif +from scipy.spatial.distance import cdist +from collections import Counter +from skimage.measure import block_reduce +import mne +import statsmodels.api as sm +import pandas as pd +from statsmodels.stats import multitest +import mne.stats +import pingouin as pg +from mne.stats.cluster_level import _pval_from_histogram + +from D_MEG_function import ATdata + + +# def source_data_for_ROI_MVPA(epochs_rs, fpath_fw, rank, common_cov, sub_info, surf_label): +# fwd = mne.read_forward_solution(fpath_fw) + +# # make inverse operator +# # cov= baseline + active, compute rank, same as the LCMV +# inv = mne.minimum_norm.make_inverse_operator(epochs_rs.info, fwd, common_cov, +# loose=.2, depth=.8, fixed=False, +# rank=rank, use_cps=True) + +# snr = 3.0 +# lambda2 = 1.0 / snr ** 2 +# stcs = apply_inverse_epochs( +# epochs_rs, inv, 1. 
/ lambda2, 'dSPM', pick_ori="normal", label=surf_label) + +# return stcs + + +def pseudotrials_rsa(x, y, n_pseudotrials, times, sample_rdm_times=None, n_features=30, feat_sel_diag=True): + """ + This function computes pseudotrials before running the corrected within class correlation cross temporal RSA + :param x: (np array trials x channels/vertices x time points) data to compute the RSA + :param y: (np array) label of each trial. Must match x first dimension + :param n_pseudotrials: (int) number of trials to average together + :param times: (numpy array) times in secs of the time points axis of x + :param sample_rdm_times: (list) time points on which to do the sample RDM on + :param n_features: (int) number of features to use + :param feat_sel_diag: (bool) whether or not to do feature selection along the diagonal only. These reflect two + different implementations of the feature selection, be careful how you use it! + :return: + """ + + if sample_rdm_times is None: + sample_rdm_times = [0.2, 0.5] + + # Compute the pseudotrials separately for each condition: + conds = np.unique(y) + pseudotrials = [] + pseudotrials_labels = [] + for ind, cond in enumerate(conds): + cond_inds = np.where(y == cond)[0] + np.random.shuffle(cond_inds) # Shuffle the indices to avoid always taking the same trials together + # Compute the pseudotrials: + pseudotrials.append(block_reduce(x[cond_inds, :, :], + block_size=tuple([n_pseudotrials, + *[1] * len(x[cond_inds, :, :].shape[1:])]), + func=np.nanmean, cval=np.nan)) + pseudotrials_labels.append([cond] * pseudotrials[ind].shape[0]) + # Stack everything back together: + data = np.concatenate(pseudotrials) + label = np.array([item for sublist in pseudotrials_labels for item in sublist]) + + # temporal smooth data + data=ATdata(data) + + cross_temporal_mat_a, sample_rdm_a, sel_features = \ + within_vs_between_cross_temp_rsa_alex(data, + label, + metric='euclidean', + zscore=False, + onset_offset=[times[0], + times[-1]], + sample_rdm_times=sample_rdm_times, + n_features=n_features, + n_folds=5, + shuffle_labels=False, + verbose=True, + feat_sel_diag=feat_sel_diag, + store_intermediate=False) + return cross_temporal_mat_a, sample_rdm_a, sel_features + +def pseudotrials_rsa_all2all(x, y, n_pseudotrials, times, sample_rdm_times=None, n_features=30, metric="correlation",fisher_transform=True,feat_sel_diag=True): + """ + This function computes pseudotrials before running the corrected within class correlation cross temporal RSA + :param x: (np array trials x channels/vertices x time points) data to compute the RSA + :param y: (np array) label of each trial. Must match x first dimension + :param n_pseudotrials: (int) number of trials to average together + :param times: (numpy array) times in secs of the time points axis of x + :param sample_rdm_times: (list) time points on which to do the sample RDM on + :param n_features: (int) number of features to use + :param feat_sel_diag: (bool) whether or not to do feature selection along the diagonal only. These reflect two + different implementations of the feature selection, be careful how you use it! 
+ :return: + """ + + if sample_rdm_times is None: + sample_rdm_times = [0.2, 0.5] + + # Compute the pseudotrials separately for each condition: + conds = np.unique(y) + pseudotrials = [] + pseudotrials_labels = [] + for ind, cond in enumerate(conds): + cond_inds = np.where(y == cond)[0] + np.random.shuffle(cond_inds) # Shuffle the indices to avoid always taking the same trials together + # Compute the pseudotrials: + pseudotrials.append(block_reduce(x[cond_inds, :, :], + block_size=tuple([n_pseudotrials, + *[1] * len(x[cond_inds, :, :].shape[1:])]), + func=np.nanmean, cval=np.nan)) + pseudotrials_labels.append([cond] * pseudotrials[ind].shape[0]) + # Stack everything back together: + data = np.concatenate(pseudotrials) + label = np.array([item for sublist in pseudotrials_labels for item in sublist]) + + # temporal smooth data + data=ATdata(data) + + cross_temporal_mat_a, rdm_diag, sel_features = all_to_all_within_class_dist(data,label, + metric=metric, + n_bootsstrap=20, + shuffle_labels=False, + fisher_transform=fisher_transform, + verbose=True, + n_features=n_features, + n_folds=5, + feat_sel_diag=feat_sel_diag) + return cross_temporal_mat_a, rdm_diag, sel_features + + +# v3 version, error with feature selection +# def all_to_all_within_class_dist(data, labels, metric="correlation", n_bootsstrap=None, shuffle_labels=False, +# fisher_transform=True, verbose=False, n_features=None, n_folds=None): +# """ +# This function computes all trials to all trials distances and computes within class correlated distances in a +# cross temporal fashion. +# :param data: +# :param labels: +# :param metric: +# :param n_bootsstrap: +# :param shuffle_labels: +# :param fisher_transform: +# :param verbose: +# :param n_features: +# :param n_folds: +# :return: +# """ +# if verbose: +# print("=" * 40) +# print("Welcome to cross_identity_cross_temp_rsm") +# # Make sure the labels are a numpy array: +# assert isinstance(labels, np.ndarray), "The labels were not of type np.array!" 
+# # Shuffle labels if needed: +# if shuffle_labels: +# perm_ind = np.random.permutation(len(labels)) +# labels = labels[perm_ind] + +# # Pre-allocating for the diagonal RDM: +# rdm_diag = [] +# if n_folds is None: +# # Preallocating for the rsa: +# rsa_matrix = np.zeros((data.shape[-1], data.shape[-1])) +# for t1 in range(0, data.shape[-1]): +# # Get the data at t1 from train set: +# d1 = np.squeeze(data[:, :, t1]) + +# # Now looping through every other time point: +# for t2 in range(0, data.shape[-1]): +# # Get the data at t2 from the test set: +# d2 = np.squeeze(data[:, :, t2]) + +# # Compute the RDM: +# rdm = cdist(d1, d2, metric) +# if metric == "correlation" and fisher_transform: +# # Performing the fisher transformation of the correlation values (converting distances to +# # correlation, fisher transform, back into distances): +# rdm = 1 - np.arctanh(1 - rdm) + +# # If we are along the diagona, store the rdm: +# if t1 == t2: +# rdm_diag.append(rdm) + +# # Compute the within class correlated distances:: +# msk_within = np.meshgrid(labels, labels)[1] == \ +# np.meshgrid(labels, labels)[0] +# msk_between = np.meshgrid(labels, labels)[1] != \ +# np.meshgrid(labels, labels)[0] +# np.fill_diagonal(msk_within, False) +# np.fill_diagonal(msk_between, False) +# within_val = rdm[msk_within] +# across_val = rdm[msk_between] +# # Finally, computing the correlation between the rsa_matrix at t1 and t2: +# if n_bootsstrap is not None: +# if len(within_val) != len(across_val): +# # Find the minimal samples between the within and across: +# min_samples = min([len(within_val), len(across_val)]) +# bootstrap_diff = [] +# for n in range(n_bootsstrap): +# bootstrap_diff.append(np.mean(np.random.choice(across_val, min_samples, replace=False)) - +# np.mean(np.random.choice(within_val, min_samples, replace=False))) +# rsa_matrix[t1, t2] = np.mean(bootstrap_diff) +# else: +# rsa_matrix[t1, t2] = np.mean(across_val) - np.mean(within_val) +# else: +# rsa_matrix[t1, t2] = np.mean(across_val) - np.mean(within_val) +# else: +# # Using stratified kfold to perform feature selection: +# skf = StratifiedKFold(n_splits=n_folds) +# folds_mat = [] +# # Store a list of the features that were used: +# sel_features = [] +# # Splitting the data in nfolds, selecting features on one fold and testing on the rest: +# for test_ind, feat_sel_ind in skf.split(data, labels): +# # Preallocate for the rdm: +# rsa_matrix = np.zeros((data.shape[-1], data.shape[-1])) +# # Compute the cross temporal RDM: +# for t1 in range(0, data.shape[-1]): +# # Perform the feature selection on the test set +# features_sel = SelectKBest(f_classif, k=n_features).fit(data[feat_sel_ind, :, t1], labels[feat_sel_ind]) +# sel_features.append(features_sel.get_support(indices=True)) +# # Get the data at t1 from train set: +# d1 = np.squeeze(features_sel.transform(data[test_ind, :, t1])) + +# # Now looping through every other time point: +# for t2 in range(0, data.shape[-1]): +# # Get the data at t2 from the test set: +# d2 = np.squeeze(features_sel.transform(data[test_ind, :, t2])) + +# # Compute the RDM: +# rdm = cdist(d1, d2, metric) +# if metric == "correlation" and fisher_transform: +# # Performing the fisher transformation of the correlation values (converting distances to +# # correlation fisher transform, back into distances): +# rdm = 1 - np.arctanh(1 - rdm) + +# # Compute the within class correlated distances:: +# msk_within = np.meshgrid(labels[test_ind], labels[test_ind])[1] == \ +# np.meshgrid(labels[test_ind], labels[test_ind])[0] +# msk_between 
= np.meshgrid(labels[test_ind], labels[test_ind])[1] != \ +# np.meshgrid(labels[test_ind], labels[test_ind])[0] +# np.fill_diagonal(msk_within, False) +# np.fill_diagonal(msk_between, False) +# within_val = rdm[msk_within] +# across_val = rdm[msk_between] +# # Finally, computing the correlation between the rsa_matrix at t1 and t2: +# if n_bootsstrap is not None: +# if len(within_val) != len(across_val): +# # Find the minimal samples between the within and across: +# min_samples = min([len(within_val), len(across_val)]) +# bootstrap_diff = [] +# for n in range(n_bootsstrap): +# bootstrap_diff.append( +# np.mean(np.random.choice(across_val, min_samples, replace=False)) - +# np.mean(np.random.choice(within_val, min_samples, replace=False))) +# rsa_matrix[t1, t2] = np.mean(bootstrap_diff) +# else: +# rsa_matrix[t1, t2] = np.mean(across_val) - np.mean(within_val) +# else: +# rsa_matrix[t1, t2] = np.mean(across_val) - np.mean(within_val) +# # Append to the fold mat: +# folds_mat.append(rsa_matrix) +# # Average across folds: +# rsa_matrix = np.average(np.array(folds_mat), axis=0) +# # Compute the diagonal RDMs: +# rdm_diag = [cdist(data[:, :, t1], data[:, :, t1], metric) for t1 in range(0, data.shape[-1])] + +# return rsa_matrix, rdm_diag, sel_features + + +def all_to_all_within_class_dist(data, labels, metric="correlation", n_bootsstrap=20, shuffle_labels=False, + fisher_transform=True, verbose=False, n_features=None, n_folds=None, + feat_sel_diag=True): + """ + This function computes all trials to all trials distances and computes within class correlated distances in a + cross temporal fashion. + :param data: + :param labels: + :param metric: + :param n_bootsstrap: + :param shuffle_labels: + :param fisher_transform: + :param verbose: + :param n_features: + :param n_folds: + :param feat_sel_diag: + :return: + """ + if verbose: + print("=" * 40) + print("Welcome to cross_identity_cross_temp_rsm") + # Make sure the labels are a numpy array: + assert isinstance(labels, np.ndarray), "The labels were not of type np.array!" 
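+ # Clarifying note (added; shapes inferred from the indexing below, they are not stated in the docstring above):
+ # `data` is expected as a 3D array of shape (trials, channels/features, time points) and `labels` as a 1D array
+ # with one condition label per trial. The returned rsa_matrix[t1, t2] is the mean between-class distance minus
+ # the mean within-class distance between time points t1 and t2, so larger values indicate more class-specific
+ # information shared across the two time points.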
+ # Shuffle labels if needed: + if shuffle_labels: + # Some label shuffles are ineffective, as they are swapping two trials of the same condition + # To avoid those, brut force approach: reshuffle until we are sure that at least 40% of the labels + # have been swapped: + ok = False + while not ok: + # Shuffle the labels: + new_lbl = labels[np.random.permutation(len(labels))] + # Check equality between original and shuffled labels: + if np.sum(labels != new_lbl) / len(labels) > 0.4: + ok = True + labels = new_lbl + + # Pre-allocating for the diagonal RDM: + rdm_diag = [] + if n_features is None or feat_sel_diag: + sel_features = [] + # Preallocating for the rsa: + rsa_matrix = np.zeros((data.shape[-1], data.shape[-1])) + for t1 in range(0, data.shape[-1]): + # Get the data at t1 from train set: + d1 = np.squeeze(data[:, :, t1]) + if n_features is not None: + # Perform the feature selection on the test set + features_sel = SelectKBest(f_classif, k=n_features).fit(d1, labels) + sel_features.append(features_sel.get_support(indices=True)) + # Now looping through every other time point: + for t2 in range(0, data.shape[-1]): + # Get the data at t2 from the test set: + d2 = np.squeeze(data[:, :, t2]) + + # Compute the RDM: + if n_features is not None: + rdm = cdist(features_sel.transform(d1), features_sel.transform(d2), metric) + else: + rdm = cdist(d1, d2, metric) + + if metric == "correlation" and fisher_transform: + # Performing the fisher transformation of the correlation values (converting distances to + # correlation, fisher transform, back into distances): + rdm = 1 - np.arctanh(1 - rdm) + + # If we are along the diagona, store the rdm: + if t1 == t2: + rdm_diag.append(rdm) + + # Compute the within class correlated distances:: + msk_within = np.meshgrid(labels, labels)[1] == \ + np.meshgrid(labels, labels)[0] + msk_between = np.meshgrid(labels, labels)[1] != \ + np.meshgrid(labels, labels)[0] + np.fill_diagonal(msk_within, False) + np.fill_diagonal(msk_between, False) + within_val = rdm[msk_within] + across_val = rdm[msk_between] + # Finally, computing the correlation between the rsa_matrix at t1 and t2: + if n_bootsstrap is not None: + if len(within_val) != len(across_val): + # Find the minimal samples between the within and across: + min_samples = min([len(within_val), len(across_val)]) + bootstrap_diff = [] + for n in range(n_bootsstrap): + bootstrap_diff.append(np.mean(np.random.choice(across_val, min_samples, replace=False)) - + np.mean(np.random.choice(within_val, min_samples, replace=False))) + rsa_matrix[t1, t2] = np.mean(bootstrap_diff) + else: + rsa_matrix[t1, t2] = np.mean(across_val) - np.mean(within_val) + else: + rsa_matrix[t1, t2] = np.mean(across_val) - np.mean(within_val) + else: + # Using stratified kfold to perform feature selection: + skf = StratifiedKFold(n_splits=n_folds) + folds_mat = [] + #Store a list of the features that were used: + sel_features = [] + # Splitting the data in nfolds, selecting features on one fold and testing on the rest: + for test_ind, feat_sel_ind in skf.split(data, labels): + # Preallocate for the rdm: + rsa_matrix = np.zeros((data.shape[-1], data.shape[-1])) + # Compute the cross temporal RDM: + for t1 in range(0, data.shape[-1]): + # Perform the feature selection on the test set + features_sel = SelectKBest(f_classif, k=n_features).fit(data[feat_sel_ind, :, t1], labels[feat_sel_ind]) + sel_features.append(features_sel.get_support(indices=True)) + # Get the data at t1 from train set: + d1 = np.squeeze(features_sel.transform(data[test_ind, :, 
t1])) + + # Now looping through every other time point: + for t2 in range(0, data.shape[-1]): + # Get the data at t2 from the test set: + d2 = np.squeeze(features_sel.transform(data[test_ind, :, t2])) + + # Compute the RDM: + rdm = cdist(d1, d2, metric) + if metric == "correlation" and fisher_transform: + # Performing the fisher transformation of the correlation values (converting distances to + # correlation fisher transform, back into distances): + rdm = 1 - np.arctanh(1 - rdm) + + # Compute the within class correlated distances:: + msk_within = np.meshgrid(labels[test_ind], labels[test_ind])[1] == \ + np.meshgrid(labels[test_ind], labels[test_ind])[0] + msk_between = np.meshgrid(labels[test_ind], labels[test_ind])[1] != \ + np.meshgrid(labels[test_ind], labels[test_ind])[0] + np.fill_diagonal(msk_within, False) + np.fill_diagonal(msk_between, False) + within_val = rdm[msk_within] + across_val = rdm[msk_between] + # Finally, computing the correlation between the rsa_matrix at t1 and t2: + if n_bootsstrap is not None: + if len(within_val) != len(across_val): + # Find the minimal samples between the within and across: + min_samples = min([len(within_val), len(across_val)]) + bootstrap_diff = [] + for n in range(n_bootsstrap): + bootstrap_diff.append( + np.mean(np.random.choice(across_val, min_samples, replace=False)) - + np.mean(np.random.choice(within_val, min_samples, replace=False))) + rsa_matrix[t1, t2] = np.mean(bootstrap_diff) + else: + rsa_matrix[t1, t2] = np.mean(across_val) - np.mean(within_val) + else: + rsa_matrix[t1, t2] = np.mean(across_val) - np.mean(within_val) + # Append to the fold mat: + folds_mat.append(rsa_matrix) + # Average across folds: + rsa_matrix = np.average(np.array(folds_mat), axis=0) + # Compute the diagonal RDMs: + rdm_diag = [cdist(data[:, :, t1], data[:, :, t1], metric) for t1 in range(0, data.shape[-1])] + + return rsa_matrix, rdm_diag, sel_features + +def within_vs_between_cross_temp_rsa_alex(data, labels, + metric="correlation", n_bootsstrap=100, zscore=False, + sample_rdm_times=None, + onset_offset=None, n_features=40, n_folds=5, + shuffle_labels=False, fisher_transform=True, + verbose=False, feat_sel_diag=True, store_intermediate=False): + """ + This function computes cross temporal RDM by computing distances between two repetitions of the same trial + identities as well as between different identities at different time points. Then, the difference is computed + between the within (i.e. diagonal) and between (off diagonal condition). This results in one metric per time x time + pixel that summarizes the extent to which identity is conserved between different time point. + :param onset_offset: + :param data: (numpy array) contains the data with shape trials x channels x time points for which to compute the + rsa. + :param labels: (numpy array) contains the labels of the first dimension of the data + :param metric: (string) metric to use to compute the distances. See from scipy.spatial.distance import cdist for + options + :param n_bootsstrap: (int) number of bootstrap in case the number of samples in within class differs from across + classes + :param zscore: (boolean) whether or not to zscore the data at each distance computations + :param sample_rdm_times: (list of time points) compute a sample RDM at this time points. This enables getting a + :param n_features: (int) number of features to select. The features are selected in a quite complicated way. 
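+ (In practice this is a SelectKBest(f_classif) selection, i.e. the k channels with the largest ANOVA F-score are kept; note added for clarity, inferred from the code below.)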
+ We are basically splitting the data such that the features are always selected on a different set of data than what + is used to compute the correlation. + :param n_folds: (int) number of folds for the feature selection + :param shuffle_labels: (bool) wheher or not to shuffle the labels + :param fisher_transform: (bool) whehther or not to fisher transform the correlation value before computing within + vs between + :param verbose: (bool) whether or not to print additional info to command line + :param feat_sel_diag: (bool) whether to perform the feature selection on the diagonal only. If yes, then performing + the feature selection only once on the "train set" for each time point along the "y axis". + :return: + rsa_matrix: (numpy array) cross temporal rsa matrix that was computed + sample_rdm: (numpy array) representation dissimilarity matrix according to the time points passed under + sample_rdm_times + """ + # Deactivate the warnings, because in some cases, scikit learn will send so many warnings that the logs get + # completely overcrowded: + import warnings + warnings.filterwarnings('ignore') + if metric != "correlation" and fisher_transform: + if verbose: + print("WARNING: fisher transform only applies for correlation!") + if onset_offset is None: + onset_offset = [-0.3, 1.5] + if sample_rdm_times is None: + sample_rdm_times = [0.2, 0.4] + if verbose: + print("=" * 40) + print("Welcome to cross_identity_cross_temp_rsm") + # Make sure the labels are a numpy array: + assert isinstance(labels, np.ndarray), "The labels were not of type np.array!" + # Shuffle labels if needed: + if shuffle_labels: + perm_ind = np.random.permutation(len(labels)) + labels = labels[perm_ind] + # Use scikit learn cross validation to compute the split between "first" and "second" presentation: + skf = StratifiedKFold(n_splits=2) + # Preallocating for the rsa: + rsa_matrix = np.zeros((data.shape[-1], data.shape[-1])) + if store_intermediate: + within_class_mean = np.zeros((data.shape[-1], data.shape[-1])) + within_class_sd = np.zeros((data.shape[-1], data.shape[-1])) + between_class_mean = np.zeros((data.shape[-1], data.shape[-1])) + between_class_sd = np.zeros((data.shape[-1], data.shape[-1])) + else: + within_class_mean = [] + within_class_sd = [] + between_class_mean = [] + between_class_sd = [] + + # Extract the indices of first and second pres_ + first_pres_ind, second_pres_ind = list(skf.split(data, labels))[0] + # Extract the label of each: + first_pres_labels = labels[first_pres_ind] + second_pres_labels = labels[second_pres_ind] + # Store a list of the features that were used: + sel_features = [] + if verbose: + print("=" * 40) + print("Welcome to cross_identity_cross_temp_rsm") + print("First presentation:", first_pres_labels, "\nSecond presentation:", second_pres_labels) + print("Computing representation similarity between all time points") + for t1 in range(0, data.shape[-1]): + # Get the data at t1 from train set: + d1 = np.squeeze(data[first_pres_ind, :, t1]) + if zscore: + scaler = StandardScaler() + scaler.fit(d1) + d1 = scaler.transform(d1) + # If the features selection is done along the diagonal only: + if n_features is not None and feat_sel_diag: + # Extract the features on the first split of the data at the current time point: + features_sel = SelectKBest(f_classif, k=n_features).fit(d1, first_pres_labels) + sel_features.append(features_sel.get_support(indices=True)) + # Now looping through every other time point: + for t2 in range(0, data.shape[-1]): + # Get the data at t2 from the test 
set: + d2 = np.squeeze(data[second_pres_ind, :, t2]) + if zscore: + d2 = scaler.transform(d2) + # Compute the distance between all combinations of trials: + # Selecting features if needed: + if n_features is not None and not feat_sel_diag: + # Prepare the rdm matrix: + rdm = np.zeros([d1.shape[0], d2.shape[0]]) + # Selecting features with a cross validation to avoid double dipping: + # Counting the number of events per halves: + first_pres_labels_cts = Counter(first_pres_labels) + second_pres_labels_cts = Counter(second_pres_labels) + # If each item occurs less often than there are folds, then using k fold: + if all(i < n_folds for i in list(first_pres_labels_cts.values())) \ + or all(i < n_folds for i in list(second_pres_labels_cts.values())): + f_d1 = KFold(n_splits=n_folds) + f_d2 = KFold(n_splits=n_folds) + else: + f_d1 = StratifiedKFold(n_splits=n_folds) + f_d2 = StratifiedKFold(n_splits=n_folds) + # Looping through the folds of d1: + for train_d1, test_d1 in f_d1.split(d1, first_pres_labels): + # Looping through d2 folds: + for train_d2, test_d2 in f_d2.split(d2, second_pres_labels): + # Extract the data for the feature selection: + feature_sel_data = np.concatenate([d1[train_d1, :], d2[train_d2, :]], axis=0) + feature_sel_labels = np.concatenate([first_pres_labels[train_d1], + second_pres_labels[train_d2]], + axis=0) + features_sel = SelectKBest(f_classif, k=n_features).fit(feature_sel_data, feature_sel_labels) + d1_test = features_sel.transform(d1[test_d1, :]) + d2_test = features_sel.transform(d2[test_d2, :]) + sub_rdm = cdist(d1_test, d2_test, metric) + for ind_1, rdm_row in enumerate(test_d1): + for ind_2, rdm_col in enumerate(test_d2): + rdm[rdm_row, rdm_col] = sub_rdm[ind_1, ind_2] + elif n_features is not None and feat_sel_diag: + # Compute the rdm based on the feature selection performed on this data: + rdm = cdist(features_sel.transform(d1), features_sel.transform(d2), metric) + else: + rdm = cdist(d1, d2, metric) + if metric == "correlation" and fisher_transform: + # Performing the fisher transformation of the correlation values (converting distances to correlation, + # fisher transform, back into distances): + rdm = 1 - np.arctanh(1 - rdm) + + # Create a mask with values == true for within condition, false otherwise: + msk = np.meshgrid(second_pres_labels, first_pres_labels)[0] == \ + np.meshgrid(second_pres_labels, first_pres_labels)[1] + within_val = rdm[msk] + across_val = rdm[~msk] + # Finally, computing the correlation between the rsa_matrix at t1 and t2: + if n_bootsstrap is not None: + if len(within_val) != len(across_val): + # Find the minimal samples between the within and across: + min_samples = min([len(within_val), len(across_val)]) + bootstrap_diff = [] + if store_intermediate: + bootstrap_within_mean = [] + bootstrap_within_sd = [] + bootstrap_between_mean = [] + bootstrap_between_sd = [] + for n in range(n_bootsstrap): + bootstrap_diff.append(np.mean(np.random.choice(across_val, min_samples, replace=False)) - + np.mean(np.random.choice(within_val, min_samples, replace=False))) + if store_intermediate: + bootstrap_within_mean.append(np.mean(np.random.choice(within_val, min_samples, + replace=False))) + bootstrap_within_sd.append(np.std(np.random.choice(within_val, min_samples, + replace=False))) + bootstrap_between_mean.append(np.mean(np.random.choice(across_val, min_samples, + replace=False))) + bootstrap_between_sd.append(np.std(np.random.choice(across_val, min_samples, + replace=False))) + rsa_matrix[t1, t2] = np.mean(bootstrap_diff) + if 
store_intermediate: + within_class_mean[t1, t2] = np.mean(bootstrap_within_mean) + within_class_sd[t1, t2] = np.mean(bootstrap_within_sd) + between_class_mean[t1, t2] = np.mean(bootstrap_between_mean) + between_class_sd[t1, t2] = np.mean(bootstrap_between_sd) + else: + rsa_matrix[t1, t2] = np.mean(across_val) - np.mean(within_val) + if store_intermediate: + within_class_mean[t1, t2] = np.mean(within_val) + within_class_sd[t1, t2] = np.std(within_val) + between_class_mean[t1, t2] = np.mean(across_val) + between_class_sd[t1, t2] = np.std(across_val) + else: + rsa_matrix[t1, t2] = np.mean(across_val) - np.mean(within_val) + if store_intermediate: + within_class_mean[t1, t2] = np.mean(within_val) + within_class_sd[t1, t2] = np.std(within_val) + between_class_mean[t1, t2] = np.mean(across_val) + between_class_sd[t1, t2] = np.std(across_val) + + # Finally, computing a sample RDM: + if sample_rdm_times is not None: + time_vec = np.around(np.linspace(onset_offset[0], onset_offset[1], num=data.shape[-1]), decimals=3) + # Find the samples that correspond to the time window: + onset_ind = np.where(time_vec >= sample_rdm_times[0])[0][0] + offset_ind = np.where(time_vec <= sample_rdm_times[1])[0][-1] + 1 # adding one to take the last point in + # Get the data of that time point: + data_win = data[:, :, onset_ind:offset_ind] + # If there are several time points in this window, averaging across them: + if data_win.shape[-1] > 1: + data_win = np.mean(data_win, axis=-1) + else: + data_win = np.squeeze(data_win) + # Prepare the data and compute the RDM: + d1 = np.squeeze(data_win[first_pres_ind, :]) + # For the second repetition, same but with an offset of 1: + d2 = np.squeeze(data_win[second_pres_ind, :]) + # Also doing feature selection if needed for the sample rdm: + if n_features is not None and not feat_sel_diag: + # Prepare the rdm matrix: + sample_rdm = np.zeros([d1.shape[0], d2.shape[0]]) + # Selecting features with a cross validation to avoid double dipping: + # Counting the number of events per halves: + first_pres_labels_cts = Counter(first_pres_labels) + second_pres_labels_cts = Counter(second_pres_labels) + # Using stratified k fold to split d1 and d2: + if all(i < n_folds for i in list(first_pres_labels_cts.values())) \ + or all(i < n_folds for i in list(second_pres_labels_cts.values())): + f_d1 = KFold(n_splits=n_folds) + f_d2 = KFold(n_splits=n_folds) + else: + f_d1 = StratifiedKFold(n_splits=n_folds) + f_d2 = StratifiedKFold(n_splits=n_folds) + # Looping through the folds of d1: + for train_d1, test_d1 in f_d1.split(d1, first_pres_labels): + # Looping through d2 folds: + for train_d2, test_d2 in f_d2.split(d2, second_pres_labels): + # Extract the data for the feature selection: + feature_sel_data = np.concatenate([d1[train_d1, :], d2[train_d2, :]], axis=0) + feature_sel_labels = np.concatenate([first_pres_labels[train_d1], + second_pres_labels[train_d2]], + axis=0) + features_sel = SelectKBest(f_classif, k=n_features).fit(feature_sel_data, feature_sel_labels) + d1_test = features_sel.transform(d1[test_d1, :]) + d2_test = features_sel.transform(d2[test_d2, :]) + sub_rdm = cdist(d1_test, d2_test, metric) + for ind_1, rdm_row in enumerate(test_d1): + for ind_2, rdm_col in enumerate(test_d2): + sample_rdm[rdm_row, rdm_col] = sub_rdm[ind_1, ind_2] + elif n_features is not None: + # Extract the features on the first split of the data at the current time point: + features_sel = SelectKBest(f_classif, k=n_features).fit(d1, first_pres_labels) + # Compute the RDM selecting these features: + 
sample_rdm = cdist(features_sel.transform(d1), features_sel.transform(d2), metric) + else: + # Otherwise, computing the rdm on the data: + sample_rdm = cdist(d1, d2, metric) + # Sorting rows and columns by label so that the matrix is ordered by condition: + row_ind, col_ind = np.argsort(first_pres_labels), np.argsort(second_pres_labels) + sample_rdm = sample_rdm[row_ind, :][:, col_ind] + else: + sample_rdm = None + + return rsa_matrix, sample_rdm, sel_features + + + def rdm_regress_groups(rdm, groups_1, groups_2): + """ + This function regresses out the group effect from the rdm. The group effect is encoded through the two groups + variables here. Groups_1 matches the 1st dim of the rdm and groups_2 the 2nd dim, and they identify which group a trial + belonged to. If the groups of group 1 and group 2 match at a given intersection, that cell is within-group; if they + don't, it is across-group. Within is encoded as 1 and across as a zero. The matrix is then flattened and regressed out from + the rdm. In other words: + face, face, object, object + face, 1 1 0 0 + face, 1 1 0 0 + object, 0 0 1 1 + object 0 0 1 1 + :param rdm: (2D numpy array) rdm from which the group should be regressed. The first dimension corresponds to the + first set of trials to compute the rdm and the 2nd to the second set of trials + :param groups_1: (numpy array) group to which the first set of trials belong to + :param groups_2: (numpy array) group to which the second set of trials belong to + :return: + rdm: the same rdm but with the group information regressed out + """ + + # Generating the groups matrix: + groups_regressor = np.zeros(rdm.shape) + for i in range(groups_regressor.shape[0]): + for ii in range(groups_regressor.shape[1]): + if groups_1[i] == groups_2[ii]: + groups_regressor[i, ii] = 1 + else: + groups_regressor[i, ii] = 0 + + # Flatten the two matrices: + rdm_flat = rdm.flatten() + groups_regressor_flat = groups_regressor.flatten() + + # Adding a couple of checks just to be sure the reshape is never messed up. My understanding is that it can't be + # messed up the way I do it, but that way it really can't be! + np.testing.assert_array_equal(rdm_flat.reshape(rdm.shape), rdm) + np.testing.assert_array_equal(groups_regressor_flat.reshape(groups_regressor.shape), groups_regressor) + + # Regress the groups regressor out: + rdm_regress_flat = sm.OLS(rdm_flat, groups_regressor_flat).fit().resid + + # Finally, reconverting it to a square matrix: + return rdm_regress_flat.reshape(rdm.shape) + + + def compute_correlation_theories(observed_matrix, theories_matrices, method="kendall"): + """ + Compute the correlation between the predicted and obtained matrices + :param observed_matrix: (list of np arrays) contains the decoding matrices + :param theories_matrices: (dict of arrays) contains the theories predicted matrices + :param method: (string) method to use to compute the correlation. Four options supported: kendall, spearman, + partial and semi-partial. In the (semi-)partial correlation, one of the theory matrices is used as a covariate + :return: (pd data frame) contains the correlation between the matrices + """ + print("-" * 40) + print("Welcome to compute_correlation_theories") + print("Computing {0} correlation between data and {1} predicted matrices". 
+ format(method.lower(), + list(theories_matrices.keys()))) + supported_method = ["kendall", "spearman", "partial", "semi-partial"] + # Flatten the decoding scores and theory matrices: + observed_matrix_flat = [observed_matrix[i].flatten() + for i in range(0, len(observed_matrix))] + theory_matrices_flat = {theory: theories_matrices[theory].flatten( + ) for theory in theories_matrices.keys()} + if method.lower() == "kendall" or method.lower() == "spearman": + # Computing the correlation coefficient between the matrices of each cross validation folds: + correlation_results = pd.DataFrame({ + theory: [pg.corr(observed_matrix_flat[i], theory_matrices_flat[theory], + method=method.lower())["r"].item() + for i in range(0, len(observed_matrix_flat))] + for theory in theories_matrices.keys() + }) + elif method.lower() == "partial": + # For partial correlation, we need to convert the data to data frames: + theories = list(theories_matrices.keys()) + data_dfs = [pd.DataFrame({ + "scores": observed_matrix_flat[i], + theories[0]: theory_matrices_flat[theories[0]], + theories[1]: theory_matrices_flat[theories[1]]}) + for i in range(0, len(observed_matrix))] + # We now perform the partial correlation always holding one theory constant while checking the other: + correlation_results = {} + for ind, theory in enumerate(theories): + correlation_results[theory] = [pg.partial_corr(data=data_dfs[i], x='scores', y=theory, + covar=theories[ind - 1])["r"].item() + for i in range(len(data_dfs))] + # Convert the dict to a dataframe to keep things consistent: + correlation_results = pd.DataFrame(correlation_results) + elif method.lower() == "semi-partial": + # For partial correlation, we need to convert the data to data frames: + theories = list(theories_matrices.keys()) + data_dfs = [pd.DataFrame({ + "scores": observed_matrix_flat[i], + theories[0]: theory_matrices_flat[theories[0]], + theories[1]: theory_matrices_flat[theories[1]]}) + for i in range(0, len(observed_matrix))] + # We now perform the partial correlation always holding one theory constant while checking the other: + correlation_results = {} + for ind, theory in enumerate(theories): + correlation_results[theory] = [pg.partial_corr(data=data_dfs[i], x='scores', y=theory, + y_covar=theories[ind - 1])["r"].item() + for i in range(len(data_dfs))] + # Convert the dict to a dataframe to keep things consistent: + correlation_results = pd.DataFrame(correlation_results) + else: + raise Exception("You have passed {0} as correlation method, but only {1} supported".format(method.lower(), + supported_method)) + + # Correct the correlation results to be positively defined between 0 and 1: + correlation_results_corrected = correlation_results.apply(lambda x: (x + 1) / 2) + + return correlation_results, correlation_results_corrected + + +def subsample_matrices(matrix, start_time, end_time, intervals_of_interest): + """ + This function extracts bits of a bigger matrix and puts it back together. It is basically subselecting only the + bits of your matrix you care about. This enables for ex to do rsa only on bits of the temporal generalization matrix + as opposed to all of it. The coordinate of the times of interest is a bit complicated. It is a dictionary + containing two keys: x and y. x constitutes the "columns", while "y" constitute the rows. 
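+ For example (hypothetical values), intervals_of_interest = {"x": [[0.3, 0.5], [0.8, 1.0]], "y": [[0.3, 0.5], [0.8, 1.0]]}
+ cuts four squares out of the matrix (x giving the column/test-time windows and y the row/train-time windows) and
+ stitches them back together into one smaller square matrix.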
Because we are subsampling + several squares of a bigger matrix (though subselecting only 1 square would be a just a specific case), the idea + is that we can loop through the columns and then within that loop through the rows to make sure we don't mess up the + order. Imagine you have a matrix like below and want to subsample the squares within it + + X + _______________________________________________________________________ + | ____ ____ | + | | | | | | + | | 1 | | 3 | | + | |____| |____| | + | | + | | + | ____ ____ | Y + | | | | | | + | | 2 | | 4 | | + | |____| |____| | + | | + | | + |______________________________________________________________________| + + The idea is that we will have the outer loop be looping through the x intervals and the inner loop looping through + the Y. We can then stack vertically in the inner loop and horizontally in the outer loop, to be sure the order + doesn't get all messed up. I.e. we first sample 1, then 2 and stack them vertically, same for 3 and 4 and then + we stack the two matrices horizontally and that's it + This means that the coordinates should be a list of x and y coordinates, but without repetition. I.e. you don't need + to pass the same x coordinates twice for the samples in the same row. + NOTE: THIS FUNCTION WILL ONLY WORK PROPERLY IF YOU SAMPLE UNIFORMLY IN THE X AND Y AXIS, YOU MUST HAVE A SQUARE + MATRIX IN THE END! + :param matrix: (2d np array) matrix to subsample + :param intervals_of_interest: (dict of list) coordinate in times of each bits of the matrix you want to extract. + Keys: x and y, see above for in depth explanation + :param start_time: (float) time of the start time point of the passed matrix + :param end_time: (float) time of the end time point of the passed matrix + :return: + subsampled_matrix (np.array): numpy array of the subsampled matrix according to what was expected + new_time_vect (np.array): the truncated time vector of only the remaining time points. + connection_indices (np.array): the coordinates within the subsample matrix in which there is the discontunity + (to plot later on) + sub_matrices_dict (dictionary): contains the subsampled squares but not stacked together. The x and y coordinates + in temporal generalization matrix corresponds to test and train times respectively. Therefore, creating keys + based on it. + """ + # Generate the fitting time vector: + time_vect = np.around(np.linspace(start_time, end_time, num=matrix.shape[0]), decimals=3) + # Check that the passed times of interest are compatible with this function. Because we are dealing with floats, + # need to convert to strings with 3 decimals. We assume that anything beyond 4 decimals isn't relevant given the + # scale is in seconds, anything beyond that would be nanosecs: + x_lengths = ["{:.3f}".format(x_coord[1] - x_coord[0]) for x_coord in intervals_of_interest["x"]] + y_lengths = ["{:.3f}".format(y_coord[1] - y_coord[0]) for y_coord in intervals_of_interest["y"]] + if len(set(x_lengths)) > 1 or len(set(y_lengths)) > 1: + raise Exception("You have passed times of interest with inconsistent x and y length. 
This doesn't work because " + "\nthen, the matrices that you want to concatenate won't be of the same sizes") + # Now computing the x and y length in samples + x_lengths_samples = [np.where(time_vect <= x_coord[1])[0][-1] + 1 - np.where(time_vect >= x_coord[0])[0][0] + for x_coord in intervals_of_interest["x"]] + y_lengths_samples = [np.where(time_vect <= y_coord[1])[0][-1] + 1 - np.where(time_vect >= y_coord[0])[0][0] + for y_coord in intervals_of_interest["y"]] + x_length = min(x_lengths_samples) + y_length = min(y_lengths_samples) + matrix_columns = [] + new_time_vect = [] + connection_indices = [] + sub_matrices_dict = {} + for col_ind, interval_x in enumerate(intervals_of_interest["x"]): + matrix_sample = [] + for row_ind, interval_y in enumerate(intervals_of_interest["y"]): + # Finding the time points corresponding to the start and end of the prediction + x1 = np.where(time_vect >= interval_x[0])[0][0] + x2 = x1 + x_length + y1 = np.where(time_vect >= interval_y[0])[0][0] + y2 = y1 + y_length + matrix_sample.append(matrix[y1:y2, x1:x2]) + new_time_vect.append(time_vect[y1:y2]) + if row_ind == 0: + connection_indices.append(len(matrix[y1:y2, x1:x2]) - 1) + else: + connection_indices.append( + connection_indices[-1] + len(matrix[y1:y2, x1:x2])) + # Also add to a dictionary, to avoid having to break it down again afterwards if needed: + key = "Train_{}:{}-Test_{}:{}".format(interval_y[0], interval_y[1], interval_x[0], interval_x[1]) + sub_matrices_dict[key] = matrix[y1:y2, x1:x2] + # Stacking the matrix samples of this column vertically: + matrix_columns.append(np.concatenate(matrix_sample, axis=0)) + # Convert new time to a numpy array: + new_time_vect = np.concatenate(new_time_vect, axis=0) + # Removing repetitions: + new_time_vect = np.unique(new_time_vect) + # Same for the connection indices: + connection_indices = np.unique(connection_indices) + # Removing the last index, because we don't need it: + connection_indices = connection_indices[:-1] + # Finally, stacking the columns horizontally: + return np.concatenate(matrix_columns, axis=1), new_time_vect, connection_indices, sub_matrices_dict + + + def label_shuffle_test_2d(observed_values, permutation_values, p_value_thresh=0.05, fdr_correction="fdr_bh"): + """ + This function compares the distribution of observed values against a null distribution generated by shuffling labels. + The observed values and the permutation values must have the same dimensions, except that the permutation values + have an extra (first) dimension representing the repetitions. + This function compares the decoding scores observed against the results obtained when shuffling the labels. For each + decoding score obtained (either time-resolved or temporal generalization), its quantile along all the values + obtained through permutation is computed. If the quantile is inferior to the threshold, it is considered significant. + :param observed_values: (np array) contains the observed decoding scores + :param permutation_values: (np array) contains the decoding scores obtained by shuffling the labels. + :param p_value_thresh: (float) p-value threshold for significance. + :param fdr_correction: which method to use for FDR + :return: p_values: (np array) p-value of each entry of the matrix (FDR-corrected if requested) + matrix_sig_mask: (np array of floats and nan) contains the score values but only the ones which are significant + significance_flag: (bool) True if any entry of the matrix is significant + """
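+ # Descriptive note (added; shapes assumed from the loops below): observed_values is a 2D score matrix
+ # (e.g. train time x test time) and permutation_values stacks the same matrix over label shuffles
+ # (n_permutations x n_times x n_times). Each pixel's p-value is derived from the rank of the observed
+ # score within its permutation distribution and then FDR-corrected across pixels.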
+ """ + # Preallocate for storing the significance mask: + p_values = np.zeros(observed_values.shape) + # Generate the null distribution by concatenating the observed value to the permutation one: + null_distribution = np.concatenate([permutation_values, np.expand_dims(observed_values, axis=0)], axis=0) + # Now looping through each row and columns of the decoding matrix to compare obtained scores to the permutation + # scores: + for row_i in range(p_values.shape[0]): + for col_i in range(p_values.shape[1]): + # Find the position in the distribution of the obs value: + null = np.append(np.squeeze(null_distribution[:, row_i, col_i]), + observed_values[row_i, col_i]) + p_values[row_i, col_i] = _pval_from_histogram(np.array([observed_values[row_i, col_i]]), null, 1) + if fdr_correction is not None: + _, p_val_flat, _, _ = multitest.multipletests(p_values.flatten(), method=fdr_correction) + p_values = p_val_flat.reshape(p_values.shape) + # Binarize the significance matrix based on p_value: + sig_mask = p_values < p_value_thresh + # The significance mask has nan where we have non-significant values, and the actual value where they are + # significant + matrix_sig_mask = observed_values.copy() + matrix_sig_mask[~sig_mask] = np.nan + # Creating the significance flag: true if there are significant bits in the matrix: + significance_flag = np.any(sig_mask) + + return p_values, matrix_sig_mask, significance_flag + + +def equalize_label_counts(data, labels, groups=None): + """ + This function equalizes the counts of each labels, i.e. the number of trials per conditions. This function + assumes that the data are in the format trials x channels x time and labels must have the same dimension as + the first dimension as the data array + :param data: (numpy array) trials x channels x time + :param labels: (numpy array) label, i.e. condition of each trial + :param groups: (numpy array) same size as labels and tracks the groups a given label belongs to. + :return: + equalized_data: (np array) trials x channels x time with the same number of trials per condition + equalized_labels: (np array) trials with the same number of trials per condition + equalized_groups: (np array or none) + """ + # Equalizing the labels counts if needed: + equalized_data = [] + equalized_labels = [] + if groups is not None: + if groups.shape != labels.shape: + raise Exception("ERROR: The groups array shape does not match the labels array shape! 
" + "\ngroups: {} \nvs \nlabels: {}".format(groups.shape, labels.shape)) + equalized_groups = [] + else: + equalized_groups = None + # Get the minimal counts in the labels: + unique_labels, counts = np.unique(labels, return_counts=True) + min_label_counts = min(counts) + # Now, looping through every unique label to randomly pick the min: + for label in unique_labels: + # Get the index of this label: + label_ind = np.where(labels == label)[0] + # Randomly picking min: + picked_ind = label_ind[np.random.choice(label_ind.shape[0], min_label_counts, replace=False)] + # Fetching these data: + equalized_data.append(data[picked_ind, :, :]) + equalized_labels.extend(labels[picked_ind]) + if groups is not None: + equalized_groups.extend(groups[picked_ind]) + + # Concatenating things back together: + equalized_labels = np.array(equalized_labels) + equalized_data = np.concatenate(equalized_data, axis=0) + if groups is not None: + equalized_groups = np.array(equalized_groups) + + return equalized_data, equalized_labels, equalized_groups + +def equate_offset(epochs, cropping_dict): + """ + This function excise some time points from certain conditions that differ in durations to make the offset consistent + between different durations. This enables to crop out chunks of data at specified time points and sew the rest + back together. + :param epochs: (mne epochs) epochs to chop and sew back up. + :param cropping_dict: (dictionary) contains the info about what and when to crop: + "1500ms": { + "excise_onset": 1.0, + "excise_offset": 1.5 + } + The dictionary key corresponds to a specific experimental condition wished to be cropped, excise_onset to when to + start cropping and excise_offset when to stop + :return: + equated_epochs: mne epochs object with epochs cropped + """ + # Looping through each condition for which some time needs to be excised: + conds_epochs = [] + for cond in cropping_dict.keys(): + # Getting time and rounding at 3 decimals to avoid weird indexing issues + times = np.around(epochs.times, decimals=3) + # Get the data of that one condition: + #cond_epochs = epochs.copy()[cond] + cond_epochs = epochs.copy()['Duration in {}'.format([cond])] + # Now, get the time points to excise: + # The onset is the first point in the time vector that is superior or equal to the onset + excise_onset_ind = np.where(times >= cropping_dict[cond]["excise_onset"])[0][0] + # The offset is the last point in time that is inferior or equal to the offset. Need to add 1 to it, + # because in python, slicing doesn't take the last point (i.e. :n-1). But in the case where our offset is at + # 2 sec for ex, and the time vector goes from 0 to 2.5, then we want to take the point 2.0 in, not go only until + # 1.98 or something like that. 
+ excise_offset_ind = np.where(times <= cropping_dict[cond]["excise_offset"])[0][-1] + 1 + print("Excising data from {} to {} from condition {}".format(times[excise_onset_ind], + times[excise_offset_ind - 1], + cond)) + # Excising: + cond_epochs_data = np.delete(cond_epochs.get_data(), range(excise_onset_ind, excise_offset_ind), axis=-1) + # Create a new epochs object: + conds_epochs.append(mne.EpochsArray(cond_epochs_data, cond_epochs.info, events=cond_epochs.events, + tmin=cond_epochs.times[0], event_id=cond_epochs.event_id, + baseline=cond_epochs.baseline, + metadata=cond_epochs.metadata, on_missing="warn")) + # Combining the epochs data: + equated_epochs = mne.concatenate_epochs(conds_epochs, add_offset=False) + + return equated_epochs + + + def regress_evoked(epochs): + """ + This function computes the evoked response and regresses it out of every single trial + :param epochs: (mne epochs object) epochs from which the evoked should be regressed + :return: (mne epochs object) epochs with the evoked response regressed out + """ + print("=" * 40) + print("Welcome to regress_evoked") + # Compute the evoked: + evoked = epochs.average() + # Extracting the data from the mne objects: + epochs_data = epochs.get_data() + evoked_data = evoked.get_data() + print("Regressing evoked response out of every trial per channel") + for channel in range(epochs_data.shape[1]): + ch_evk = evoked_data[channel, :] + for trial in range(epochs_data.shape[0]): + epochs_data[trial, channel, :] = sm.OLS(epochs_data[trial, channel, :], ch_evk).fit().resid + # Packaging everything back into an mne epochs object: + epochs_regress = mne.EpochsArray(epochs_data, epochs.info, events=epochs.events, + tmin=epochs.times[0], event_id=epochs.event_id, baseline=epochs.baseline, + metadata=epochs.metadata, on_missing="warn") + return epochs_regress + + + def create_prediction_matrix(start, end, predicted_intervals, matrix_size): + """ + This function generates a binary matrix of zeros and ones according to the theories' predictions. This can then later be + compared to the results of the decoding. 1 is for when a theory predicts above chance decoding, 0 for no decoding + :param start: (float or int) start time of the decoding matrix. In secs + :param end: (float or int) end time of the decoding matrix. In secs + :param predicted_intervals: (dict of list of floats) contains the predicted onsets and offsets of above chance + decoding + in the start to end time vector. The format is like so: + { + "x": [[0.3, 0.5], [0.8, 1.5]...], + "y": [[0.3, 0.5], [0.8, 1.5]...], + "predicted_vals": [1, 1...] value to fill each predicted window with (a number, or "nan") + } + IMPORTANT: The x and y must both have the same number of entries! + :param matrix_size: (int) size of the matrix to generate. Must be the same size as the decoding matrix to compare + it to. + :return: (np array) predicted_matrix binary matrix containing predicted significant decoding + """ + if len(predicted_intervals["x"]) != len(predicted_intervals["y"]): + raise Exception("The x and y coordinates of the predicted intervals have different lengths! That doesn't work!") + # Create a matrix of zeros of the correct size: + predicted_matrix = np.zeros((matrix_size, matrix_size)) + # Generating a time vector matching the matrix size: + time_vect = np.around(np.linspace(start, end, num=matrix_size), decimals=3) + # The theories make predictions such that there will be decoding within specific time windows. 
Looping through those: + for ind, interval_x in enumerate(predicted_intervals["x"]): + # Finding the time points corresponding to the start and end of the prediction + # The onset is the first point in the time vector that is superior or equal to the onset + onset_x = np.where(time_vect >= interval_x[0])[0][0] + # The offset is the last point in time that is inferior or equal to the offset. Need to add 1 to it, + # because in python, slicing doesn't take the last point (i.e. :n-1). But in the case where our offset is at + # 2 sec for ex, and the time vector goes from 0 to 2.5, then we want to take the point 2.0 in, not go only until + # 1.98 or something like that. + offset_x = np.where(time_vect <= interval_x[1])[0][-1] + 1 + # Same for y: + onset_y = np.where(time_vect >= predicted_intervals["y"][ind][0])[0][0] + offset_y = np.where(time_vect <= predicted_intervals["y"][ind][1])[0][-1] + 1 + # Setting all these samples to 1: + if not isinstance(predicted_intervals["predicted_vals"][ind], str): + predicted_matrix[onset_y:offset_y, onset_x:offset_x] = predicted_intervals["predicted_vals"][ind] + elif predicted_intervals["predicted_vals"][ind].lower() == "nan": + predicted_matrix[onset_y:offset_y, onset_x:offset_x] = np.nan + else: + raise Exception("The predicted value must be either a number (float or int) or nan, check spelling!!") + + return predicted_matrix + + +def remove_too_few_trials(epochs, condition="identity", min_n_repeats=2, verbose=False): + """ + This function removes the conditions for which there are less than min_n_repeats. So say you only want conditions + for which you have at least 2 repeats, set min_n_repeats to 2. + :param epochs: (mne epochs object) contains the data and metadata to remove conditions from + :param condition: (string) name of the condition for which to equate. The string must match a column in the metadata + :param min_n_repeats: (int) minimal number of repeats a condition must have to pass! + :param verbose: (bool) whether or not to print information to the command line + :return: + epochs: (mne epochs object) the mne object with equated trials. Note that the function modifies data in place! + """ + if verbose: + print("Equating trials by downsampling {}".format(condition)) + # Get the meta data for that subject: + sub_metadata = epochs.metadata.reset_index(drop=True) + # Find the identity for which we have less than two trials: + cts = sub_metadata.groupby([condition])[condition].count() + id_to_remove = [identity for identity in cts.keys() if cts[identity] < min_n_repeats] + if verbose: + print("The following identity have less than two repetitions") + print(id_to_remove) + # Get the indices of the said identity to drop the trials: + id_idx = sub_metadata.loc[sub_metadata[condition].isin(id_to_remove)].index.values.tolist() + # Dropping those: + epochs.drop(id_idx, verbose="error") + return epochs + + diff --git a/roi_mvpa/sublist.py b/roi_mvpa/sublist.py new file mode 100644 index 0000000..07dfa57 --- /dev/null +++ b/roi_mvpa/sublist.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Nov 29 19:02:19 2022 + +@author: Ling_BLCU +""" + +# -*- coding: utf-8 -*- +""" +=========== +Subject list file +=========== + +Configurate the parameters of the subject. 
+ +""" + +# ============================================================================= +# subject_info +# ============================================================================= + +# subject_list +sub_a_list = ['SA102', 'SA103', 'SA104', 'SA111', 'SA114', + 'SA118','SA121', 'SA123', 'SA125', 'SA132', + 'SA133','SA134','SA136','SA138', 'SA139', + 'SA140','SA144','SA145','SA146', 'SA147', + 'SA148','SA150','SA151','SA154', 'SA158', + 'SA163','SA166','SA167','SA169', 'SA170', + 'SA173','SA174','SA176'] +sub_b_list = ['SB001','SB003','SB006','SB008','SB009', + 'SB011','SB012','SB016','SB019','SB020', + 'SB023','SB027','SB028','SB029','SB031', + 'SB035','SB036','SB039','SB040','SB042', + 'SB044','SB049','SB051','SB056','SB060', + 'SB063','SB069','SB072','SB074','SB081', + 'SB084','SB999'] +sub_list = sub_a_list + sub_b_list \ No newline at end of file diff --git a/roi_mvpa/sublist_phase2.py b/roi_mvpa/sublist_phase2.py new file mode 100644 index 0000000..ee95f83 --- /dev/null +++ b/roi_mvpa/sublist_phase2.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Nov 29 19:02:19 2022 + +@author: Ling_BLCU +""" + +# -*- coding: utf-8 -*- +""" +=========== +Subject list file +=========== + +Configurate the parameters of the subject. + +""" + +# ============================================================================= +# subject_info +# ============================================================================= + +# subject_list +sub_a_list = ['SA106', 'SA107', 'SA109', 'SA110','SA112', #, + 'SA113','SA116', 'SA124', 'SA126', 'SA127', + 'SA128','SA131','SA142','SA152','SA160', #, + 'SA172'] +sub_b_list = ['SB002','SB013','SB015','SB022', 'SB024', + 'SB030','SB038','SB041','SB045', 'SB050', #, + 'SB061','SB065', 'SB071','SB073','SB078', + 'SB085'] +sub_list = sub_a_list + sub_b_list \ No newline at end of file diff --git a/source_modelling/S00_bem.py b/source_modelling/S00_bem.py new file mode 100644 index 0000000..edfea87 --- /dev/null +++ b/source_modelling/S00_bem.py @@ -0,0 +1,231 @@ +""" +================= +S00. BEM (and coregistration) +================= + +Perform the automated coregistration: + +Step 1 - Visualize Freesurfer parcellation +Step 2 - MNE-python scalp surface reconstruction +Step 3 - Boundary Element Model (BEM) reconstruction +Step 4 - Get Boundary Element Model (BEM) +(Step 5 - Coregistration) + +@author: Oscar Ferrante oscfer88@gmail.com + +""" + +import os +import os.path as op +# import numpy as np +import argparse + +import mne + + +parser=argparse.ArgumentParser() +parser.add_argument('--sub', + type=str, + default='SA101', + help='site_id + subject_id (e.g. "SA101")') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. 
"V1")') +parser.add_argument('--bids_root', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids', + help='Path to the BIDS root directory') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +opt=parser.parse_args() + + +# Set params +subject = "sub-"+opt.sub +visit = opt.visit +subjects_dir = opt.fs_path +if visit == "V1": + fname_raw = op.join(opt.bids_root, subject, "ses-"+visit, "meg", subject+"_ses-V1_task-dur_run-01_meg.fif") +elif visit == "V2": + fname_raw = op.join(opt.bids_root, subject, "ses-"+visit, "meg", subject+"_ses-V2_task-vg_run-01_meg.fif") #TODO: to be tested +coreg_deriv_root = op.join(opt.bids_root, "derivatives", "coreg") +if not op.exists(coreg_deriv_root): + os.makedirs(coreg_deriv_root) +coreg_figure_root = op.join(coreg_deriv_root, + f"sub-{opt.sub}",f"ses-{visit}","meg", + "figures") +if not op.exists(coreg_figure_root): + os.makedirs(coreg_figure_root) + +# Step 1 - Freesurfer recontruction (only on Linux/MACos) +def viz_fs_recon(): + ''' + Freesurfer recontruction (only on Linux/MACos) + + Run the following command in a terminal: + > recon-all -i SA101.nii -s SA101 -all + For more info, go to https://surfer.nmr.mgh.harvard.edu/fswiki/recon-all/ + + To convert DICOM to NIFTI, use MRIcron + + ''' + # Visualize reconstruction: + Brain = mne.viz.get_brain_class() + brain = Brain(subject, + hemi='lh', + surf='pial', + subjects_dir=subjects_dir, + size=(800, 600)) + brain.add_annotation('aparc', borders=False) #aparc.a2009s + + # Save figure + fname_figure = op.join(subjects_dir, "fs_aparc.png") + brain.save_image(fname_figure) + + +# Step 2 - Scalp surface reconstruction +def make_scalp_surf(): + ''' + Scalp surface reconstruction + + Either use this function ot run the following commands in a terminal: + > mne make_scalp_surfaces --overwrite --subject SA101 --force + + + ''' + mne.bem.make_scalp_surfaces(subject, + subjects_dir=subjects_dir, + force=True, + overwrite=True, + verbose=True) + + +# Step 3 - Boundary Element Model (BEM) reconstruction +def make_bem(): + ''' + Boundary Element Model (BEM) + + To create the BEM, either use this function or run the following command + in a terminal (requires FreeSurfer): + > mne watershed_bem --overwrite --subject ${file} + + ''' + mne.bem.make_watershed_bem(subject, + subjects_dir=subjects_dir, + overwrite=True, + verbose=True) + + +# Step 4 - Get Boundary Element Model (BEM) solution +def get_bem(): + ''' + Make Boundary Element Model (BEM) solution + + Computing the BEM surfaces requires FreeSurfer and is done using the + following command: + > mne watershed_bem --overwrite --subject SA101 + + Once the BEM surfaces are read, create the BEM model + + ''' + # Create BEM model + conductivity = (0.3,) # for single layer + # conductivity = (0.3, 0.006, 0.3) # for three layers + model = mne.make_bem_model(subject, + ico=4, + conductivity=conductivity, + subjects_dir=subjects_dir) + + # Finally, the BEM solution is derived from the BEM model + bem = mne.make_bem_solution(model) + + # Save data + fname_bem = op.join(subjects_dir, subject, subject+"_ses-"+visit+"_bem-sol.fif") + mne.write_bem_solution(fname_bem, + bem, + overwrite=True) + # Visualize the BEM + fig = mne.viz.plot_bem(subject=subject, + subjects_dir=subjects_dir, + #brain_surfaces='white', + orientation='coronal') + + # Save figure + fname_figure = op.join(subjects_dir, subject, "bem-sol.png") + fig.savefig(fname_figure) 
+ + return bem + + +# # Step 5 - Coregistration +# def coreg(): +# ''' +# Coregistration + +# Tutorial: https://www.slideshare.net/mne-python/mnepython-coregistration + +# To get the path of MNE sample data, run: +# > mne.datasets.sample.data_path() + +# Save fiducials as: +# SA101_MRI-fiducials + +# At the end of the coregistration, save the transformation matrix and +# rename the file following the naming convention (see example below) +# SA101-trans.fif + +# To open the coregistration GUI, run: +# > mne.gui.coregistration(subject=subject, subjects_dir=subjects_dir) +# or run "mne coreg" from the terminal + +# ''' +# # Automated coregistration +# info = mne.io.read_info(fname_raw) +# fiducials = "estimated" # get fiducials from fsaverage +# coreg = mne.coreg.Coregistration(info, subject, subjects_dir, fiducials=fiducials) + +# # Fit using 3 fiducial points +# coreg.fit_fiducials(verbose=True) + +# # Refine the transformation using the Iterative Closest Point (ICP) algorithm +# coreg.fit_icp(n_iterations=6, nasion_weight=2., verbose=True) + +# # Remove outlier points +# coreg.omit_head_shape_points(distance=5. / 1000) + +# # Do a final coregistration fit +# coreg.fit_icp(n_iterations=20, nasion_weight=10., verbose=True) + +# # Compute the distance error +# dists = coreg.compute_dig_mri_distances() * 1e3 # in mm +# print( +# f"Distance between HSP and MRI (mean/min/max):\n{np.mean(dists):.2f} mm " +# f"/ {np.min(dists):.2f} mm / {np.max(dists):.2f} mm" +# ) + +# # Save transformation matrix +# fname_trans = op.join(coreg_deriv_root, subject+"_ses-"+visit+"_trans.fif") +# mne.write_trans(fname_trans, coreg.trans) + +# # # Visualize the transformation #TODO: 3d plots don't work on the HPC +# # fig = mne.viz.plot_alignment(info, coreg.trans, subject=subject, dig=True, +# # meg=['helmet', 'sensors'], subjects_dir=subjects_dir, +# # surfaces='head-dense') + +# # # Save figure +# # fname_figure = op.join(coreg_figure_root, "coreg.png") +# # fig.savefig(fname_figure) + +# return coreg.trans + + +if __name__ == "__main__": + # viz_fs_recon() #TODO: 3d plots don't work on the HPC + make_scalp_surf() + make_bem() + bem = get_bem() + # coreg() + \ No newline at end of file diff --git a/source_modelling/S01_forward_model.py b/source_modelling/S01_forward_model.py new file mode 100644 index 0000000..fa57b46 --- /dev/null +++ b/source_modelling/S01_forward_model.py @@ -0,0 +1,173 @@ +""" +================= +08. Forward model +================= + +Compute the forward model + +Step 1 - Freesurfer recontruction +Step 2 - MNE-python scalp surface reconstruction +Step 3 - Get Boundary Element Model (BEM) +Step 4 - Coregistration +Step 5 - Compute source space +Step 6 - Forward Model + +@author: Oscar Ferrante oscfer88@gmail.com + +""" + +import os +import os.path as op +# import numpy as np +import argparse + +import mne + + +parser=argparse.ArgumentParser() +parser.add_argument('--sub', + type=str, + default='SA124', + help='site_id + subject_id (e.g. "SA101")') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. 
"V1")') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--bids_root', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids', + help='Path to the BIDS root directory') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--coreg_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/coreg', + help='Path to the coreg (derivative) directory') +parser.add_argument('--out_fw', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/forward', + help='Path to the forward (derivative) directory') +opt=parser.parse_args() + + +# Set params +subject = "sub-"+opt.sub +visit = opt.visit +space = opt.space + +subjects_dir = opt.fs_path +fname_coreg = op.join(opt.coreg_path, subject, "ses-"+visit, "meg") + +fpath_fw = op.join(opt.out_fw, subject, "ses-"+visit, "meg") +if not op.exists(fpath_fw): + os.makedirs(fpath_fw) + +# fpath_fig = op.join(fpath_fw, "figures") +# if not op.exists(fpath_fig): +# os.makedirs(fpath_fig) + + +# Step 1 - Compute source space +def make_source_space(space): + ''' + Compute source space + + Surface-based source space is computed using: + > mne.setup_source_space() + Volumetric source space is computed using: + > mne.setup_volume_source_space() + + ''' + if space == 'surface': + # Get surface-based source space + spacing='oct6' # 4098 sources per hemisphere, 4.9 mm spacing + src = mne.setup_source_space(subject, + spacing=spacing, + add_dist='patch', + subjects_dir=subjects_dir) + # Set filename + fname_src = '%s-surface%s_src.fif' % (subject, spacing) + elif space == 'volume': + # Get volumetric source space (BEM required) + surface = op.join(subjects_dir, subject, + 'bem', 'inner_skull.surf') + src = mne.setup_volume_source_space(subject, + subjects_dir=subjects_dir, + surface=surface, + mri='T1.mgz', + verbose=True) + # Set filename + fname_src = '%s-volume_src.fif' % (subject) + # Save source space + mne.write_source_spaces(op.join(subjects_dir,subject,fname_src), + src, + overwrite=True) + # Visualize source space and BEM + mne.viz.plot_bem(subject=subject, + subjects_dir=subjects_dir, + brain_surfaces='white', + src=src, + orientation='coronal') + # # Visualize sources in 3d space + # if space == 'surface': #TODO: doesn't work with volume space + # fig = mne.viz.plot_alignment(subject=subject, + # subjects_dir=subjects_dir, + # trans=trans, + # surfaces='white', + # coord_frame='head', + # src=src) + # mne.viz.set_3d_view(fig, azimuth=173.78, elevation=101.75, + # distance=0.35, focalpoint=(-0.03, 0.01, 0.03)) + return src + + +# Step 2 - Forward Model +def make_forward_model(src, task): + ''' + Forward Model + + ''' + + # Set path to raw FIF + fname_raw = op.join(opt.bids_root, subject, "ses-"+visit, "meg", subject+"_ses-"+visit+"_task-"+task+"_run-01_meg.fif") + + + # Set transformation matrix and bem pathes + trans = op.join(fname_coreg, subject+"_ses-"+visit+"_trans.fif") + bem = op.join(subjects_dir, subject, subject+"_ses-V1_bem-sol.fif") #BEM is shared between sessions + + # Calculate forward solution for MEG channels + fwd = mne.make_forward_solution(fname_raw, + trans=trans, + src=src, + bem=bem, + meg=True, eeg=False, + mindist=5., + verbose=True) + # Save forward model + fname_fwd = op.join(fpath_fw, 
subject+"_ses-"+visit+"_task-"+task+"_%s_fwd.fif" % space) + mne.write_forward_solution(fname_fwd, + fwd, + overwrite=True) + # Number of vertices + print(f'\nNumber of vertices: {fwd["src"]}') + # Leadfield size + leadfield = fwd['sol']['data'] + print("\nLeadfield size : %d sensors x %d dipoles" % leadfield.shape) + return fwd + + +# RUN +if __name__ == "__main__": + src = make_source_space(space) + if visit == 'V1': + make_forward_model(src, 'dur') + elif visit == 'V2': + make_forward_model(src, 'vg') + make_forward_model(src, 'replay') diff --git a/source_modelling/S01b_forward_model_template.py b/source_modelling/S01b_forward_model_template.py new file mode 100644 index 0000000..cc009ef --- /dev/null +++ b/source_modelling/S01b_forward_model_template.py @@ -0,0 +1,109 @@ +# -*- coding: utf-8 -*- +""" +================= +S01. Forward model with MRI template +================= + +@author: Oscar Ferrante oscfer88@gmail.com + +""" + +import os +import os.path as op +import argparse + +import mne + + +parser=argparse.ArgumentParser() +parser.add_argument('--sub', + type=str, + default='SA101', + help='site_id + subject_id (e.g. "SA101")') +parser.add_argument('--visit', + type=str, + default='V1', + help='visit_id (e.g. "V1")') +parser.add_argument('--space', + type=str, + default='surface', + help='source space ("surface" or "volume")') +parser.add_argument('--bids_root', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids', + help='Path to the BIDS root directory') +parser.add_argument('--fs_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/fs', + help='Path to the FreeSurfer directory') +parser.add_argument('--coreg_path', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/coreg', + help='Path to the coreg (derivative) directory') +parser.add_argument('--out_fw', + type=str, + default='/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/forward', + help='Path to the forward (derivative) directory') +opt=parser.parse_args() + + +# Set params +subject = "sub-"+opt.sub +visit = opt.visit +space = opt.space + +subjects_dir = opt.fs_path +fname_coreg = op.join(opt.coreg_path, subject, "ses-"+visit, "meg") + +fpath_fw = op.join(opt.out_fw, subject, "ses-"+visit, "meg") +if not op.exists(fpath_fw): + os.makedirs(fpath_fw) + +# fpath_fig = op.join(fpath_fw, "figures") +# if not op.exists(fpath_fig): +# os.makedirs(fpath_fig) + +def make_forward_model_from_template(task): + + # Set path to raw FIF + fname_raw = op.join(opt.bids_root, subject, "ses-"+visit, "meg", subject+"_ses-"+visit+"_task-"+task+"_run-01_meg.fif") + + # Set path to template files: + subj = 'fsaverage' + trans = 'fsaverage' + if space == 'surface': + src = op.join(subjects_dir, subj, 'bem', 'fsaverage-ico-5-src.fif') + bem = op.join(subjects_dir, subj, 'bem', 'fsaverage-5120-5120-5120-bem-sol.fif') + + # Load raw + raw = mne.io.read_raw(fname_raw, preload=True) + + # # Check that the locations of sensors is correct with respect to MRI + # mne.viz.plot_alignment( + # raw.info, src=src, trans=trans, + # subjects_dir=subjects_dir, + # show_axes=True, mri_fiducials=True, dig='fiducials') + + # Setup source space and compute forward + fwd = mne.make_forward_solution(raw.info, + trans=trans, + src=src, + bem=bem, + meg=True, eeg=False, + mindist=5., + verbose=True) + + # Save forward model + fname_fwd = op.join(fpath_fw, subject+"_ses-"+visit+"_task-"+task+"_%s_fwd.fif" % space) + 
mne.write_forward_solution(fname_fwd, + fwd, + overwrite=True) + + +# RUN +if __name__ == "__main__": + if visit == 'V1': + make_forward_model_from_template('dur') + elif visit == 'V2': + make_forward_model_from_template('vg') + make_forward_model_from_template('replay')
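
As a quick sanity check after running `S01_forward_model.py` (or its template-based counterpart `S01b_forward_model_template.py`), the saved forward solution can be reloaded and inspected with MNE-Python before moving on to the source-localization analyses. The following is only a minimal sketch: it assumes the surface-space forward file for sample subject SA124 (visit V1, task "dur") was written under the default `--out_fw` derivatives path; adjust `deriv_fw`, the subject, and the task to match the arguments you actually used.

```
import os.path as op

import mne

# Assumed output root: the --out_fw default used by S01_forward_model.py
deriv_fw = "/mnt/beegfs/XNAT/COGITATE/MEG/phase_2/processed/bids/derivatives/forward"
subject, visit, task, space = "sub-SA124", "V1", "dur", "surface"

# File name pattern written by make_forward_model() / make_forward_model_from_template()
fname_fwd = op.join(deriv_fw, subject, "ses-" + visit, "meg",
                    subject + "_ses-" + visit + "_task-" + task + "_" + space + "_fwd.fif")

# Load the forward solution
fwd = mne.read_forward_solution(fname_fwd)

# Number of source locations kept in the solution (both hemispheres for a surface space)
n_sources = sum(len(s["vertno"]) for s in fwd["src"])
print(f"Number of sources: {n_sources}")

# Leadfield dimensions: n_sensors x n_dipoles (3 dipoles per source, free orientation)
leadfield = fwd["sol"]["data"]
print("Leadfield size: %d sensors x %d dipoles" % leadfield.shape)

# Optional: convert to a fixed (surface-normal) orientation, as is commonly done
# before computing inverse operators on surface source spaces
fwd_fixed = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True)
print("Fixed-orientation leadfield:", fwd_fixed["sol"]["data"].shape)
```

For the `oct6` surface source space set up in `make_source_space()`, the free-orientation leadfield should have on the order of 3 x 8196 columns (slightly fewer once sources closer than `mindist=5.` mm to the inner skull are dropped); a very different number usually means the wrong `--space` or task file was loaded.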