diff --git a/README.md b/README.md index f9d8f55..5158619 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ Our main objective is to present the core idea of the proposed method in a minim
-With just **100 epochs** of pre-training and a fairly lightweight Autoencoder architecture we achieve **44.25%** accuracy +With just **100 epochs** of pre-training and a fairly lightweight Autoencoder architecture we achieve **46.85%** accuracy with linear probing on the **CIFAR-10** dataset. Our training logs and encoder weights are available inside the [`encoder_weights_logs`](https://github.com/ariG23498/mae-scalable-vision-learners/tree/master/encoder_weights_logs) directory. For comparison, we took the encoder architecture and trained it from scratch (refer to [`regular-classification.ipynb`](https://github.com/ariG23498/mae-scalable-vision-learners/blob/master/regular-classification.ipynb)) in a fully supervised manner. This gave us ~76% test top-1 accuracy. diff --git a/encoder_weights_logs/linear_probe_211120-044529.tar.gz b/encoder_weights_logs/linear_probe_211121-072602.tar.gz similarity index 53% rename from encoder_weights_logs/linear_probe_211120-044529.tar.gz rename to encoder_weights_logs/linear_probe_211121-072602.tar.gz index 338e936..bcd1fa6 100644 Binary files a/encoder_weights_logs/linear_probe_211120-044529.tar.gz and b/encoder_weights_logs/linear_probe_211121-072602.tar.gz differ diff --git a/encoder_weights_logs/mae_logs_211120-044529.tar.gz b/encoder_weights_logs/mae_logs_211121-072602.tar.gz similarity index 86% rename from encoder_weights_logs/mae_logs_211120-044529.tar.gz rename to encoder_weights_logs/mae_logs_211121-072602.tar.gz index 404f117..dc15512 100644 Binary files a/encoder_weights_logs/mae_logs_211120-044529.tar.gz and b/encoder_weights_logs/mae_logs_211121-072602.tar.gz differ diff --git a/mae-pretraining.ipynb b/mae-pretraining.ipynb index 892985f..f05339d 100644 --- a/mae-pretraining.ipynb +++ b/mae-pretraining.ipynb @@ -33,7 +33,7 @@ "base_uri": "https://localhost:8080/" }, "id": "80ZKaTtG9zw9", - "outputId": "944d16c1-42b0-47aa-d3f8-04593bf18faa" + "outputId": "467ea57f-03ac-43d5-e19a-a5a8ac99c393" }, "outputs": [], "source": [ @@ -91,7 +91,7 @@ "IMAGE_SIZE = 48 # We'll resize input images to this size.\n", "PATCH_SIZE = 6 # Size of the patches to be extract from the input images.\n", "NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2\n", - "MASK_PROPORTION = 0.6\n", + "MASK_PROPORTION = 0.75\n", "\n", "# ENCODER and DECODER\n", "LAYER_NORM_EPS = 1e-6\n", @@ -130,7 +130,7 @@ "base_uri": "https://localhost:8080/" }, "id": "lMOYr_h1_QY6", - "outputId": "bfe4e094-3b98-450b-955a-e25b8b80eac5" + "outputId": "e0201dc6-c933-4cf8-c54f-7fde8e3d27de" }, "outputs": [], "source": [ @@ -291,7 +291,7 @@ "height": 496 }, "id": "ptI3I2aMB_rS", - "outputId": "2dc21f01-9c8d-43c5-de0f-d16f2a613d93" + "outputId": "895d0714-f9ba-4235-ca8a-b3c21ec7091c" }, "outputs": [], "source": [ @@ -322,7 +322,7 @@ "height": 248 }, "id": "qv0Va_68CPmF", - "outputId": "6e7e902e-cf50-48e9-f3f6-d52ae3fcd8e8" + "outputId": "5dd6d051-2740-452b-f57c-2131b7b8a3b1" }, "outputs": [], "source": [ @@ -497,7 +497,7 @@ "height": 301 }, "id": "UlxangzdFwMJ", - "outputId": "bee08ebe-cc62-43be-f650-498cb4bd3c52" + "outputId": "72caf44d-71d8-4e95-b66d-f4a7a75b0c90" }, "outputs": [], "source": [ @@ -749,7 +749,9 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "oCOI8_9BX_6g" + }, "source": [ "# Model init" ] @@ -783,7 +785,9 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "TCDOA_kdX_6h" + }, "source": [ "## Training callbacks" ] @@ -850,7 +854,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "id": "s4PEJxppX_6h" + }, "outputs": [], "source": [ "# Some code is taken from:\n", @@ -901,7 +907,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 283 + }, + "id": "iTn6VcaBX_6h", + "outputId": "c0df7a58-763f-4a1f-bbf3-1e1f2e64784f" + }, "outputs": [], "source": [ "total_steps = int((len(x_train) / BATCH_SIZE) * EPOCHS)\n", @@ -938,7 +951,9 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "-dUZ1RYqX_6h" + }, "source": [ "# Compilation and training" ] @@ -967,7 +982,7 @@ "height": 1000 }, "id": "ZUAXzpDoJiXG", - "outputId": "1488c9ba-08b1-4a11-970b-f9191d1af41a" + "outputId": "2015187b-0bfb-445e-a687-018a9284a649" }, "outputs": [], "source": [ @@ -984,7 +999,7 @@ "base_uri": "https://localhost:8080/" }, "id": "S0aVUm63Lj-L", - "outputId": "0634a201-0b1a-4fa3-c079-11ede2cae9d5" + "outputId": "cbe80727-6cc3-4db3-ec5d-cf01469f38d0" }, "outputs": [], "source": [ @@ -1010,7 +1025,7 @@ "base_uri": "https://localhost:8080/" }, "id": "kXnO5jNALndF", - "outputId": "2f1dc3e9-70f7-44b4-e4cf-9a15ea3e0bf2" + "outputId": "8b6c45a7-10bf-4e58-f866-f691b5352b12" }, "outputs": [], "source": [ @@ -1084,7 +1099,7 @@ "base_uri": "https://localhost:8080/" }, "id": "xdeuZ98oLvis", - "outputId": "75bdb66e-ffcb-454a-e4ff-1bb29e64cd73" + "outputId": "c9f7b566-47e9-4afa-a86d-728df915cd99" }, "outputs": [], "source": [ @@ -1116,7 +1131,11 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "0o8DHVfAMrSL" + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0o8DHVfAMrSL", + "outputId": "c98b307d-c002-4507-ac21-0ca777cf3f06" }, "outputs": [], "source": [ @@ -1127,10 +1146,9 @@ "metadata": { "accelerator": "GPU", "colab": { - "authorship_tag": "ABX9TyMop/SYrAmInThOJSwKchR0", "collapsed_sections": [], "machine_shape": "hm", - "name": "mae.ipynb", + "name": "mae-pretraining.ipynb", "provenance": [] }, "environment": { @@ -1140,7 +1158,7 @@ "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-7:m84" }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -1154,9 +1172,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.2" + "version": "3.9.6" } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 1 }