Hi! I'm encountering a problem during my kappa scan when it tries to visualize the median syllable duration.
```
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[15], line 17
15 # stage 1: fit the model with AR only
16 model = kpms.update_hypparams(model, kappa=kappa)
---> 17 model = kpms.fit_model(
18 model,
19 data,
20 metadata,
21 project_dir,
22 model_name,
23 ar_only=True,
24 num_iters=num_ar_iters,
25 save_every_n_iters=25,
26 parallel_message_passing=False
27 )[0];
29 # stage 2: fit the full model
30 model = kpms.update_hypparams(model, kappa=kappa/decrease_kappa_factor)

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/keypoint_moseq/fitting.py:272, in fit_model(model, data, metadata, project_dir, model_name, num_iters, start_iter, verbose, ar_only, parallel_message_passing, jitter, generate_progress_plots, save_every_n_iters, location_aware, **kwargs)
270 save_hdf5(checkpoint_path, model, f"model_snapshots/{iteration}")
271 if generate_progress_plots:
--> 272 plot_progress(
273 model,
274 data,
275 checkpoint_path,
276 iteration,
277 project_dir,
278 model_name,
279 savefig=True,
280 )
282 return model, model_name

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/keypoint_moseq/viz.py:620, in plot_progress(model, data, checkpoint_path, iteration, project_dir, model_name, path, savefig, fig_size, window_size, min_frequency, min_histogram_length)
618 z = np.array(f[f"model_snapshots/{i}/states/z"])
619 sample_state_history.append(z[batch_ix, start : start + window_size])
--> 620 median_durations.append(np.median(get_durations(z, mask)))
622 axs[2].scatter(saved_iterations, median_durations)
623 axs[2].set_ylim([-1, np.max(median_durations) * 1.1])

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/utils/utils.py:82, in get_durations(stateseqs, mask)
80 print(mask)
81 #AM edits
---> 82 stateseq_flat = concatenate_stateseqs(stateseqs, mask=mask).astype(int)
83 stateseq_padded = np.hstack([[-1], stateseq_flat, [-1]])
84 changepoints = np.diff(stateseq_padded).nonzero()[0]

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/utils/utils.py:40, in concatenate_stateseqs(stateseqs, mask)
38 stateseq_flat = np.hstack(stateseqs)
39 elif mask is not None:
---> 40 stateseq_flat = stateseqs[mask[:, -stateseqs.shape[1] :] > 0]
41 else:
42 stateseq_flat = stateseqs.flatten()

IndexError: boolean index did not match indexed array along dimension 0; dimension is 102 but corresponding boolean dimension is 116
```
This is on keypoint_moseq version 0.4.10 and jax_moseq version 0.2.2. I put some print statements into the jax_moseq utils get_durations function (since debugging on notebooks can be cumbersome) and let them crash out to find that the stateseqs and mask shapes consistently looked like this:
Keep in mind that only the first dimension of the shape contributes to this mismatch, since this function can handle mismatches in the other dimensions. I don't encounter this issue when training the AR-HMM or the full model, so I just ran the AR-HMM cell a few times with different kappa values to decide on a good one, then trained the full model. Wanted to open this issue in case anyone else runs into it, or in case Caleb has any thoughts on why this is happening.
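The failing indexing line at the bottom of the traceback can be reproduced in isolation with NumPy. This is a sketch, not the library code: `concatenate_stateseqs` below is a simplified copy of only the masked and unmasked array branches shown in the traceback (the branch at line 38 is omitted), and `check_shapes` is a hypothetical helper that could stand in for the print statements. The row counts 102 and 116 are taken from the error message; the frame counts are made up for illustration. It shows that the column slice `mask[:, -stateseqs.shape[1]:]` absorbs a difference in the number of frames, while a difference in the number of rows still raises `IndexError`.

```python
import numpy as np


def concatenate_stateseqs(stateseqs, mask=None):
    """Simplified copy of the array branches shown in the traceback
    (jax_moseq/utils/utils.py, lines 39-42); not the full library function."""
    if mask is not None:
        # Align the mask to the last T columns of the state sequences, then keep
        # only unmasked frames. The column slice handles a difference along the
        # time axis, but the row (batch) counts must already match.
        return stateseqs[mask[:, -stateseqs.shape[1]:] > 0]
    return stateseqs.flatten()


def check_shapes(stateseqs, mask):
    """Hypothetical helper: raise a readable error on a row-count mismatch."""
    if stateseqs.shape[0] != mask.shape[0]:
        raise ValueError(
            f"stateseqs has {stateseqs.shape[0]} rows but mask has {mask.shape[0]}"
        )


# Illustrative shapes: 102 and 116 rows come from the error message;
# the frame counts (1003 and 1000) are made up.
mask = np.ones((116, 1003))
z_ok = np.zeros((116, 1000), dtype=int)   # fewer frames than the mask: sliced away
z_bad = np.zeros((102, 1000), dtype=int)  # fewer rows than the mask: fails

print(concatenate_stateseqs(z_ok, mask).shape)

try:
    concatenate_stateseqs(z_bad, mask)
except IndexError as err:
    print("IndexError:", err)

try:
    check_shapes(z_bad, mask)
except ValueError as err:
    print("ValueError:", err)
```

Going by the traceback, `fit_model` only calls `plot_progress` when `generate_progress_plots` is true, so passing `generate_progress_plots=False` to the `fit_model` calls in the kappa scan may be another way to sidestep the crash (untested), at the cost of losing the progress plots.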