From 42fb220ccca1a8f1368c1be49c4389df97512f7a Mon Sep 17 00:00:00 2001 From: Michael Schartner Date: Tue, 1 Oct 2024 13:14:32 +0100 Subject: [PATCH 01/10] Update bwm_figs.py --- brainwidemap/meta/bwm_figs.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/brainwidemap/meta/bwm_figs.py b/brainwidemap/meta/bwm_figs.py index 55865f9..8da2a3f 100644 --- a/brainwidemap/meta/bwm_figs.py +++ b/brainwidemap/meta/bwm_figs.py @@ -466,6 +466,9 @@ def plot_swansons(variable, fig=None, axs=None): f'{ana}_significant'] == True][ f'{ana}_{dt}'].values + vmax = np.max(scores) + vmin = np.min(scores) + if lat: mask = res[np.bitwise_or( res[f'{ana}_significant'] == False, @@ -482,8 +485,14 @@ def plot_swansons(variable, fig=None, axs=None): else: acronyms = res['region'].values scores = res[f'{ana}_effect'].values - mask = [] + mask = [] + if variable == 'stim': + vmax = np.percentile(scores, 95) + vmin = np.percentile(scores, 5) + else: + vmax = np.max(scores) + vmin = np.min(scores) plot_swanson_vector(acronyms, scores, @@ -499,10 +508,12 @@ def plot_swansons(variable, fig=None, axs=None): annotate= True, annotate_n=5, annotate_order='bottom' if lat else 'top', - fontsize=f_size_s) + fontsize=f_size_s, + vmin=vmin, + vmax=vmax) - clevels = (min(scores), max(scores)) - norm = mpl.colors.Normalize(vmin=clevels[0], vmax=clevels[1]) + + norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) cbar = fig.colorbar( mpl.cm.ScalarMappable(norm=norm,cmap=cmap.reversed() if lat else cmap), @@ -2290,8 +2301,8 @@ def get_example_results(): lambda c: c == -1, 0.2, 0.05, - "fmoveL", #fmoveL - "fmoveR", #fmoveR + "fmoveL", #choice left + "fmoveR", #choice right ), "feedback_times": ( "feedbackType", @@ -2309,7 +2320,7 @@ def get_example_results(): "stim": ( 'e0928e11-2b86-4387-a203-80c77fab5d52', # EID '799d899d-c398-4e81-abaf-1ef4b02d5475', # PID - 235, # clu_id, was 235 + 235, # clu_id, was 235 -- online 218 looks good "VISp", # region 0.04540706, # drsq (from 02_fit_sessions.py) "stimOn_times", # Alignset key @@ -3350,13 +3361,13 @@ def ghostscript_compress_pdf(variable, level='/printer'): output_path = Path(imgs_pth, variable, f'n5_main_figure_{variverb[variable]}_revised.pdf') - if variable == 'wheel': + elif variable == 'wheel': input_path = Path(imgs_pth, 'speed', f'n5_main_figure_wheel_revised_raw.pdf') output_path = Path(imgs_pth, 'speed', f'n5_main_figure_wheel_revised.pdf') - if variable == 'manuscript': + elif variable == 'manuscript': input_path = Path('/home/mic/Brainwide_Map_Paper.pdf') output_path = Path('/home/mic/Brainwide_Map_Paper2.pdf') @@ -3389,3 +3400,4 @@ def ghostscript_compress_pdf(variable, level='/printer'): + From 512b3cf68468821238a7429e8b022beb26977ca3 Mon Sep 17 00:00:00 2001 From: Michael Schartner Date: Wed, 2 Oct 2024 13:24:33 +0100 Subject: [PATCH 02/10] manifold --> population trajectory --- brainwidemap/meta/bwm_figs.py | 93 +++++++++++++++-------------------- 1 file changed, 41 insertions(+), 52 deletions(-) diff --git a/brainwidemap/meta/bwm_figs.py b/brainwidemap/meta/bwm_figs.py index 8da2a3f..d1c793e 100644 --- a/brainwidemap/meta/bwm_figs.py +++ b/brainwidemap/meta/bwm_figs.py @@ -4,7 +4,7 @@ import math, string from collections import Counter, OrderedDict from functools import reduce -import os +import os, sys import itertools from scipy import stats from statsmodels.stats.multitest import multipletests @@ -47,7 +47,7 @@ ''' This script is used to plot the main result figures of the BWM paper. 
The raw results from each analysis can be found in bwm_figs_res. -There are 4 analyses: manifold, decoding, glm, single-cell +There are 4 analyses: population trajectory, decoding, glm, single-cell See first function in code block of this script for each analysis type for data format conversion. ''' @@ -76,8 +76,8 @@ dec_pth = Path(one.cache_dir, 'bwm_res', 'bwm_figs_data','decoding') dec_pth.mkdir(parents=True, exist_ok=True) -# manifold results -man_pth = Path(one.cache_dir, 'bwm_res', 'bwm_figs_data','manifold') +# population trajectory results +man_pth = Path(one.cache_dir, 'bwm_res', 'bwm_figs_data','trajectory') man_pth.mkdir(parents=True, exist_ok=True) # encoding results @@ -114,7 +114,7 @@ def pool_results_across_analyses(return_raw=False): 4 different analysis types ['glm','euc', 'mw', 'dec'] variables ['stim', ' choice', 'fback'] - some files need conversion to csv (manifold, glm); + some files need conversion to csv (trajectory, glm); see first functions for in the subsequent sections ''' @@ -166,9 +166,25 @@ def pool_results_across_analyses(return_raw=False): d = {} + for vari in variables: - d[vari] = pd.read_csv(Path(man_pth / f'{vari}_restr.csv'))[[ - 'region','amp_euc_can', 'lat_euc_can','p_euc_can']] + r = [] + variable = vari + '_restr' + columns = ['region','nclus', + 'p_euc_can', 'amp_euc_can', + 'lat_euc_can'] + + dd = np.load(Path(man_pth, f'{variable}.npy'), + allow_pickle=True).flat[0] + + for reg in dd: + + r.append([reg, dd[reg]['nclus'], + dd[reg]['p_euc_can'], + dd[reg]['amp_euc_can'], + dd[reg]['lat_euc_can']]) + + d[vari] = pd.DataFrame(data=r,columns=columns) d[vari]['euclidean_significant'] = d[vari].p_euc_can.apply( lambda x: x Date: Fri, 4 Oct 2024 16:21:08 +0100 Subject: [PATCH 03/10] Update bwm_figs.py --- brainwidemap/meta/bwm_figs.py | 159 +++++++++++++++++++++++++++++++--- 1 file changed, 146 insertions(+), 13 deletions(-) diff --git a/brainwidemap/meta/bwm_figs.py b/brainwidemap/meta/bwm_figs.py index d1c793e..3dd541c 100644 --- a/brainwidemap/meta/bwm_figs.py +++ b/brainwidemap/meta/bwm_figs.py @@ -29,12 +29,15 @@ from matplotlib.patches import Rectangle from matplotlib.ticker import (MultipleLocator, AutoMinorLocator) import matplotlib.ticker as tck +import matplotlib.colors as mcolors from brainwidemap import download_aggregate_tables, bwm_units from brainwidemap.encoding.design import generate_design from brainwidemap.encoding.glm_predict import GLMPredictor, predict from brainwidemap.encoding.utils import load_regressors, single_cluster_raster, find_trial_ids from brainbox.plot import peri_event_time_histogram +from brainbox.behavior.training import compute_performance +from reproducible_ephys_functions import labs import neurencoding.linear as lm from neurencoding.utils import remove_regressors @@ -1620,6 +1623,92 @@ def swansons_SI(vari): fig.savefig(Path(imgs_pth, 'si', f'mannwhitney_SI_{vari}.svg')) +''' +##### +Behavioral SI figure +##### +''' + +def perf_scatter(rerun=False): + + ''' + two scatter plots, a point a mouse from the BWM set, + left panel (x, y): + (# bias training sessions, # pre-bias training sessions + right panel (x, y): + (# bias training sessions, % trials correct during ephys) + ''' + + pth_ = Path(one.cache_dir, 'bwm_res', 'bwm_figs_data', + 'training_perf.pqt') + + if (not pth_.is_file() or rerun): + + # get all subjects + eids = bwm_units(one)['eid'].unique() + pths = one.eid2path(eids) + subs = np.unique([str(x).split('/')[-3] for x in pths]) + + # get lab colors and shorter names (repro paper standards) + rr = 
labs() + sub_labs = dict(zip([str(x).split('/')[-3] for x in pths], + [rr[1][str(x).split('/')[-5]] for x in pths])) + + sub_cols = {sub: rr[-1][sub_labs[sub]] for sub in subs} + + # fill dataframe + r = [] + columns = ['subj', '#sess biasedChoiceWorld', '#sess trainingChoiceWorld', + 'perf ephysChoiceWorld', 'lab', 'lab_color'] + + for sub in subs: + trials = one.load_aggregate('subjects', sub, + '_ibl_subjectTrials.table') + nbiased = trials[trials['task_protocol' + ] == 'biasedChoiceWorld']['session'].nunique() + nunbiased = trials[trials['task_protocol' + ] == 'trainingChoiceWorld']['session'].nunique() + perf_ephys = np.mean(compute_performance( + trials[trials['task_protocol'] == 'ephysChoiceWorld'])[0]) + r.append([sub, nbiased, nunbiased, + perf_ephys, sub_labs[sub], sub_cols[sub]]) + + df = pd.DataFrame(r, columns=columns) + df.to_parquet(pth_) + + df = pd.read_parquet(pth_) + + # convert to hex for seaborn + df['lab_color'] = df['lab_color'].apply(lambda x: mcolors.to_hex(x)) + + # Create a dictionary to map lab to its color + lab_palette = dict(zip(df['lab'], df['lab_color'])) + + fig, axs = plt.subplots(1, 2, figsize=(10, 5)) + + # Left scatter plot: (#sess biasedChoiceWorld, #sess trainingChoiceWorld) + sns.scatterplot(ax=axs[0], data=df, x='#sess biasedChoiceWorld', + y='#sess trainingChoiceWorld', hue='lab', + palette=lab_palette, legend=True) + + axs[0].set_xlabel('# biasedChoiceWorld sessions') + axs[0].set_ylabel('# trainingChoiceWorld sessions') + axs[0].set_title('Sessions: biased vs unbiased') + + # Right scatter plot: (#sess biasedChoiceWorld, perf ephysChoiceWorld) + sns.scatterplot(ax=axs[1], data=df, x='#sess biasedChoiceWorld', + y='perf ephysChoiceWorld', hue='lab', + palette=lab_palette, legend=False) + + axs[1].set_xlabel('# biasedChoiceWorld sessions') + axs[1].set_ylabel('% Trials Correct (ephys)') + axs[1].set_title('Bias Sessions vs Performance') + + # Adjust layout and display plot + plt.tight_layout() + plt.show() + + ''' ##### decoding @@ -1797,8 +1886,8 @@ def dec_scatter(variable,fig=None, ax=None): variable in [choice, fback] ''' - red = (255/255, 48/255, 23/255) - blue = (34/255,77/255,169/255) + red = red_right + blue = blue_left alone = False if not fig: @@ -2245,8 +2334,8 @@ def plot_twocond( error_bars="sem", ax=ax[i], smoothing=0.01, - pethline_kwargs={"color": "blue", "linewidth": 2}, - errbar_kwargs={"color": "blue", "alpha": 0.5}, + pethline_kwargs={"color": blue_left, "linewidth": 2}, + errbar_kwargs={"color": blue_left, "alpha": 0.5}, ) oldticks.extend(ax[i].get_yticks()) peri_event_time_histogram( @@ -2260,15 +2349,15 @@ def plot_twocond( error_bars="sem", ax=ax[i], smoothing=0.01, - pethline_kwargs={"color": "red", "linewidth": 2}, - errbar_kwargs={"color": "red", "alpha": 0.5}, + pethline_kwargs={"color": red_right, "linewidth": 2}, + errbar_kwargs={"color": red_right, "alpha": 0.5}, ) oldticks.extend(ax[i].get_yticks()) pred1 = cond1pred if not rem_regressor else nrcond1pred pred2 = cond2pred if not rem_regressor else nrcond2pred - ax[i].step(x, pred1, color="darkblue", linewidth=2) + ax[i].step(x, pred1, color="skyblue", linewidth=2) oldticks.extend(ax[i].get_yticks()) - ax[i].step(x, pred2, color="darkred", linewidth=2) + ax[i].step(x, pred2, color="#F08080", linewidth=2) oldticks.extend(ax[i].get_yticks()) ax[i].set_ylim([0, np.max(oldticks) * 1.1]) return fig, ax, sspkt, sspkclu, stdf @@ -2344,7 +2433,7 @@ def get_example_results(): "choice": ( "671c7ea7-6726-4fbe-adeb-f89c2c8e489b", "04c9890f-2276-4c20-854f-305ff5c9b6cf", - 130, # 
was 143 + 143, # was 143 "GRN", 0.000992895, # drsq "firstMovement_times", @@ -2434,7 +2523,7 @@ def ecoding_raster_lines(variable, clu_id0=None, axs=None, # custom legend all_lines = axs[1].get_lines() - legend_labels = [reg2, reg1, 'model', 'model'] + legend_labels = [reg1, reg2, 'model', 'model'] axs[1].legend(all_lines, legend_labels, loc='upper right', bbox_to_anchor=(1.2, 1.3), fontsize=f_size_s, frameon=False) @@ -2448,7 +2537,7 @@ def ecoding_raster_lines(variable, clu_id0=None, axs=None, stdf[aligntime], trial_idx, dividers, - ["b", "r"], + [blue_left, red_right], [reg1, reg2], pre_time=t_before, post_time=t_after, @@ -2513,8 +2602,50 @@ def encoding_wheel_boxen(ax=None, fig=None): if alone: fig.tight_layout() fig.savefig(Path(imgs_pth, 'speed', - 'glm_boxen.svg')) - + 'glm_boxen.svg')) + + +def encoding_wheel_2d_density(ax=None, fig=None): + # Load data and configurations + d = {} + fs = {'speed': 'GLMs_wheel_speed.pkl', + 'velocity': 'GLMs_wheel_velocity.pkl'} + + for v in fs: + d[v] = pd.read_pickle( + Path(enc_pth, fs[v]))['mean_fit_results']["wheel"].to_frame() + + # Join data + joinwheel = d['speed'].join(d['velocity'], how="inner", + rsuffix="_velocity", lsuffix="_speed") + + # Extract the columns to plot + x = joinwheel.iloc[:, 0] # wheel_speed values + y = joinwheel.iloc[:, 1] # wheel_velocity values + + # Check if we are plotting alone or on provided axis + alone = False + if not ax: + alone = True + fig, ax = plt.subplots(constrained_layout=True, figsize=[5, 5]) + + # Plot 2D density using seaborn's kdeplot + sns.kdeplot(x=x, y=y, ax=ax, cmap="Blues", fill=True) + + # Label axes + ax.set_xlabel('Wheel Speed') + ax.set_ylabel('Wheel Velocity') + + # Adjust axis visibility + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.set_aspect('equal') + # Save plot if alone + if alone: + fig.tight_layout() + fig.savefig(Path(imgs_pth, 'speed', 'glm_2d_density.svg')) + + ''' ########## @@ -2853,6 +2984,8 @@ def plot_curves_scatter(variable, ga_pcs=False, curve='euc', color=palette[reg], label=f"{reg} {d[reg]['nclus']}") + axs[k].set_xscale('log') + # put region labels y = yy[-1] x = xx[-1] From 8162fedb9f6e956ea3bd457d4c3bd47ac682e942 Mon Sep 17 00:00:00 2001 From: Michael Schartner Date: Sat, 5 Oct 2024 14:16:03 +0100 Subject: [PATCH 04/10] include session number SI figure --- brainwidemap/meta/bwm_figs.py | 118 +++++++++++++++++++++++----------- 1 file changed, 81 insertions(+), 37 deletions(-) diff --git a/brainwidemap/meta/bwm_figs.py b/brainwidemap/meta/bwm_figs.py index 3dd541c..e1fb85c 100644 --- a/brainwidemap/meta/bwm_figs.py +++ b/brainwidemap/meta/bwm_figs.py @@ -36,7 +36,6 @@ from brainwidemap.encoding.glm_predict import GLMPredictor, predict from brainwidemap.encoding.utils import load_regressors, single_cluster_raster, find_trial_ids from brainbox.plot import peri_event_time_histogram -from brainbox.behavior.training import compute_performance from reproducible_ephys_functions import labs import neurencoding.linear as lm @@ -1639,43 +1638,65 @@ def perf_scatter(rerun=False): (# bias training sessions, % trials correct during ephys) ''' - pth_ = Path(one.cache_dir, 'bwm_res', 'bwm_figs_data', - 'training_perf.pqt') - - if (not pth_.is_file() or rerun): + # Define path to the parquet file + pth_ = Path(one.cache_dir, 'bwm_res', 'bwm_figs_data', 'training_perf.pqt') + + # Only reprocess data if file does not exist or rerun is set to True + if not pth_.is_file() or rerun: - # get all subjects + # Retrieve unique subject IDs eids = 
bwm_units(one)['eid'].unique() pths = one.eid2path(eids) subs = np.unique([str(x).split('/')[-3] for x in pths]) - # get lab colors and shorter names (repro paper standards) + # Get lab metadata rr = labs() - sub_labs = dict(zip([str(x).split('/')[-3] for x in pths], - [rr[1][str(x).split('/')[-5]] for x in pths])) - - sub_cols = {sub: rr[-1][sub_labs[sub]] for sub in subs} + sub_labs = {str(x).split('/')[-3]: rr[1][str(x).split('/')[-5]] for x in pths} + sub_cols = {sub: rr[-1][lab] for sub, lab in sub_labs.items()} - # fill dataframe + # Initialize results list and column names r = [] - columns = ['subj', '#sess biasedChoiceWorld', '#sess trainingChoiceWorld', - 'perf ephysChoiceWorld', 'lab', 'lab_color'] + columns = ['subj', '#sess biasedChoiceWorld', '#sess trainingChoiceWorld', + 'perf ephysChoiceWorld', 'lab', 'lab_color'] + # Process each subject's data for sub in subs: + # Load trials and training data trials = one.load_aggregate('subjects', sub, - '_ibl_subjectTrials.table') - nbiased = trials[trials['task_protocol' - ] == 'biasedChoiceWorld']['session'].nunique() - nunbiased = trials[trials['task_protocol' - ] == 'trainingChoiceWorld']['session'].nunique() - perf_ephys = np.mean(compute_performance( - trials[trials['task_protocol'] == 'ephysChoiceWorld'])[0]) - r.append([sub, nbiased, nunbiased, - perf_ephys, sub_labs[sub], sub_cols[sub]]) - - df = pd.DataFrame(r, columns=columns) + '_ibl_subjectTrials.table') + training = one.load_aggregate('subjects', sub, + '_ibl_subjectTraining.table') + + # Join and sort by session start time + trials = trials.set_index('session').join( + training.set_index('session')).sort_values( + 'session_start_time', kind='stable') + + # Separate behavior sessions based on task protocol + session_types = { + 'training': ['_iblrig_tasks_trainingChoiceWorld', 'trainingChoiceWorld'], + 'biased': ['_iblrig_tasks_biasedChoiceWorld', 'biasedChoiceWorld'], + 'ephys': ['_iblrig_tasks_ephys', 'ephysChoiceWorld'] + } + + # Create filtered sessions based on task_protocol + filtered_sessions = {k: trials[trials[ + 'task_protocol'].str.startswith(tuple(v))] + for k, v in session_types.items()} + + # Calculate metrics + nbiased = filtered_sessions['biased'].index.nunique() + nunbiased = filtered_sessions['training'].index.nunique() + perf_ephys = np.nanmean((filtered_sessions['ephys'].feedbackType + 1) / 2) + + # Append results for the current subject + r.append([sub, nbiased, nunbiased, perf_ephys, sub_labs[sub], sub_cols[sub]]) + + # Create DataFrame and save to parquet + df = pd.DataFrame(r, columns=columns) df.to_parquet(pth_) + # Read the parquet file into DataFrame df = pd.read_parquet(pth_) # convert to hex for seaborn @@ -1684,30 +1705,53 @@ def perf_scatter(rerun=False): # Create a dictionary to map lab to its color lab_palette = dict(zip(df['lab'], df['lab_color'])) - fig, axs = plt.subplots(1, 2, figsize=(10, 5)) + fig, axs = plt.subplots(1, 2, figsize=([7.99, 3.19])) # Left scatter plot: (#sess biasedChoiceWorld, #sess trainingChoiceWorld) sns.scatterplot(ax=axs[0], data=df, x='#sess biasedChoiceWorld', y='#sess trainingChoiceWorld', hue='lab', - palette=lab_palette, legend=True) + palette=lab_palette, legend=True, s=20) + + legend = axs[0].legend_ + legend.get_frame().set_linewidth(0) # Remove legend box outline + axs[0].legend(loc='upper left', bbox_to_anchor=(1, 1), frameon=False) + + # Drop NaNs for Pearson correlation calculation + valid_left = df[['#sess biasedChoiceWorld', + '#sess trainingChoiceWorld']].dropna() + r_left, p_left = 
stats.pearsonr(valid_left['#sess biasedChoiceWorld'], + valid_left['#sess trainingChoiceWorld']) + + # Annotate plot with r and p-value + axs[0].text(0.05, 0.95, f'r = {r_left:.3f}\np = {p_left:.3f}', + transform=axs[0].transAxes, verticalalignment='top', + fontsize=f_size) - axs[0].set_xlabel('# biasedChoiceWorld sessions') - axs[0].set_ylabel('# trainingChoiceWorld sessions') - axs[0].set_title('Sessions: biased vs unbiased') + axs[0].set_xlabel('# biased training sessions', fontsize=f_size) + axs[0].set_ylabel('# non-biased training sessions', fontsize=f_size) # Right scatter plot: (#sess biasedChoiceWorld, perf ephysChoiceWorld) sns.scatterplot(ax=axs[1], data=df, x='#sess biasedChoiceWorld', y='perf ephysChoiceWorld', hue='lab', - palette=lab_palette, legend=False) + palette=lab_palette, legend=False, s=20) - axs[1].set_xlabel('# biasedChoiceWorld sessions') - axs[1].set_ylabel('% Trials Correct (ephys)') - axs[1].set_title('Bias Sessions vs Performance') + valid_right = df[['#sess biasedChoiceWorld', + 'perf ephysChoiceWorld']].dropna() + r_right, p_right = stats.pearsonr(valid_right['#sess biasedChoiceWorld'], + valid_right['perf ephysChoiceWorld']) - # Adjust layout and display plot - plt.tight_layout() - plt.show() + # Annotate plot with r and p-value + axs[1].text(0.05, 0.95, f'r = {r_right:.3f}\np = {p_right:.3f}', + transform=axs[1].transAxes, verticalalignment='top', + fontsize=f_size) + axs[1].set_xlabel('# biased training sessions', fontsize=f_size) + axs[1].set_ylabel('% trials correct \n (recording sessions only)', + fontsize=f_size) + + # Adjust layout and display plot + plt.tight_layout() + fig.savefig(Path(imgs_pth, 'si', f'session_numbers_SI_.pdf'), dpi=150) ''' ##### From 4aa3516fd8a1ef7b4299898527d6e0cfa60f715f Mon Sep 17 00:00:00 2001 From: Michael Schartner Date: Sat, 5 Oct 2024 14:39:00 +0100 Subject: [PATCH 05/10] Update bwm_figs.py --- brainwidemap/meta/bwm_figs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainwidemap/meta/bwm_figs.py b/brainwidemap/meta/bwm_figs.py index e1fb85c..b3a8cbe 100644 --- a/brainwidemap/meta/bwm_figs.py +++ b/brainwidemap/meta/bwm_figs.py @@ -1751,7 +1751,7 @@ def perf_scatter(rerun=False): # Adjust layout and display plot plt.tight_layout() - fig.savefig(Path(imgs_pth, 'si', f'session_numbers_SI_.pdf'), dpi=150) + fig.savefig(Path(imgs_pth, 'si', f'n6_supp_figure_learning_stats.pdf'), dpi=150) ''' ##### From c85263740cbe5d9f32578fd23d878c3f68f03ffd Mon Sep 17 00:00:00 2001 From: Michael Schartner Date: Thu, 17 Oct 2024 14:41:58 +0100 Subject: [PATCH 06/10] add male/female S! 
fig --- brainwidemap/meta/bwm_figs.py | 124 ++++++++++++++++++++++++++++++---- 1 file changed, 110 insertions(+), 14 deletions(-) diff --git a/brainwidemap/meta/bwm_figs.py b/brainwidemap/meta/bwm_figs.py index b3a8cbe..c1d227f 100644 --- a/brainwidemap/meta/bwm_figs.py +++ b/brainwidemap/meta/bwm_figs.py @@ -30,11 +30,13 @@ from matplotlib.ticker import (MultipleLocator, AutoMinorLocator) import matplotlib.ticker as tck import matplotlib.colors as mcolors +import matplotlib.patches as mpatches from brainwidemap import download_aggregate_tables, bwm_units from brainwidemap.encoding.design import generate_design from brainwidemap.encoding.glm_predict import GLMPredictor, predict -from brainwidemap.encoding.utils import load_regressors, single_cluster_raster, find_trial_ids +from brainwidemap.encoding.utils import (load_regressors, + single_cluster_raster, find_trial_ids) from brainbox.plot import peri_event_time_histogram from reproducible_ephys_functions import labs @@ -1759,6 +1761,9 @@ def perf_scatter(rerun=False): ##### ''' +dec_d = {'stim': 'stimside', 'choice': 'choice', + 'fback': 'feedback'} + def group_into_regions(): ''' @@ -1811,9 +1816,6 @@ def significance_by_region(group): result['sig_combined'] = result['pval_combined'] < ALPHA_LEVEL return result - dec_d = {'stim': 'stimside', 'choice': 'choice', - 'fback': 'feedback'} - # indicate in file name constraint exx = '' if MIN_TRIALS == 0 else ('_' + str(MIN_TRIALS)) @@ -2218,8 +2220,7 @@ def swansons_SI_dec(vari): fig, ax = plt.subplots(figsize=(5.67, 3.18)) - dec_d = {'stim': 'stimside', 'choice': 'choice', - 'fback': 'feedback'} + pqt_file = os.path.join(dec_pth,f"{dec_d[vari]}_stage2.pqt") df1 = pd.read_parquet(pqt_file) @@ -2283,6 +2284,103 @@ def swansons_SI_dec(vari): +def plot_female_male_repro(): + + ''' + For the 5 repro regions, and all variables (stimulus, choice, feedback), + plot results split by female and male mice with two strips of dots + (blue for male, red for female) for each region. 
+ ''' + repro_regs = ["VISa/am", "CA1", "DG", "LP", "PO"] + + + # subject = one.get_details(eid)['subject'] + # info = one.alyx.rest('subjects', 'read', id=subject) + # sex = info['sex'] # 'M', 'F', or 'U' + + # Load the male/female subject dictionary + ds = np.load(os.path.join(dec_pth, 'male_female.npy'), allow_pickle=True).flat[0] + + # Set up the plot structure (3 rows for each variable, 1 column) + fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(8.26, 4.37), + sharey=True, sharex=True) + + for row_idx, vari in enumerate(variables): + pqt_file = os.path.join(dec_pth, f"{dec_d[vari]}_stage2.pqt") + df = pd.read_parquet(pqt_file) + + # Combine 'VISa' and 'VISam' into 'VISa/am' + df['region'] = df['region'].replace({'VISa': 'VISa/am', 'VISam': 'VISa/am'}) + + # Restrict the DataFrame to the 5 repro regions + df = df[df['region'].isin(repro_regs)] + + # Map the sex information to the DataFrame + df['sex'] = df['subject'].map(ds) + + # Separate by sex + female_df = df[df['sex'] == 'F'] + male_df = df[df['sex'] == 'M'] + + # Combine male and female data into a new DataFrame for better comparison in the plot + df_combined = pd.concat([female_df.assign(gender='Female'), male_df.assign(gender='Male')]) + + # Plot in the corresponding row + ax = axes[row_idx] + + # Strip plot for both male (blue) and female (red) side by side for each region + sns.stripplot(data=df_combined, x='region', y='score', hue='gender', + palette={'Male': 'blue', 'Female': 'red'}, + dodge=True, jitter=True, size=3, ax=ax) + + # Add mean null score (white dot) for each gender and region + # Ensure the x-alignment is consistent (no dodge for null scores) + for region in repro_regs: + # Get the x position for this region (for both male and female) + x_female = repro_regs.index(region) - 0.2 # Slight left shift for female (consistent with dodge) + x_male = repro_regs.index(region) + 0.2 # Slight right shift for male (consistent with dodge) + + # Plot null score for females (red) + female_null = female_df[female_df['region'] == region]['median-null'].mean() + ax.scatter(x_female, female_null, color='white', + edgecolor='red', zorder=3, s=30) + + # Plot null score for males (blue) + male_null = male_df[male_df['region'] == region]['median-null'].mean() + ax.scatter(x_male, male_null, color='white', + edgecolor='blue', zorder=3, s=30) + + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + + # Set labels and titles for each row + ax.set_xlabel("Brain Regions") + ax.set_ylabel(f"{variverb[vari]} \n decoding accuracy") + + # Remove legend from individual plots to keep the layout clean + ax.get_legend().remove() + for i in range(1, len(repro_regs)): + ax.axvline(x=i-0.5, color='grey', linestyle='--', linewidth=1) + + legend_handles = [ + plt.Line2D([0], [0], marker='o', color='white', label='Male', + markerfacecolor='blue', markeredgecolor='blue', markersize=5), + plt.Line2D([0], [0], marker='o', color='white', label='Female', + markerfacecolor='red', markeredgecolor='red', markersize=5), + plt.Line2D([0], [0], marker='o', color='white', label='Null (Male)', + markerfacecolor='white', markeredgecolor='blue', markersize=5), + plt.Line2D([0], [0], marker='o', color='white', label='Null (Female)', + markerfacecolor='white', markeredgecolor='red', markersize=5) + ] + + fig.legend(handles=legend_handles, loc='upper center', + frameon=False, ncols=2).set_draggable(True) + + # Adjust layout + plt.tight_layout() + fig.savefig(Path(imgs_pth, 'si', + f'n6_supp_dec_repro_male_female.pdf'), dpi=150) + ''' 
########## @@ -2475,9 +2573,9 @@ def get_example_results(): "stimOn_times", # Alignset key ), "choice": ( - "671c7ea7-6726-4fbe-adeb-f89c2c8e489b", - "04c9890f-2276-4c20-854f-305ff5c9b6cf", - 143, # was 143 + "a7763417-e0d6-4f2a-aa55-e382fd9b5fb8",#"671c7ea7-6726-4fbe-adeb-f89c2c8e489b" + "57c5856a-c7bd-4d0f-87c6-37005b1484aa",#"04c9890f-2276-4c20-854f-305ff5c9b6cf" + 74, # was 143 "GRN", 0.000992895, # drsq "firstMovement_times", @@ -2500,7 +2598,7 @@ def get_example_results(): return targetunits, alignsets, sortlookup -def ecoding_raster_lines(variable, clu_id0=None, axs=None, +def encoding_raster_lines(variable, clu_id0=None, axs=None, frac_tr=3): ''' @@ -3028,8 +3126,6 @@ def plot_curves_scatter(variable, ga_pcs=False, curve='euc', color=palette[reg], label=f"{reg} {d[reg]['nclus']}") - axs[k].set_xscale('log') - # put region labels y = yy[-1] x = xx[-1] @@ -3282,12 +3378,12 @@ def ax_str(x): # encoding panels if not save_pans: - ecoding_raster_lines(variable,clu_id0= clu_id0, + encoding_raster_lines(variable,clu_id0= clu_id0, axs=[ax_str(x) for x in ['ras', 'enc0', 'enc1']]) else: - ecoding_raster_lines(variable, clu_id0=clu_id0) + encoding_raster_lines(variable, clu_id0=clu_id0) ''' From 3d6431ca795df9ecf0d83097f3c99348b97650b9 Mon Sep 17 00:00:00 2001 From: Michael Schartner Date: Mon, 21 Oct 2024 11:40:43 +0100 Subject: [PATCH 07/10] Update bwm_figs.py --- brainwidemap/meta/bwm_figs.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/brainwidemap/meta/bwm_figs.py b/brainwidemap/meta/bwm_figs.py index c1d227f..bb7533a 100644 --- a/brainwidemap/meta/bwm_figs.py +++ b/brainwidemap/meta/bwm_figs.py @@ -755,7 +755,7 @@ def plot_slices(variable): fig.savefig(Path(imgs_pth, 'si', f'n6_supp_figure_{variverb[variable]}_raw.svg'), - bbox_inches='tight') + bbox_inches='tight', dpi=200) @@ -2235,6 +2235,12 @@ def swansons_SI_dec(vari): acronyms = np.array(list(res.keys())) scores = np.array(list(res.values())) + # turn regions with zero sig sessions to grey + mask = acronyms[scores == 0] + + acronyms = acronyms[scores != 0] + scores = scores[scores != 0] + # turn fraction into percentage scores = scores * 100 @@ -2250,6 +2256,7 @@ def swansons_SI_dec(vari): empty_color="white", linewidth=lw, mask_color='silver', + mask=mask, annotate= True, annotate_n=8, annotate_order='top', From 90c46fc44b3266018379d59cb76211c30f6200be Mon Sep 17 00:00:00 2001 From: Michael Schartner Date: Mon, 21 Oct 2024 14:41:47 +0100 Subject: [PATCH 08/10] Update bwm_figs.py --- brainwidemap/meta/bwm_figs.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/brainwidemap/meta/bwm_figs.py b/brainwidemap/meta/bwm_figs.py index bb7533a..da5ebb7 100644 --- a/brainwidemap/meta/bwm_figs.py +++ b/brainwidemap/meta/bwm_figs.py @@ -32,7 +32,7 @@ import matplotlib.colors as mcolors import matplotlib.patches as mpatches -from brainwidemap import download_aggregate_tables, bwm_units +from brainwidemap import download_aggregate_tables, bwm_units, bwm_query from brainwidemap.encoding.design import generate_design from brainwidemap.encoding.glm_predict import GLMPredictor, predict from brainwidemap.encoding.utils import (load_regressors, @@ -780,7 +780,7 @@ def plot_all_swansons(): 'glm_effect': ['Abs. diff. 
$\\Delta R^2$ (log)',[], ['Encoding', 'General linear model']]} - cmap = 'viridis' + cmap = 'viridis_r' num_columns = len(res_types) @@ -2389,6 +2389,28 @@ def plot_female_male_repro(): f'n6_supp_dec_repro_male_female.pdf'), dpi=150) +def print_age_weight(): + + ''' + for all BWM mice print age and weight and means + ''' + subjects = bwm_query(one)['subject'].unique() + d = {} + for sub in subjects: + print(sub) + info = one.alyx.rest('subjects', 'read', id=sub) + d[sub] = [info['age_weeks'], info['reference_weight']] + + print('mean age [weeks]', np.round(np.mean([d[s][0] for s in d]),2)) + print('median age [weeks]', np.round(np.median([d[s][0] for s in d]),2)) + print('max, min [weeks]', np.max([d[s][0] for s in d]), + np.min([d[s][0] for s in d])) + + print('mean weight [gr]', np.round(np.median([d[s][1] for s in d]),2)) + print('median weight [gr]', np.round(np.mean([d[s][1] for s in d]),2)) + print('max, min [gr]', np.max([d[s][1] for s in d]), + np.min([d[s][1] for s in d])) + ''' ########## encoding (glm) From e76f6981932dbab4aefd2662fc3d2ed1ea186752 Mon Sep 17 00:00:00 2001 From: Michael Schartner Date: Tue, 22 Oct 2024 16:11:08 +0100 Subject: [PATCH 09/10] Update bwm_loading.py --- brainwidemap/bwm_loading.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/brainwidemap/bwm_loading.py b/brainwidemap/bwm_loading.py index 979a6b5..de23d68 100644 --- a/brainwidemap/bwm_loading.py +++ b/brainwidemap/bwm_loading.py @@ -202,7 +202,7 @@ def merge_probes(spikes_list, clusters_list): def load_trials_and_mask( one, eid, min_rt=0.08, max_rt=2., nan_exclude='default', min_trial_len=None, max_trial_len=None, exclude_unbiased=False, exclude_nochoice=False, sess_loader=None, - truncate_to_pass=True, saturation_intervals=None + truncate_to_pass=True, saturation_intervals=None, revision='2024-07-14' ): """ Function to load all trials for a given session and create a mask to exclude all trials that have a reaction time @@ -271,7 +271,7 @@ def load_trials_and_mask( ] if sess_loader is None: - sess_loader = SessionLoader(one=one, eid=eid) + sess_loader = SessionLoader(one=one, eid=eid, revision=revision) if sess_loader.trials.empty: sess_loader.load_trials() From 5bb0b54a158f19e5efc0d3416035a928d60dce0f Mon Sep 17 00:00:00 2001 From: Michael Schartner Date: Fri, 1 Nov 2024 14:43:41 +0000 Subject: [PATCH 10/10] Add files via upload --- brainwidemap/meta/repro_bwm_decoding.py | 142 ++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 brainwidemap/meta/repro_bwm_decoding.py diff --git a/brainwidemap/meta/repro_bwm_decoding.py b/brainwidemap/meta/repro_bwm_decoding.py new file mode 100644 index 0000000..b351786 --- /dev/null +++ b/brainwidemap/meta/repro_bwm_decoding.py @@ -0,0 +1,142 @@ +import pandas as pd +import numpy as np +from pathlib import Path + +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib.patches import Rectangle + +from scipy.stats import kruskal, f_oneway +from statsmodels.stats.multitest import multipletests +from scipy.cluster import hierarchy + +from one.api import ONE +from reproducible_ephys_functions import figure_style, labs +from dmn_bwm import get_allen_info + + +''' +Replottig BWM decoding results for the repro paper, grouped by labs, +testing for systematic lab biases +''' + +# for vari plot +_, b, lab_cols = labs() + +one = ONE() + +dec_d = {'stimside': 'stimside', 'choice': 'choice', + 'feedback': 'feedback', 'wheel-speed': 'wheel-speed'} + +dec_pth = Path(one.cache_dir, 'bwm_res', 'bwm_figs_data', 
'decoding') + + +def bwm_scores(nscores=3, tt='stripplot', sb='lab'): + """ + Analyze decoding and encoding scores across regions grouped by labs or animals. + + Parameters: + - nscores: Minimum number of scores for a lab/region to be included. + - ana: Analysis type ('dec' or 'enc'), for decoding or encoding (GLM). + - sb: Sort by 'lab' or 'animals'. + """ + + varis = ['choice', 'stimside', 'feedback', 'wheel-speed'] + regs = ['VISa/am', 'CA1', 'DG', 'LP', 'PO'] + + # Use loaded data paths as in `pool_results_across_analyses` + _, pa = get_allen_info() + + # Pooled data paths + ana = 'dec' + analysis_path = dec_pth + + ps = {} + fig, axs = plt.subplots(nrows=1, ncols=len(varis), sharex=True, sharey=True, figsize=(10.88, 7.03)) + k = 0 + + for vari in varis: + # Load pooled data based on `pool_results_across_analyses` + + data_file = analysis_path / f'{dec_d[vari]}_stage2.pqt' + d = pd.read_parquet(data_file) + pths = one.eid2path(d['eid'].values) + d['lab'] = [b[str(p).split('/')[5]] for p in pths] + d['subject'] = [str(p).split('/')[7] for p in pths] + d['region'] = d['region'].replace(['VISa', 'VISam'], 'VISa/am') + d = d.dropna(subset=['score', 'lab', 'region', 'subject']) + + # Plot logic + if tt == 'mean_std': + reg_stats = d.groupby('region')['score'].agg( + mean_score=np.nanmean, std_score=np.nanstd, count_scores='count' + ).reset_index() + reg_stats = reg_stats[reg_stats['count_scores'] >= nscores] + + x = reg_stats['mean_score'].values + y = reg_stats['std_score'].values + regions = reg_stats['region'].values + cols = [pa[region] for region in regions] + sizes = reg_stats['count_scores'].values + + axs[k].scatter(x, y, color=cols, s=sizes if ana == 'dec' else sizes / 10) + for i, reg in enumerate(regions): + axs[k].annotate(f' {reg}', (x[i], y[i]), fontsize=5, color=cols[i]) + + axs[k].set_title(vari) + axs[k].set_xlabel('mean') + axs[k].set_ylabel('std') + + elif tt == 'stripplot': + filtered_data = d[d['region'].isin(regs)] + labs_counts = filtered_data.groupby([sb, 'region'])['score'].count().reset_index(name='score_count') + valid_labs_regions = labs_counts[labs_counts['score_count'] >= nscores] + filtered_data = pd.merge(filtered_data, valid_labs_regions[[sb, 'region']], on=[sb, 'region']) + + labss = np.unique(filtered_data[sb].values) + palette = {lab: lab_cols[lab] for lab in labss} if sb == 'lab' else None + + sns.stripplot(x='score', y='region', hue=sb, palette=palette, data=filtered_data, jitter=True if sb == 'lab' else False, dodge=True, ax=axs[k], order=regs, size=3) + for i, region in enumerate(regs): + if i == len(regs) - 1: + continue + axs[k].axhline(i + 0.5, color='grey', linestyle='--') + + axs[k].set_title(vari) + if sb == 'lab': + if k != 0: + axs[k].legend([], [], frameon=False) + else: + axs[k].legend(loc='lower left', fontsize=9, bbox_to_anchor=(-0.55, 1.04), ncols=len(labss)).set_draggable(True) + + # ANOVA + labs = np.unique(d[sb].values) + for reg in regs: + scores_by_lab = [d[(d[sb] == lab) & (d['region'] == reg)]['score'].values for lab in labs] + filtered_scores_by_lab = [lab_scores for lab_scores in scores_by_lab if lab_scores.size >= nscores] + + if len(filtered_scores_by_lab) < 2: + continue + + F, p = kruskal(*filtered_scores_by_lab) + ps[f"{vari}_{reg}"] = p + m = np.max(np.concatenate(scores_by_lab)) + + weight = 'bold' if p < 0.05 else 'normal' + if vari == 'wheel-speed': + x = 0.6 + + else: + x = 0.1 + axs[k].text(x, regs.index(reg), f'F={F:.2f}\np={p:.3f}', weight=weight, ha='left', va='center', fontsize=8) + + k += 1 + + if tt == 
'stripplot': + p_values_list = list(ps.values()) + _, ps_corrected, _, _ = multipletests(p_values_list, alpha=0.05, method='fdr_by') + corrected_p_values_dict = dict(zip(ps.keys(), ps_corrected)) + for key, value in corrected_p_values_dict.items(): + print(f"{key}: p-value = {value:.3f}") + + fig.subplots_adjust(top=0.922, bottom=0.088, left=0.094, right=0.982, hspace=0.2, wspace=0.211)
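
Note on the lab-bias statistics introduced with repro_bwm_decoding.py (PATCH 10/10): per region, bwm_scores compares decoding scores across labs with a Kruskal-Wallis test and then corrects the collected p-values across regions with Benjamini-Yekutieli FDR (multipletests with method='fdr_by'). Below is a minimal, self-contained sketch of just that statistical step, using a synthetic DataFrame; the column names (score, lab, region) and the nscores threshold mirror the script, whereas the real analysis loads the {variable}_stage2.pqt tables and derives lab labels via reproducible_ephys_functions.labs.

import numpy as np
import pandas as pd
from scipy.stats import kruskal
from statsmodels.stats.multitest import multipletests

rng = np.random.default_rng(0)

# Synthetic stand-in for the per-session decoding results
# (one row per session: decoding score, lab, region).
df = pd.DataFrame({
    'score': rng.uniform(0.5, 1.0, size=120),
    'lab': rng.choice(['labA', 'labB', 'labC'], size=120),
    'region': rng.choice(['CA1', 'DG', 'LP'], size=120),
})

nscores = 3  # minimum number of sessions per lab/region to enter the test
ps = {}

for reg, d_reg in df.groupby('region'):
    # one group of scores per lab, dropping labs with too few sessions
    groups = [g['score'].values for _, g in d_reg.groupby('lab')
              if len(g) >= nscores]
    if len(groups) < 2:
        continue
    H, p = kruskal(*groups)  # non-parametric test for a lab effect
    ps[reg] = p

# correct the per-region p-values across regions (Benjamini-Yekutieli FDR)
regs, pvals = zip(*ps.items())
_, p_corr, _, _ = multipletests(pvals, alpha=0.05, method='fdr_by')
for reg, p in zip(regs, p_corr):
    print(f'{reg}: corrected p = {p:.3f}')

In the patch itself the same pattern runs inside the loop over the four decoded variables (keys of the form f"{vari}_{reg}"), and the correction plus printout only executes in the tt == 'stripplot' branch at the end of bwm_scores.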