From dc686ef5618b9ce181d6a5cabb70a5a26a10cf16 Mon Sep 17 00:00:00 2001 From: Michael Schartner Date: Fri, 4 Oct 2024 16:21:08 +0100 Subject: [PATCH] Update bwm_figs.py --- brainwidemap/meta/bwm_figs.py | 159 +++++++++++++++++++++++++++++++--- 1 file changed, 146 insertions(+), 13 deletions(-) diff --git a/brainwidemap/meta/bwm_figs.py b/brainwidemap/meta/bwm_figs.py index d1c793e..3dd541c 100644 --- a/brainwidemap/meta/bwm_figs.py +++ b/brainwidemap/meta/bwm_figs.py @@ -29,12 +29,15 @@ from matplotlib.patches import Rectangle from matplotlib.ticker import (MultipleLocator, AutoMinorLocator) import matplotlib.ticker as tck +import matplotlib.colors as mcolors from brainwidemap import download_aggregate_tables, bwm_units from brainwidemap.encoding.design import generate_design from brainwidemap.encoding.glm_predict import GLMPredictor, predict from brainwidemap.encoding.utils import load_regressors, single_cluster_raster, find_trial_ids from brainbox.plot import peri_event_time_histogram +from brainbox.behavior.training import compute_performance +from reproducible_ephys_functions import labs import neurencoding.linear as lm from neurencoding.utils import remove_regressors @@ -1620,6 +1623,92 @@ def swansons_SI(vari): fig.savefig(Path(imgs_pth, 'si', f'mannwhitney_SI_{vari}.svg')) +''' +##### +Behavioral SI figure +##### +''' + +def perf_scatter(rerun=False): + + ''' + two scatter plots, a point a mouse from the BWM set, + left panel (x, y): + (# bias training sessions, # pre-bias training sessions + right panel (x, y): + (# bias training sessions, % trials correct during ephys) + ''' + + pth_ = Path(one.cache_dir, 'bwm_res', 'bwm_figs_data', + 'training_perf.pqt') + + if (not pth_.is_file() or rerun): + + # get all subjects + eids = bwm_units(one)['eid'].unique() + pths = one.eid2path(eids) + subs = np.unique([str(x).split('/')[-3] for x in pths]) + + # get lab colors and shorter names (repro paper standards) + rr = labs() + sub_labs = dict(zip([str(x).split('/')[-3] for x in pths], + [rr[1][str(x).split('/')[-5]] for x in pths])) + + sub_cols = {sub: rr[-1][sub_labs[sub]] for sub in subs} + + # fill dataframe + r = [] + columns = ['subj', '#sess biasedChoiceWorld', '#sess trainingChoiceWorld', + 'perf ephysChoiceWorld', 'lab', 'lab_color'] + + for sub in subs: + trials = one.load_aggregate('subjects', sub, + '_ibl_subjectTrials.table') + nbiased = trials[trials['task_protocol' + ] == 'biasedChoiceWorld']['session'].nunique() + nunbiased = trials[trials['task_protocol' + ] == 'trainingChoiceWorld']['session'].nunique() + perf_ephys = np.mean(compute_performance( + trials[trials['task_protocol'] == 'ephysChoiceWorld'])[0]) + r.append([sub, nbiased, nunbiased, + perf_ephys, sub_labs[sub], sub_cols[sub]]) + + df = pd.DataFrame(r, columns=columns) + df.to_parquet(pth_) + + df = pd.read_parquet(pth_) + + # convert to hex for seaborn + df['lab_color'] = df['lab_color'].apply(lambda x: mcolors.to_hex(x)) + + # Create a dictionary to map lab to its color + lab_palette = dict(zip(df['lab'], df['lab_color'])) + + fig, axs = plt.subplots(1, 2, figsize=(10, 5)) + + # Left scatter plot: (#sess biasedChoiceWorld, #sess trainingChoiceWorld) + sns.scatterplot(ax=axs[0], data=df, x='#sess biasedChoiceWorld', + y='#sess trainingChoiceWorld', hue='lab', + palette=lab_palette, legend=True) + + axs[0].set_xlabel('# biasedChoiceWorld sessions') + axs[0].set_ylabel('# trainingChoiceWorld sessions') + axs[0].set_title('Sessions: biased vs unbiased') + + # Right scatter plot: (#sess biasedChoiceWorld, perf ephysChoiceWorld) + sns.scatterplot(ax=axs[1], data=df, x='#sess biasedChoiceWorld', + y='perf ephysChoiceWorld', hue='lab', + palette=lab_palette, legend=False) + + axs[1].set_xlabel('# biasedChoiceWorld sessions') + axs[1].set_ylabel('% Trials Correct (ephys)') + axs[1].set_title('Bias Sessions vs Performance') + + # Adjust layout and display plot + plt.tight_layout() + plt.show() + + ''' ##### decoding @@ -1797,8 +1886,8 @@ def dec_scatter(variable,fig=None, ax=None): variable in [choice, fback] ''' - red = (255/255, 48/255, 23/255) - blue = (34/255,77/255,169/255) + red = red_right + blue = blue_left alone = False if not fig: @@ -2245,8 +2334,8 @@ def plot_twocond( error_bars="sem", ax=ax[i], smoothing=0.01, - pethline_kwargs={"color": "blue", "linewidth": 2}, - errbar_kwargs={"color": "blue", "alpha": 0.5}, + pethline_kwargs={"color": blue_left, "linewidth": 2}, + errbar_kwargs={"color": blue_left, "alpha": 0.5}, ) oldticks.extend(ax[i].get_yticks()) peri_event_time_histogram( @@ -2260,15 +2349,15 @@ def plot_twocond( error_bars="sem", ax=ax[i], smoothing=0.01, - pethline_kwargs={"color": "red", "linewidth": 2}, - errbar_kwargs={"color": "red", "alpha": 0.5}, + pethline_kwargs={"color": red_right, "linewidth": 2}, + errbar_kwargs={"color": red_right, "alpha": 0.5}, ) oldticks.extend(ax[i].get_yticks()) pred1 = cond1pred if not rem_regressor else nrcond1pred pred2 = cond2pred if not rem_regressor else nrcond2pred - ax[i].step(x, pred1, color="darkblue", linewidth=2) + ax[i].step(x, pred1, color="skyblue", linewidth=2) oldticks.extend(ax[i].get_yticks()) - ax[i].step(x, pred2, color="darkred", linewidth=2) + ax[i].step(x, pred2, color="#F08080", linewidth=2) oldticks.extend(ax[i].get_yticks()) ax[i].set_ylim([0, np.max(oldticks) * 1.1]) return fig, ax, sspkt, sspkclu, stdf @@ -2344,7 +2433,7 @@ def get_example_results(): "choice": ( "671c7ea7-6726-4fbe-adeb-f89c2c8e489b", "04c9890f-2276-4c20-854f-305ff5c9b6cf", - 130, # was 143 + 143, # was 143 "GRN", 0.000992895, # drsq "firstMovement_times", @@ -2434,7 +2523,7 @@ def ecoding_raster_lines(variable, clu_id0=None, axs=None, # custom legend all_lines = axs[1].get_lines() - legend_labels = [reg2, reg1, 'model', 'model'] + legend_labels = [reg1, reg2, 'model', 'model'] axs[1].legend(all_lines, legend_labels, loc='upper right', bbox_to_anchor=(1.2, 1.3), fontsize=f_size_s, frameon=False) @@ -2448,7 +2537,7 @@ def ecoding_raster_lines(variable, clu_id0=None, axs=None, stdf[aligntime], trial_idx, dividers, - ["b", "r"], + [blue_left, red_right], [reg1, reg2], pre_time=t_before, post_time=t_after, @@ -2513,8 +2602,50 @@ def encoding_wheel_boxen(ax=None, fig=None): if alone: fig.tight_layout() fig.savefig(Path(imgs_pth, 'speed', - 'glm_boxen.svg')) - + 'glm_boxen.svg')) + + +def encoding_wheel_2d_density(ax=None, fig=None): + # Load data and configurations + d = {} + fs = {'speed': 'GLMs_wheel_speed.pkl', + 'velocity': 'GLMs_wheel_velocity.pkl'} + + for v in fs: + d[v] = pd.read_pickle( + Path(enc_pth, fs[v]))['mean_fit_results']["wheel"].to_frame() + + # Join data + joinwheel = d['speed'].join(d['velocity'], how="inner", + rsuffix="_velocity", lsuffix="_speed") + + # Extract the columns to plot + x = joinwheel.iloc[:, 0] # wheel_speed values + y = joinwheel.iloc[:, 1] # wheel_velocity values + + # Check if we are plotting alone or on provided axis + alone = False + if not ax: + alone = True + fig, ax = plt.subplots(constrained_layout=True, figsize=[5, 5]) + + # Plot 2D density using seaborn's kdeplot + sns.kdeplot(x=x, y=y, ax=ax, cmap="Blues", fill=True) + + # Label axes + ax.set_xlabel('Wheel Speed') + ax.set_ylabel('Wheel Velocity') + + # Adjust axis visibility + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.set_aspect('equal') + # Save plot if alone + if alone: + fig.tight_layout() + fig.savefig(Path(imgs_pth, 'speed', 'glm_2d_density.svg')) + + ''' ########## @@ -2853,6 +2984,8 @@ def plot_curves_scatter(variable, ga_pcs=False, curve='euc', color=palette[reg], label=f"{reg} {d[reg]['nclus']}") + axs[k].set_xscale('log') + # put region labels y = yy[-1] x = xx[-1]