I have a loss function that returns (loss_value, extra_data). Native JAX supports this kind of construct with jax.value_and_grad(loss_fn, has_aux=True) (doc). The differentiated function returns ((loss_value, extra_data), grad).
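For concreteness, here is a minimal sketch of the `has_aux=True` pattern; the loss function and its auxiliary dictionary are hypothetical examples, not from the issue:

```python
import jax
import jax.numpy as jnp

def loss_fn(params):
    # Hypothetical loss: squared distance to 3.0, plus auxiliary diagnostics.
    residual = params - 3.0
    loss = jnp.sum(residual ** 2)
    aux = {"residual_norm": jnp.linalg.norm(residual)}
    return loss, aux

params = jnp.array([1.0, 2.0])
# has_aux=True tells JAX the second output is data, not part of the loss.
(loss, aux), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)
```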
In optax, when using the linesearch algorithms (for example as part of L-BFGS), I can use optax.value_and_grad_from_state(loss_fn) (doc) which uses the optimizer state to save function evaluations done inside the linesearch. Unfortunately, the linesearch algorithms and optax.value_and_grad_from_state don't support auxiliary data.
I added support for this to the optax code, and it works for my use case. Are you interested in merging this upstream? I don't have time for proper testing, documentation, etc., though, so I would appreciate some assistance.
I'd be happy to see how you handled it. I was not sure what would be the best solution to add this while keeping the API light. So if you have some example, I'd be happy to look at a PR.
Cool, I'll prepare a PR once I'm back from vacation in 1-2 weeks.