diff --git a/examples/consistency_models.ipynb b/examples/consistency_models.ipynb
index 09ac6a5..2b1ed65 100644
--- a/examples/consistency_models.ipynb
+++ b/examples/consistency_models.ipynb
@@ -4,7 +4,7 @@
   "metadata": {
     "colab": {
       "provenance": [],
-      "authorship_tag": "ABX9TyMSVH2s2QGB6w2mhmEv6Y1D",
+      "authorship_tag": "ABX9TyNMuCZYN/WQPLvMBhitJgGv",
       "include_colab_link": true
     },
     "kernelspec": {
@@ -65,7 +65,7 @@
     {
       "cell_type": "code",
       "source": [
-        "!pip install datasets wandb consistency==0.2.1"
+        "!pip install datasets wandb consistency==0.2.2"
       ],
       "metadata": {
         "id": "IpmvA2RSUctd"
@@ -89,7 +89,7 @@
       "source": [
         "DATASET_NAME = \"cifar10\"\n",
         "RESOLUTION = 32\n",
-        "BATCH_SIZE = 360\n",
+        "BATCH_SIZE = 300\n",
         "MAX_EPOCHS = 600\n",
         "LEARNING_RATE = 1e-4\n",
         "\n",
@@ -171,7 +171,7 @@
       "source": [
         "from diffusers import UNet2DModel\n",
         "from consistency import Consistency\n",
-        "from consistency.loss import LPIPSLoss\n",
+        "from consistency.loss import PerceptualLoss\n",
         "\n",
         "consistency = Consistency(\n",
         "    model=UNet2DModel(\n",
@@ -197,7 +197,11 @@
         "            \"UpBlock2D\",\n",
         "        ),\n",
         "    ),\n",
-        "    loss_fn=LPIPSLoss(net_type=\"squeeze\"), # could use MSELoss here, but the sample quality is ⬇️\n",
+        "    # You could use multiple net types. \n",
+        "    # Recommended setting is \"squeeze\" + \"vgg\"\n",
+        "    # loss_fn=PerceptualLoss(net_type=(\"squeeze\", \"vgg\"))\n",
+        "    # See https://github.com/richzhang/PerceptualSimilarity\n",
+        "    loss_fn=PerceptualLoss(net_type=\"squeeze\"), \n",
         "    learning_rate=LEARNING_RATE,\n",
         "    samples_path=SAMPLES_PATH,\n",
         "    save_samples_every_n_epoch=1,\n",
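
For reference, the last two hunks replace LPIPSLoss with the renamed PerceptualLoss. Below is a minimal sketch of the resulting notebook cell, assuming the consistency==0.2.2 API exactly as it appears in the hunks above; the UNet2DModel configuration is abbreviated to a stand-in, and the SAMPLES_PATH value is hypothetical (the notebook defines it in an earlier config cell).

from diffusers import UNet2DModel
from consistency import Consistency
from consistency.loss import PerceptualLoss

LEARNING_RATE = 1e-4      # matches the notebook's config cell
SAMPLES_PATH = "samples"  # hypothetical; the notebook sets this earlier

consistency = Consistency(
    # Abbreviated stand-in for the notebook's full UNet2DModel config
    model=UNet2DModel(sample_size=32),
    # Single LPIPS backbone, as committed:
    loss_fn=PerceptualLoss(net_type="squeeze"),
    # Or combine backbones, per the comment added in the diff:
    # loss_fn=PerceptualLoss(net_type=("squeeze", "vgg")),
    learning_rate=LEARNING_RATE,
    samples_path=SAMPLES_PATH,
    save_samples_every_n_epoch=1,
)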