Skip to content

Commit

Permalink
Update bwm_figs.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mschart authored Oct 4, 2024
1 parent 512b3cf commit dc686ef
Showing 1 changed file with 146 additions and 13 deletions.
159 changes: 146 additions & 13 deletions brainwidemap/meta/bwm_figs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@
from matplotlib.patches import Rectangle
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
import matplotlib.ticker as tck
import matplotlib.colors as mcolors

from brainwidemap import download_aggregate_tables, bwm_units
from brainwidemap.encoding.design import generate_design
from brainwidemap.encoding.glm_predict import GLMPredictor, predict
from brainwidemap.encoding.utils import load_regressors, single_cluster_raster, find_trial_ids
from brainbox.plot import peri_event_time_histogram
from brainbox.behavior.training import compute_performance
from reproducible_ephys_functions import labs

import neurencoding.linear as lm
from neurencoding.utils import remove_regressors
Expand Down Expand Up @@ -1620,6 +1623,92 @@ def swansons_SI(vari):
fig.savefig(Path(imgs_pth, 'si', f'mannwhitney_SI_{vari}.svg'))


'''
#####
Behavioral SI figure
#####
'''

def perf_scatter(rerun=False):

'''
two scatter plots, a point a mouse from the BWM set,
left panel (x, y):
(# bias training sessions, # pre-bias training sessions
right panel (x, y):
(# bias training sessions, % trials correct during ephys)
'''

pth_ = Path(one.cache_dir, 'bwm_res', 'bwm_figs_data',
'training_perf.pqt')

if (not pth_.is_file() or rerun):

# get all subjects
eids = bwm_units(one)['eid'].unique()
pths = one.eid2path(eids)
subs = np.unique([str(x).split('/')[-3] for x in pths])

# get lab colors and shorter names (repro paper standards)
rr = labs()
sub_labs = dict(zip([str(x).split('/')[-3] for x in pths],
[rr[1][str(x).split('/')[-5]] for x in pths]))

sub_cols = {sub: rr[-1][sub_labs[sub]] for sub in subs}

# fill dataframe
r = []
columns = ['subj', '#sess biasedChoiceWorld', '#sess trainingChoiceWorld',
'perf ephysChoiceWorld', 'lab', 'lab_color']

for sub in subs:
trials = one.load_aggregate('subjects', sub,
'_ibl_subjectTrials.table')
nbiased = trials[trials['task_protocol'
] == 'biasedChoiceWorld']['session'].nunique()
nunbiased = trials[trials['task_protocol'
] == 'trainingChoiceWorld']['session'].nunique()
perf_ephys = np.mean(compute_performance(
trials[trials['task_protocol'] == 'ephysChoiceWorld'])[0])
r.append([sub, nbiased, nunbiased,
perf_ephys, sub_labs[sub], sub_cols[sub]])

df = pd.DataFrame(r, columns=columns)
df.to_parquet(pth_)

df = pd.read_parquet(pth_)

# convert to hex for seaborn
df['lab_color'] = df['lab_color'].apply(lambda x: mcolors.to_hex(x))

# Create a dictionary to map lab to its color
lab_palette = dict(zip(df['lab'], df['lab_color']))

fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# Left scatter plot: (#sess biasedChoiceWorld, #sess trainingChoiceWorld)
sns.scatterplot(ax=axs[0], data=df, x='#sess biasedChoiceWorld',
y='#sess trainingChoiceWorld', hue='lab',
palette=lab_palette, legend=True)

axs[0].set_xlabel('# biasedChoiceWorld sessions')
axs[0].set_ylabel('# trainingChoiceWorld sessions')
axs[0].set_title('Sessions: biased vs unbiased')

# Right scatter plot: (#sess biasedChoiceWorld, perf ephysChoiceWorld)
sns.scatterplot(ax=axs[1], data=df, x='#sess biasedChoiceWorld',
y='perf ephysChoiceWorld', hue='lab',
palette=lab_palette, legend=False)

axs[1].set_xlabel('# biasedChoiceWorld sessions')
axs[1].set_ylabel('% Trials Correct (ephys)')
axs[1].set_title('Bias Sessions vs Performance')

# Adjust layout and display plot
plt.tight_layout()
plt.show()


'''
#####
decoding
Expand Down Expand Up @@ -1797,8 +1886,8 @@ def dec_scatter(variable,fig=None, ax=None):
variable in [choice, fback]
'''

red = (255/255, 48/255, 23/255)
blue = (34/255,77/255,169/255)
red = red_right
blue = blue_left

