Skip to content

Commit

Permalink
Fix introductory CNN example in recent JAX.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Oct 14, 2024
1 parent 40eac53 commit d9b3ffd
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion examples/mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,9 @@
" y: Int[Array, \" batch\"],\n",
" ):\n",
" loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)\n",
" updates, opt_state = optim.update(grads, opt_state, model)\n",
" updates, opt_state = optim.update(\n",
" grads, opt_state, eqx.filter(model, eqx.is_array)\n",
" )\n",
" model = eqx.apply_updates(model, updates)\n",
" return model, opt_state, loss_value\n",
"\n",
Expand Down

0 comments on commit d9b3ffd

Please sign in to comment.