I don't yet understand how to mix general random number generation with the NNX `Rngs`. Can you provide a good pattern to replace the repeated calls to `key, subkey = random.split(key, num=2)` etc. that people often do with JAX? To be more concrete, let's say I have
Note that this uses `rngs` for the NNX modules, but also draws random numbers directly and calls functions that expect plain JAX keys.
Is the above correct? Is `rngs.next()` the way to make this reproducible, instead of splitting keys manually?