get_durations(z, mask) encountering IndexError when running kappa scan but not when training AR-HMM or full model #174

Open
amorsi1 opened this issue Oct 11, 2024 · 0 comments


amorsi1 commented Oct 11, 2024

Hi! I'm encountering an error during my kappa scan when it tries to plot the median syllable duration.

 ---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[15], line 17
     15 # stage 1: fit the model with AR only
     16 model = kpms.update_hypparams(model, kappa=kappa)
---> 17 model = kpms.fit_model(
     18     model,
     19     data,
     20     metadata,
     21     project_dir,
     22     model_name,
     23     ar_only=True,
     24     num_iters=num_ar_iters,
     25     save_every_n_iters=25,
     26     parallel_message_passing=False
     27 )[0];
     29 # stage 2: fit the full model
     30 model = kpms.update_hypparams(model, kappa=kappa/decrease_kappa_factor)

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/keypoint_moseq/fitting.py:272, in fit_model(model, data, metadata, project_dir, model_name, num_iters, start_iter, verbose, ar_only, parallel_message_passing, jitter, generate_progress_plots, save_every_n_iters, location_aware, **kwargs)
    270                 save_hdf5(checkpoint_path, model, f"model_snapshots/{iteration}")
    271                 if generate_progress_plots:
--> 272                     plot_progress(
    273                         model,
    274                         data,
    275                         checkpoint_path,
    276                         iteration,
    277                         project_dir,
    278                         model_name,
    279                         savefig=True,
    280                     )
    282 return model, model_name

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/keypoint_moseq/viz.py:620, in plot_progress(model, data, checkpoint_path, iteration, project_dir, model_name, path, savefig, fig_size, window_size, min_frequency, min_histogram_length)
    618         z = np.array(f[f"model_snapshots/{i}/states/z"])
    619         sample_state_history.append(z[batch_ix, start : start + window_size])
--> 620         median_durations.append(np.median(get_durations(z, mask)))
    622 axs[2].scatter(saved_iterations, median_durations)
    623 axs[2].set_ylim([-1, np.max(median_durations) * 1.1])

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/utils/utils.py:82, in get_durations(stateseqs, mask)
     80 print(mask)
     81 #AM edits
---> 82 stateseq_flat = concatenate_stateseqs(stateseqs, mask=mask).astype(int)
     83 stateseq_padded = np.hstack([[-1], stateseq_flat, [-1]])
     84 changepoints = np.diff(stateseq_padded).nonzero()[0]

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/utils/utils.py:40, in concatenate_stateseqs(stateseqs, mask)
     38     stateseq_flat = np.hstack(stateseqs)
     39 elif mask is not None:
---> 40     stateseq_flat = stateseqs[mask[:, -stateseqs.shape[1] :] > 0]
     41 else:
     42     stateseq_flat = stateseqs.flatten()

IndexError: boolean index did not match indexed array along dimension 0; dimension is 102 but corresponding boolean dimension is 116
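The error itself can be reproduced in isolation with plain NumPy. This is a minimal sketch: the array contents are dummy values and the time axis is shortened; only the row counts (116 for the mask, 102 for z) come from the printed shapes further down. It applies the same indexing expression as concatenate_stateseqs (utils.py line 40) to show that trimming the mask's leading columns aligns only the time axis, so a row-count mismatch still fails.

```python
import numpy as np

# Dummy arrays; only the row counts match the printed shapes (116 vs 102).
mask = np.ones((116, 10))                  # one row per segment of the current data
stateseqs = np.zeros((102, 7), dtype=int)  # z from a snapshot with fewer rows

# Same expression as concatenate_stateseqs: keep the last stateseqs.shape[1]
# columns of the mask so the time axes line up.
trimmed = mask[:, -stateseqs.shape[1]:] > 0
print(trimmed.shape)  # (116, 7): time axis aligned, row count still 116

caught = None
try:
    stateseqs[trimmed]  # boolean index must match along axis 0 as well
except IndexError as err:
    caught = err
    print(err)
```

The exact wording of the message varies between NumPy versions, but it reports the same 102 vs. 116 mismatch along the first axis.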

This is on keypoint_moseq version 0.4.10 and jax_moseq version 0.2.2. Since debugging in notebooks can be cumbersome, I added some print statements to get_durations in jax_moseq's utils and let it crash. The stateseqs and mask shapes consistently looked like this:

stateseqs len (116, 10027)
[[ 3  3  3 ... 79 79 46]
 [95 95  5 ... 90 90 90]
 [90 90 90 ... 58 58 58]
 ...
 [90 90 90 ... 54 54 54]
 [36 36 36 ... 90 90 90]
 [90 90 90 ... 67 90 90]]
mask len (116, 10030)
[[1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 0. 0. 0.]
 ...
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 0. 0. 0.]]
stateseqs len (102, 10027)
[[72 72 96 ... 91 91 91]
 [72 72 72 ... 52 52 52]
 [52 52 52 ... 23 23 23]
 ...
 [23 23 23 ... 77 77 77]
 [23 23 23 ... 12 12 12]
 [12 12 12 ... 87 87 87]]
mask len (116, 10030)
[[1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 0. 0. 0.]
 ...
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 0. 0. 0.]]
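The printout shows z switching from 116 to 102 rows between calls while mask stays at 116 rows. One possible (unconfirmed) explanation: plot_progress appears to reload saved snapshots from checkpoint_path (the `model_snapshots/{i}/states/z` reads in viz.py), while mask comes from the data currently being fit, so a checkpoint file that still holds snapshots from an earlier run on a different set of recordings or segments would trigger exactly this error. The helpers below are a hypothetical diagnostic sketch, not part of keypoint_moseq: `read_snapshot_shapes` assumes h5py is installed and that snapshots live under the key pattern seen in the traceback, and `mismatched_snapshots` only compares row counts.

```python
import numpy as np  # not strictly needed by the helpers; kept for interactive use


def mismatched_snapshots(snapshot_shapes, n_mask_rows):
    """Return the iterations whose z has a different row count than the mask.

    snapshot_shapes: mapping of iteration (str or int) -> shape tuple of z.
    n_mask_rows: mask.shape[0] for the data currently being fit.
    """
    return sorted(
        int(it) for it, shape in snapshot_shapes.items() if shape[0] != n_mask_rows
    )


def read_snapshot_shapes(checkpoint_path):
    """Collect z shapes from a checkpoint file (hypothetical helper, needs h5py).

    Assumes snapshots are stored under "model_snapshots/<iteration>/states/z",
    the key pattern plot_progress reads in the traceback.
    """
    import h5py  # imported lazily so mismatched_snapshots works without h5py

    with h5py.File(checkpoint_path, "r") as f:
        return {
            it: f[f"model_snapshots/{it}/states/z"].shape
            for it in f["model_snapshots"].keys()
        }


# Example with the two row counts from the printout (iteration numbers are hypothetical):
example = {"25": (116, 10027), "50": (102, 10027)}
print(mismatched_snapshots(example, 116))  # [50]
```

For a real project, something like `mismatched_snapshots(read_snapshot_shapes(checkpoint_path), data["mask"].shape[0])` would list the offending iterations; both the checkpoint location (for example `<project_dir>/<model_name>/checkpoint.h5`) and the `data["mask"]` key are assumptions to check against your own setup. If stale iterations do show up, fitting each kappa under a fresh model_name might avoid them.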

Note that only the first (batch) dimension causes the mismatch: concatenate_stateseqs already absorbs a difference in the second (time) dimension by slicing mask[:, -stateseqs.shape[1]:], which is why 10027 vs. 10030 columns is not a problem.

I don't encounter this error when training the AR-HMM or the full model on their own, so as a workaround I ran the AR-HMM cell a few times with different kappa values to pick a good one and then trained the full model. I wanted to open this issue in case anyone else runs into it, or in case Caleb has any thoughts on why this is happening.
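Whatever produces the extra snapshot shape, one way plot_progress could avoid crashing is to skip snapshots whose row count differs from the mask's instead of passing them to get_durations. The function below is a hypothetical sketch (`median_durations_for_matching` is not a jax_moseq or keypoint_moseq function). It copies the mask indexing from concatenate_stateseqs and the -1 padding and changepoint step from get_durations as shown in the traceback; the final `np.diff` over changepoints is an assumption about how durations are derived. Separately, the fit_model signature includes `generate_progress_plots`, and plot_progress is only called when it is true, so passing `generate_progress_plots=False` during the kappa scan should sidestep the error at the cost of the progress plots.

```python
import numpy as np


def median_durations_for_matching(snapshots, mask):
    """Median syllable duration per snapshot, skipping shape-mismatched ones.

    snapshots: iterable of (iteration, z) pairs, z of shape (n_segments, n_frames).
    mask: array of shape (n_segments, n_frames_padded) for the current data.
    Returns (iterations, medians) for the snapshots that were kept.
    """
    mask = np.asarray(mask)
    iterations, medians = [], []
    for it, z in snapshots:
        z = np.asarray(z)
        if z.shape[0] != mask.shape[0]:
            continue  # snapshot belongs to a differently shaped dataset
        # Align the time axis the same way concatenate_stateseqs does.
        flat = z[mask[:, -z.shape[1]:] > 0].astype(int)
        # Pad with -1 so the first and last runs produce changepoints.
        padded = np.hstack([[-1], flat, [-1]])
        changepoints = np.diff(padded).nonzero()[0]
        durations = np.diff(changepoints)  # run lengths (assumed final step)
        iterations.append(it)
        medians.append(float(np.median(durations)))
    return iterations, medians
```

Skipping mismatched snapshots, rather than returning NaN for them, keeps the later `np.max(median_durations)` call in plot_progress from producing a NaN axis limit.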