alone = False
if not fig:
Expand Down Expand Up @@ -2245,8 +2334,8 @@ def plot_twocond(
error_bars="sem",
ax=ax[i],
smoothing=0.01,
pethline_kwargs={"color": "blue", "linewidth": 2},
errbar_kwargs={"color": "blue", "alpha": 0.5},
pethline_kwargs={"color": blue_left, "linewidth": 2},
errbar_kwargs={"color": blue_left, "alpha": 0.5},
)
oldticks.extend(ax[i].get_yticks())
peri_event_time_histogram(
Expand All @@ -2260,15 +2349,15 @@ def plot_twocond(
error_bars="sem",
ax=ax[i],
smoothing=0.01,
pethline_kwargs={"color": "red", "linewidth": 2},
errbar_kwargs={"color": "red", "alpha": 0.5},
pethline_kwargs={"color": red_right, "linewidth": 2},
errbar_kwargs={"color": red_right, "alpha": 0.5},
)
oldticks.extend(ax[i].get_yticks())
pred1 = cond1pred if not rem_regressor else nrcond1pred
pred2 = cond2pred if not rem_regressor else nrcond2pred
ax[i].step(x, pred1, color="darkblue", linewidth=2)
ax[i].step(x, pred1, color="skyblue", linewidth=2)
oldticks.extend(ax[i].get_yticks())
ax[i].step(x, pred2, color="darkred", linewidth=2)
ax[i].step(x, pred2, color="#F08080", linewidth=2)
oldticks.extend(ax[i].get_yticks())
ax[i].set_ylim([0, np.max(oldticks) * 1.1])
return fig, ax, sspkt, sspkclu, stdf
Expand Down Expand Up @@ -2344,7 +2433,7 @@ def get_example_results():
"choice": (
"671c7ea7-6726-4fbe-adeb-f89c2c8e489b",
"04c9890f-2276-4c20-854f-305ff5c9b6cf",
130, # was 143
143, # was 143
"GRN",
0.000992895, # drsq
"firstMovement_times",
Expand Down Expand Up @@ -2434,7 +2523,7 @@ def ecoding_raster_lines(variable, clu_id0=None, axs=None,

# custom legend
all_lines = axs[1].get_lines()
legend_labels = [reg2, reg1, 'model', 'model']
legend_labels = [reg1, reg2, 'model', 'model']
axs[1].legend(all_lines, legend_labels, loc='upper right',
bbox_to_anchor=(1.2, 1.3), fontsize=f_size_s,
frameon=False)
Expand All @@ -2448,7 +2537,7 @@ def ecoding_raster_lines(variable, clu_id0=None, axs=None,
stdf[aligntime],
trial_idx,
dividers,
["b", "r"],
[blue_left, red_right],
[reg1, reg2],
pre_time=t_before,
post_time=t_after,
Expand Down Expand Up @@ -2513,8 +2602,50 @@ def encoding_wheel_boxen(ax=None, fig=None):
if alone:
fig.tight_layout()
fig.savefig(Path(imgs_pth, 'speed',
'glm_boxen.svg'))

'glm_boxen.svg'))


def encoding_wheel_2d_density(ax=None, fig=None):
# Load data and configurations
d = {}
fs = {'speed': 'GLMs_wheel_speed.pkl',
'velocity': 'GLMs_wheel_velocity.pkl'}

for v in fs:
d[v] = pd.read_pickle(
Path(enc_pth, fs[v]))['mean_fit_results']["wheel"].to_frame()

# Join data
joinwheel = d['speed'].join(d['velocity'], how="inner",
rsuffix="_velocity", lsuffix="_speed")

# Extract the columns to plot
x = joinwheel.iloc[:, 0] # wheel_speed values
y = joinwheel.iloc[:, 1] # wheel_velocity values

# Check if we are plotting alone or on provided axis
alone = False
if not ax:
alone = True
fig, ax = plt.subplots(constrained_layout=True, figsize=[5, 5])

# Plot 2D density using seaborn's kdeplot
sns.kdeplot(x=x, y=y, ax=ax, cmap="Blues", fill=True)

# Label axes
ax.set_xlabel('Wheel Speed')
ax.set_ylabel('Wheel Velocity')

# Adjust axis visibility
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_aspect('equal')
# Save plot if alone
if alone:
fig.tight_layout()
fig.savefig(Path(imgs_pth, 'speed', 'glm_2d_density.svg'))



'''
##########
Expand Down Expand Up @@ -2853,6 +2984,8 @@ def plot_curves_scatter(variable, ga_pcs=False, curve='euc',
color=palette[reg],
label=f"{reg} {d[reg]['nclus']}")

axs[k].set_xscale('log')

# put region labels
y = yy[-1]
x = xx[-1]
Expand Down

0 comments on commit dc686ef

Please sign in to comment.