Regarding train_step in NNX #4069
Is this the correct way to write a `train_step` function in NNX?

```python
@nnx.jit
def train_step(model, x, y, optimizer):
    def loss_fn(model, x, y):
        logits = model(x)
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=y).mean()
        return logits, loss

    grad_fn = nnx.grad(loss_fn, has_aux=True)
    (logits, loss), grads = grad_fn(model, x, y)
    optimizer.update(grads)
    return loss
```
Unfortunately, the Flax documentation I'm following for NNX is not very detailed yet.

Edit: `nnx.grad` expects a scalar as output, but that is not possible when we are training a model in batches. It is very confusing to me. Also, what is the difference between …

Edit 2: Okay, I got it: `loss` is a scalar, so the grad function is expecting that...
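A minimal plain-JAX sketch of that point (the function names here are just illustrative, not from the thread): `grad` is only defined for scalar-output functions, so the per-example losses have to be reduced, e.g. with `.mean()`, before differentiating.

```python
import jax
import jax.numpy as jnp

def per_example_loss(w, x):
    # One loss value per example in the batch, shape (batch,).
    return (w * x) ** 2

def batched_loss(w, x):
    # Reducing with .mean() turns the batch of losses into a scalar,
    # which is what jax.grad requires.
    return per_example_loss(w, x).mean()

x = jnp.arange(4.0)
print(jax.grad(batched_loss)(2.0, x))   # works: scalar output
# jax.grad(per_example_loss)(2.0, x)    # raises: output is not a scalar
```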
Replies: 1 comment · 2 replies
Hey! Check the MNIST tutorial; I think you want:
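Presumably something along these lines, following the MNIST tutorial's pattern (a sketch, not the tutorial's exact code): `nnx.value_and_grad` returns the loss alongside the gradients, so the scalar is still available to return for logging.

```python
from flax import nnx
import optax

@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        logits = model(x)
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=y).mean()
        # has_aux=True: scalar loss first, auxiliary output second.
        return loss, logits

    # value_and_grad returns ((loss, aux), grads), so the loss survives
    # the transform and can be returned from train_step.
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model)
    optimizer.update(grads)  # applies the optax update to the model in place
    return loss
```

Note the ordering inside `loss_fn`: with `has_aux=True` the scalar loss must come first and the logits second, which is the reverse of the `return logits, loss` in the question.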