Skip to content

Commit

Permalink
WIP: bootstrap errorbar (not working?
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Oct 20, 2023
1 parent 9bc225f commit 3450e55
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 20 deletions.
87 changes: 68 additions & 19 deletions code/pages/3_Population dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ def _get_coding_direction(model, para, align_to, beta_aver_epoch, select_units):
return aver_betas / np.sqrt(np.sum(aver_betas**2))


def compute_psth_proj_on_CD(psth, coding_direction):
# Handle PSTHs with some nan values
def compute_psth_proj_on_CD(psth, coding_direction, if_error_bar):
# Handle PSTHs with some nan values
psth_reshaped = psth.reshape(psth.shape[0], -1)
nan_idx = np.any(np.isnan(psth_reshaped), axis=1)
psth_reshaped_valid = psth_reshaped[~nan_idx]
Expand All @@ -338,13 +338,42 @@ def compute_psth_proj_on_CD(psth, coding_direction):
# Do projection
psth_proj = (psth_reshaped_valid.T @ coding_direction_valid).reshape(psth.shape[1:])

return psth_proj
if not if_error_bar:
return psth_proj, None, None

N_bootstrap = 100

# Do bootstrap
bootstrap_results = []
for i in range(N_bootstrap):
# Randomly select indices with replacement
indices = np.random.choice(psth_reshaped_valid.shape[0],
size=psth_reshaped_valid.shape[0],
replace=True)

psth_reshaped_valid = psth_reshaped_valid[indices, :]
coding_direction_valid = coding_direction_valid[indices]
coding_direction_valid = coding_direction_valid / np.sqrt(np.sum(coding_direction_valid**2)) # Renormalize

# Compute the product and store the result
bootstrap_results.append(psth_reshaped_valid.T @ coding_direction_valid)
progress_bar.progress(i/N_bootstrap, text=f'{i/N_bootstrap:.0%}')


bootstrap_results = np.array(bootstrap_results)

# Compute 95% confidence intervals
lower_bound = np.percentile(bootstrap_results, 2.5, axis=0).reshape(psth.shape[1:])
upper_bound = np.percentile(bootstrap_results, 97.5, axis=0).reshape(psth.shape[1:])

return psth_proj, lower_bound, upper_bound


def plot_psth_proj_on_CDs(
model, psth_align_to,
paras, psth_grouped_bys,
# combine_araes=True,
if_error_bar=False,
):

# Retrieve unit_keys from dataset
Expand Down Expand Up @@ -418,33 +447,48 @@ def plot_psth_proj_on_CDs(
)

# Compute projection
psth_proj = compute_psth_proj_on_CD(psth, coding_direction)
psth_proj, lower_bound, upper_bound = compute_psth_proj_on_CD(psth=psth,
coding_direction=coding_direction,
if_error_bar=if_error_bar)

# Do plotting
fig = go.Figure()
for i_group in range(psth_proj.shape[0]):
fig.add_trace(go.Scatter(x=psth_t,
y=psth_proj[i_group, :],
mode='lines',

if not if_error_bar:
fig.add_trace(go.Scatter(x=psth_t,
y=psth_proj[i_group, :],
mode='lines',
name=psth_group_names[i_group],
**eval(psth_plot_specs[i_group]),
),
)
else:
add_plotly_errorbar(x=pd.Series(psth_t),
y=pd.Series(psth_proj[i_group, :]),
err=pd.Series(((upper_bound - lower_bound) / 2)[i_group, :]),
# **eval(psth_plot_specs[i_group]),
color=eval(psth_plot_specs[i_group])['marker_color'],
name=psth_group_names[i_group],
**eval(psth_plot_specs[i_group]),
),
)
mode='lines',
fig=fig, alpha=0.4,
)

fig.update_layout(
font_size=17,
hovermode='closest',
xaxis=dict(title=f'Time to {psth_align_to} (sec)',
title_font_size=20,
tickfont_size=20),
yaxis=dict(visible=False),
legend=dict(
font_size=17,
hovermode='closest',
xaxis=dict(title=f'Time to {psth_align_to} (sec)',
title_font_size=20,
tickfont_size=20),
yaxis=dict(visible=False),
legend=dict(
yanchor="top", y=1.3,
xanchor="left", x=0,
orientation="h",
font_size=15,
),
)
),
)


fig.for_each_xaxis(lambda x: x.update(showgrid=True))

Expand Down Expand Up @@ -629,6 +673,10 @@ def plot_psth_proj_on_CDs(

available_paras_this_model = models[selected_model]
# if_combine_araes = cols[0].checkbox('Combine areas', True)
if_error_bar = cols[0].checkbox('Bootstrap CI 95%', False)
if if_error_bar:
progress_bar = st.columns([1, 15])[0].progress(0, text='0%')


selected_paras = cols[1].multiselect('Coding directions',
coding_direction_beta_aver_epoch.keys(),
Expand All @@ -653,6 +701,7 @@ def plot_psth_proj_on_CDs(
paras=selected_paras,
psth_grouped_bys=selected_grouped_bys,
# combine_araes=if_combine_araes
if_error_bar=if_error_bar,
)


Expand Down
2 changes: 1 addition & 1 deletion code/util/plotly_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

def add_plotly_errorbar(x, y, err, color, fig, alpha=0.2, name='',
mode=None,
legend_group=None, subplot_specs=None, **kwargs):
legend_group=None, subplot_specs={}, **kwargs):
if legend_group is None:
legend_group = f'group_{name}'

Expand Down

0 comments on commit 3450e55

Please sign in to comment.