We've been successfully using keypoint-moseq on several datasets, but we are encountering a consistent issue with larger datasets, where the kernel crashes during model initialization.
Steps Taken to Isolate the Issue:
Tested with smaller datasets: the model initializes and runs fine.
Simulated larger dataset: we artificially inflated previously working datasets to match the size of the problematic dataset, and the kernel crash occurred, confirming the issue is size-related and not specific to the data itself.
Isolated problematic function: we traced the issue to init_states in jax_moseq/models/arhmm/initialize.py, specifically to its call to resample_discrete_stateseqs.
Behavior Observed:
Without @jax.jit on init_states: the kernel dies after log_likelihoods has been computed correctly but before it is returned to init_states. Print statements added for debugging show that the crash happens after this computation succeeds and before the value reaches init_states.
With @jax.jit: the crash moves to the return of z. Again, z is computed with the correct array dimensions (as confirmed by logs), but the kernel crashes when the value is returned to init_states. Not returning the value prevents the crash, and init_states then continues to execute normally, which suggests that the crash happens during the return itself.
JAX Debugging: for working datasets, jax.debug.print outputs values such as array shapes as expected, but for datasets that cause the crash it produces no output at all, even though execution does proceed to the following steps.
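One thing worth ruling out given the behavior above: JAX dispatches device work asynchronously, so a Python print can run before the computation it reports on has actually finished, and a failure that seems to happen "on return" may simply be the first moment the result is materialized. The minimal sketch below forces synchronization right after a step so any device-side failure surfaces at that line. The helper run_and_sync and the function fake_log_likelihoods are hypothetical stand-ins, not part of jax_moseq.

```python
import jax
import jax.numpy as jnp


def run_and_sync(fn, *args, label="step"):
    """Call fn, then block until its device results are ready.

    Blocking immediately makes a device-side failure surface here,
    instead of later when the value is first used or returned.
    """
    out = fn(*args)
    out = jax.block_until_ready(out)
    shapes = jax.tree_util.tree_map(lambda a: a.shape, out)
    print(f"{label}: finished, shapes={shapes}")
    return out


# Hypothetical stand-in for the real log-likelihood computation.
@jax.jit
def fake_log_likelihoods(x):
    return -0.5 * jnp.sum(x**2, axis=-1)


x = jnp.ones((4, 100, 3))
ll = run_and_sync(fake_log_likelihoods, x, label="log_likelihoods")
```

If the kernel dies inside run_and_sync rather than at the later return, that points at the computation itself (for example a device allocation) rather than at the return statement.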
Attempts to Work Around the Issue:
Manual batching and vmap/lax batching: we tried processing z manually in smaller batches, or with vmap/lax, to reduce the size of the returned array. However, the kernel still crashes, now at the point where init_states returns z.
System Monitoring: we did not observe any VRAM or RAM saturation, so memory exhaustion does not seem to be the cause.
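The manual-batching idea above can be pushed one step further by copying each batch's result to host RAM as soon as it is computed, so the full array never has to sit on the GPU at once. This is a minimal sketch, assuming the work can be split along a leading per-session axis (the forward-backward sampling inside resample_discrete_stateseqs runs along time and could not be split along that axis). per_session_step and process_in_host_chunks are hypothetical names, and argmax is only a placeholder computation.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def per_session_step(ll):
    # Hypothetical per-session computation: most likely state per frame.
    return jnp.argmax(ll, axis=-1)


def process_in_host_chunks(fn, data, chunk_size):
    """Apply fn to chunks along axis 0, copying each result to host RAM.

    Only one chunk's inputs and outputs live on the device at a time;
    the full result is assembled as a NumPy array in system memory.
    """
    results = []
    for start in range(0, data.shape[0], chunk_size):
        chunk = jnp.asarray(data[start:start + chunk_size])
        out = fn(chunk)
        results.append(np.asarray(jax.device_get(out)))  # device -> host copy
        del chunk, out  # drop device references before the next chunk
    return np.concatenate(results, axis=0)


rng = np.random.default_rng(0)
# sessions x frames x states
log_likelihoods = rng.normal(size=(10, 50, 5)).astype(np.float32)
z = process_in_host_chunks(per_session_step, log_likelihoods, chunk_size=3)
```

A smaller last chunk triggers one extra compilation of the jitted function; padding the final chunk would avoid that at the cost of some bookkeeping.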
We would be grateful for any workaround for that step, even if it means a slower initialization!
Thank you for your help!
Thanks for the thorough report! Given that the crash is strictly a function of dataset size, it does point strongly in the direction of VRAM saturation. What OS are you using? I have found that OOM errors on Windows cause the kernel to crash without an explicit OOM error.
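If VRAM saturation is the cause, one slower but memory-safe fallback is to run the initialization step on the CPU, where it is bounded by system RAM instead of GPU memory. The sketch below uses jax.default_device to do this for a single call. It assumes the initialization code does not explicitly pin its arrays to the GPU; run_on_cpu and hypothetical_init are illustrative names only, not keypoint-moseq functions. Setting the environment variable JAX_PLATFORMS=cpu before importing JAX is a blunter alternative that forces the whole session onto the CPU.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]


def run_on_cpu(fn, *args):
    """Run fn with its inputs and computations placed on the CPU.

    Slower than the GPU, but limited by system RAM rather than VRAM.
    """
    with jax.default_device(cpu):
        args = jax.device_put(args, cpu)  # move any existing device arrays too
        return jax.block_until_ready(fn(*args))


@jax.jit
def hypothetical_init(x):
    # Stand-in for the real initialization step.
    return jnp.cumsum(x, axis=-1)


out = run_on_cpu(hypothetical_init, jnp.arange(6.0).reshape(2, 3))
print(out.devices())
```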
We are indeed on Windows, and this looks a lot like what we get when the set_mixed_map_iters value for the fitting step is too small. Here, however, we only checked resource usage visually and by running nvidia-smi frequently, and there was no obvious trend or peak. The spike may be very brief, occurring only when the values are returned, so I wouldn't completely rule it out.
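Polling nvidia-smi can easily miss a short allocation spike, and because JAX by default preallocates a large fraction of GPU memory, nvidia-smi may show a flat line regardless of actual usage. The allocator's own high-water mark is a more direct check. A minimal sketch, assuming a recent JAX version in which Device.memory_stats() is available; on GPU it reports keys such as peak_bytes_in_use and bytes_limit, while on other backends it may return None or be unsupported, which the helper treats as "no stats". report_peak_memory is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


def report_peak_memory(device=None):
    """Return (peak_bytes, limit_bytes), or None if no stats are available.

    peak_bytes_in_use is the allocator's high-water mark, so it captures
    brief spikes that periodic nvidia-smi polling could miss.
    """
    device = device or jax.devices()[0]
    try:
        stats = device.memory_stats()
    except Exception:  # some backends do not implement memory stats
        stats = None
    if not stats:
        return None
    return stats.get("peak_bytes_in_use"), stats.get("bytes_limit")


x = jnp.ones((1024, 1024)).block_until_ready()
print(report_peak_memory())
```

Calling report_peak_memory() just before the crashing step, and comparing the peak against bytes_limit, would show whether the GPU was close to full at that point.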