diff --git a/code/pages/3_Population dynamics.py b/code/pages/3_Population dynamics.py index 7ef475f..eb4a3a8 100644 --- a/code/pages/3_Population dynamics.py +++ b/code/pages/3_Population dynamics.py @@ -300,9 +300,10 @@ def plot_beta_auto_corr(ds, model, align_tos, paras, def _get_psth(psth_name, align_to, psth_grouped_by, select_units): t_range = {f't_to_{align_to}': slice(*plot_settings[align_to]['win'])} - mean_psth = ds_psth[psth_name].sel(stat='mean', - unit_ind=select_units, - **t_range).values + psth_mean, psth_sem = ds_psth[psth_name].sel(stat=['mean', 'sem'], + unit_ind=select_units, + **t_range).values + ts = ds_psth[f't_to_{align_to}'].sel(**t_range).values group_name = ds_psth[f'psth_groups_{psth_grouped_by}'].values @@ -310,7 +311,7 @@ def _get_psth(psth_name, align_to, psth_grouped_by, select_units): if 'reward' in psth_grouped_by: # Bug fix plot_spec = plot_spec[[1, 0, 3, 2]] - return [mean_psth, ts, group_name, plot_spec] + return [psth_mean, psth_sem, ts, group_name, plot_spec] @st.cache_data(ttl=60*60*24) def _get_coding_direction(model, para, align_to, beta_aver_epoch, select_units): @@ -327,11 +328,15 @@ def _get_coding_direction(model, para, align_to, beta_aver_epoch, select_units): return aver_betas / np.sqrt(np.sum(aver_betas**2)) -def compute_psth_proj_on_CD(psth, coding_direction, if_error_bar): +def compute_psth_proj_on_CD(psth, psth_sem, coding_direction, if_error_bar): # Handle PSTHs with some nan values psth_reshaped = psth.reshape(psth.shape[0], -1) + psth_sem_reshaped = psth_sem.reshape(psth_sem.shape[0], -1) nan_idx = np.any(np.isnan(psth_reshaped), axis=1) + psth_reshaped_valid = psth_reshaped[~nan_idx] + psth_sem_reshaped_valid = psth_sem_reshaped[~nan_idx] + coding_direction_valid = coding_direction[~nan_idx] # Remove nan units (temporary fix) coding_direction_valid = coding_direction_valid / np.sqrt(np.sum(coding_direction_valid**2)) # Renormalize @@ -339,34 +344,13 @@ def compute_psth_proj_on_CD(psth, coding_direction, if_error_bar): psth_proj = (psth_reshaped_valid.T @ coding_direction_valid).reshape(psth.shape[1:]) if not if_error_bar: - return psth_proj, None, None - - N_bootstrap = 100 + return psth_proj, None - # Do bootstrap - bootstrap_results = [] - for i in range(N_bootstrap): - # Randomly select indices with replacement - indices = np.random.choice(psth_reshaped_valid.shape[0], - size=psth_reshaped_valid.shape[0], - replace=True) + # Compute 95% CI from psth_sem (assuming uncertainties all come from psth, not betas; also, I didn't take care of the trial number, since Var = SEM^2 * trial_num) + psth_proj_sem = np.sqrt((psth_sem_reshaped_valid.T**2 @ coding_direction_valid**2)).reshape(psth_sem.shape[1:]) + psth_proj_95CI = 1.96 * psth_proj_sem - psth_reshaped_valid = psth_reshaped_valid[indices, :] - coding_direction_valid = coding_direction_valid[indices] - coding_direction_valid = coding_direction_valid / np.sqrt(np.sum(coding_direction_valid**2)) # Renormalize - - # Compute the product and store the result - bootstrap_results.append(psth_reshaped_valid.T @ coding_direction_valid) - progress_bar.progress(i/N_bootstrap, text=f'{i/N_bootstrap:.0%}') - - - bootstrap_results = np.array(bootstrap_results) - - # Compute 95% confidence intervals - lower_bound = np.percentile(bootstrap_results, 2.5, axis=0).reshape(psth.shape[1:]) - upper_bound = np.percentile(bootstrap_results, 97.5, axis=0).reshape(psth.shape[1:]) - - return psth_proj, lower_bound, upper_bound + return psth_proj, psth_proj_95CI def plot_psth_proj_on_CDs( @@ -440,16 +424,17 @@ def plot_psth_proj_on_CDs( f'''{psth_align_mapping[psth_align_to]}_''' +\ f'''grouped_by_{psth_grouped_by}''' - psth, psth_t, psth_group_names, psth_plot_specs = _get_psth(psth_name=psth_name, - align_to=psth_align_to, - psth_grouped_by=psth_grouped_by, - select_units=unit_ind_filtered.values, - ) + psth, psth_sem, psth_t, psth_group_names, psth_plot_specs = _get_psth(psth_name=psth_name, + align_to=psth_align_to, + psth_grouped_by=psth_grouped_by, + select_units=unit_ind_filtered.values, + ) # Compute projection - psth_proj, lower_bound, upper_bound = compute_psth_proj_on_CD(psth=psth, - coding_direction=coding_direction, - if_error_bar=if_error_bar) + psth_proj, psth_proj_95CI = compute_psth_proj_on_CD(psth=psth, + psth_sem=psth_sem, + coding_direction=coding_direction, + if_error_bar=if_error_bar) # Do plotting fig = go.Figure() @@ -464,14 +449,24 @@ def plot_psth_proj_on_CDs( ), ) else: + # Hack of the dash and line width + line_spec = eval(psth_plot_specs[i_group]) + if 'line_dash' in line_spec: + line_dash = line_spec['line_dash'] + line_width = 2 if line_dash == 'dot' else 3 + else: + line_dash = 'solid' + line_width = 3 + add_plotly_errorbar(x=pd.Series(psth_t), y=pd.Series(psth_proj[i_group, :]), - err=pd.Series(((upper_bound - lower_bound) / 2)[i_group, :]), - # **eval(psth_plot_specs[i_group]), + err=pd.Series(psth_proj_95CI[i_group, :]), color=eval(psth_plot_specs[i_group])['marker_color'], + line_dash=line_dash, + line_width=line_width, name=psth_group_names[i_group], mode='lines', - fig=fig, alpha=0.4, + fig=fig, alpha=0.2, ) fig.update_layout( @@ -502,12 +497,12 @@ def plot_psth_proj_on_CDs( # Increase line width for all scatter traces - for trace in fig.data: - if trace.type == 'scatter' and 'line' in dir(trace): - if trace.line.width is None: - trace.line.width = 2 - else: - trace.line.width = trace.line.width + 1 + # for trace in fig.data: + # if trace.type == 'scatter' and 'line' in dir(trace): + # if trace.line.width is None: + # trace.line.width = 2 + # else: + # trace.line.width = trace.line.width + 1 with cols[i+1]: st.plotly_chart(fig, @@ -673,10 +668,7 @@ def plot_psth_proj_on_CDs( available_paras_this_model = models[selected_model] # if_combine_araes = cols[0].checkbox('Combine areas', True) - if_error_bar = cols[0].checkbox('Bootstrap CI 95%', False) - if if_error_bar: - progress_bar = st.columns([1, 15])[0].progress(0, text='0%') - + if_error_bar = cols[0].checkbox('Show 95% CI', True) selected_paras = cols[1].multiselect('Coding directions', coding_direction_beta_aver_epoch.keys(), diff --git a/code/util/plotly_util.py b/code/util/plotly_util.py index 75fef94..c23ae77 100644 --- a/code/util/plotly_util.py +++ b/code/util/plotly_util.py @@ -13,21 +13,6 @@ def add_plotly_errorbar(x, y, err, color, fig, alpha=0.2, name='', err = err[valid_y] err[~err.notna()] = 0 - fig.add_trace(go.Scattergl( - x=x, - y=y, - # error_y=dict(type='data', - # symmetric=True, - # array=tuning_sem), - name=name, - legendgroup=legend_group, - mode="markers+lines" if mode is None else mode, - marker_color=color, - opacity=1, - **kwargs, - ), - **subplot_specs) - fig.add_trace(go.Scatter( # name='Upper Bound', x=x, @@ -41,6 +26,12 @@ def add_plotly_errorbar(x, y, err, color, fig, alpha=0.2, name='', ), **subplot_specs) + color_in_rgba = plotly.colors.convert_colors_to_same_type(color)[0][0].split("(")[-1][:-1] + + if color_in_rgba.count(',') == 3: # Already have the alpha element, remove the alpha. + color_in_rgba = ', '.join(color_in_rgba.split(', ')[:-1]) + fillcolor = f'rgba({color_in_rgba}, {alpha})' + fig.add_trace(go.Scatter( # name='Upper Bound', x=x, @@ -49,9 +40,24 @@ def add_plotly_errorbar(x, y, err, color, fig, alpha=0.2, name='', marker=dict(color=color), line=dict(width=0), fill='tonexty', - fillcolor=f'rgba({plotly.colors.convert_colors_to_same_type(color)[0][0].split("(")[-1][:-1]}, {alpha})', + fillcolor=fillcolor, legendgroup=legend_group, showlegend=False, hoverinfo='skip' ), **subplot_specs) + + fig.add_trace(go.Scatter( + x=x, + y=y, + # error_y=dict(type='data', + # symmetric=True, + # array=tuning_sem), + name=name, + legendgroup=legend_group, + mode="markers+lines" if mode is None else mode, + marker_color=color, + opacity=1, + **kwargs, + ), + **subplot_specs) \ No newline at end of file