
Regarding train_step in NNX #4069

Answered by cgarciae
qnixsynapse asked this question in Q&A
Jul 10, 2024 · 1 comment · 2 replies

Hey! Check the MNIST tutorial; I think you want:

@nnx.jit
def train_step(model, x, y, optimizer):
    def loss_fn(model, x, y):
        logits = model(x)
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=y
        ).mean()
        return loss, logits  # loss first: the first output is the one differentiated
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)  # use value_and_grad
    (loss, logits), grads = grad_fn(model, x, y)
    optimizer.update(grads)
    return loss
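To see why loss_fn must return (loss, logits) in that order, here is a minimal pure-JAX sketch of the same pattern (jax.value_and_grad rather than the NNX wrapper, with a hypothetical hand-rolled linear model and cross-entropy so it runs without Flax or Optax): with has_aux=True, only the first output is differentiated, and the second is passed through untouched.

import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    logits = x @ w  # hypothetical linear model standing in for model(x)
    # softmax cross-entropy with integer labels, written out by hand
    logprobs = jax.nn.log_softmax(logits)
    loss = -jnp.mean(jnp.take_along_axis(logprobs, y[:, None], axis=1))
    return loss, logits  # loss first (differentiated), logits second (aux)

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

w = jnp.zeros((3, 2))
x = jnp.ones((4, 3))
y = jnp.array([0, 1, 0, 1])
(loss, logits), grads = grad_fn(w, x, y)
# loss is a scalar, logits keep their (4, 2) shape, grads match w's (3, 2) shape

If you returned (logits, loss) instead, value_and_grad would try to differentiate the logits array and fail, because the differentiated output must be a scalar; that is the bug the "loss first" ordering fixes.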

Replies: 1 comment · 2 replies (@qnixsynapse, @cisprague)
Answer selected by qnixsynapse