Monitoring training logs #1484
Unanswered
patrickvonplaten asked this question in Show and tell
Replies: 1 comment
-
Yes — see flax/examples/imagenet/train.py, lines 62, 142, and 226 at commit 8611538. Consider renaming the discussion to "Returning multiple values from loss function" or similar to make the tip more easily discoverable (since it's not so much about monitoring but more about returning multiple values).
-
I've run into the problem of how to monitor values that are calculated within the loss function, since the loss function is expected to return only a single loss value. E.g. imagine you have a training loop:
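Roughly, such a training step might look like the following minimal sketch (apply_model, batch, and the cross-entropy loss here are illustrative placeholders; loss_fn, grad_fn, and model_weights are the names used in the snippets below):

```python
import jax
import jax.numpy as jnp


def apply_model(weights, inputs):
    # Stand-in for a real forward pass (e.g. a Flax module's apply).
    return inputs @ weights


def train_step(model_weights, batch):
    def loss_fn(weights):
        logits = apply_model(weights, batch["inputs"])
        # Cross-entropy of the true labels, averaged over the batch.
        log_probs = jax.nn.log_softmax(logits)
        loss = -jnp.mean(log_probs[jnp.arange(logits.shape[0]), batch["labels"]])
        return loss  # only a single scalar can be returned here

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(model_weights)
    model_weights = model_weights - 0.01 * grad  # plain SGD update for illustration
    return model_weights, loss
```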
Now, how is it possible to log other metrics (e.g. perplexity) besides the loss? It's not that straightforward in JAX, since the loss function is differentiated with respect to a single scalar output. Luckily, jax.value_and_grad(...) has a special flag for this, called has_aux=True (see: https://jax.readthedocs.io/en/latest/jax.html#jax.grad), which allows us to return and monitor additional values. So the lines above have to be changed as follows:
return loss
=> return loss, perplexity
grad_fn = jax.value_and_grad(loss_fn)
=> grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
loss, grad = grad_fn(model_weights)
=> (loss, perplexity), grad = grad_fn(model_weights)
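Putting it together, a minimal runnable sketch of the modified step might look like this (apply_model and batch are the same illustrative placeholders as above; the auxiliary value here is perplexity, computed as the exponential of the mean cross-entropy):

```python
import jax
import jax.numpy as jnp


def apply_model(weights, inputs):
    # Same stand-in forward pass as in the sketch above.
    return inputs @ weights


def train_step(model_weights, batch):
    def loss_fn(weights):
        logits = apply_model(weights, batch["inputs"])
        log_probs = jax.nn.log_softmax(logits)
        loss = -jnp.mean(log_probs[jnp.arange(logits.shape[0]), batch["labels"]])
        perplexity = jnp.exp(loss)  # auxiliary value we want to monitor
        return loss, perplexity  # (scalar to differentiate, auxiliary output)

    # has_aux=True tells JAX that loss_fn returns (loss, aux) and that only
    # the first element should be differentiated.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, perplexity), grad = grad_fn(model_weights)
    model_weights = model_weights - 0.01 * grad  # illustrative SGD update
    return model_weights, loss, perplexity


# Example call with dummy data:
weights = jnp.zeros((8, 4))
batch = {"inputs": jnp.ones((2, 8)), "labels": jnp.array([0, 3])}
weights, loss, perplexity = train_step(weights, batch)
print(loss, perplexity)  # perplexity can now be logged every step
```

Note that perplexity never enters the gradient computation; it simply rides along in the aux slot of loss_fn, so it can be logged at every step without a second forward pass.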
A more detailed example can be found here:
flax/examples/lm1b/train.py, line 220 (commit 9350b44)