Skip to content

Commit

Permalink
error bar now works
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Oct 20, 2023
1 parent 3450e55 commit e15e053
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 68 deletions.
96 changes: 44 additions & 52 deletions code/pages/3_Population dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,17 +300,18 @@ def plot_beta_auto_corr(ds, model, align_tos, paras,
def _get_psth(psth_name, align_to, psth_grouped_by, select_units):
t_range = {f't_to_{align_to}': slice(*plot_settings[align_to]['win'])}

mean_psth = ds_psth[psth_name].sel(stat='mean',
unit_ind=select_units,
**t_range).values
psth_mean, psth_sem = ds_psth[psth_name].sel(stat=['mean', 'sem'],
unit_ind=select_units,
**t_range).values

ts = ds_psth[f't_to_{align_to}'].sel(**t_range).values
group_name = ds_psth[f'psth_groups_{psth_grouped_by}'].values

plot_spec = ds_psth[f'psth_setting_plot_spec_{psth_grouped_by}'].values
if 'reward' in psth_grouped_by: # Bug fix
plot_spec = plot_spec[[1, 0, 3, 2]]

return [mean_psth, ts, group_name, plot_spec]
return [psth_mean, psth_sem, ts, group_name, plot_spec]

@st.cache_data(ttl=60*60*24)
def _get_coding_direction(model, para, align_to, beta_aver_epoch, select_units):
Expand All @@ -327,46 +328,29 @@ def _get_coding_direction(model, para, align_to, beta_aver_epoch, select_units):
return aver_betas / np.sqrt(np.sum(aver_betas**2))


def compute_psth_proj_on_CD(psth, coding_direction, if_error_bar):
def compute_psth_proj_on_CD(psth, psth_sem, coding_direction, if_error_bar):
# Handle PSTHs with some nan values
psth_reshaped = psth.reshape(psth.shape[0], -1)
psth_sem_reshaped = psth_sem.reshape(psth_sem.shape[0], -1)
nan_idx = np.any(np.isnan(psth_reshaped), axis=1)

psth_reshaped_valid = psth_reshaped[~nan_idx]
psth_sem_reshaped_valid = psth_sem_reshaped[~nan_idx]

coding_direction_valid = coding_direction[~nan_idx] # Remove nan units (temporary fix)
coding_direction_valid = coding_direction_valid / np.sqrt(np.sum(coding_direction_valid**2)) # Renormalize

# Do projection
psth_proj = (psth_reshaped_valid.T @ coding_direction_valid).reshape(psth.shape[1:])

if not if_error_bar:
return psth_proj, None, None

N_bootstrap = 100
return psth_proj, None

# Do bootstrap
bootstrap_results = []
for i in range(N_bootstrap):
# Randomly select indices with replacement
indices = np.random.choice(psth_reshaped_valid.shape[0],
size=psth_reshaped_valid.shape[0],
replace=True)
# Compute 95% CI from psth_sem (assuming uncertainties all come from psth, not betas; also, I didn't take care of the trial number, since Var = SEM^2 * trial_num)
psth_proj_sem = np.sqrt((psth_sem_reshaped_valid.T**2 @ coding_direction_valid**2)).reshape(psth_sem.shape[1:])
psth_proj_95CI = 1.96 * psth_proj_sem

psth_reshaped_valid = psth_reshaped_valid[indices, :]
coding_direction_valid = coding_direction_valid[indices]
coding_direction_valid = coding_direction_valid / np.sqrt(np.sum(coding_direction_valid**2)) # Renormalize

# Compute the product and store the result
bootstrap_results.append(psth_reshaped_valid.T @ coding_direction_valid)
progress_bar.progress(i/N_bootstrap, text=f'{i/N_bootstrap:.0%}')


bootstrap_results = np.array(bootstrap_results)

# Compute 95% confidence intervals
lower_bound = np.percentile(bootstrap_results, 2.5, axis=0).reshape(psth.shape[1:])
upper_bound = np.percentile(bootstrap_results, 97.5, axis=0).reshape(psth.shape[1:])

return psth_proj, lower_bound, upper_bound
return psth_proj, psth_proj_95CI


def plot_psth_proj_on_CDs(
Expand Down Expand Up @@ -440,16 +424,17 @@ def plot_psth_proj_on_CDs(
f'''{psth_align_mapping[psth_align_to]}_''' +\
f'''grouped_by_{psth_grouped_by}'''

psth, psth_t, psth_group_names, psth_plot_specs = _get_psth(psth_name=psth_name,
align_to=psth_align_to,
psth_grouped_by=psth_grouped_by,
select_units=unit_ind_filtered.values,
)
psth, psth_sem, psth_t, psth_group_names, psth_plot_specs = _get_psth(psth_name=psth_name,
align_to=psth_align_to,
psth_grouped_by=psth_grouped_by,
select_units=unit_ind_filtered.values,
)

# Compute projection
psth_proj, lower_bound, upper_bound = compute_psth_proj_on_CD(psth=psth,
coding_direction=coding_direction,
if_error_bar=if_error_bar)
psth_proj, psth_proj_95CI = compute_psth_proj_on_CD(psth=psth,
psth_sem=psth_sem,
coding_direction=coding_direction,
if_error_bar=if_error_bar)

# Do plotting
fig = go.Figure()
Expand All @@ -464,14 +449,24 @@ def plot_psth_proj_on_CDs(
),
)
else:
# Hack of the dash and line width
line_spec = eval(psth_plot_specs[i_group])
if 'line_dash' in line_spec:
line_dash = line_spec['line_dash']
line_width = 2 if line_dash == 'dot' else 3
else:
line_dash = 'solid'
line_width = 3

add_plotly_errorbar(x=pd.Series(psth_t),
y=pd.Series(psth_proj[i_group, :]),
err=pd.Series(((upper_bound - lower_bound) / 2)[i_group, :]),
# **eval(psth_plot_specs[i_group]),
err=pd.Series(psth_proj_95CI[i_group, :]),
color=eval(psth_plot_specs[i_group])['marker_color'],
line_dash=line_dash,
line_width=line_width,
name=psth_group_names[i_group],
mode='lines',
fig=fig, alpha=0.4,
fig=fig, alpha=0.2,
)

fig.update_layout(
Expand Down Expand Up @@ -502,12 +497,12 @@ def plot_psth_proj_on_CDs(


# Increase line width for all scatter traces
for trace in fig.data:
if trace.type == 'scatter' and 'line' in dir(trace):
if trace.line.width is None:
trace.line.width = 2
else:
trace.line.width = trace.line.width + 1
# for trace in fig.data:
# if trace.type == 'scatter' and 'line' in dir(trace):
# if trace.line.width is None:
# trace.line.width = 2
# else:
# trace.line.width = trace.line.width + 1

with cols[i+1]:
st.plotly_chart(fig,
Expand Down Expand Up @@ -673,10 +668,7 @@ def plot_psth_proj_on_CDs(

available_paras_this_model = models[selected_model]
# if_combine_araes = cols[0].checkbox('Combine areas', True)
if_error_bar = cols[0].checkbox('Bootstrap CI 95%', False)
if if_error_bar:
progress_bar = st.columns([1, 15])[0].progress(0, text='0%')

if_error_bar = cols[0].checkbox('Show 95% CI', True)

selected_paras = cols[1].multiselect('Coding directions',
coding_direction_beta_aver_epoch.keys(),
Expand Down
38 changes: 22 additions & 16 deletions code/util/plotly_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,6 @@ def add_plotly_errorbar(x, y, err, color, fig, alpha=0.2, name='',
err = err[valid_y]
err[~err.notna()] = 0

fig.add_trace(go.Scattergl(
x=x,
y=y,
# error_y=dict(type='data',
# symmetric=True,
# array=tuning_sem),
name=name,
legendgroup=legend_group,
mode="markers+lines" if mode is None else mode,
marker_color=color,
opacity=1,
**kwargs,
),
**subplot_specs)

fig.add_trace(go.Scatter(
# name='Upper Bound',
x=x,
Expand All @@ -41,6 +26,12 @@ def add_plotly_errorbar(x, y, err, color, fig, alpha=0.2, name='',
),
**subplot_specs)

color_in_rgba = plotly.colors.convert_colors_to_same_type(color)[0][0].split("(")[-1][:-1]

if color_in_rgba.count(',') == 3: # Already have the alpha element, remove the alpha.
color_in_rgba = ', '.join(color_in_rgba.split(', ')[:-1])
fillcolor = f'rgba({color_in_rgba}, {alpha})'

fig.add_trace(go.Scatter(
# name='Upper Bound',
x=x,
Expand All @@ -49,9 +40,24 @@ def add_plotly_errorbar(x, y, err, color, fig, alpha=0.2, name='',
marker=dict(color=color),
line=dict(width=0),
fill='tonexty',
fillcolor=f'rgba({plotly.colors.convert_colors_to_same_type(color)[0][0].split("(")[-1][:-1]}, {alpha})',
fillcolor=fillcolor,
legendgroup=legend_group,
showlegend=False,
hoverinfo='skip'
),
**subplot_specs)

fig.add_trace(go.Scatter(
x=x,
y=y,
# error_y=dict(type='data',
# symmetric=True,
# array=tuning_sem),
name=name,
legendgroup=legend_group,
mode="markers+lines" if mode is None else mode,
marker_color=color,
opacity=1,
**kwargs,
),
**subplot_specs)

0 comments on commit e15e053

Please sign in to comment.