Skip to content

Commit

Permalink
add unit filter
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Oct 20, 2023
1 parent dc787da commit 9bc225f
Showing 1 changed file with 51 additions and 17 deletions.
68 changes: 51 additions & 17 deletions code/pages/3_Population dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,12 @@ def plot_beta_auto_corr(ds, model, align_tos, paras,
return

@st.cache_data(ttl=60*60*24)
def _get_psth(psth_name, align_to, psth_grouped_by):
def _get_psth(psth_name, align_to, psth_grouped_by, select_units):
t_range = {f't_to_{align_to}': slice(*plot_settings[align_to]['win'])}

mean_psth = ds_psth[psth_name].sel(stat='mean', **t_range).values
mean_psth = ds_psth[psth_name].sel(stat='mean',
unit_ind=select_units,
**t_range).values
ts = ds_psth[f't_to_{align_to}'].sel(**t_range).values
group_name = ds_psth[f'psth_groups_{psth_grouped_by}'].values

Expand All @@ -311,24 +313,56 @@ def _get_psth(psth_name, align_to, psth_grouped_by):
return [mean_psth, ts, group_name, plot_spec]

@st.cache_data(ttl=60*60*24)
def _get_coding_direction(model, para, align_to, beta_aver_epoch):
def _get_coding_direction(model, para, align_to, beta_aver_epoch, select_units):
var_name = f'linear_fit_para_stats_aligned_to_{align_to}'
t_name = f'linear_fit_t_center_aligned_to_{align_to}'

aver_betas = ds_linear_fit_over_time[var_name].sel(model=model,
para=para,
para_stat='beta',
unit_ind=select_units,
**{t_name: slice(*beta_aver_epoch)}
).mean(dim=t_name).values

return aver_betas / np.sqrt(np.sum(aver_betas**2))


def compute_psth_proj_on_CD(psth, coding_direction):
# Handle PSTHs with some nan values
psth_reshaped = psth.reshape(psth.shape[0], -1)
nan_idx = np.any(np.isnan(psth_reshaped), axis=1)
psth_reshaped_valid = psth_reshaped[~nan_idx]
coding_direction_valid = coding_direction[~nan_idx] # Remove nan units (temporary fix)
coding_direction_valid = coding_direction_valid / np.sqrt(np.sum(coding_direction_valid**2)) # Renormalize

# Do projection
psth_proj = (psth_reshaped_valid.T @ coding_direction_valid).reshape(psth.shape[1:])

return psth_proj


def plot_psth_proj_on_CDs(
model, psth_align_to,
paras, psth_grouped_bys,
combine_araes=True,
# combine_araes=True,
):

# Retrieve unit_keys from dataset
df_unit_keys = ds_psth[primary_keys].to_dataframe().reset_index()
df_filtered_unit_with_aoi = st.session_state.df_unit_filtered[primary_keys + ['area_of_interest']]

# Add aoi to df_unit_keys; right join to apply the filtering
# Because ds_psth and ds_linear_fit_over_time share the same unit_ind, we can use the same unit_ind to filter
df_unit_keys_filtered_and_with_aoi = df_unit_keys.merge(df_filtered_unit_with_aoi, on=primary_keys, how='right')
unit_ind_filtered = df_unit_keys_filtered_and_with_aoi.unit_ind

st.markdown(f'##### (N = {len(df_unit_keys_filtered_and_with_aoi)}, '
f'PSTH bin size = {ds_psth.bin_size:.2g} s, '
f'smoothed by half Gaussian kernel with $\sigma$ = {ds_psth.psth_aligned_to_Choice_grouped_by_choice_and_reward.smooth_sigma:.2g} s)')

if len(unit_ind_filtered) == 0:
st.write('No units selected!')
return

for para in paras: # Iterate over rows

Expand Down Expand Up @@ -357,7 +391,9 @@ def plot_psth_proj_on_CDs(
coding_direction = _get_coding_direction(model=model,
para=para,
align_to=coding_direction_align_to,
beta_aver_epoch=[win_start, win_end])
beta_aver_epoch=[win_start, win_end],
select_units=unit_ind_filtered.values,
)

# Plot units contribution to coding direction
fig = px.bar(np.sort(coding_direction))
Expand All @@ -375,18 +411,15 @@ def plot_psth_proj_on_CDs(
f'''{psth_align_mapping[psth_align_to]}_''' +\
f'''grouped_by_{psth_grouped_by}'''

# Compute projection
psth, psth_t, psth_group_names, psth_plot_specs = _get_psth(psth_name, psth_align_to, psth_grouped_by)
psth, psth_t, psth_group_names, psth_plot_specs = _get_psth(psth_name=psth_name,
align_to=psth_align_to,
psth_grouped_by=psth_grouped_by,
select_units=unit_ind_filtered.values,
)

# Handle PSTHs with some nan values
psth_reshaped = psth.reshape(psth.shape[0], -1)
nan_idx = np.any(np.isnan(psth_reshaped), axis=1)
psth_reshaped_valid = psth_reshaped[~nan_idx]
coding_direction_valid = coding_direction[~nan_idx] # Remove nan units (temporary fix)
coding_direction_valid = coding_direction_valid / np.sqrt(np.sum(coding_direction_valid**2)) # Renormalize
# Compute projection
psth_proj = compute_psth_proj_on_CD(psth, coding_direction)

psth_proj = (psth_reshaped_valid.T @ coding_direction_valid).reshape(psth.shape[1:])

# Do plotting
fig = go.Figure()
for i_group in range(psth_proj.shape[0]):
Expand Down Expand Up @@ -595,7 +628,7 @@ def plot_psth_proj_on_CDs(
)

available_paras_this_model = models[selected_model]
if_combine_araes = cols[0].checkbox('Combine areas', True)
# if_combine_araes = cols[0].checkbox('Combine areas', True)

selected_paras = cols[1].multiselect('Coding directions',
coding_direction_beta_aver_epoch.keys(),
Expand All @@ -619,7 +652,8 @@ def plot_psth_proj_on_CDs(
psth_align_to=psth_align_to,
paras=selected_paras,
psth_grouped_bys=selected_grouped_bys,
combine_araes=if_combine_araes)
# combine_araes=if_combine_araes
)


if if_debug:
Expand Down

0 comments on commit 9bc225f

Please sign in to comment.