Skip to content

Commit

Permalink
docs: update colab notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
junhsss committed Mar 24, 2023
1 parent eadb294 commit 969c451
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions examples/consistency_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMSVH2s2QGB6w2mhmEv6Y1D",
"authorship_tag": "ABX9TyNMuCZYN/WQPLvMBhitJgGv",
"include_colab_link": true
},
"kernelspec": {
Expand Down Expand Up @@ -65,7 +65,7 @@
{
"cell_type": "code",
"source": [
"!pip install datasets wandb consistency==0.2.1"
"!pip install datasets wandb consistency==0.2.2"
],
"metadata": {
"id": "IpmvA2RSUctd"
Expand All @@ -89,7 +89,7 @@
"source": [
"DATASET_NAME = \"cifar10\"\n",
"RESOLUTION = 32\n",
"BATCH_SIZE = 360\n",
"BATCH_SIZE = 300\n",
"MAX_EPOCHS = 600\n",
"LEARNING_RATE = 1e-4\n",
"\n",
Expand Down Expand Up @@ -171,7 +171,7 @@
"source": [
"from diffusers import UNet2DModel\n",
"from consistency import Consistency\n",
"from consistency.loss import LPIPSLoss\n",
"from consistency.loss import PerceptualLoss\n",
"\n",
"consistency = Consistency(\n",
" model=UNet2DModel(\n",
Expand All @@ -197,7 +197,11 @@
" \"UpBlock2D\",\n",
" ),\n",
" ),\n",
" loss_fn=LPIPSLoss(net_type=\"squeeze\"), # could use MSELoss here, but the sample quality is ⬇️\n",
" # You could use multiple net types. \n",
" # Recommended setting is \"squeeze\" + \"vgg\"\n",
" # loss_fn=PerceptualLoss(net_type=(\"squeeze\", \"vgg\"))\n",
" # See https://github.com/richzhang/PerceptualSimilarity\n",
" loss_fn=PerceptualLoss(net_type=\"squeeze\"), \n",
" learning_rate=LEARNING_RATE,\n",
" samples_path=SAMPLES_PATH,\n",
" save_samples_every_n_epoch=1,\n",
Expand Down

0 comments on commit 969c451

Please sign in to comment.