Progress Bar?
#226
Does anyone have any advice on how to display progress as samples are being drawn? I'm currently using the inference loop below (copy-pasted from the example notebooks) and would ideally like a progress bar that goes from 1 to num_samples during the scan.

    import jax

    def inference_loop(rng_key, kernel, initial_state, num_samples):
        @jax.jit
        def one_step(state, rng_key):
            # Advance the chain by one step and emit the new state as a sample.
            state, _ = kernel(rng_key, state)
            return state, state

        # One PRNG key per sample; scan stacks the emitted states.
        keys = jax.random.split(rng_key, num_samples)
        _, states = jax.lax.scan(one_step, initial_state, keys)
        return states
Answered by
rlouf
Jun 24, 2022
Replies: 1 comment
See the discussion in #219.
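For anyone landing here without following the link, here is a minimal sketch of one way to report progress from inside `jax.lax.scan`. It is not necessarily the approach discussed in #219. It assumes a JAX version that provides `jax.debug.callback`; the names `inference_loop_with_progress`, `report_every`, `report`, and `toy_kernel` are made up for illustration, and `toy_kernel` is only a stand-in for a real sampler kernel with the same `(rng_key, state) -> (state, info)` shape as in the question.

```python
import jax
import jax.numpy as jnp


def inference_loop_with_progress(rng_key, kernel, initial_state, num_samples,
                                 report_every=100, report=None):
    """Same scan as in the question, plus a periodic host-side progress report."""
    if report is None:
        # Default reporter (hypothetical): plain text instead of a real bar.
        def report(step):
            print(f"drew {int(step)}/{num_samples} samples", flush=True)

    def one_step(state, xs):
        i, key = xs
        state, _ = kernel(key, state)
        step = i + 1  # 1-based sample counter
        is_due = (step % report_every == 0) | (step == num_samples)
        # Only call back to the host when a report is due, so most
        # iterations stay entirely on the device.
        jax.lax.cond(
            is_due,
            lambda s: jax.debug.callback(report, s),
            lambda s: None,
            step,
        )
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    steps = jnp.arange(num_samples)
    # Scan over (index, key) pairs so each iteration knows its position.
    _, states = jax.lax.scan(one_step, initial_state, (steps, keys))
    return states


def toy_kernel(rng_key, state):
    """Stand-in kernel (assumption): a Gaussian random-walk step, no info."""
    return state + 0.1 * jax.random.normal(rng_key), None


states = inference_loop_with_progress(
    jax.random.PRNGKey(0), toy_kernel, jnp.zeros(()), 1000, report_every=250
)
```

The `report` argument is where a real progress bar could be plugged in, for example by updating a bar object from the callback. Unordered debug callbacks are not guaranteed to arrive in order, so a reporter that depends on ordering should treat the step value it receives as the source of truth.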
Answer selected by
rlouf