Hi,

I would like to implement, using Flax NNX, a model that does the backward pass differently, in the following way. At each step, it first computes the gradient w.r.t. the weights. Then, it updates the weights (syncing with any collectives as necessary first, e.g. using `jax.lax.pmean`). Then, it uses the updated weights to compute the gradient w.r.t. the activations. This process then repeats for the layer below:
$$
\begin{aligned}
\Delta w &\equiv \dfrac{dL}{dw}(w, a) \\
w &\leftarrow w - \alpha \, \Delta w \\
\Delta a &\equiv \dfrac{dL}{da}(w, a)
\end{aligned}
$$
The total amount of computation done by this altered procedure is the same; only the order is different. In particular, because the weights are updated during the backward pass, there is no need to collect the weight gradients at the end.
But in order to take advantage of JAX's autograd, the typical idiom in Flax NNX would be:
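Roughly the following sketch of a standard NNX training step (the `Model` class, the loss, and the batch contents are placeholders here):

```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = Model(rngs=nnx.Rngs(0))                    # placeholder nnx.Module
optimizer = nnx.Optimizer(model, optax.sgd(1e-2))

@nnx.jit
def train_step(model, optimizer, batch):
    def loss_fn(model):
        preds = model(batch['x'])
        return jnp.mean((preds - batch['y']) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    # Under data parallelism the weight gradients would be all-reduced here,
    # e.g. grads = jax.lax.pmean(grads, 'batch'), and only then applied.
    optimizer.update(grads)
    return loss
```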
Could I somehow hack `model` and use `@custom_vjp` for each layer, and then treat the weights as part of the state? Then I'd have no need to call `optimizer.update(grads)`, because the weights would be updated during the call to `grad_fn`.

It seems like, in order to make use of JAX's autograd machinery (at least, the part that chains the "gradients" together), I'd need to use `@custom_vjp`, but that code needs to be side-effect free, so I'm not sure how that would work.
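To make the question concrete, here is a rough sketch of the kind of per-layer hack I'm imagining, for a single matmul layer with a fixed learning rate baked into the backward rule (all names and numbers here are hypothetical). Since `custom_vjp` has to return one cotangent per input, the updated weights come back through the slot where the weight gradient would normally go:

```python
import jax
import jax.numpy as jnp

LR = 1e-2  # hypothetical fixed step size, baked into the backward rule

@jax.custom_vjp
def dense(w, a):
    # ordinary forward pass for one layer
    return a @ w

def dense_fwd(w, a):
    return a @ w, (w, a)

def dense_bwd(res, g):
    w, a = res
    # 1) gradient w.r.t. the weights, from the incoming cotangent g
    dw = a.T @ g
    # dw = jax.lax.pmean(dw, 'batch')  # sync across devices if under pmap/shard_map
    # 2) apply the weight update immediately
    w_new = w - LR * dw
    # 3) gradient w.r.t. the activations, computed with the *updated* weights
    da = g @ w_new.T
    # Return the updated weights in place of the weight gradient.
    return w_new, da

dense.defvjp(dense_fwd, dense_bwd)
```

In principle, `jax.grad` over a stack of such layers would then hand back the already-updated weights where the weight gradients would normally be, and those could be written straight back into the model state, but I don't know whether this composes cleanly with NNX's state handling, which is really what I'm asking.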
I'd like to avoid having to hand-code the entire model backward function. Any ideas would be greatly appreciated!
P.S. This is related to my discussion here, in which @mattjj outlines how one would hand-code this function (thank you, Matt!).