diff --git a/swirl_dynamics/projects/probabilistic_diffusion/colabs/guided_diffusion_demo.ipynb b/swirl_dynamics/projects/probabilistic_diffusion/colabs/guided_diffusion_demo.ipynb index 014fa9f..0dfe0e2 100644 --- a/swirl_dynamics/projects/probabilistic_diffusion/colabs/guided_diffusion_demo.ipynb +++ b/swirl_dynamics/projects/probabilistic_diffusion/colabs/guided_diffusion_demo.ipynb @@ -143,6 +143,7 @@ }, "outputs": [], "source": [ + "from clu import metric_writers\n", "import jax\n", "import jax.numpy as jnp\n", "import optax\n", @@ -289,6 +290,7 @@ " trainer=trainer,\n", " workdir=workdir,\n", " total_train_steps=num_train_steps,\n", + " metric_writer=metric_writers.create_default_writer(workdir, asynchronous=False),\n", " metric_aggregation_steps=20,\n", " eval_dataloader=eval_dataloader,\n", " eval_every_steps = 1000,\n", @@ -656,6 +658,10 @@ "accelerator": "GPU", "colab": { "gpuType": "T4", + "last_runtime": { + "build_target": "//learning/grp/tools/ml_python:ml_notebook", + "kind": "private" + }, "private_outputs": true, "provenance": [ {