Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding notebook illustrating schedule_free optimizer integration #1022

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ def _recursive_add_annotations_import():
'cifar10_resnet.ipynb',
'adversarial_training.ipynb',
'reduce_on_plateau.ipynb',
'differentially_private_sgd.ipynb'
'differentially_private_sgd.ipynb',
'schedule_free.ipynb'
]

# -- Options for katex ------------------------------------------------------
Expand Down
333 changes: 333 additions & 0 deletions examples/contrib/schedule_free.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,333 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "7e-UQLFJE0VK"
},
"source": [
"# Schedule-free optimizer\n",
"\n",
"This notebook illustrates how to incoprorate the [optax.contrib.schedule_free](https://optax.readthedocs.io/en/latest/api/contrib.html#optax.contrib.schedule_free) optimizer in usual pipelines.\n",
"\n",
"The notebook is purely for implementation details purposes not for performance illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "C-ZmkG6kFvpT"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"from flax import linen as nn\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"import optax.tree_utils as otu\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
"\n",
"from matplotlib import pyplot as plt\n",
"\n",
"tf.config.experimental.set_visible_devices([], \"GPU\")\n",
"print(\"JAX running on\", jax.devices()[0].platform.upper())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qn0EwMT4F5I0"
},
"outputs": [],
"source": [
"# @markdown Total number of epochs to train for:\n",
"N_STEPS = 1000 # @param{type:\"integer\"}\n",
"\n",
"# @markdown Number of samples in each batch:\n",
"BATCH_SIZE = 4 # @param{type:\"integer\"}\n",
"\n",
"# @markdown Frequency to eval loss\n",
"EVAL_EVERY = 50 # @param{type:\"integer\"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "51zbr9nJHO9k"
},
"source": [
"## Setup\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Bzdx8a3sGAhE"
},
"outputs": [],
"source": [
"# @title Data\n",
"\n",
"def tf_to_numpy(xs):\n",
" return jax.tree_util.tree_map(lambda x: x._numpy(), xs)\n",
"\n",
"\n",
"def get_data():\n",
" (train_loader, test_loader), info = tfds.load(\n",
" \"cifar10\", split=[\"train\", \"test\"], as_supervised=True, with_info=True\n",
" )\n",
"\n",
" def augment(image, label):\n",
" \"\"\"Performs data augmentation.\"\"\"\n",
" image = tf.image.resize_with_crop_or_pad(image, 40, 40)\n",
" image = tf.image.random_crop(image, [32, 32, 3])\n",
" image = tf.image.random_flip_left_right(image)\n",
" image = tf.image.random_brightness(image, max_delta=0.2)\n",
" image = tf.image.random_contrast(image, 0.8, 1.2)\n",
" image = tf.image.random_saturation(image, 0.8, 1.2)\n",
" return image, label\n",
"\n",
"\n",
" train_loader = train_loader.repeat().map(augment)\n",
"\n",
" train_loader = train_loader.shuffle(\n",
" buffer_size=10_000, reshuffle_each_iteration=True\n",
" ).batch(BATCH_SIZE, drop_remainder=True)\n",
" train_loader = map(tf_to_numpy, train_loader)\n",
"\n",
" test_loader = test_loader.batch(BATCH_SIZE, drop_remainder=True).repeat().prefetch(10)\n",
" test_loader = map(tf_to_numpy, test_loader)\n",
"\n",
" train_steps_per_epoch = math.ceil(info.splits['train'].num_examples / BATCH_SIZE)\n",
" val_steps_per_epoch = math.ceil(info.splits['test'].num_examples / BATCH_SIZE)\n",
" info = {'train_steps_per_epoch': train_steps_per_epoch, 'val_steps_per_epoch': val_steps_per_epoch}\n",
" return train_loader, test_loader, info"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L0sVFLVCGFHm"
},
"outputs": [],
"source": [
"# @title Model\n",
"class CNN(nn.Module):\n",
" \"\"\"A simple CNN model.\"\"\"\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n",
" x = nn.relu(x)\n",
" x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
" x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n",
" x = nn.relu(x)\n",
" x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
" x = x.reshape((x.shape[0], -1)) # flatten\n",
" x = nn.Dense(features=256)(x)\n",
" x = nn.relu(x)\n",
" x = nn.Dense(features=10)(x)\n",
" return x\n",
"\n",
"net = CNN()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_7ZMIPfNHgPe"
},
"outputs": [],
"source": [
"# @title Train, eval steps\n",
"\n",
"def get_eval_params(params, state):\n",
" sfo_state = otu.tree_get(state, \"ScheduleFreeState\")\n",
" if sfo_state is not None:\n",
" eval_params = optax.contrib.schedule_free_eval_params(sfo_state, params)\n",
" else:\n",
" eval_params = params\n",
" return eval_params\n",
"\n",
"\n",
"def train_obj(params, data):\n",
" inputs, labels = data\n",
" logits = net.apply(params, inputs)\n",
" loss = optax.softmax_cross_entropy_with_integer_labels(\n",
" logits=logits, labels=labels\n",
" ).mean()\n",
" accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)\n",
" return loss, accuracy\n",
"\n",
"\n",
"def train_step(params, state, data, opt):\n",
" _, grads = jax.value_and_grad(train_obj, has_aux=True)(params, data)\n",
" udpates, state = opt.update(grads, state, params)\n",
" params = optax.apply_updates(params, udpates)\n",
" return params, state\n",
"\n",
"\n",
"train_step = jax.jit(train_step, static_argnames=[\"opt\"])\n",
"\n",
"\n",
"def eval_step(params, state, data):\n",
" eval_params = get_eval_params(params, state)\n",
" loss, accuracy = train_obj(eval_params, data)\n",
" return loss, accuracy\n",
"\n",
"\n",
"eval_step = jax.jit(eval_step)\n",
"\n",
"\n",
"def eval(params, state, dataset, num_batch_per_eval):\n",
" total_loss = 0.0\n",
" total_acc = 0.0\n",
" for _ in range(num_batch_per_eval):\n",
" batch = next(dataset)\n",
" loss, acc = eval_step(params, state, batch)\n",
" total_loss += loss\n",
" total_acc += acc\n",
" return total_loss/num_batch_per_eval, total_acc/num_batch_per_eval\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "maeJY_rSWX9p"
},
"outputs": [],
"source": [
"# @title Train loop\n",
"\n",
"def init_params(input_example, opt):\n",
" key = jax.random.PRNGKey(0)\n",
" params = net.init(key, input_example)\n",
" state = opt.init(params)\n",
" return params, state\n",
"\n",
"def train_loop(params, state, opt, train_loader, test_loader, info_data):\n",
" loss_log = []\n",
" acc_log = []\n",
"\n",
" for step, batch in zip(range(N_STEPS), train_loader):\n",
" params, state = train_step(params, state, batch, opt)\n",
" if (step % EVAL_EVERY) == 0:\n",
" avg_loss, avg_acc = eval(params, state, test_loader, info_data['val_steps_per_epoch'])\n",
" print(f'step: {step}, loss: {avg_loss}, acc: {avg_acc}')\n",
" loss_log.append(avg_loss)\n",
" acc_log.append(avg_acc)\n",
" return loss_log, acc_log\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LSAFeKGBfyBs"
},
"source": [
"## Experiments"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gY05LUaSa28-"
},
"outputs": [],
"source": [
"# @title Adam with prefixed schedule\n",
"\n",
"schedule = optax.warmup_cosine_decay_schedule(0., 1e-3, int(N_STEPS/10), decay_steps=N_STEPS)\n",
"opt = optax.adam(learning_rate=schedule)\n",
"\n",
"train_loader, test_loader, info_data = get_data()\n",
"input_exmp = next(iter(train_loader))[0]\n",
"params, state = init_params(input_exmp, opt)\n",
"\n",
"loss_log, acc_log = train_loop(params, state, opt, train_loader, test_loader, info_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nmcmtnkgLieD"
},
"outputs": [],
"source": [
"plt.plot(acc_log)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xpRJ6OFUf-vH"
},
"outputs": [],
"source": [
"# @title Schedule-free Adamw\n",
"opt = optax.contrib.schedule_free_adamw(learning_rate=1e-3, warmup_steps=int(N_STEPS/10))\n",
"\n",
"train_loader, test_loader, info_data = get_data()\n",
"input_exmp = next(iter(train_loader))[0]\n",
"params, state = init_params(input_exmp, opt)\n",
"\n",
"loss_log, acc_log = train_loop(params, state, opt, train_loader, test_loader, info_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SNXytKweK0Jl"
},
"outputs": [],
"source": [
"plt.plot(acc_log)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "skPVpeOfLrFQ"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"last_runtime": {
"build_target": "//learning/grp/tools/ml_python:ml_notebook",
"kind": "private"
},
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
9 changes: 0 additions & 9 deletions examples/lookahead_mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,6 @@
"@jax.jit\n",
"def loss_accuracy(params, data):\n",
" \"\"\"Computes loss and accuracy over a mini-batch.\n",
"\n",
" Args:\n",
" params: parameters of the model.\n",
" bn_params: state of the model.\n",
" data: tuple of (inputs, labels).\n",
" is_training: if true, uses train mode, otherwise uses eval mode.\n",
"\n",
" Returns:\n",
" loss: float\n",
" \"\"\"\n",
" inputs, labels = data\n",
" logits = predict(params, inputs)\n",
Expand Down
Loading