diff --git a/notebooks/c2st/02_failures.ipynb b/notebooks/c2st/02_failures.ipynb new file mode 100644 index 0000000..8310f85 --- /dev/null +++ b/notebooks/c2st/02_failures.ipynb @@ -0,0 +1,1217 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "39c4d6b8-9316-4eba-a648-3095ed4b5d22", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch import ones, zeros, float32, as_tensor, tensor, eye, sum, Tensor, manual_seed\n", + "from torch.distributions import MultivariateNormal, Normal\n", + "from typing import Any\n", + "import matplotlib as mpl\n", + "import seaborn as sns\n", + "import numpy as np\n", + "from sklearn.neural_network import MLPClassifier\n", + "\n", + "import time\n", + "import IPython.display as IPd\n", + "from svgutils.compose import *\n", + "\n", + "from sklearn.datasets import fetch_openml\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.mixture import GaussianMixture\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import pickle" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a215b377-4ddb-4f7f-a6fa-2e2ea93ef155", + "metadata": {}, + "outputs": [], + "source": [ + "from labproject.metrics.c2st import c2st_optimal, c2st_nn, c2st_knn, c2st_rf, c2st_scores\n", + "from labproject.data import toy_mog_2d" + ] + }, + { + "cell_type": "markdown", + "id": "f3f6191a-b33e-4536-b511-310304c7708e", + "metadata": {}, + "source": [ + "# Visualize data and fit" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "f79753b9-4d8b-4118-aa62-b74704e03745", + "metadata": {}, + "outputs": [], + "source": [ + "_ = torch.manual_seed(0)\n", + "_ = np.random.seed(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ff550e7b-23a3-429b-b68a-7d4c0d3aeb69", + "metadata": {}, + "outputs": [], + "source": [ + "data = toy_mog_2d()\n", + "data_samples = data.sample((10_000,))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "243aa674-446d-4fb0-80eb-012024c472a7", + "metadata": {}, + "outputs": [], + "source": [ + "mean = torch.mean(data_samples, dim=0)\n", + "cov = torch.cov(data_samples.T)\n", + "gen_model = MultivariateNormal(mean, cov)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f66a5f10-e716-4fcd-a683-3836b7fb5b52", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGkAAADJCAYAAAAgnhzUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAT0klEQVR4nO2dXWgT2d/Hv2nTF/OCTfrertZa+i9K2dZdBV9gUSyisVRLBe90watdvFi8E2+8ci+8UfbCCxEXURAWF8Xu4pWIII/gPmsrPOxjtbWmptXaNtE2sW/O/C+yZ3JmMjOZSTKTOfF8QEwmk5nT88nvnN+cczJxiaIoguNoSgpdAE5muCQG4JIYgEtiAC6JAbgkBuCSGMBt9QlEUUQikQAAeDweuFwuq09ZdFgeSYlEAj6fDz6fT5LFMQdv7hiAS2IALokBuCQG4JIYgEtiAC6JAbgkBuCSGMDyYaFiZC6+qvla0Jv/KuWRZIK5+KquILJPvuGRZAArKt4MXJIOhZZD4JI00BI0FVuUPW+sqrS8LFySCkpBSjGlJck5sc+CiKnYouWimJKk9enOV0alJ6e0xIXxaFx6viHgRWmJC58F69eWMiHJTEaVrTAtQRMfUhOVlWXJZHhxRcB4NI4NAW9W5zKLoyVl03GT95iRRZ9HKWc0mkCJyox/a5VH9VhWXCc5VlKumdVcfDVjhelFD5EzOruYJqmt2vpkgcaRkowIIhWq12nriVITFI4l4HIlo2d0dhETs3G4FYaag+oRZCWOlKSFWpaVKSXOJJyOnrFYKnomZuN4N5dAWUlqUKYuuEb2XmXSYEVTBzAkiVRmqeKTTT83mxIrmzelnOnpBMrdyePX1qgnCV/kdZLaJ18piE6FaUi2ZaQp1BNE5PynPnm88dlUhieIyX924jhJSmhBtBySDhNIWgwkZZGoAtJlZRL0diyMcncJ3o4l92/YuA4rgoBVQUR7TSVaqzzSOezA8ZIAuaDKshL8835Bc9/WKo+mLOXx1ATNhd+g3F2C9nofAODFuwUsryZltFBNnl2CAIdLUlYuLWh0djFt/7bqSryKJZsmpSwC2aYliMhR0qzSJ9nRHwEm5pPu3LmDX375BaOjo7Ltly5dynuhaOgoUgqq9pQjMpdAZC6Bak85RmcXMTq7iBczi3gVS0jCxqNx6R8AvIrJBQ3/z/+rCkpGkYDmhuR2ZfAoP0RWjZobknT69GlcvHgRIyMj2LlzJ27cuCG9dvnyZUsKpoT0QURQZC6BZ29iaPRXoNFfgWdvYpIwkkYDkGRVlpVIj1/MyCOooqxEM4LW/6cFAKT+aFNtcj9llmklhpq7wcFBPH36FG63GydPnsS+ffvg9Xpx+PBhFOLL6ySCGv0VGH4dk71WF1yDyFwCzUGPJKqtulIWhUabOMKKIMj6I7sx3CeVlpYCADo6OnD79m0cOHAA9fX1Bf0qy/DrGNYF5P3CxNwn2XNaFgC8nknP4pyOoRLu378foVAIQ0NDAIAtW7bg+vXr6Ovrw/j4uIXFM8bfL2alf+9nkv3O9L+ySH9Fos9d4pIJyhRFTsCQpPPnz+PEiROybT09PXj48CFCoZAlBQOS2dNnQcSGgBeLK4K0fTaxLD3++8UsAGBDtQcbqpPjakpRz97E8OxNTNpW7nZlJchf4dZN/wFrkgfDsX7kyBF0d3fj48ePCIfDCIfD8Hq9OHfuXF4LlGn8S20EmsihHxNRJLFo9Fdgeu4TIm/1K5nw4t0Ctm/rwPKqKKXfQ5Pp71VmeFZg6jrp1KlTuHLlCqqrq6WEweVyYWxszJLCEciFY2uVR0qrgWSHXlvjxfhMPE3U+GwC6wKVaYlFuduFt2MThs47PptAc4MPjf4KKRmRXvt30s9xM7N3795FJBKBz2dfO95YVYmp2CI2BLzSdY4gpqYMphWJgpJ1gUpMRBelxyTC9HjxLhUxRDQ9Al5ZViJrfq3GlKSuri4sLS1ZLinodae17XQ0AfIRBxJNWpAMkPRf27d14PGT5zIZSrZv65ANrBYSU5L6+/vR3t6Ozs5OuN2pt96/fz/vBaNRiyZAHk1kKoFOyUkEke3ftFfLRLGCKUlnzpzBhQsX0NLSYlV5DNNWXYnR2UVZPzE990kmBgC6Wqow/DombadFZeKb9uq04xUCU5KCwSCOHTtmVVkMs6nWh3/eL6C7yYehyQU0Bz2IzCXSZk4BYGp+CV0tVQAgydKawFNeGBOUxyX9kSOnKnbs2IH+/n4cPHgQ5eXl0vZ8izNyraEmilDtSZXt2ZsYpuaXAKRXNpBM0QHIok2N5qBHNsBKMjvHzcwmEglUVVXh0aNHsu1WR5fetcj80iq6m1KJzNDkguxil8ijxdGQi1wikEgDIMklxyATfnbjsvr2nvF4XMoGFxYW4PXqD1RqLbNSTvwRyAiAv0L786Z2EUpDBBJhNEQQuYhWzsqqRVK+F6QYOlpvby8GBwfR2tqqOqBq1cWsniDl8Mz8knYTqbVOjoglEtWWawliMoKAlCCCHYIAg5J27dqFa9eu4ezZs3kvAI3aSlJaED3aoIyc+aVVaa5HCX3hqRaFRCJ9TFr6plqf7Bh29UUEQ5JGRkYwMjKCsbExvHz5EqFQCG63G/fu3cPmzZtx/PjxnAuit9SXyHkxk76aFICsTwKQNhpAPv3j0bjsNaVUOjrJa4srAhZXBOkYehldQdfdXb16FQCwZ88eDA8Po6amBgAQi8Vw6NAhSwpGZj7JTGqJC5orSocmF9KaNLXF9MptyqVhdFJAy1QKsquZI5g6ciQSQTAYlJ57PB5MTU3lXAi1KBqPxmWCInOJtBWlK0KyIpuDHmkaQZl9qX3yyQfAzLciCiUIMCmpt7cXe/fuxcDAAERRxM2bN3H06NG8F4peP0AETc99kq0oBeSrSulo0vvkT8UWs7oI1eqDrBYEZJGC37p1Cw8ePIDL5UJPTw/6+vp09zeSgisjaeKD+oIRMs1AD37W1nhRF1yD5qAHbdXpCxet7ODtEARkse5uYGAAAwMDVpQljRIXpOlueh7o8ZPnhkapi0EQwNB9HMikHpnyfvzkufTaqiDauj7bTkGAw1ewqkHmgBo2rsPyqigttyJNHWDddYzdcghMSGpu8GF8Ji7NAZFmrq7Og3qNgU8lhargfODokgtisilbEQQsr4qSnOVVEXV1HjTXeGUJA6CecrMsCHCwJFLpJHkAIF0jrQiCTBDxsm5t8j12DtnYgSMlqS06WRVEuEtcWBVEackvEdQW8GjO77AeRYBDJNELT8h6BkA+TCOIyagSRPm0AZCMoGIVBDhEkhrr1now8SGRNswjiHJ5Wk1csQgCHCqJRNO6tR5pqoIWo8zgilkQ4LCZ2Uw3XqIp5Fia3TjqL1IuijSTpRWjHILjhoWyqexiFgQ4LJIIpNKtvnUaKzj6r/3SZGjhuOaOkw6XxABcEgNwSQzAJTEAl8QAtvzOLCEez/x91S8NI7+9a7kk+rdl6+vrrT4dcxj5pglv7hjA8lFwQRAwMzMDgP+sthpG6sRySZzc4c0dA3BJDMAlMQCXxABcEgNwSQzAJTEAl8QAXBIDcEkMwCUxAJfEAFwSA3BJDMAlMQCXxAC2LEQh6xz4zGx2WB5JiUQCPp8PPp9PtiiFYxze3DEAl8QAXBIDcEkMwCUxAJfEAFwSA3BJDMC/3m2Q+ZXU7Qr8Zfz2no6AlqL1ml2yvmhJeiLMvN9qWV+cpFzFZDqmFcK+CElWiNE7V75FFa0kO8VonTtfsopOUrZyoktLsueBigqNPc2VJR+iikqSGUFKKQBQQk1Iqr2eD3HZUBSSjMqhK74kwwyx8nVBFKX3m5GVj2hiWpIROcqIyCRHC/K+bGXlApOSzMrREhNJRA2fs9kTkB2LyDIiKtdoYk5SJkFacrSElJdkHr5cFgREElFJFH1sO6KKGUlmokdLjpaQZ3OTmsf8OtgkvY8cSylLsPguC4665bQWRqNHTQ6pYD0RHneZ6vbE6or0+OtgE4BkVAFyUUSSXjQVbXOXTfQoI4eWoyUDAJ5Mj8qeb6trk/ZPrK7g2dykFFXK5s/qaHJsJGXT99DRoyZHKYImUJG6g390KbU+cFtdG4BUVNERZSaacokkJiVlat6III+7TCaGFkH437fDadu+bej69zxJWWqitJo9KyQ5srkzI8iMHDUhANC4Zq30eOrTB2k/IuvJ9KjU/BFRpNkz8zdlK8qRkrQwKyhQ4ZGJoWUAwF+jj5PH+ff51rbt0j5EllJUITAs6c6dOwiHwwiFQmhrSxX20qVL+OGHH/JWIK0oMiNITc5fo48lGYT1/lrZcyKNyJr69AFAUjbdT9mNoYUop0+fxsWLFzEyMoKdO3fixo0b0muXL1/OW2FyEfR/c+E0QY1r1iIy+Y9U+ev9tVjvr8Xk5AQmJyfw+PnfePz8b+k8RBrZH9BuIu3EUCQNDg7i6dOncLvdOHnyJPbt2wev14vDhw8jX3lHroIAefMWmfxHipzJyeSPCE9iQjruxsBX0uPHz//G9o5vACRFheffA4AsmnLFlsShtLQUANDR0YHbt2/jwIEDqK+vz8v3jcwIUl6cEkFj0RcAUk0bkJIDyKUo2Rj4SiYKSEbT1rbtWf09+caQpP379yMUCuHnn39Gd3c3tmzZguvXr6Ovrw+fP3+2tIBqF6l0HwSkBJHoWe+vlZqxudczsv+VbO3ulh4TUXQ0OQFDfdL58+dx4sQJ2baenh48fPgQoVAopwJkiiJAP0mgBQEpQRsDX0liOupaVf8BwF9DQwD0I80IeiMOuc4nGV7BeuTIEXR3d+Pjx48Ih8MIh8Pwer04d+5cTgXQLRzVlNKDo3QWB6RSa1oQqXwiQw2lKACyREIP5YgDYN1IuCnFp06dwpUrV1BdXS0lDC6XC2NjY1md3EwUAclmjr4OIkkCnY3R6Ami93k+/Up6vr3jG4Tn32Nr23YpaYguJbCtrk024GonpiTdvXsXkUhEGuaxEmUUkWaOCKITBbofIlFkRBAAmSA1vm3okl0j0UNCBKunKkwt2O/q6sKSygKNfKIVRUAqUaCbOa0oMgIRtLW7G2PRN1IUAZBFEYC0KDLa1Nm+Wqi/vx/t7e3o7OyE25166/37902fWG98TiuKtFDri4xCBAGQBDU3bQKQiiLS1BUiigCTks6cOYMLFy6gpaUlp5NmszZObcgnl75I2cwp+yKlIBo7owgwKSkYDOLYsWN5ObEaek2dErovUvJ8+pWmKLqJA6DazNH9kNr0BMGOKAJMStqxYwf6+/tx8OBBlJeXS9vNiMsURVppt1G2dnfjr6EhSYYyeyP7kCYOSG/m6HkkZTNHosjKST4lpo6USCRQVVWFR48eybZbGV2ZCM+/T4smehSB9FH0NiKIjiAiaGOgHQB0mzk7BQEmJV29ejWvJ6eJLi2ZWrg49ekDtrZtl/VLpPLp0QNaDr1PU9M6Q4JIFBkVZAWGJPX29mJwcBCtra2qA6rZXsyaZVtdG55Mj+Lbhi7ZFEJ4/r1scFRv1ICOHvqCVS1RUPZDVq8K0sLQEXft2oVr167h7NmzeS9ArpBoIhW/3l8rE6aEjh4iaGOgXVcQnc3ZLQgwKGlkZAQjIyMYGxvDy5cvEQqF4Ha7ce/ePWzevBnHjx83dLL5lVV4y/PTTJBoalyzVppSoGVpoRY9gHoTB6T3Q1pY+ZVMQ0cmfdGePXswPDyMmpoaAEAsFsOhQ4csKxxNYnVFGnGILiUQqPCkNXuZ5n+mPn2QpdhmBdmVKCgxdfRIJIJgMCg993g8mJqaylthBFFUTR6+DjZJow6kX6JFAcnB1kyzqFpySJoNpK9QLbQgwKSk3t5e7N27FwMDAxBFETdv3sTRo0fzUpBARUXa11SWBUG6ViKiPO4ymajke1Oy1CD7kX4HgCSHHDub5cN23SLA9OLIW7du4cGDB3C5XOjp6UFfX5/u/vTiyKloTHdxpDIN11rPTS8X1luVSqCXYinXd9PZm5k0284bbti6gtWsJEC+rgFQF6WH8oJU2awB9i6+zwbHLY5U9kvNnoAsokglZxoZpyHvIWT7zQjAfkFAAdaCCxlScLV13kB602eWXCIHKIwcgu1n9pe5dQdZSQKhFVF0ZasJ01ufrSaHnDNTmQuJ45o7QC4KSEWVch5HazpDuR+NGTlA4QUBBZKUKZqAVAWqRRVBT4YSs3JIOZ1AwUphRBQgl0WjN2KuNoRjZtTaKXIIziqNDnQl001hpn3N4DQ5hIKWymg0Kcn3XI5T5RAKfg/WQldQoc9vhIJLAgpTUf4yNxOCAAf1Sdk2fWbPwSKOKjVdibkIY1WGFo79a4qtonPBEX0SRx8uiQG4JAbgkhiAS2IALokBbPmdWUI8Hrf6dMxh5Ld3LZdE/7ZsfX291adjDiO3l+PNHQNYvhBFEATMzCRvesF/VjsdI3ViuSRO7vDmjgG4JAbgkhigKCV9//33OHfuXFZ3EJucnERTU1PmHW2kKCUBQFNTE/78809T7/njjz+we/fuvH7nKh8UhSRRFHHq1Cm0t7dj9+7dGB1Nfh1mw4YNAJKR9eOPP2LLli1Yv349rl+/joGBAbS1teGnn36SjnP58mX8/vvvBfgLMiAWAb/99pv43XfficvLy+L09LTY0NAgXr16VWxpaRFFURSPHz8u9vX1iaIoir/++qu4du1a8d27d+LHjx9Fv98vRqNR2fGcVi1FEUkPHjzAwMAAysrKUFtbq9oXHTx4EADQ0tKCzs5O1NXVwe/3IxgMIhaL2VxicxSFJJfLJRvILStL/4IZfZsd+g5jLFAUknp6enDz5k0sLS0hFovh3r17hS5SXmHrI6XBoUOH8OTJE3R2dqKhoQGbNm0qdJHyCh+7Y4CiaO6KHS6JAbgkBuCSGIBLYgAuiQG4JAbgkhiAS2IALokBuCQG+C/dW2Gn4U1nXQAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "_ = torch.manual_seed(0)\n", + "\n", + "with mpl.rc_context(fname=\"../../../matplotlibrc\"):\n", + " n_samples = 10000\n", + " samples_mog = data_samples\n", + " samples_np = samples_mog.numpy()\n", + " samples_normal_approx = gen_model.sample((100_000,)).numpy()\n", + " \n", + " n_plot = 10\n", + " samples_to_plot = samples_np[0:10,:]\n", + " \n", + " al=0.8\n", + " ms=8\n", + " mec='k'\n", + " \n", + " densities = [samples_np, samples_normal_approx]\n", + " cmaps = ['Blues', 'BuGn']\n", + " \n", + " fig, axs = plt.subplots(2, 1, figsize=(0.9, 2.15))\n", + " for i_a, ax in enumerate(axs):\n", + " density, cmap = densities[i_a], cmaps[i_a]\n", + " sns.kdeplot(x=density[:,0], y=density[:,1], fill=True, thresh=0.05, levels=10, cmap=cmap, ax=ax, alpha=al)\n", + " ax.set_xticks([]); ax.set_yticks([])\n", + "\n", + " ax.set_xlim([-7,4]); ax.set_ylim([-6,4])\n", + "\n", + " axs[1].set_xlabel('dim1')\n", + " axs[0].set_ylabel('dim2')\n", + " axs[1].set_ylabel('dim2')\n", + "\n", + " plt.subplots_adjust(hspace=0.3)\n", + "\n", + " plt.savefig(\"svg/fig2_illustration.svg\", bbox_inches=\"tight\", transparent=True)\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "52f43874-52df-4ce4-9579-ef66a68fb546", + "metadata": {}, + "source": [ + "# Failure modes of c2st" + ] + }, + { + "cell_type": "markdown", + "id": "6d4745a2-51ae-428d-a0a0-4cc746bc51db", + "metadata": {}, + "source": [ + "### Too few samples" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "813f4fdb-e500-49f7-97b9-340937e15336", + "metadata": {}, + "outputs": [], + "source": [ + "c2st_gt = c2st_optimal(data, gen_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "742ca764-e77d-494c-a5cc-052e3c3b9c9d", + "metadata": {}, + "outputs": [], + "source": [ + "budgets = [10, 100, 1000, 10_000]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "507fde3e-dc26-4c24-9c01-c4e364dccab2", + "metadata": {}, + "outputs": [], + "source": [ + "estimates = []\n", + "for budget in budgets:\n", + " _ = torch.manual_seed(0)\n", + " estimates.append(c2st_nn(data.sample((budget,)), gen_model.sample((budget,)), seed=0).item())" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "4035cc12-7753-4760-8a67-4451d50b5602", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAALMAAABqCAYAAAD3JWAUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWFElEQVR4nO2deVBUV9qHn2aXfREURGASTCxAVIwK4kJQQcAlOvphGTVjTTJuEyduk6mxgqiTL1pi6ZiMUTNGXBKjqcTRMaFQM4LmA0bA4IKGZZSobAZla5Ctud8fSIeWBppma3rOU3WL7nPOPfel76/e+55zzyKTJElCINADDPraAIGguxBiFugNQswCvUGIWaA3CDEL9AYhZoHeIMQs0BuEmAV6gxCzQG8QYhboDV0Sc0VFBSNGjCAvL69V3q1bt/D392f48OHMnz+fqqqqrlxKIOgQrcWcnJxMYGAgWVlZavMXL17Mjh07+PHHH/Hy8uIvf/mL1kYKBJqgtZgPHjzIvn37cHFxaZX34MEDysvLmTJlCgBvvvkmJ0+ebLOuqqoq5SGXy3n06BFVVVWIMVCCzqC1mA8fPsykSZPU5uXn5zNkyBDldxcXFx4+fNhmXZaWlsrDysqKQYMGYWlpSXV1tbbm6QUKhYIdO3bg4+ODt7c3w4YN491336W+vr5X7YiOjiY6OrpV+ubNm7ly5Uqn6jp48CAnTpwA4De/+Q2xsbHdYGETPdIAbGxsbH0hA9HW7Cxvv/02V65c4fLly2RmZnLz5k3u3bvHe++919emAZCYmIhCoejUOUlJSdTW1vaIPT2iMFdXVwoLC5XfCwsLcXV1bbO8XC5XHsXFxT1hUp8hSRINDQ0dHs+HVAUFBcTGxnLkyBHs7e0BMDMz48MPP+TFF18EmjxmaGgo3t7e7Nmzh+zsbIKCgvD19SUgIIDU1FSgtQf08PAgLy+P2NhYIiMjCQsL46WXXiIyMpK6ujoAdu7cybBhwwgICODq1aut/q/Dhw+TlpbGm2++SUZGBkFBQcydO5fhw4eTkpKCTCZTlk1ISCAoKIj4+HjOnj1LVFQU3377LQBxcXGMHz8ed3f3LrerjLp0dhu4ublhbm5OYmIiU6ZM4dNPPyUiIqLN8hYWFj1hhk6gUCiIi4vrsFxYWBhGRr/cjn//+994eXnh4OCgUm7QoEG89dZbyu/V1dVkZmYCMG7cODZu3MiCBQtISUlh/vz5ZGdnt3vdpKQkbt26haWlJePGjSM+Ph5nZ2f+/ve/c+3aNQwNDfH392fcuHEq5y1btowjR44QHR3NqFGjAPD29ub06dNtXis0NJTZs2cTFBREeHg4p06doqqqiuTkZEpKSvDw8OAPf/gDVlZWHf5e6uhWzxweHk5aWhoAJ06c4E9/+hNeXl5cvXqVrVu3duel9B5JklS8W3x8PKNGjWLUqFEMHjxYmR4QEAA0Pd1ycnJYsGABAP7+/tjb27fZ29RMYGAgNjY2GBoa4uPjw5MnT0hISCAiIgIrKyvMzc2VdXbEhAkTOvtv8tprr2FgYICTkxOOjo48efKk03U002XP3LKPufnRAeDj40NycnJXq+/3GBoaEhYWplG5lowZM4bbt29TXl6OjY0NoaGhhIaGAqiI3NzcHFDfTpEkifr6emQymUoY0xxKQFPo0kxzuefLGxsbaxQbN9vS8voymUzles/T8mn0/HU7i2iV9TAymQwjI6MOj5YCBXB3d+eNN95gyZIlPH78GGgSx+nTp9U2pq2trXnxxRf58ssvAUhJSaGgoAAfHx8GDhzI9evXAfj+++9V2jPqmDp1KmfPnqWsrIza2lr+8Y9/qC1nZGREQ0OD2ryW12y2qaNzukqPxMyC7uGjjz5i7969hISE0NjYSE1NDSNHjlQ27J7n+PHjrFixgq1bt2JiYsJXX32Fqakpq1atYuHChfj4+DBmzBj8/Pzave6oUaNYt24dY8eOxd7eHjc3N7XlIiIiWLFiBYcPH26VFxMTw9y5c3F0dCQsLIycnBygKW7+4x//qHVc3B4yXZudXVVVhaWlJdAUB+pz41DQvYgwQ6A3CDEL9AYhZoHeIMSso+Tl5SGTyVq9cGl+e9eTPN+zoi3Lli3jp59+6pa6NEGIuYeRJImndXUdHura4cbGxixfvpzy8vI+sLzrXLp0qVdHPoquuR6mpr6eOTG7Oyx3ZsNaBpiYqKS5uLgwffp01q1bx6FDh1qdExMTw2effUZjYyOTJ09m9+7dPHz4kKCgIKX3jo2NJSEhgdjYWDw8PBg3bhwZGRn861//4uOPP+bixYuUlZVhZ2fH6dOncXZ2VmtfbGwscXFxVFRU8J///IfRo0dz7NgxTExMOH78OLt370ahUODt7c3+/fvZu3cvBQUFhIeHk5CQgJOTU+d/vE6ikWfOz8/vaTsEbbBr1y4uXrxIfHy8Svr58+dJTk4mNTWVjIwMamtr2b9/f4f1hYSEkJ2dTU1NDZmZmSQlJZGVlcXLL7/M559/3u65SUlJfPHFF9y5c4fc3Fzi4+O5c+cO+/fv5/vvvycjIwMvLy+2bNnCpk2bcHFx4dtvv+0VIYOGnnnWrFlcu3atp23RS8yMjTmzYa1G5dRhbW3NJ598wltvvcXNmzeV6efPn+fq1au88sorANTU1GBkZMTMmTPbvU7z+AlPT0/27NnDoUOHyMrKIikpiRdeeKHdc5vHcQDKcRzfffcdOTk5yjEi9fX1HdbTU2gkZh17r9KvkMlkrcKHzhISEkJISAjr169XpikUCtauXcu6desAKC8vRyaTUVpa2uY4DPhl/ER6ejqRkZGsX7+e+fPnY2ho2OF9VjeOQ6FQEBkZyd69e4Gml149NV65IzQSc1FRUbuj3qKiorrNIIF6du3axYgRI5TjKoKDg4mKiuJ3v/sdAwYMYOHChcyYMYNly5bx5MkTCgsLcXJy4uuvv1YZZddMYmIiwcHBrFy5kvLyclauXMmsWbM6bVdQUBAxMTFs2rSJQYMGsXbtWiwsLNi9e3ePjsNQh8a9GZIktXkIep7mcKN5ytSsWbOYP38+48ePx9vbG3d3d1avXo21tTV//vOfCQgIIDAwEF9fX7X1RUZGcv36dXx9fQkODmbkyJHcvXu303aNHDmS6Ohopk2bhre3N0+ePGHbtm0AzJkzh/DwcHJzc7X/xzuBRmMz/Pz8ei1mFmMzBNqikWcW3lfQH9BIzC0H3QsEuopGYraxsWH9+vXKiY1r1qzB0tKSwMDAdpcQEAh6E43E/M477/D06VM8PDz45ptvOHnyJBkZGWzcuJG33367p20UCDRCowbgiBEjlB32y5cvx8TEhA8//BAALy8vbt++3W0GiQagQFs08swtJ1teunSJadOmKb+3N1lRIOhNNHpp4uDgwNWrV6moqKCgoEAp5sTExHYXdxEIehONxLxnzx4iIyMpLi5m3759WFhY8P7777N3717OnTvX0zYKBBqh9YTW3NxcHB0dlQNPugsRMwu0RaOYuaGhgb/+9a+sX79eueqjp6cnNjY2OrOIn0CgkZiXL19Oeno6Li4uLF26lO3btyvzvvnmmx4zTiDoDBrFzKmpqdy4cQOAJUuWEBwcjLW1NatWrRKvugU6g8bjmWtqajAzM8PJyYlz584xceJEXFxcum3yo0DQVTQKM5YtW8b48eOV8bKHhwfnzp1j5cqVHS6ZKhD0Fhp55nXr1imn5zRTUFDAP//5zzYX1RMIehuNPHNKSgoLFiygpqZGmZaRkcHs2bOZPXt2jxknEHQGjfqZg4KC2L59O/7+/irply9fJioqioSEhG4zqGU/c1lZmehn/i+h5TrNWtehSaGKiopWQgaYPHkyZWVlXTaiLc6fP68yiVKgH5RWVZFbVEzF0xqsB5jhOXgQSxcu7HK9Gom5vr6exsbGVotcKxSKHh1o9LSuDknsUqVXlMjlpN1/QAMyGmVQW1lJSWUlYx48xHto18b5aBRmrFmzBjs7O7Zs2aKSHhUVxf3797t1L7eWYcb13Fxsra27rW5Bz1OvUPBELqekUs7PFU1CbTrklFRWIm+5DMEz6RnKZIx+4Vf878L/6dK1NRJzZWUlERERFBQUMHbsWBobG0lPT8fZ2ZmzZ89iZ2fXJSNa0lLMd/LycHy2bZhAN2hsbKS0qopHFZU8qqjg54oKHlVUPvtbQam8Cm1eozlYWnJizeou2aZRmGFlZUViYiKXLl3ihx9+wMDAgN///vdt7tAq6L9IkkRlTY1SpI/KK5RC/bmikp8rK1Go2QyoJSZGRjhaWeFoY42zrQ2DbWwZZGuDs60tnyYkcuOn+yqCN5DJ+JWTY5dt1+ltIIRn1o6sgkK+Tk3jQcljhg50YN7YV3jZ5ZcFEWvq6nmkFKiqZ/25opKaDrYzNjSQ4WBphaO1FYNsbBj8TLCD7WxwsbXD3tICQwMDtW+HMx88ZMPxz5GARknCQCZDBsQsWYR3F8fGCzHrGVkFhWz56nSrR7236xCliCtbvC9oC1tzcxytrRlkY80gmyavOtjWBmc7W5ysrTE2MsJAy6EMmQ8e8tn/JXHv0c/8ysmR1ydO6LKQQYi5X1NbX09hWTn5paUUPCklv7SU6z/d79CzAliYmjLQyopBNtZNntXWlsE2Ngy2s8XZ1hYzY2MM+1lPklifuR9Q8fSpUqwFpWXkPymloKyUkopKjRtbA0xMWDMjhMG2NrjY2WE9YEDTI16PBooJMesIjZJESUXlM8GWNgm2tIz80lLk7YQFFqamuNjZ4Wpvx9CBDnz/YzZ3Hz1SKWMgk+HtOoSpPt49/W/0KULMvUxdQwOFZWXPPG0ZBc/EW1BaRn07W/o6WFkyxM4OV3t73AY64DbQAfeBA7G3tFTxsKPd3dU2sF6f2Pl9rfsbImbWgo56CwDkNTXkK0ODXzztzxUVbYYGRgYGDLa1YYi9Pa729rgPdGCogwNDBzpgYWqqcQzbUw0sXUeIuZNkFRSy9VlvQfMPJwNCfEfQoFAovW3F06dt1mFuYtIUGjjYNYnVocnTDrGzw0TNPtoCzehSmPHll18SHR1NXV0dixcvZvPmzSr58fHxvP7668q1NUaPHq12n2VdoFGSqK6tpar5qPnls7zF5/S793j+lYEExN+42apOe0sLhtjZMcTBHjf7Z6GB40AcLC3b7IcVaI/WnrmoqIjx48eTlpaGra0tYWFhbNy4kdDQUGWZbdu2YWlpydq1He/p0UxLz7z64Ccsmjy51SO8LZq3KauqrUXeQpC/iLGmlTiby1TX1mr1GrYZY0NDZvmNxm3gQIY6NMW1lmZm/a57qz+jtWe+cOECwcHBODo2vYZcunQpJ0+eVBFzamoq1dXVHD16FDc3N/72t7+pXQGpqqpK+Vkulys/37r3E+/ln2Te2DHYmJtT/Uyk1bV1SgFWPftcVVNDVX29cvCKthgbGWJhYoq5qSkWpqaYm5pgZWaGhZkZlqamJGXnUKhm2KuvhztLJrQYJitJ1LQTaghaY25u3rWnlaQlH3zwgbRp0ybl9wsXLkjTp09XKbN06VIpLi5OkiRJ2rdvnzRp0iS1dQHiEIckl8u1laMkSZKk9TOwUc1gk+fHOx85coQZM2YAsHLlSm7cuNFvdxsV6D5ai9nV1VW58xFAYWGhSghRU1PDBx98oHKOJEkYq9nvTi6XK4+WdRYVFankqTuKi4uV5YuLizXKb5nWclOa4uJilTx1ZZ5H3TU7Y2NH9j1vU0tbNNlQpz/Z17ytm9Zo69Lz8/Mld3d3qaioSKqrq5OmT58uff311yplXnrpJens2bOSJEnSp59+KoWEhHRYr1wu79Rjp6Py6vJbphUXF6vkt8xTV+b5o6s2dmTf8za1tKU9u/TJPk3RWsySJEmnTp2SfHx8pGHDhkkbNmyQJEmSfvvb30pnzpyRJEmS0tPTpXHjxkleXl7Sq6++Kt2/f7/DOjsr5r5A1238b7VPp1+a6OoqoLpu43+rfTonZoFAW0SPvkBvEGIW6A1CzAK9QYhZoDcIMQv0hn4n5rt377ZaXlegPzQ0NDBlyhTS0tI6fW6/EnNZWRkHDhxQ9lEK9I8tW7YwdOhQrc7tV2K2tbVlx44dQsx6ytGjR/H398fT01Or88WEVoHO8NVXX+Hs7ExaWhrZ2dl8/vnnnTpfiFmgM5w5cwaA6OhoZs6c2fkKum2URxcoLy+XfHx8pHv37inTTp06JXl5eUmenp5SdHR03xnXDqNHj5YaGxul48ePSzt27OjVa7u7u6v8XrpMb93fPo+Zk5OTCQwMJCsrS5lWVFTEhg0bSEhI4Pbt21y5coX4+Pg+tLI19+/fx9XVFZlMxuXLl5k4cWJfm6ST9Ob97XMxHzx4kH379uHi4qJMazm/0NjYWDm/UFcIDQ1lwoQJZGRkMGrUKI4ePcqKFStYtmwZI0eOxM/Pj+joaADy8/OZMWMG/v7+uLm5sXHjRgBiY2P59a9/TWBgIG5ubmzdupV33nkHX19fpkyZwtOnT8nLy8PLy4t58+bh7e1NaGgojx8/VrFFoVCwceNG/Pz88PX1Zdu2bQCUlJQwffp0xowZwyuvvKJ8hPc2vXl/+1zMhw8fbrXOc35+PkOGDFF+d3Fx4eHDh71tWpvEx8ezZMkSDhw4QGpqKiNGjODYsWPcuHGD69evk5SURE5ODtXV1Zw4cYIFCxaQkpLCrVu3OHjwICUlJQBcvXqVuLg4rly5wubNmwkLC+PGjRsYGBhw/vx5AO7cucPq1avJzMzEx8eHqKgoFVsOHTpEXV0d6enppKenk5yczLlz5/jss8/w9fUlPT2dY8eOkZiY2Ou/E/Tu/dXJBqAm8wv7mszMTFatWsWPP/7I8OHD8fT0pLa2lsmTJxMeHs7777+Pubk5GzZs4NKlS8TExHDr1i1qa2uVs9EnTZqEtbU11s+2upg6dSoA7u7ulJaWAvDCCy8o09944w0WLVqkYsf58+f54YcflGKtqqri5s2bhIeHExoayr179wgLC1M+KXSBnrq/OilmV1dXFU/y/PzCviY0NJQrV64wa9YsSktLaWhoIDAwkOTkZFJSUoiPjycgIIDExEQOHDhAbm4uixcv5rXXXuPixYvK/cZNTExU6lW3fVjLtMbGRgwNDVXyFQoFO3fuZN68eQA8fvyYAQMGYG5uTnZ2NnFxcZw7d45du3Zx584dnVh4pqfur265u2dMmzaN7777juLiYurr6zl27BgRERF9bZaSw4cP8+qrr5KRkcHMmTP54osv+Oijj5g3bx7BwcHExMTg5eVFVlYWFy5c4N1332XBggU8ePCA/Px8FO0skPg8ubm5pKenK68bEhKikh8cHMwnn3xCfX091dXVTJ06lQsXLrB9+3Z27txJZGQkH3/8MY8ePdKZmfE9dX910jO7uLiwc+dOpk2bRm1tLXPmzGHu3Ll9bZaS5ORkAgICALh27RoxMTGYmZkxfPhwfHx8GDBgAH5+foSFhSGXy1myZAm2trY4OTkxduxYjWYtN+Pg4MC2bdvIycnB29ubQ4cOqeSvWLGC3NxcRo8eTX19PQsXLmTOnDkEBgayaNEifH19MTIyYsuWLdja2nbnz6A1PXV/xbQpHSYvL4+goCDy8vL62pR+gU6GGQKBNgjPLNAbhGcW6A1CzAK9QYhZoDcIMQv0BiFmgd4gxCzQG4SYBXqDELNAbxBiFugNQswCveH/AciqeNhQG1g6AAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "good_vals = [c2st_gt for _ in range(len(budgets))]\n", + "\n", + "with mpl.rc_context(fname=\"../../../matplotlibrc\"):\n", + " fig, ax = plt.subplots(1, 1, figsize=(1.6, 0.7))\n", + " _ = ax.axhline(c2st_gt, color=\"gray\", alpha=0.6)\n", + " _ = ax.plot(budgets, estimates, c=\"#458588\")\n", + " _ = ax.scatter(budgets, estimates, c=\"#458588\", s=15.0)\n", + " _ = ax.fill_between(budgets, good_vals, estimates, color=\"#458588\", alpha=0.1)\n", + " _ = ax.set_ylim([0.5, 1.0])\n", + " _ = ax.set_xlim([10, 12200])\n", + " _ = ax.legend([\"Ground truth\", \"Neural net\"], handlelength=0.7, handletextpad=0.4, labelspacing=0.1, loc=\"upper right\", bbox_to_anchor=[1.1, 1.2, 0.0, 0.0])\n", + " _ = ax.set_ylabel(\"C2ST\", labelpad=-5)\n", + " _ = ax.set_xscale(\"log\")\n", + " _ = ax.set_xticks(budgets)\n", + " _ = ax.set_xticklabels([r\"$10^1$\", \"\", \"\", r\"$10^4$\"])\n", + " _ = ax.set_yticks([0.5, 1.0])\n", + " _ = ax.set_xlabel(\"#samples\", labelpad=-8.0)\n", + "\n", + " locmin = mpl.ticker.LogLocator(base=10.0,subs=np.arange(0, 1.0, 0.1),numticks=12)\n", + " ax.xaxis.set_minor_locator(locmin)\n", + " ax.xaxis.set_minor_formatter(mpl.ticker.NullFormatter())\n", + "\n", + " plt.savefig(\"svg/fig2_panel_a.svg\", bbox_inches=\"tight\", transparent=True)\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "c08cd877-0baf-4c94-9336-0fc37afae320", + "metadata": {}, + "source": [ + "### A too poor classifier" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "22aa6528-e5f9-4b4c-a0de-261542c7c60a", + "metadata": {}, + "outputs": [], + "source": [ + "def poor_c2st(\n", + " X: Tensor,\n", + " Y: Tensor,\n", + " seed: int = 1,\n", + " n_folds: int = 5,\n", + " metric: str = \"accuracy\",\n", + " hidden_size: int = 5,\n", + " clf_kwargs: dict[str, Any] = {}\n", + "):\n", + " clf_class = MLPClassifier\n", + " ndim = X.shape[-1]\n", + " defaults = {\n", + " \"activation\": \"relu\",\n", + " \"hidden_layer_sizes\": (hidden_size * ndim),\n", + " \"max_iter\": 1000,\n", + " \"solver\": \"adam\",\n", + " \"early_stopping\": True,\n", + " \"n_iter_no_change\": 50,\n", + " }\n", + " defaults.update(clf_kwargs)\n", + " \n", + " scores_ = c2st_scores(\n", + " X,\n", + " Y,\n", + " seed=seed,\n", + " n_folds=n_folds,\n", + " metric=metric,\n", + " z_score=True,\n", + " noise_scale=None,\n", + " verbosity=0,\n", + " clf_class=clf_class,\n", + " clf_kwargs=defaults,\n", + " )\n", + " \n", + " scores = np.mean(scores_).astype(np.float32)\n", + " value = torch.from_numpy(np.atleast_1d(scores))\n", + " return value" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "6e6dbc38-46a0-46e8-b9c7-433a7cf1a8c0", + "metadata": {}, + "outputs": [], + "source": [ + "budget = 10_000\n", + "\n", + "hidden_sizes = [1, 2, 4, 8, 16]\n", + "poors = []\n", + "for hidden_size in hidden_sizes:\n", + " _ = torch.manual_seed(1)\n", + " poors.append(poor_c2st(data.sample((budget,)), gen_model.sample((budget,)), hidden_size=hidden_size, seed=1).item())" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "39910e45-fcc9-46f1-88d4-13f58d193a9a", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAK4AAAB1CAYAAADX9doCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAS9UlEQVR4nO2de1ATV9/Hv0kgCSGAgHIJoNRblYvEgBTFR1EZq6VSq1JaQQUvo7ajUwuMIi0itU+tWmsL03bqBRSs1fq+VlRaxA5q+2qLKIg+VistKHIJqAVMAoEk5/0D3IfILSABVs9nZie752zO7+zmm7O/3T3ndziEEAIKhWVw+7sCFEpPoMKlsBIqXAorocKlsBIqXAorocKlsBIqXAorocKlsBIqXAoreSrh1tXVwdPTEyUlJW3yrl+/Dj8/P4wZMwYLFiyAUql8GlMUih49Fu7Fixfh7++PW7dutZsfHh6OTz75BDdv3oSbmxu2bNnSYVlKpZJZFAoFqqqqoFQqQd9GUzqE9JCIiAhy/vx5MmzYMFJcXKyXd/fuXeLq6sps37lzh7zwwgsdlgWg3UWhUPS0epRnHJOeCj4lJaXDvLKyMjg5OTHbEokE9+7d66kpCqUNPRZuZ+h0ujZpXG7HXolCoWDWlUol7O3tjVEtyjOEUYTr7OyMiooKZruiogLOzs4d7m9ubm6MalCeYYzyOGzo0KEQiUQ4d+4cAGDfvn0ICgoyhinKc0qvCveVV15BXl4eAODQoUPYsGED3NzckJubi8TExN40RXnO4RAysJ45KZVKiMViAM2+L3UjKO1B35xRWAkVLoWVUOFSWAkVLoWVUOFSWAkVLoWVUOFSWAkVLoWVUOFSWIlBwi0rKzN2PSiUbmGQcOfMmWPselAo3cIg4Q6w7gwUimH9cSsrKzvt3RUfH99rFaJQDMHgjuS01aUMJAzq1iiTyXDlypW+qA/t1kgxCOrjUliJQcLNzMw0dj0olG5hkHCtrKwQFRWF3NxcAMDatWshFovh7+9Ph51T+gWDhPvuu++ivr4erq6uOHXqFA4fPoyCggLExMRgzZo1xq4jhdIGg27OPD09ce3aNQDAypUrwefzkZSUBABwc3PDjRs3eq1C9OaMYggGtbg8Ho9Zz8nJQWBgILPd2NjY+7WiULrAoOe4tra2yM3NRV1dHcrLyxnhnjt3rtNAHxSK0TAkwFhhYSEZO3YssbGxIfv37yeEELJlyxZiZ2dHcnNzezWYmUKhoEHvKF3S47gKRUVFGDJkCKysrHr1j0R9XIohGOTjajQafP7554iKisIvv/wCABg5ciSsrKzwwQcfGLWCFEp7GCTclStX4vLly5BIJFi8eDG2bt3K5J06dcpolaNQOsKgm7NLly6hsLAQALBo0SJMnz4dlpaWePvtt+nrYEq/YJBwCSFoaGiAUCiEnZ0dTp48icmTJ0MikYDD4Ri7jhRKGwxyFSIjI/HSSy8x/q2rqytOnjyJ1atX488//zRqBSmU9jCoxX3vvffg4+Ojl1ZeXo4TJ07ghx9+MEa9KJROMUi4v/32G0JCQpCens6kFRQUIDk52ajC1Wg00Gg0RiufMnAwMelecHyDnuMGBARg69at8PPz00s/f/484uPjcfbs2W4Z7YzWz3GPHDkCoVDYa2VTBi7dHZBrkMzr6uraiBYApkyZgpqamm4Z7A4NTU1Aq34SlGeDWqUKxferoWhQw8rMDCMduj9ZjUHCbWpqgk6nazNzjlarNWonG6+XXoKVhYXRyn/eKaqoxMmCqyh/8BASWxu8KvXCSEcHo9nT6XT4s7wCe37MwuPLPE/VgLy/78C79B7cXQzv92KQq7B27VpYW1tj8+bNeunx8fG4e/cuUlNTu1P/TmntKvxRUoIhNja9VvZA5VZ5Bf73Uh5K7z+Ay2BbzJvggxcljr1WPiEEao0GKnUj6hubl1sVFTj46wW0/vE5AKa5u2GQSASNTguNVgeNTgetVguNTsekabW6Vvn//dTqdPppj7/X8tmR1LgcDmQvuOLfb75h8DEZJNxHjx4hKCgI5eXlmDBhAnQ6HS5fvgxHR0dkZGTA2traYINd8bwJ91Z5BRL/5xgznSanZYmf/zpelDhCq9OhvrERKnUjVC2iUzU2QqVWo76xCapGNZNWr26d34j6pkZGrLoB/qLIVizGobXvGLy/wZ1sCCHIyclBfn4+uFwufHx88K9//avHFe2IZ124Op0ONSoVHigUePBIgSO/5aKinfsEEy4XPC4X6l58qsLhcGBmagozAR//KJTtipnP42GauxtMuDyYmPBgwuPClMuDCY8HUx63OZ3XnN6c1rLN5TLppjyTlvTmdZOW75ma8PBJxklcvXNXz6bRWty+hM3CJYSgrr6eEeUDhYJZf9iy/o9SBW07M292Bd+EB6EpHyI+HyIBH2Z8PswFApgJ+DDnC2AuEMBcKICIL4C5gA9zoRAiAR9igQDmAmFLHh8cDgdcDgdxh7/HleISPfH2REDd5T+l9xCd/i0IAB0h4HI44ADYsWgh3LvRt9soM0uyka78TEIIlGq1nigfPiHQhwolmrTaLm1xORwMMhfBVixGVV0dalX1bfZ50dERMXNegUjQLEq+iQk4LT9yb7xmD/OfhPziEnA5HD0BhU2e9NRld4a7izN2hC/Ewf+7gOKqarxgNwRhkyd1S7QAbXEB/NfPfLIdHO86DBqtlhGmusmwy7aVyAw2YjEGi8UYbGkJO0tLDLG0wBBLC9hbWcFWLIapiQm4HE6vtUA94T+l955aQP3FcytcjVaLkur7KJLLcTzvCmpUKoO+JxYKYCMWw1YsxhBLSwyxsMAQS0vYWTZ/DrGyhMDEBLxOJt1+EjYLqL94LoRLCEF13SPclstRVCnHX3I5Sqrvd3lZF5qaYsX0gBZhNotSxOc3t4q0V1y/8kz6uCq1GkXyKhTJ5firUo4ieRXq6tv6keYCAUba26Oytgby2jq9PC6HAw8XZ8zxlvVVtSndgPXC1ep0uPvgQUtLWoWiSjnK/vmnzX48LhdDB9titIMDxjhJMNZJAhcbG5iamHToZxr7RoXScwa0q/DON7uxcMoU5u6eEIKHCkVza1opR5Fcjr+rqtHYzrPOwRYWGOVgjxcljnBzcsJIB3uYCwQdXuKpn8kunkq433//PRISEtDY2Ijw8HBs2rRJLz8rKwthYWFM7IXx48cjJSWl0zJbC3f6pkSY8PmY4eGOWpUKt+Vy1Cjb3kSZ8U0xws4eoyUOGCORYKxEgsGWFt26QaKwix67CpWVlYiOjkZeXh4GDRqE2bNnIysrCy+//DKzT25uLuLi4rBu3bpOy1Iqlcy6QqFg1rUtHXiyruQzaRwOB0421hhlb4/REkeMdnTAsMGDYcrj6bWmDe34tJSBjUgkMvymt6cBGQ4cOEAiIiKY7f3795PIyEi9febMmUNmzJhBpFIpCQ4OJqWlpe2WBYAudOlWAJgeX0vLysrg5OTEbEskkjYhR62trREdHY38/HzMmjULCxcu7Kk5CkWPHrsKunbetz/ZX3f//v3M+urVqxEbG4va2to20W9auwePHj2Co2PzzVhlZSXj7xobpVIJe/vmDs1yubxPI+j0l+2BZlckEhlcRo+F6+zsjHPnzjHbFRUVegHwGhoa8NlnnyE2NpZJI4TA1NS0TVkdnTCxWNwvIZjMzc37LfRTf9lmm90euwqBgYH4+eefIZfL0dTUhLS0NAQFBTH5QqEQqampOHHiBAAgJSUFfn5+3fpXUSgd0WPhSiQSbN++HYGBgXB3d4eXlxdef/11LF++HBkZGQCAQ4cOYcuWLXB3d0daWhr27NnTaxWnPN8MuBcQFIoh0Cf0FFZChUthJVS4FFZChUthJVS4FFYyYIVbV1cHT09PlJSU9JnNnTt3wsPDAx4eHoiMjOzzqbBiYmIQERHRZ/bS09Ph7u4Od3d3REdHG93ek7/pxYsX4efnB3d3d7z11lvdO9896F9jdC5cuEA8PDyIqakpKS4u7hObv//+O/Hw8CAKhYLodDoSHh5Odu7c2Se2CSHkzJkzZPDgwWTJkiV9Yk+pVBJra2sil8tJU1MT8fX1JdnZ2Uaz9+RvWltbSxwcHMjVq1cJIYS8+eabJDk52eDyBmSL+8033+DLL7+ERCLpM5vW1tZITk6Gubk5OBwOvLy8cPfu3a6/2As8fPgQcXFx2LhxY5/YA5rjvul0OtTX1zPhXM3MzIxm78nfNDs7GxMnTsS4ceMAAElJSZg3b57B5Q3IoTtddTY3BqNGjcKoUaMAAFVVVUhOTu7VmGidsXLlSnz00UcoLS3tE3sAYGFhgQ8//BBjxoyBSCTC1KlTMWmS8YYqPfmbFhUVwcLCAvPnz8ft27cxefJk7Ny50+DyBmSL25+UlJRg2rRpWLFiBQICAoxub8+ePXBxccGMGTOMbqs1hYWF2LdvH+7cuYPy8nLweDzs2LGjz+xrNBpkZmZi27ZtyM/Ph0ql0pvNqSuocFtRUFAAf39/rFq1CnFxcX1i8/Dhwzh9+jSkUini4+ORkZGBtWvXGt1uVlYWZsyYATs7OwgEAkRERPRqgO6ucHBwgK+vL0aMGAEej4c33ngDubm5Bn+fCreF6upqzJo1C0lJSVizZk2f2c3Ozsb169dRUFCAxMREBAcH44svvjC6XS8vL2RlZUGhUIAQghMnTsDb29vodh8zc+ZM5Ofn486dOwCAzMxMyGSGhwKgwm1h165dqKurQ2JiIqRSKaRSaZ+1uv3BzJkzER4eDm9vb4wbNw5qtRobNmzoM/suLi7YvXs3goODMWbMGFRVVen13e4K2juMwkpoi0thJVS4FFZChUthJVS4FFZChUthJVS4FFZChUthJVS4FFZChUthJVS4FFZChUthJc+NcEtKSuDq6tom3dXVFSUlJcjLy8Py5csN/h7QOxPl9WY5PaWjYx/IPDfC7QofH5/nNrYZG4+dCreFs2fPMiMe8vPzIZPJIJPJsHnzZmafkpISTJ48GVKpFKtWrWLSlUolli5dCplMBi8vL0YEqampCA0NxezZszF69GiEhoZ2OpK1rKwMs2bNgp+fH4YOHYqYmBgAwLRp05CZmcns5+npiVu3buGvv/7CzJkzIZPJMHHiRFy4cAEAEBERgVdffRVjx47F0aNH9Wzs3r0bXl5e8Pb2xoIFC1BfX88ce2NjI9OlUyqVwtraGuHh4QCaRwR7e3tDKpUiLCwMjx49eoqz3QsYZUjnAKS4uJiYmpoSLy8vveXxqNOcnBwydepUQgghHh4e5KeffiKEEJKYmEiGDRtGCCEkKCiIfP3114SQ5qkEHp++2NhY8umnnxJCCFEoFGT8+PHk6tWrJCUlhTg7O5Oamhqi0WiITCYjGRkZber2uJzt27eTPXv2EEIIqa2tJZaWlqS6upqkpaWR0NBQQgghly5dIpMmTSKEEOLv708uXbpECCGkqKiIuLq6kqamJrJkyRISFhbW7nmwsbEhNTU1hBBC4uLiSF5ent6xP+bKlStk+PDh5N69e+TGjRvE39+fqFQqQgghW7ZsIVFRUd04+73PgBwsaSwkEgkKCgr00p70X+/fv4+ysjJmEpaIiAjs3bsXQHOr/O233wIAwsLCsGzZMgDA6dOnoVQqceDAAQDN8QMKCwsBAP7+/kwEdg8PDzx8+LDD+kVHRyMnJwc7duzA9evXoVaroVQqMX/+fMTExKCurg6pqalYunQpFAoFcnNz9XzTpqYmZmRyRwMfg4OD4evri9deew3z5s2Dt7d3myE7lZWVCAkJQVpaGpycnHDs2DHcvn0bEydOZOwMHz68w+PoC54r4RoCh8MBadW3vnUE9dZ5HA4HPB4PQPNQ74MHDzJDT6qqqmBlZYVDhw5BKBR2WPaTREVFoaioCOHh4Zg7dy7OnDkDQgjMzMwwd+5cHDlyBCdPnsTHH38MrVYLoVCo90csKytjpiHoKIB2SkoK8vPz8eOPPyI8PBwJCQl6keTVajXmzp2L9evXM+LXarUIDQ1lhhQplUqo1equT6YRoT7uE9ja2sLV1RXHjx8HAHz33XdMXmBgIDNk/fjx42hoaAAATJ8+HV999VXznMHV1ZDJZPjjjz+6bTs7Oxvr169HSEgISktLUVZWBm3LfMPLli1DQkICAgICYGFhASsrK4waNQrp6ekAgF9//RXe3t7QtDNZ4WNUKhVGjBgBFxcXbNy4EYsXL0Z+fr7ePsuXL4ePjw9WrFjBpAUEBODYsWOQy+UAgHXr1uHDDz/s9vH1JrTFbYf09HRERkYiISGBuTwCQHJyMhYtWoS9e/fC19cXFhYWAIBNmzbhnXfegaenJzQaDd5//31IpdI2bklXxMbGYtGiRRg0aBDs7OwwYcIE/P333xgxYgR8fHwgFAoRGRnJ7H/w4EGsXr0a27ZtA4/Hw9GjR8Hn8zssXyQSIS4uDlOmTIFIJIK1tTVSU1Nx+/ZtAMCFCxeYm7Dx48eDEIKhQ4ciIyMDCQkJCAwMhE6nw9ixY7sVA8EY0DFnLIAQgps3byIkJATXrl3r9+e+AwHqKrCAXbt2Yfr06UhKSqKibYG2uBRWQltcCiuhwqWwEipcCiuhwqWwEipcCiuhwqWwEipcCiuhwqWwkv8HZuUlL2u1iAEAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "good_vals = [c2st_gt for _ in range(len(hidden_sizes))]\n", + "\n", + "with mpl.rc_context(fname=\"../../../matplotlibrc\"):\n", + " fig, ax = plt.subplots(1, 1, figsize=(1.6, 0.7))\n", + " _ = ax.axhline(c2st_gt, color=\"gray\", alpha=0.6)\n", + " _ = ax.plot(hidden_sizes, poors, c=\"#458588\")\n", + " _ = ax.scatter(hidden_sizes, poors, c=\"#458588\", s=15.0)\n", + " _ = ax.fill_between(hidden_sizes, good_vals, poors, color=\"#458588\", alpha=0.1)\n", + " _ = ax.set_ylim([0.5, 1.0])\n", + " _ = ax.set_ylabel(\"C2ST\", labelpad=-5)\n", + " _ = ax.set_xscale(\"log\")\n", + "\n", + " _ = ax.set_xticks([1, 2, 4, 8, 16])\n", + " _ = ax.set_xticklabels([\"1\", \"2\", \"4\", \"8\", \"16\"])\n", + " _ = ax.minorticks_off()\n", + " _ = ax.set_xlim([1, 17.4])\n", + " _ = ax.set_yticks([0.5, 1.0])\n", + " _ = ax.set_xlabel(\"Hidden layer size\", labelpad=7.4)\n", + " plt.savefig(\"svg/fig2_panel_b.svg\", bbox_inches=\"tight\", transparent=True)\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "5e7594ab-d765-4f86-b114-6382b442ac56", + "metadata": {}, + "source": [ + "### High-D Gaussian behavior" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "b6f297c6-2857-47d2-b6c9-a014d0430939", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJUAAABnCAYAAAAJ+ABdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAATMElEQVR4nO2de1RVZf6HHzzg4RwODijkgMIoaAkIHEBFQCwu3gID0gwbRbExHWksLGfSWhPVLFvNRZIpHSeXhZoGpKlcFUzHMXNm5GLazKiNAoqIopXC4XKA9/cHP/eKvIHuwwHcz1p7LfY5e7/vZ7M+593f924hhBAoKMhIP3MLUOh7KKZSkB3FVAqyo5hKQXYUUynIjmIqBdlRTKUgO4qpFGRHMZWC7MhuqqysLLy8vBg5ciRvvPHGTd/v2bMHBwcH9Ho9er2exMREuSUomBkLObtpLl68SGBgIEePHsXOzo5p06axfPlypkyZIl3z1ltvodPpSE5Ovmt69fX10t9CCAwGAzY2Nmi1WiwsLOSSrSAzspZUhYWFhIeH4+joiJWVFQkJCWRkZHS45l//+he5ubn4+fkRExPD+fPnb5ueTqeTDltbWwYPHoxOp8NgMMgpW0FmZDVVVVUVQ4YMkc6dnZ1vMo29vT0vv/wypaWlTJ06lWeeeUZOCQo9AFlN1dbWdnMG/TpmkZ6eztSpUwH45S9/yVdffcX3339/y/Tq6uqko6amRk6pCiZEVlMNHTqU6upq6by6upqhQ4dK542Njbz99tsd7hFCYGVldcv0bGxsOhwKvQNZTRUZGcm+ffuoqanBaDSyefNmoqKipO+tra356KOPyM7OBuDDDz9k/PjxaLVaOWUomBshM5mZmWL06NFi5MiR4uWXXxZCCPHss8+KXbt2CSGEKC4uFuPGjROenp4iLCxMVFZWdirduro6AQhA1NXVyS1bQUZkbVIwJfX19eh0OqA91lJehz0XpUVdQXYUUynIjmIqBdmxNLeAvo4QgkuXLtHY2IiTkxNWVlbU1dXR1taGra3tTe14fQHFVCaisrKSP/7xj2RmZlJTU4OjoyMhISF4eXlJ7XKWlpa4u7sTGBjI8OHDzaxYPpTan8wIIUhNTeXVV1+lsbGRfv36ERERQVBQkFQq3eh5+GEp5enpSXR0NBqNxiy65UQxlYw0NTWRkJBAZmYmAOHh4URFRXH9+nWgvcfhzJkzrFmzBoPBwMMPP8xvfvMbzp07hxACe3t75syZw8CBA835GPeNYiqZaG5uJjY2lvz8fKysrEhNTcXKyorq6mr69+9PbGwsHh4eAHzzzTc89dRTlJWVodPpyMrK4uTJk3z33XfY2NiQmJjIoEGDzPxE907fixLNQFtbGwkJCeTn56PVasnJyUGr1VJdXY1WqyUxMVEyFMCIESP4+9//TkREBHV1dcyePZvQ0FAGDx5MfX09mzdvlkq33ohiKhl48803ycjIwMrKis8++4yGhgYqKytRq9XMnTuXn/70pzfdo9Pp2L17N8HBwXz33XfMnDmTqKgoBg0axPfff09GRgYtLS1meJr7RzHVfZKfny8Nm/7LX/7CwIEDKSsrw8LCgqeeeuqWhrqBVqtl165duLm5cfbsWRYtWkR8fDwajYaqqioKCgq66zFkRTHVfXDx4kUSEhIAWLJkCdHR0ZIRIiIicHd3v2saDg4O7NixA2tra/Lz80lPT+fJJ58EoLi4mH//+9+mewAToZjqHhFC8Oyzz1JbW4uvry9/+MMf2LFjB62trYwYMYLg4OBOp+Xr68uaNWsAWLFiBQaDgQkTJgCQnZ3d6+KrezLVt99+S3FxMWVlZbcdtdnX+fDDD8nLy0OtVrN161aOHDlCTU0NWq2WmJiYLk/MWLhwITExMRiNRubNm0dISAhOTk40NjaSk5NDL6mkA100VV5eHqGhoTz88MMsXLiQRYsW4eHhQUREBHv27DGVxh7HhQsXWLZsGdA+O2jgwIF88cUXADz++ONS00dXsLCwYP369QwaNIiysjJWr15NbGwsKpWKU6dOcfz4cVmfwZR0up0qMTERBwcHEhIS8Pb27vDdiRMn2LBhA1evXmXTpk0mEdqT2qlmzpzJ9u3bGTt2LF988QUfffQRFy5cwMPDg1mzZt1X2lu2bGHu3Lmo1WqOHz9OdXU1+/fvR6vVkpSU1CtGyXbaVOfOncPFxeWO11RWVuLq6iqLsB/TU0yVk5PD9OnTUalUlJSU0NjYSH5+Pmq1mqSkJGxtbe8rfSEE06ZNY8+ePURERFBQUMBf//pXLl++jF6vJyYmRqYnMR2dfv3dzlCnT5+mpKSEkpISamtrKSkpkU1cT8NgMPCrX/0KgGXLluHm5sbnn38OtNf27tdQ0P4aXLt2LdbW1uzbt4+srCyio6MBKCsro7Ky8r7zMDX3Xfs7dOgQ2dnZ0pGTkyOHrh7JqlWrKC8vx8XFhddff529e/fS1NSEs7MzAQEBsuXj5ubGq6++CrSb187ODj8/P6A9rr3VVLgeRVcHtV+9evWmz77++ut7GB7fNcw98eHUqVOif//+AhDbt28X5eXlIiUlRaSkpIjz58/Lnl9jY6MYOXKkAMSyZctEfX29eOedd0RKSoo4cuSI7PnJSZdLKj8/P4qKim4YknfeeYeIiAhZjd7TEELwwgsv0NzczJQpU4iJiSEvLw8Af3//DrOy5UKtVpOWlgbAmjVrOHv2LOHh4QDs37+furo62fOUiy6bavv27SQnJ5OUlERISAglJSWUlpaaQluPITs7Wxp9kJaWxtGjR7l06RIajcakP6ipU6cSGxtLa2srzz//PH5+fjg5OdHU1MS+fftMlu/90mVTBQQEsGjRIjZv3kxFRQXJycl37N/q7TQ0NPDiiy8C7fGNs7Mz+/fvB9rHS5m6ip+amoq1tTUHDhzg008/5fHHHwfag/Y7LW5iTrpsqgkTJrBz505OnDjBtm3bSEhIYOnSpabQ1iP4/e9/z9mzZxkyZAivvfYaRUVFNDU14eTkhL+/v8nzHzZsGCtWrADgpZdews7ODr1eD/TcoL3TprqxVlR8fDxFRUW4uroyceJESktLaW1tBejR7/l74cyZM9LaD6tXr+bq1ascO3YMaG85765JC8uXL2f48OFUVVXx5ptvEhkZiVqtprq6muLi4m7R0BU6/V+ZP38+7733HnPmzOnwuY2NDatWreLdd99l3rx5sgs0F0IIli5dSlNTE+Hh4cyYMUMKzvV6fYeFR0yNRqORgvbU1FQqKiqkoP3zzz/vsDhcT6DTpsrMzMTS0pLAwEDGjh3LjBkziI+PJzAwkMDAQDQaDVlZWabU2q3s3LmT3NxcrKyseP/99/nnP/9JTU0NGo2GSZMmdbue6OhonnjiCVpaWliyZAkBAQFSh/PevXu7Xc+duKcx6seOHeP06dOoVCpGjBhxU1+gKejObppr167h6elJVVUVK1eu5Ne//jXvv/8+RqOR6dOnd0ssdSsqKirw9PTEYDCwceNGJk+ezIYNGwCYO3cubm5uZtH1Y7ocFDQ3N1NUVER6ejqbNm3i0KFDPTJYvB9WrlxJVVWV1LKdl5eH0WjExcVFatk2Bz/72c9ISUkB2oN2S0tLxo4dC0Bubi5Go9Fs2n5Il021YMEC/vGPf/Dcc88xf/58ioqK+lTt79ChQ6xduxaA9evX87///Y9Tp06hUqmYPn262RewTU5ORq/X8+2337J06VKpz/Hq1ascOHDArNpu0OXXn4eHB//5z3+k87a2Nry8vDp8Zgq64/VnMBjQ6/WcPn2axMRE0tLSWLt2LQ0NDTz22GM8+uijsud5L5SUlDBu3DhaW1vJysrCx8eHbdu2YWFhwYIFC7q1EnErulxSubq6curUKem8urr6rkNiegsrVqzg9OnTODs786c//YmcnBwaGhoYPHiwNLy3J+Dv788rr7wCwOLFixkwYAA+Pj4IIdi5c6fZX4NdNlVDQwN6vZ5JkyYxbdo0PD09OX/+POHh4VI1tzdSUFAgVds3btzImTNnOHnyJP369SMuLg6VSmVmhR357W9/i6+vL1euXGHBggVMmTIFW1tbrly5YvZRuF1eoOOtt94yhQ6zUl1dLbWx3ehj++CDD4D2cVKDBw82p7xb0r9/f7Zs2cKYMWPIz89n3bp1xMXFsXnzZoqLi3Fzc8PT09Ms2h74ae9Go5HIyEgOHjyIt7c3Bw8eZMuWLVy5cgV3d3d+/vOfmz04vxPr1q1jyZIlWFpacuDAAQwGA4cPH0atVrNw4UKzTJ9/4KdovfTSSxw8eBBbW1syMzMpKCjgypUr2NraEhcX16MNBe0xVXx8PC0tLcycOZNHHnkEFxcXmpqayMjIoKmpqds1PdCmWrt2LX/+85+B9k0Dzp07x8mTJ1GpVDz99NM9bhGQW2FhYcGGDRvw9vbm4sWLxMbGMm3aNHQ6HZcvX+bTTz+V+ma7iwfWVDt27JDGm//ud7/DycmJw4cPA/DEE0+YZOCdqbCxsWH37t04OjpSWlrK/PnzmTFjBpaWlnzzzTfdPm/wgTRVTk4O8fHxtLW1sXDhQqZMmSLVmMLCwvDx8TGzwq4zbNgwabWZvXv38sILL0iTWsvKysjLy+s2Yz1wpsrIyODJJ5/EaDQya9Ys5s2bR25uLgBBQUGEhoaaWeG9M27cOHbt2oVarWb37t0sX75c2gfo6NGj7N69u1u61B4YUwkhePvtt4mPj8doNDJ79mzmzp0rjbcPCgpi0qRJPT4wvxuRkZFkZ2ej0WgoKChgyZIlTJw4USqxtm7dSkNDg0k1PBBNCrW1tfziF79g165dALz44ouMHj1aGo4bGRlJcHBwrzfUDzly5AjTp0+ntrYWBwcH1qxZQ3l5OUajETs7O2bMmGGy7pw+bSohBJ988gnJycnU1NSgVqtZtWoVTU1NNDc3o1ariY2NZdSoUaaWbxbKy8uJi4ujrKwMgOeee45Ro0Zx7do1LCwsCAoKYuLEiajValnz7ZOmEkJQWFhISkoKX375JRYWFkyaNInJkydLQ55dXFyIjY3t9Yu23o3GxkZWrFjBu+++C4CjoyNJSUnS9zY2NoSGhuLv73/bLfK6Sp8y1YULF8jMzGTDhg18/fXX2Nvb4+/v3yH41mg0hIWFERAQ0CcXxr8dhw8f5vnnn5em03l7exMTE4OlZXtPnUajwcfHBx8fH5ycnO4rFJDdVFlZWaSkpNDc3MycOXN4/fXXO3xfVVXFnDlzuHjxIk5OTnzyySc89NBDd033x6bSarVUVlZSWlrK4cOH2b9/P2fOnMHZ2ZmhQ4cyYsQIHBwcpPs1Gg3jxo1j/PjxWFtby/nIvYa2tjYyMzNZtWoVx48fR6VS4efnR2hoKD/5yU+k62xtbXF3d8fV1RVnZ+cu9312+27vsbGxxMXFMW/ePDZu3EhhYSHbtm27ZXo/HNBfV1cnzS+MioqiqakJtVqNra0tAwYMwM7OTvrVSQ9nYcGwYcMYPXo0o0aNkq147+0IITh06BBbtmwhOzuba9euScPC3d3d6d+/f4frV65ciVar7XzpJecc+k2bNon58+dL5+np6SIxMVE6b25uFgMGDBDNzc1CCCGMRqOwtbWVzn8M/792gnKY/+jK+hXdutv7jY7aH+7NMmDAAC5fviynDAUT0JVpYLJueHS33d5v15p7u4D5h5NTr1+/jpOTE9D+mr2XJRAVukZ9fb0UT3Vler+spho6dCh/+9vfpPMf7/bu6OjItWvXaGlpwdLSkpaWFq5fv37bMT+3azbQ6XS9YgRBX6IrtcFu3e3dysqKRx99lI8//hiAjz/+mMcee0wJoPsa9xiT35a77fZeWVkpIiIihKenp5gwYYKoqKjoVLrmXvTsQeRe/+e9pvFToffw4DQpK3QbiqkUZEcxlYLsKKZSkB3FVAqy0+tMdezYMYKCgvD19SUsLIyKigpzS+qzvPLKK3h6euLl5cXq1as7f6PJGjlMxJgxY0RhYaEQQoh169aJ2bNnm1lR3yQnJ0dMnDhRtLS0CIPBIIYNGyb++9//dupeWbtpuoMvv/wSS0tL2traqKiowN7e3tyS+iRRUVFMnjwZlUrFpUuXaGlp6XTXWK8zlaWlJbW1tXh7e9PQ0NBjFvrqi1hZWfHaa6+xevVqZs2a1fkJtiYuRe+ZzMxMMWTIkA7H+PHjO1yTm5srXFxcREtLi5lU3pm0tDQRHBws2traxNGjR4Wrq6uora01t6wuU1dXJ8LCwsT69es7dX2PNdWtaG1tFRkZGR0+c3BwEJcvXzaTojvT1tYmwsLCRFpamvDy8hIFBQXmltRpTpw4Ib766ivp/L333hNJSUmdurdXmUoIIby8vERubq4QQojCwkLh4eFhZkV3pry8XOh0OrF48WJzS+kSWVlZIiQkRDQ3N4vGxkYRERFx0w/6dvS6mGrr1q0sXryYlStXYm9vz/bt280t6Y5UVFSg0+mknTF62op8t2PmzJmUlJTg6+uLSqVi1qxZnd7KVxmlYELq6+vR6/Wkp6eTmppKQEAADz30ECEhITzyyCPmlmcyel1J1ZtYvnw5kydPJjg4mOHDh6PX6wkNDe2Vq8p0BaWk6mZSUlKIjo5mzJgx5pZiMnpdN41Cz0cxlYLsKK8/BdlRSioF2VFMpSA7iqkUZEcxlYLsKKZSkB3FVAqyo5hKQXYUUynIjmIqBdlRTKUgO4qpFGTn/wDSlFUsOw7O2AAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "with mpl.rc_context(fname=\"../../../matplotlibrc\"):\n", + " fig, ax = plt.subplots(1, 1, figsize=(1.2, 0.7))\n", + " x = np.linspace(-3, 3, 100)\n", + " gaussian = lambda x, mu=0.0: np.exp(-0.5 * (x - mu) ** 2) / np.sqrt(2 * np.pi)\n", + " _ = ax.plot(x, gaussian(x), label=r'$\\mathcal{N}(0, 1)$', c=\"k\")\n", + " _ = ax.plot(x, gaussian(x, 0.25), label='$\\mathcal{N}(0.25, 1)$', c=\"gray\")\n", + " _ = ax.set_xlabel(r\"$x_i$\", labelpad=-5)\n", + " _ = ax.set_ylabel(r\"$p(x_i)$\", labelpad=-5)\n", + " _ = ax.set_ylim([0, 0.5])\n", + " _ = ax.set_yticks([0, 0.5])\n", + " _ = ax.set_xlim([-3, 3])\n", + " _ = ax.set_xticks([-3, 3])\n", + " plt.savefig(\"svg/fig2_panel_c1.svg\", bbox_inches=\"tight\", transparent=True)\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "14b9348b-9da4-48b1-bbb7-43d49d452d84", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 Tue Feb 6 17:05:47 2024\n", + "2 Tue Feb 6 17:05:47 2024\n", + "4 Tue Feb 6 17:05:47 2024\n", + "8 Tue Feb 6 17:05:47 2024\n", + "16 Tue Feb 6 17:05:47 2024\n", + "32 Tue Feb 6 17:05:47 2024\n", + "64 Tue Feb 6 17:05:47 2024\n", + "128 Tue Feb 6 17:05:47 2024\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "df = []\n", + "for dim in [1, 2, 4, 8, 16, 32, 64, 128]:\n", + " print(dim, time.ctime())\n", + " true = MultivariateNormal(0.0 * ones(dim), eye(dim))\n", + " model = MultivariateNormal(0.25 * ones(dim), eye(dim))\n", + "\n", + " c2st_optimal_score = c2st_optimal(true, model, 100_000)\n", + " \n", + " df.append(dict(\n", + " dim=dim,\n", + " c2st_optimal_score=c2st_optimal_score.item(),\n", + " ))\n", + "df = pd.DataFrame(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "3e6c9080-63c2-4f57-9b96-4a1d6ca07323", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAI8AAAB4CAYAAADL9KEyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAARS0lEQVR4nO2deUxUVxvGn2FREKgUW1AcAQuoIOKAiAtUBHFBWkBcQptqVJS6oFbcqqkEFWtxQ0NsKR202lpbbLC1VEuNLMYoogQqaKAsguzuUlBkgPf7g88bxxnwMswwA5xfMsnce+49572Xh7O+8x4BEREYDAXQUrcBjJ4LEw9DYZh4GArDxMNQGCYehsIw8TAUhomHoTBMPAyFYeJhKIxSxFNXV4cxY8agtLRUJi0vLw8TJ07EqFGjMG/ePDQ0NCijSIYG0GXxXL16FW5ubigoKJCb/sknnyAqKgr5+fmwt7dHZGRku3k1NDRwn/r6ety7dw8NDQ1gKygaCnWRxYsX06VLl8jS0pLu3LkjlXb37l2ysrLijsvKymj48OHt5gVA7qe+vr6rZjJUgE5XxXfs2LF20yorKzF06FDu2NzcHBUVFV0tkqEhdFk8HdHa2ipzTkur/Zayvr6e+97Q0AAzMzOV2MVQDioVj1AoRHV1NXdcXV0NoVDY7vUGBgaqNIehZFQ6VLewsMCAAQOQnp4OADh69Ch8fX1VWSSjG1GJeGbPno0bN24AAE6dOoXPP/8c9vb2yMzMxM6dO1VRJEMNCIg0cxzc0NAAQ0NDAG19IdakaR5shrmPU1FRgdTUVIVGwUw8fRixWAwLCwt4eXnB0tIS8fHxnbqfNVt9lNzcXDg6Okqd09bWRmlpaYcj4ldhNU8fJDk5GZ6enjLnW1paUFRUxDsfJp4+xPPnz7F27VrMmjULDx8+lEnX1taGjY0N7/yYePoI2dnZGDduHGJiYgAAoaGhOHLkCLS1tQG0Cefbb7/l3WQB6PrCqKqor69nC6NKoLm5mb766ivS1dUlADR48GA6d+4cl15eXk6pqalUXl7e6byZeHoxpaWlNGXKFO49BgQE0P3795WWv0rXthjdS0VFBQoLC2FjY4P09HSsXr0adXV1MDQ0xOHDh7FkyRIIBALlFchHYRUVFUpTK19YzdM5xGIxaWlpyfhCTZo0iYqKilRSJi/xODk5qaTwjmDi4U95eblc4WzYsIEkEonKyuU12iLNnEdk/J/8/Hy5vlMffPABdHRU1zPhlXNNTU2Hq+Hh4eFKM4jROcrKyrBt2zaZ852ds1EE3vM81NbEyf0w1ENiYiJEIhGuX78OPT09rjOs0JyNIvBp21ifR7N49uwZrVixgns/rq6uVFxc3KU5G0Xg1WwRq100hlu3biEoKAh5eXkAgM2bNyMyMhK6uroAoPra5lX4KKyqqkrFGpaF1TzStLa2UlxcHOnr6xMAMjMzo+TkZLXaxEs8DQ0NFBYWRteuXSMiojVr1pCBgQFNnjxZZVUkE0/bEDwlJYXy8vJo/vz53PuYMWMG1dTUqNs8fuJZvnw5rVy5kmpraykpKYlMTU2psLCQzpw5QwEBASoxrK+LR96kn46ODu3du5daWlrUbR4R8RSPg4MD9z0kJIRCQ0O5Yzs7O+VbRX1bPOXl5SQQCGQm/c6ePatu06TgNVR/uWwPAKmpqfD29uaOm5qautrtYrzGTz/9JHeQYmRkpAZr2ofXaGvQoEHIzMxEXV0dqqqqOPGkp6d3b+++l1NVVYX169cjISFBJq07Jv06DZ/q6ebNm2RnZ0cmJiZ0/PhxIiKKjIwkU1NTyszMVEmV2JearebmZjp8+DAZGRkRANLW1qbp06dzfR5tbW0Si8XqNlMGhf15CgsL6cmTJ8q0RYq+Ip7MzExydnbmnnXChAmUnZ1NRF1z1OoOeIlHIpHQoUOHKCwsjC5duiSV9sUXX6jEsN4unsePH9OqVau4jrGxsTHFxsZqzEiKD7zEs3TpUlq4cCHt37+frKysaM+ePVyaqpYueqN4ysvL6eLFixQTE0NmZmbc8y1cuFAj5m06Cy/xjBkzhvteW1tLo0ePpiNHjhARkUgkUolhvU08YrFYZvg9cuRISklJUbdpCsN7bauxsRF6enowNTVFUlIS3N3dYW5urly3xl7K+fPnsWzZMqlzAoEAf/75J6ytrdVklRLgo7ADBw6Qo6OjVH8nOzubBg8eTAYGBipRdW+oea5cuUKzZ89uN1xeamqquk3sErxqnrCwMLi4uEidq6qqwh9//IHffvtNmVru8RAR0tLSEBkZiZSUFABt0dBe9/TTyHmbTsJrhjkjIwPz589HY2Mjdy4nJwd+fn7w8/NTmXGazqsRJogI58+fh7u7O7y8vJCSkgIdHR0EBwejoKAAYrG4az+w00T4VE8eHh509epVmfPp6enk4eGh3Lrw/2h6s/XqwqVAICALCwvO3v79+9Pq1auprKxM6h5Nn7fpLF32JBw7dqyybJFCk8XT3sKlvr4+bdiwQS3+T+qAV59HIpGgtbVVJpJpS0tLn1oYraysRGJiIsRisdyFy1OnTsHf318NlqkHXn0eT09P7NixQ+b8jh074OrqqnSj1IW8KFllZWU4ePAgJk+eDKFQiLVr1+LmzZsy92pra2PcuHHdaa764VM91dXV0fvvv0/W1tYUFBRECxYsIGtra3J3d6dHjx6ppErs7mbr9T5MYGAgubi4yDRNbm5uFB0dTVFRUaStra3RC5eqhndkMCJCamoqsrOzoaWlBRcXF7z//vuq0nS3RgYrLy+HpaWl3KZIS0sLU6ZMwbx58zBnzhyYm5tzaRUVFSgqKoKNjU3PHzkpgnq12z6qrnlaW1vp2rVrtGXLFhIKhXIn8cLCwqi2tlbpZfcWenVMwpdRI2xtbSEUCtHc3IzLly8jMTERZ86c6TACaGfj8/VFem2Ilfj4eISEhKC1tRUCgQBubm7Iz8/HgwcPuGsMDQ3h6+uLwMBA3L9/H+vWrUNLS0vvmcRTMT2q5nm9JnmV58+fo6SkBEVFRcjKysKuXbvk5mtiYgJ/f38EBgbC29sbenp6XFqf78N0kh4jnp9//hnLly8HEUEgEMDf3x/GxsYoLi5GcXExqqqq3pjnwYMHsWbNGpVGjuhL9AjxFBQUYNSoUW/82fPAgQNhY2ODIUOGICkpSSqN9WGUT4/4FywqKpIrnKVLl8Lb2xvW1tawtraGiYkJ518UHx+PTz/9lPVhVEiPqXns7Oyk3Br41CSsD6NauhyH+fTp0xg9ejRsbW3lLmEkJyfjnXfegUgkgkgkwpIlSzpdxtChQxEXF9dplwahUIipU6cy4aiILtU8NTU1mDBhAm7cuAFjY2P4+Phg06ZNmDlzJnfNrl27YGhoiPXr178xv1e3za6vr8fgwYMBALW1tTAwMEBlZSWKi4thbW0ttXcpQ/kMGDDgzS7GXZlhPHHiBC1evJg7Pn78OC1ZskTqmg8//JCmTZtGIpGI/Pz8OvRlQTvumuzT/R8+s/pdarb47F789ttvY+PGjcjOzsasWbPw8ccfd6VIhgbRpdEWn92Ljx8/zn1fuXIltm7diqdPn2LgwIEy9766u/F///2HIUOGyC33ZTPWEa/ujvz69fLSXj8HgDsuKSnBe++9J/O9PXqDfQMGDOgwHeiieIRCIbf5LCC7e3FjYyOio6OxdetW7hwRcSHQXofv+pWBgUGn1ro6ul5eWkfHfMrtbfa1i8IdHiKqrKwkS0tLqqmpoaamJpo+fTolJiZKXTNixAgurszRo0dpxowZCpWlyW6pRH3Tvi67ZCQkJJCDgwPZ2trSxo0biYgoODiYfv/9dyIiysrKIldXV7K3tydPT0+6e/euQuX0xT+OMlGFfRo7ScjQfNhmbQyFYeJhKAwTD0NhmHgYCsPEw1AYJh6GwvR48ZSXl+Ojjz5CSEgITp48qW5z5FJSUiITokZTuHXrFhYuXIjQ0FBERkZ27malzBapke3bt9P169eJiGjmzJlqtkaWx48f0+bNm1UWTaSrpKWlcfEQfXx8OnVvj695ampquJX9VyPVawrGxsaIiorivCI1DQ8PD5iZmSEqKgpBQUGdurfHi2fYsGHcLyfkrfIzOqaxsRGrVq2CSCTCokWLOnVvj3CA74hly5Zh48aN0NPTw9KlS9VtTo9jx44dyMrKQl1dHU6fPg2xWMz/ZlW0o8rg6dOn5ODgQHfu3OHOJSQkkL29PdnY2FBERIT6jCNmH5ESVtVVwZUrV8jBwYF0dXW5h6+uriYLCwu6d+8eNTU10bRp0+ivv/5i9qnRPo3s88TFxeHrr7+WCmdy4cIFeHl54d1334Wuri4WLVqEX375hdmnRvs0ss9z7NgxmXN8/KW7C2ZfGxpZ88iDj7+0OumL9mnO070BoVCI6upq7vh1f2l10xft6zHi8fb2xsWLF1FbWwuJRIIffvgBvr6+6jaLoy/ap5F9HnmYm5tj37598Pb2xosXL+Dv7485c+ao2yyOvmgf82FmKEyPabYYmgcTD0NhmHgYCsPEw1AYJh6GwjDxMBSGiYehMEw8DIXpVeIpLS1Fv379uOCZdnZ28PX1RUlJCaqqqjB79uxutyk2NhaxsbFKzTMiIgIREREAAJFIBADIzMzEli1blFrOm+gxyxN8MTc3R05ODnccExODmTNn4tatWzh37ly327NixQqV5v/yWW/fvs1FDOsuelXNI4+X2wXExsbCysoKALB48WKsWrUKTk5OsLCwwI8//oi5c+fC2toan332GYC2LTA3bdoEZ2dnODo6cntZpKWlwdvbG3PnzoW9vT1mzJiBR48eobW1FStXrsTYsWPh7OzM1Qyv1hJJSUkQiURwdHREQEAA98e2srJCeHg4Jk6cCFtbWyQnJwMA8vLyMHXqVIwfPx4WFhaIjo6WeT6BQICHDx8iPDwcZ8+exc6dO+Hp6Sn1jzJmzBgUFBQo/d32evEAbS/v1W29gTbnqOzsbOzatQuhoaH45ptvkJOTg6NHj+LJkyeIj49HU1MTsrKykJWVhatXr3JbEmRkZCA6Ohq3b9+Gvr4+Tp48idzcXGRmZuKff/7BlStXUFhYiGfPnnHl3bt3DyEhIUhMTMTNmzfh5uaG0NBQLn3gwIHIyMjA3r17sW3bNgCAWCzG1q1bcf36daSnpyM8PFzu8w0aNAg7d+6En58fwsPDERwcjBMnTgAAbty4gbfeegsjR45U6jsF+oh4AEBfX1/q+KU7gqWlJRwcHGBqagojIyOYmJjgyZMn+Pvvv5GUlAQnJyeMHz8ehYWFyM3NBQA4ODjAwsICQFuf49GjR7CxscGLFy8wZcoUHDp0CLt375YKCpmZmQlXV1cu2GRISAguXrwoY8/L/ADgwIEDaG5uxp49e7Bt2zapgJ8dMXfuXKSnp6Ourg7ff/+9yn5V0ifEk5OTIxPds1+/ftx3ebvgtLS0YN++fcjJyUFOTg4yMjKwbt06AJDaZkkgEICIYGBggJycHGzfvh0PHjzApEmT8O+//3LXve7JR0SQSCTc8cs8X+YHAAsWLMCvv/4Ke3t7fPnll7yfV19fHwEBAUhISEBSUhIWLFjA+97O0OvFExMTg/79+8PLy6tT93l5eeG7776DRCLBs2fPMG3aNFy4cKHd6y9fvgwfHx94eXlh//79sLe3l+pnTJgwAdeuXUNJSQmANid1Dw+PDm24cOECdu/eDX9/fy7qbEtLi9xrdXR00NzczB0HBwcjIiICU6dOhZGREe/n7gy9brRVVVXFDV9bW1sxYsQInD9/XurF8mHFihUoKiqCk5MTJBIJgoKC4O/vj7S0NLnXu7m5YdSoUXBwcIC+vj6cnZ3h4+ODrKwsAG0xk+Pi4hAYGAiJRIJhw4YhPj6+QxsiIiLg7u4OY2NjjBgxAsOHD+fE9zoTJ07Ejh07sGnTJuzbtw8uLi7Q09NTaK8PvjBnsF4IESE/Px/z589Hbm7um/eQUJBe32z1RQ4dOgQvLy/ExMSoTDgAq3kYXYDVPAyFYeJhKAwTD0NhmHgYCsPEw1AYJh6GwjDxMBSGiYehMP8DGhwbml3Mhi8AAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "with mpl.rc_context(fname=\"../../../matplotlibrc\"):\n", + " fig, ax = plt.subplots(1, 1, figsize=(1.2, 0.7))\n", + " _ = ax.plot(df['dim'], df['c2st_optimal_score'], label='gt', color='k')\n", + " _ = ax.scatter(df['dim'], df['c2st_optimal_score'], label='gt', color='k')\n", + " \n", + " _ = ax.set_xlabel(r\"Dimensionality\")\n", + " _ = ax.set_ylabel(\"C2ST\", labelpad=-4)\n", + " _ = ax.set_xscale('log')\n", + " _ = ax.set_ylim([0.5, 1.0])\n", + " _ = ax.set_yticks([0.5, 1.0])\n", + " \n", + " plt.savefig(\"svg/fig2_panel_c2.svg\", bbox_inches=\"tight\", transparent=True)\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "29aee1c5-fd0c-4153-91db-9dc4505914cd", + "metadata": {}, + "source": [ + "### MNIST behavior" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "b4090aaa-3067-4e40-90af-0228325dd5ea", + "metadata": {}, + "outputs": [], + "source": [ + "mnist = fetch_openml('mnist_784', as_frame=False, cache=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "012c157c-6767-49f3-83ce-9426116043a2", + "metadata": {}, + "outputs": [], + "source": [ + "X = mnist.data.astype('float32')\n", + "X /= 255.0\n", + "y = mnist.target.astype('int64')" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "ae75703b-4407-41ff-963d-6f348e2a5c3c", + "metadata": {}, + "outputs": [], + "source": [ + "mask = (y == 1)\n", + "X = X[mask]\n", + "y = y[mask]" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "36faf67d-78b8-4cb9-bab7-c53d5a014c59", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
GaussianMixture(n_components=20, random_state=1)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" + ], + "text/plain": [ + "GaussianMixture(n_components=20, random_state=1)" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "_ = torch.manual_seed(1)\n", + "_ = np.random.seed(1)\n", + "\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)\n", + "gmm = GaussianMixture(\n", + " n_components=20,\n", + " random_state=1,\n", + ")\n", + "gmm.fit(X_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "7901813d-055a-4c84-90c3-0e162f45663b", + "metadata": {}, + "outputs": [], + "source": [ + "gmm_samples = gmm.sample(10_000)[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "1291ae79-0bb9-47df-87a2-8298b6401685", + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"many_gmm_samples.pkl\", \"wb\") as handle:\n", + " pickle.dump(gmm_samples, handle)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "c1ec0e40-1e6e-4030-8187-89ac3ec5e0c2", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAEI0lEQVR4nO2cTWxUVRTHf+/NTKHTYjvYdEiLWDQQmyh+xBhj4kISWVRcEGN0y8qFW6MJ0a3GhS5g48YFJq7cSBQTozFRg0KgaLDBYAMoY0OhNKXS6XQ678PFmSkztKd9085708X5bZp5776b2//733PuPXdaJwzDEGMZbrsHsFkxYRRMGAUTRsGEUTBhFEwYBRNGwYRRSEdt+KL7apzjSJTvgi/WbGOOUTBhFEwYBRNGwYRRiJyVkiT90BAAwbZOAPxsBwC3Hs8CkBtfZMtYQe5NTctDgd/SMZhjFGJxjJOWbkO/7i02USis5HsAuPymvLftvUUARnaeB+DsEyla64/lmGMUYnFM6HkNn51MB2FlMfLzfqcM6/7vJbZ88N5nAEx59wFwfs8L+ONXVu/EcaqDWV9J2xyjkEhWiuqWWmwqDmQASL12E4CB1B0A3r74CgB5vBWevrez6jsP1xeNzDEKm2odU4tN3RPisI+HPwdgb2YrAIs/9gEQXBsllcsB4M/MrNzZBtc1yQnTRDD0j8iibV+HCFKpTocdZ0rSRbmMX4kwnTaATSWFWB3jdnUBEBSLkZxSOfA0AM/0nWu4fmj8oPT38293L7Z4C3Av5hiFWB0TzM831f7q6/LzWO5XAOaDFACFE7sB2MH11g1uDcwxCvFmpSaW4+mhXXyz/ygAwx1SXvhweg8AfRfKrR/bGphjFDbNAu/iu/1LTvnXmwPgq4nHAEh1SqzJ1MoZXrxrGEhQGEf5pVJ7Hwbgree+Xbp2ZGIEAP94PwBdJ8/Iswl++cumkkKsjqmv5C2zf3WLcP1AHoAgdPHDAIDCnOyDKt3Sxu2U2m+z6X8jmGMUYnXMkktqG8iGmxIv+g9dA+BwzyU+mZX0PHlbKnVDZ2eBZJ1SwxyjkExWctxllbT04AAAb+z6AQAXl5M3JD17V7ql0fgFtUs3K6k9LjeZYxSScUx9iaAab/585wEAnt96Sz5X0mzfIm9/8CeJTau5Ie64Y45RSLy0WVvpHhs5DkAuJbHi08lnOXVJ7g2f+wcAv41/5pCIMKneHvzbknoLL8sy/6XsAgBzwcJSu+H3pbDt37i5an9OOh37fsmmkkJkx2ibwCgEpbuuKD1Zarj30fRTAEyWtq197FqlfgxORo5xmzkCjoI5RiGRGONms/hlqcIdfOQPAE4vSAr/uvAoAPOn+tjJVNN9N3zVpIWYYxQiO2YjWcCfmSG9+0EAToxKVrqzT04Z9w/+BcDpsd71dR7T+ZI5RiHeGFN3Xu1dlUVbz9ggAL+PyoYx/+VlALr4G3rlK2b+7H9Lz7ULc4xC4udK+aO/NDapnm+HnkdYTv78SKPtxydBsdjuIayITSUFE0bBhFFw7H87rIw5RsGEUTBhFEwYBRNGwYRRMGEUTBgFE0bhf7AsN3svJcxQAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAEGElEQVR4nO2cT2gcVRzHP7Ozs9lNusk2TUINVpNg2iRUBSmGFlQkGBAUKaJHlXoQwUMQPNiTehb04E08qAehgp5sLSUooqn4DzVoxca2aSxJ/2iTbbKNuzszHt7MYtr91d1NZ3YPv89ld/a9nff4vu/7/d6bt6zl+76Pch2JZnegVVFhBFQYARVGQIURUGEEVBgBFUZAhRFI1lrxocQTUfYjVo55H/5vHXWMgAojoMIIqDACKoxAzVmpmVx+Zi8AF8ddANr71hiYWgGgvPBnJG2qYwRa2jGLL+4D4MCzhwEYazsHwKtzj17nlEQ6DYC3vn5T2lbHCETrGMsyr3U+VrbHdgLw+NOfAzC19cyG8qkvt9PVeQkAN583TZTLjfezCuoYgWgdU69TensBKAx2AXC+2Fm1Xnberzil0pQ6Jh5aKysFoz6/3zhtun8mKDDj98K5cQBy7x9vOH7VSksJ8/vBEQA+nngTANtq21A++9rdAKT5JvK+6FQSaBnHWHt2c/jJ1wHY6XRsKCv5ZivQMTMHgAuRTaEQdYxA8xwTBM/k7TsAODPZWXHKF8GqfrytBMDo9HMADP/1Q2zdU8cINM0xdjYLQOmWrQDse+ynSlkuYSzz0uJ9AAw/FZ9TQtQxAs2LMX3bAMgPZQB4e8dXlaK0ZbLQp8f2ADDI8Zg710xhHNP0hcli5aPF8ioAb118EIDBl+MXJESnkkDsjrGSQZNBun5g10kAThQLDDlmC3D0j1EABvj5JjVa/75KHSMQj2MsqzJa4XOTpftN8P3g1ncBuOJ5nCqZBV3uk44qN9kEDWwf1DEC8TjmPyNm77oDgCsD5vrIWj8A+7dcYGr+YQBy7zUvG4WoYwRiy0phNsrvNrHl+UeOApBOmLhyqlTil2lzOnAbl+Lqlog6RiBaxyRs85JyYGQIACsIN06w7J/M/A3AOyujdP/qmfrBBtNbNSvherKK3WlOFq49RaiXSIVJpBzzxnGwSkYIJ2/S9Z3pBQBWfTOV3vhuguFDXwPgRdmpGtGpJBCpY8IDdjuTIT+SA2Bpr1med9sFAF5ZmgCgYzZNcmgAgPLpeXMDKxg3z5UbuWa5v9kpFKKOEYgnXW/vYb3bjIHX+w8AJd9cH/n2LgDS7eB2ma2A3dMDgL+2Zr5TKMj3Dl3l38BVDaCOEYjFMe6Jk1w+uAWAVJvJSoeW7wWgfcF0oWe2jDV3FgDv6lWgxoP6G8WfTaCOEYjFMfa2bvo/SgGQnTE/F/t+7B4A+lLm0Wbmx7N4RfPed6NxQT2oYwTiyUquS/Y3s/T3V02msT8zZ0VJJxVUcSOLF40QT/BdXoHllaplfqlY9fNmo1NJQIURUGEELP1vh+qoYwRUGAEVRkCFEVBhBFQYARVGQIURUGEE/gUTKxCtHsL4lgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAEMklEQVR4nO2cT2wbRRSHv1nHdrCdOEFq0xS1lKiUpBLQUwEJjog/ogIJcQBUTqj0DipckDigSiAuICGO3DgAx6pFAoGKVJooIEFoS4GWJqraUlqlcZw6dmzvcphdFyd5dWM7Yxe977L27mj26Te/mXkzY9kEQRCgrMDrdADdigojoMIIqDACKoyACiOgwgioMAIqjEDPrRZ83HthPeNwytf+Fw3LqGMEVBgBFUZAhRFQYQRueVZygYknAAjKSwDEBgcBOHzyu7pyT/z2DLH9tmz1z7/WJRZ1jEBXOSZySsRyp3xftNfp41vYXrkIgJdOA+Bfv97WWNQxAl3lmIjYQLbue85fBGBTzAdg5PNrVM7N3LwSL2avfrWpGNQxAl3pmDNv7gw/HQUgH7b6gennAfCnTjesw3gGgCAw0MRBiDpGoKsck3v5YQAOvfRBeMfOOGlj2+/kxAgAI1xuWFdQqbQUS9cI4z0wyuyeAgA74um6Z4cLWwC49+ApAJobTtcYj4N33JZ03DFeKmWvV+Z448EpAEpBGYCkiQPw7tTTAGyd+9VdXM7edJvRcceYO3oBmH/kbvZlvwKgHNj2ihK7e94Jlwo7dwBQPfXHuseljhHomGNMMgmAv20YgCMffQhY9yz4JQAmS3ZpYArh6rHc2hS8FtQxAh1zjJexuUpho52VMl5v7VnKs7PRe9NPARAbHgDAHJ9yFp97YcJVr79g908eOji5osjFiu1KpU9sN0v/MOEouBtoVxJw75hwpRyU7PXIzBgAb28Yr3WnX5Y2AZAdPw9Addle8JppYm9GHSPgxDEmmSQo2XHD9NhXnj+wG4AXR75dUf71CXtOfl/lAtCCUyKa2MVTxwi4cYwxmHCxaDYPAfDN/vcBGO7JAFAO4rXyQ4ds8kdVbulYf78tMj/f9nhBHSPixDF+sVg7ZSyMbgDgctU6ZDiMIG5ijB3bC0AmZfdrq7NzYp3r5ZQIdYyAszwmmlny+3IA7AoXkQXf3k95CYayeQASv9887/D6+vDz+VWfLT//bhY3g288Qfmx+wF4a/RLAM6WFwDoC485UiT45+hmALYWrXim1yZ8frFYV1+wuCi/LPDbErN2JQEnjvEyac49awfbyQV7BNLvWRc8mbKJ39ixvQyeta3t5expQXWpvGp9/z0aifZ1ogSy1WOTWsxtqeV/iBPHBHcN8dmejwF49edXAHhu10/hU9s2pUsp+s/YcYecvXoJ6zK/6OIkqR51jICbBO/EaT69+igAizN9ALyGTea23TkLwMYJQ/DjCWBtJ43R2NJu1DECzhK86d0299jOeN39aN7JcunGzRZ/9NMO1DECzjaqYoMDAFSvzQENxoYOOiXCzXRdKlH5u/FvWroJ7UoCKoyACiNg9L8dVkcdI6DCCKgwAiqMgAojoMIIqDACKoyACiPwLweoIO1ayiZbAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "with mpl.rc_context(fname=\"../../../matplotlibrc\"):\n", + " for i in range(3):\n", + " fig, ax = plt.subplots(1, 1, figsize=(0.65, 0.65))\n", + " _ = ax.imshow(gmm_samples[i].reshape(28, 28), clim=[0, 1])\n", + " _ = ax.spines[\"bottom\"].set_visible(False)\n", + " _ = ax.spines[\"left\"].set_visible(False)\n", + " _ = ax.set_xticks([])\n", + " _ = ax.set_yticks([])\n", + " plt.savefig(f\"svg/fig2_gmm_mnist_{i}.svg\", bbox_inches=\"tight\")\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "6252d5f8-ea29-4a3b-8db0-e174f363d598", + "metadata": {}, + "outputs": [], + "source": [ + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)\n", + "cov = torch.cov(torch.as_tensor(X_train.T))\n", + "cov += torch.eye(len(cov)) * 1e-6\n", + "mean = torch.mean(torch.as_tensor(X_train), dim=0)\n", + "gaussian_fit = MultivariateNormal(mean, cov)\n", + "\n", + "gaussian_samples = gaussian_fit.sample((100,))" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "8e99ad03-7d04-4c62-ba66-220bb515a0b6", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAGbElEQVR4nO2cXWxURRTHf/fe3W23n0spFCgV+RCqVEVA8INEEtRIeDAxiokxGhM1Gp9888UnH+XBqAnxQV8ML35gjAkPIibEBCKCGD9iRQX5CqWUfu12u9vuvePDmbvtbncKbe9SQub/cvfemTl37pn/nHPmzGQdpZTCYgrc+e7AzQqrGAOsYgywijHAKsYAqxgDrGIMsIoxwCrGgNj1VnzMfaaa/bihOBh8fs06ljEGWMUYcN1T6UbAbWyUH74PQJDNzltf5k0xbm0tAOHi3lu8iGBBgxSeuSjPmpoA8IeHS9o6sRhuqlnaZ0cBCHJ5KQz8aPoXiZRbELNnjONAeSrHcUrvdbnb2Ijj6TFw5Dqy7Q4A6n/4C4CenR30bxkHoOmPu6VqQZq0fXCkROzw05vx8iK77qsfZ/0J08EyxoDZM6ZS4k8/c2Ii1qmpAcBtSYEfAOC3pQBI9ohtoG0RAMOrYXuXsOe9J74F4HhebM5r7a8CsOqtowD03+XQdFrYWTfrD5geljEGVMUrKe1uHX3F8/BbxcMECXnl8KokAC0n87ouxBxhVZ2TAGBHUtoXFoixCR65D4CxZkXzmXw1ul6EZYwB1YljQvvjeXI7nMZpFIb4zfLKQlJsRLpzAQDBihyBkmdxxysRt7XrXwCO7RZPpmoKuIdPVqXrISxjDIiUMaEXIhDGqDGJSxzPwxkTO+HXyFjU9Yn96F8nXaj9LcnJxnZpf1up3FeWHAYgs17kD71fVqEKiEYxYWCnFeLERawaG5PniThKTyt01fr/MnJbEJd8YYeLn09UFB8a4Ze7OwBYu786Qd1k2KlkQDSM0cY2dNNFhCxRCudSLwBJXeToadbwuzCnrnM5fsf0u8XtB51py6OEZYwB0TDGFWY4rh5RJYGak0jqa4LCpR4pu9pf0mb8UQnaHAVbl5+tKP7d/tUA1H85ybZou+Zqgx/kchF8yAQsYwyI1F2rgs4TaDYwKgtFPzc1fI91LAOgd614ouDhIeq9sYpyP97/OAArOFp85sTi0i6vZYeeMaJTLZYxBsyNMeWJqTIUGTQJxbyubpt5SPK6m9oukYpXzvG2Hy5jkusV7VixK5pBarwy62aKOSnGCddC5QqYJu8apNMAuIsWAtDWIvncxTVpliYGK7YJo+XJnS2GBuHUKVPUXGGnkgFzYkylqXIthNm908+L8d2z5hMAdtVNdbebTuwGYOGAnh6hUVfBFLZOCS5LXjpzw2wZY0B18jHTjFA4wt6GIQDe/OIlAHa9sJd9abE7VwpioBMxYUF2mexBNTXJglP5AU5CG9uRbIncInP0bgSBPysXbhljQHV3Iiswx9ncBcBInywXGq9MuPyvr2wA4KMV3wCw9/hOABbGdDpDu3p1tR+V017IlbF162S/IAz41PjM7d9kWMYYUN2c7yQ4m9YDcPkB2S1InpPnmdsnvMlzbbJIfPH0UwDEMzov3CHXBXo3M8hmK2QLtefStiVc0JaEN25pLnk6WMYYEC1jtE3xtC3wdZSLUmRWikcZ7JK5n+iT0evsOl9svq32MgDv7FspcnRWy9fX0TWya1mbiENvn/4C/QmhV9IL1mCsQlwzg+g4EsWEQVu4DvIHBkoruB5+XG+NDApJlebqgXUHitXiehrE9O5t8qp83NBKUWKuVVx0LNtAbChd8u4w8R5mDR1XK4pJ6yob4M0d0TAmqTfTypkSlm+8k0y7dqtjMmpvP/vZlHrfZdsASH0qeRf14L0AjLaKK06khUGx3uGJqRO655AN2n0X881BwQZ4USISxgQj1zgr90s3iY33A5BZL271p4wY2EFf2PBG6jwpb6S0c39fACC3fZ3cZ6W7NS0NxMoCONU/KD/GxdYUA7xZHj2zjDEgEsa44YJOaQ+RL83xqkKB9Cr5nWoRVvyTFtd7pEeYs+fcTp7c8rPIu2cNAP6v3dJJ7aUaLgrbvEweNSCL0OLBAZ1fnk0qpOI3RSLlFkQ0Nkbv6RSPqFao03FQRvtsLAXA5U6JLfLftwJw26kCf74ejnZ3SdulHx4DZH8KQAVB5PtI5bCMMSDSJcF0oxg7dAKA1YfkvhixFk7JfU0N3tIlImdgUJeVJp/UDTwpPm8nw8uNpMrnJ7ZxbwLYqWSAVYwBVjEGOPa/HSrDMsYAqxgDrGIMsIoxwCrGAKsYA6xiDLCKMcAqxoD/AdQoCfVuKDHdAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAFwElEQVR4nO2cS28bVRTHfzPjOJ7YeT+a0pIQWhEKIlAJFbGqhEQ3rNmxYMmmC5Z8B1jwGRASasUOiUVBiEcFFQK1lIoGQpO+X0ka5+HYsWeGxbnX8Ti+qWNPYxD3v5lk5sz19Zn/eU/iRFEUYbEDbqc38G+FVYwBVjEGWMUYYBVjgFWMAVYxBljFGGAVY0CqWcE33bef5D72FefCs4+VaVox+wEnlYodw1JJLnSgarGmZEDnGON6ADiuA0BUqRAFgVzKZZWIV70GEGkG7QM6phg33QVAWCzuuBZNHATAKYlCnLsPAAhqFeM4SvjJmJk1JQOSZcwuT9HNZEREmQmjQ3JemUv4+1Xcl48BsHakF4Ds5xdia5RPvQrAg/c2mXj3JgDB6mrTe9gLLGMMSJYxhqfkZjK4A/0AhGODABQm+gDIXn0o508e584JH4DcrRAAb/ooAMHsHAD+3CIA/Z+N72CK1yfrBesbADjetlNvBZYxBuxLVIqCcOcHb0poDkbEnxBE5G6L3L2Tchz8zYuvkxeW5M4sVM+5ver+QwcA8FQEi4oSwSxjEkZ7jFERwO3uBswpfFTeIlL5inN3CYDMagGAysINALzjL1Ic6JGf+0Q2/HM+tk6wtLy98Wcm5NygMMbZUkmgyoGYu9HON7OMMSERH+NkhDFay42y2VBFi6iSB8AbeDZ2PcilWZmRpz49riLViRdk/fMXd6wXjEgU8u6viGx/TvawJH6oUmyvfGhPMcpkoiY2Ue8Eg7+uxX7fHEvzxitXADjSI4r54fKayNatlTo4TmFYEkY/nwbAuaeUWdjcwxcww5qSAYmYUrX6DfeehnujowD0nfuD0x9+Hbv27eprDe8Jxofxr4kjDuaUg9YOX1XthPU82xssYwxIhjEtMKV675r4kbBY5JPl1wH46OCvAHhHp4BtVqQmnxbZ2XnCsvJZ9WVIm0zRsIwxIJlwrTttqgNXtfNIUnsn1UVU3mp4b21o/2Dse/WTtCaq/kNjqyz3FAoJ7Hp3WMYYkEwRqZih7dvpSqvzqvQ3sKUW3oExinX+or7tULl7L4ndNoX2FOPWmZBCM4qoR/jUKF+sTwMw4ClT8dzY5zR0rHVmm1QP2JqSAe0xRj9B3Wdto9+6MZXju0fPAXBnQ7p9PcrZulmpukMV2mNQTKlW+A3qtFZgGWNA64xxnG1mNMEQPXatLyZTU5MAbA65/HRRGOPfEr8x6Uk3jmCXpE19dlJM0bCMMaB1xuzBj7jZLG6fdNp0yHWUT6jMXwdgcHaId97/BoDLa4cBmP9ZolRmUXXu9iGxq+553z7pP4YnMiXQbHB7JJo4fTkq12/GZOoH9I+mM3gICz8+/BUAb/nSwdO9Xk/NpoKVfM2H6YjoqoNKKrVfqmW2lm0CljEGJMoYHXm80REAol6Vf3R3wfW4zI5WZ8bh9KAIfbp2CICKL8/NnXle7tHjqRrGuL6vFo5PHjU3Wp0rJVNdq9rIG5FBfeSLKaEGbe5iHv2dTBtNbUScL4rUl0svAbB8TBTTf0nd0yBs60pbj2jb6SbWwpqSAc0zpkG6X62ilcPDV696qFR+fUaGX6lCQOrW7V2XX56JmEzJ0//xb+ncDS/En7pTKhv3pYd9ke7sNSo495JiNC35P0PzjNlF21WHqpnjyWtkWznRe/fy49sQuak8Z1ZnAEjPiUMdnF0Htvsxu+0r6ffzLGMMaCsq6YaUPnor6mWeAxKuuzYlyqTu54nU6xr1rYPw5HEAyr/4nD1zCoCMvFuEUxY/0XbrqYVZk2WMAYkmeMGivOKBOvpX1PkGsrps2OoVf9S7EFHqFx/Vf01Fn0uzyWyshVmTZYwBHf9bguwFmR1lS6VqLuKkJT+K1LzKUQViK032VtExxejwGjx8aLzWSVhTMsAqxgCrGAMc+78dGsMyxgCrGAOsYgywijHAKsYAqxgDrGIMsIoxwCrGgH8A2jQOxDJrNHIAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAF8klEQVR4nO2bz28bVRDHP7trJ85PJ01M3LSkCk1TpQS1lfhZoJxoERLiAoceUQUCpB6QOCHxF3DqrSckrlRCiBMqN1roD1rUHkAqrYKSJm2ApM5P24m9uxzm7cbr5DWO7dSR+r6XzXrfjicz3zczb96z5fu+j8E62I1WYKfCGEYDYxgNjGE0MIbRwBhGA2MYDYxhNDCG0SBW6cA37fe3U4/Hip+885uO2ZGMsWIxrFjFPtsW7EjD7AQ01i0a+MVio1WowTCWBY9pYR7buwcAf2ERAHdhYe2h7UQH+5661qabmUoaVM0Yy3HWU96yNh5c4r3YvqcBKI7fi74ab2Lp3aMAZFPir1xK5PVdLwDQevehDC5ljOdWpf9mMIzRoGrGRNhSxhTLUfNeXf2VlfBZOVNCeYVVigmR89GZHwD4fvoIAHNjwjJ3tBeA+5+kGPrsyqMVDGKP567pt4W4YxijQX3SdbknLGVvb2uZYWmvvPdx1xQAP2eGAVjOihy7INf2wQUKJ54HIH7h+sbCSmNPFRnKMEaDbSnw/MLqpmOc3h4ArLZWQGJPPiU1yLk5qVuWis0ANM1LPJs9JPfN8SItd/6T95S8oNYpTk7V4T8wjNGiYUsCb24egNwr+wFIjN+jf/QfAL4ZfxmAuC0MspvFf4sv5gDo/q4Xd/KGPGtrE3mZubrq1zDD2K0yhZb7JK22dSVpjcsUHO78F4CLX78AQEeTTJineqWwS0w6OL27APDUMqG0JACqStER/ap66wlA46aS8nChQzw7f2KEwx03AeiM5QFI3coCMPNcCwC5K30ADMwv4c2rZYFihtUiY/ylJfk8YIrtVLVsMIzRoGGMcfrTAOR7xLPFhM3Z/l8A+HTyOADxB3MAFI5JPOocU4VerhAuN6zmJgD8nLAsLC59xZKgDbFFGMZo0DDG+HMSI1b7UwDs2p0h44nXL008A8CArBlxFBlsVxhjTf5DmGuCWOLWt/1gGKNBwxjjZjIAvHdECrUPey6RtCVe5GclwziZGQDiy+0AdN+Qez+Xx8tLVrOR+BO2QerUuGpcM1z1S75Ki2FOT5zk8/SFyBD3zhgAvVPTcp+TytdubV0Lqp5cfd1UMgVefdEwxkx8+ZL6SxhzZWofIwMyLYYOPIgODjqCyvve8nJY2PmrhcizDWE6ePXD9jKmtO+q4B87DMDA8QkAZtxlABJNhXDM7Pm9AKSQ/rC3uLhedljyqyVBs/RqwrStCj2/WFhf9FWiesUjnzBsL2NKmOJ0dwMw9pbEkeGY9F5O3vwAgLk7uzjsngIgfe5yRIydSIi4IJ6UMrAgadpJdqoPVJZS6Rzf0WesR8AwRoPtYcwGseXh2wflC0dlKdDTLLHl7orU/enLPv61ZFS5tLQZvD5pSjl/Sz/Xy+XXmKEKO1/VOMFB94BJssdlGFM31MaY8r3qcD9JPBT0Y632NjKHZKzjybXgCav2nI2LIjf/JPvGSERccVrikKX6uV55+7IEXj6/4efVnnmo84ZbtPdhxUX8wuuDrPZJ4OyIidF+vSFT68DFq+GbLT/+DoD9rDxz/7gt4h9hkE1R5drJTCUNamKMFZNpEGywWU2qm6Y87KotkmzKBks8tzgr02v4zNV18oK0mt0nqbdtVoKvn+wQebfv1qLulmAYo0FtMaasnxqkyPD+1SMArCYtrJzaPxqPHg1zDg4BwobZ07LRttKtdg4G90fGpu9LMN5wiVDjPlI5DGM0qIkx646alWWAYpuIz47k6UhKAZa41hUZE8SN2O40liLgSrd4fVXVe+3qrJE7Kr3g2F/3wn2lQIfgsNJGJz6tuIp9FRw2CGAYo0F96hjN/E7cktZC69Eh3jn1GwDfHnoNgK4yEcUH0/j2IACDX0QXkU5KdhL8rOxMuiVLggDas8GWtSWmBDCM0cCq9OfFj+NHFsFhouCIiB8cVavzkdVKfmSxo47MuzOzjVYhhJlKGhjDaGAMo0HFwfdJg2GMBsYwGhjDaGAMo4ExjAbGMBoYw2hgDKOBMYwG/wNWIAshLWlEywAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "with mpl.rc_context(fname=\"../../../matplotlibrc\"):\n", + " for i in range(3):\n", + " fig, ax = plt.subplots(1, 1, figsize=(0.65, 0.65))\n", + " _ = ax.imshow(gaussian_samples[i].reshape(28, 28), clim=[0, 1])\n", + " _ = ax.spines[\"bottom\"].set_visible(False)\n", + " _ = ax.spines[\"left\"].set_visible(False)\n", + " _ = ax.set_xticks([])\n", + " _ = ax.set_yticks([])\n", + " plt.savefig(f\"svg/fig2_gauss_mnist_{i}.svg\", bbox_inches=\"tight\")\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "226df20a-8dab-419c-9278-c0c747f56a42", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAB+UlEQVR4nO3avWtTARSG8aeh1NDSoYh0kYtIqbjUNSB00sVBFHFzc3Ox4OruJm6F/gsFdVQMVMShowgiFMQuLYpgoQpKctvrcCGYti8N1OQc8P1BINwEcng4+b5jVVVV2CGN6AGychjBYQSHERxGcBjBYQSHERxGGB/0jlcbt4c5x0i92l899j7eGMFhBIcRHEZwGCE+TGsBWgu83H5H2S4o20X0RECGMEkN/Dlm2LrVHveK1wCstG7UB9ffh83jjRHCN6bxswPAx26X61M7ADx8UP8MXQR+2PbGCGOD/ksw7O9KZbvgxcVnALz5PQHA48tX6tu+fP2nj+XvSieQMsxis8NiswPNU/UlQMowGTiM4DCCwwhpwmy/PRs9Qp80YbJJE2Z6M9fZKGnCZJMmzPSdregR+qQJk43DCA4jOIzgMILDCA4jOIwQ/i/B+PlzANwt1mIHOcAbI4RvzO6lWQBuTn3vHfvQKesr3TJiJCAwTGNyEoDTS5uHbrv1/D4Ac1vroxypj59KQtzGzJ4BYHXuae/Y2q8mABeWvwGwN/qxerwxQviL798efb4GwMTGp+BJvDFS2MZUuz8AeLIzD8DSzEbUKEfyxghpTgMZJZ8GcgIOIziM4DCCwwgOIwz8dv2/8cYIDiM4jOAwgsMIDiM4jOAwgsMIfwC16V1XQIICrgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAACKElEQVR4nO3bv2sTcRzG8XfORutUA6atOEgdHERQikt3i4s4KG46+A9IJ/8B/QecdBEEySS4KjgrXVyKoLRUpRoQhKJE8AfmzuGaa095NIv3+YrPa8mXy0E+PHmO7xFyraIoCuwXWfQAqXIwgoMRHIzgYAQHIzgYwcEIDkaYGPfExezC35yjUY/ye388x40RHIzgYAQHIzgYYexdqQmnng0A6E6Ur7eunQNgqrfc+CxujJBMY7ITR1nq3AUgJwfg+nz53lQvYJ7mP/LfkE4w62+iJ6hJJ5jEJBNMPhhEj1CTTDCpcTBCMts1QEarWgHs/hD3vbkxQjKN2by8QM5TYPsGb+7OBgDfA+ZxY4RkGrN45XG13hx+LRd5HjSNGyOFN6ZYOA7Apc5NYA8Ap29cBeBA/0nUWG6MEt6Yz7OTABxut6tj+9Yj9qG68GBGMjLarV3RY1R8KQnhjXl7ZgiUN3XPv5Xb9N53XyJHAtwYKbwx3dmP1frBp2PlYnklaJptbowQ1phsstymj3TeR43wW26MENeYmS4Atw/dr471Xp0EYD+rITPt5MYIYY15sXQQKO94R+an+wBshExUF7ddbz3aMfq1DqB/cWZr9bL5eX7iS0kIv8HbabgW35QRN0ZIJpjza2ejR6hJJpjUJBPM64dz0SPUJBNMalrjPizqv8wb4GAkByM4GMHBCA5GGHu7/t+4MYKDERyM4GAEByM4GMHBCA5GcDDCD/vjW5gaPr9sAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEYAAABGCAYAAABxLuKEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAACqklEQVR4nO2bT0hUURSHv/mXxUxWOg6STIipIUoSRkEEujEKhiAiqF0E0cqgRbRpWRC0qIXCoEEQRJFBuCpoERQZQdiQxeRQWYFBDTIRxrSYZlpMMzX4Do2Ld9+Fzrc858E7/Pjuu4/3x1cqlUooy/B7PYCtaDACGoyABiOgwQhoMAIajIAGI6DBCATrPXDYf8jNOYxyvzj5z2PUGAENRkCDEdBgBDQYgbp3JS/JXB0A4M2eiWpt94vyLtm4760r51RjBKw2JrB+HQADnR8AKPLnYeNiKgZAI2qMUaw2hlgUgBsdt2rKd7430TX+CYCCS6dWYwSsNMa3rReA7itzjv2JYwfwz6dcnUGNEbDSmKXNEQAutj517IfSH/np8gxWBeMLrQJg06mMY39otnxTt3ZpwfVZdCkJWGXM3Fg/AJn2ZE19+NVBACJ73wFg4p2yGiNglTFHdz52rH95tBGAOO+NzaLGCFhjTKB3C31r7tbUHuRXA9A+mgZwfYv+GzVGwHNjAj1dAAzenGF/OAfAfOEHAGcujQAQy00bn8vzYNInNwAw1fS6Wrv+dQcAsTHzgVTQpSTgmTGBlhYAZhKXf1cayBXLS2jy9iAAcdQY6zBujD8cBqDn3iIAEX9DtTeUPA1A/Lx3plRQYwSMG5M9vBWAC62jy3ptD/OmxxFRYwSMGVPZhfqPzzr2dz0/QvOTcs+GT9XVGAFjxmQTnQBMxWuvLde+tQEQPRukWHDrLdHKMRJMINpM34mXjr1z0wkAulPPTIxSN7qUBIwYk9/ewXg86dgLZkMmRlgxaoyA548dwgs+r0dwRI0R8N6Yz0WvR3BEjRHw1fuzqH4yrwAajIgGI6DBCGgwAhqMQN3b9f+GGiOgwQhoMAIajIAGI6DBCGgwAhqMgAYj8AtnBX0OD0TeCgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "with mpl.rc_context(fname=\"../../../matplotlibrc\"):\n", + " for i in range(3):\n", + " fig, ax = plt.subplots(1, 1, figsize=(0.65, 0.65))\n", + " _ = ax.imshow(X_train[i].reshape(28, 28), clim=[0, 1])\n", + " _ = ax.spines[\"bottom\"].set_visible(False)\n", + " _ = ax.spines[\"left\"].set_visible(False)\n", + " _ = ax.set_xticks([])\n", + " _ = ax.set_yticks([])\n", + " plt.savefig(f\"svg/fig2_mnist_{i}.svg\", bbox_inches=\"tight\")\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "f1441a82-f838-410c-bb15-bb360514e339", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def svg(img):\n", + " IPd.display(IPd.HTML(''.format(img, time.time())))\n", + "\n", + "# > Inkscape pixel is 1/90 of an inch, other software usually uses 1/72.\n", + "# > http://www.inkscapeforum.com/viewtopic.php?f=6&t=5964\n", + "svg_scale = 1.25 # set this to 1.25 for Inkscape, 1.0 otherwise\n", + "\n", + "# Panel letters in Helvetica Neue, 12pt, Medium\n", + "kwargs_caption = {'size': '10pt', 'font': 'Arial', 'weight': '800'}\n", + "kwargs_text = {'size': '7.75pt', 'font': 'Arial'}\n", + "\n", + "f = Figure(\n", + " \"16.6cm\",\n", + " \"5.6cm\",\n", + " Panel(SVG(\"svg/fig2_panel_a.svg\").scale(svg_scale)).move(102, 26.4),\n", + " Panel(Text(\"Data\", 23, 32.0, **kwargs_text)),\n", + " Panel(SVG(\"svg/fig2_panel_b.svg\").scale(svg_scale)).move(102, 111.5),\n", + " Panel(Text(\"Model\", 23, 117.0, **kwargs_text)),\n", + " Panel(SVG(\"svg/fig2_illustration.svg\").scale(svg_scale).move(0, 5)).move(-3, 10),\n", + " Panel(Text(\"a\", 5, 12.0, **kwargs_caption), Text(\"Failure of C2ST to discriminate classes\", 45, 12.0, **kwargs_text)).move(-4, 0),\n", + " \n", + " Panel(Text(\"b\", -35, -8.5, **kwargs_caption), Text(\"C2STs for high-D data\", -12, -8.5, **kwargs_text)).move(342, 20.5),\n", + " Panel(SVG(\"svg/fig2_panel_c1.svg\").scale(svg_scale)).move(305, 27.6),\n", + " Panel(SVG(\"svg/fig2_panel_c2.svg\").scale(svg_scale)).move(305, 111.8),\n", + "\n", + " Panel(Text(\"c\", -25, 12, **kwargs_caption), Text(\"High C2ST on MNIST\", 3, 12, **kwargs_text)).move(502, 0.0),\n", + " Panel(SVG(\"svg/fig2_mnist_0.svg\")).move(485, 25.5),\n", + " Panel(SVG(\"svg/fig2_mnist_1.svg\")).move(532.5, 25.5),\n", + " Panel(SVG(\"svg/fig2_mnist_2.svg\")).move(580, 25.5),\n", + " Panel(Text(\"Data\", 491, 31, **kwargs_text)),\n", + "\n", + " Panel(SVG(\"svg/fig2_gauss_mnist_0.svg\")).move(485, 78),\n", + " Panel(SVG(\"svg/fig2_gauss_mnist_1.svg\")).move(532.5, 78),\n", + " Panel(SVG(\"svg/fig2_gauss_mnist_2.svg\")).move(580, 78),\n", + " Panel(Text(\"Gaussian: C2ST=1.0\", 491, 83.5, **kwargs_text)),\n", + "\n", + " Panel(SVG(\"svg/fig2_gmm_mnist_0.svg\")).move(485, 130.5),\n", + " Panel(SVG(\"svg/fig2_gmm_mnist_1.svg\")).move(532.5, 130.5),\n", + " Panel(SVG(\"svg/fig2_gmm_mnist_2.svg\")).move(580, 130.5),\n", + " Panel(Text(\"MoG: C2ST=1.0\", 491, 135.5, **kwargs_text)),\n", + ")\n", + "\n", + "!mkdir -p fig\n", + "f.save(\"fig/fig2.svg\")\n", + "svg(\"fig/fig2.svg\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b6c2ca94-64fc-4f2c-bb1a-c3e054f6f7ed", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sitemap.xml.gz b/sitemap.xml.gz index 3dc7fc1..66ff6d7 100644 Binary files a/sitemap.xml.gz and b/sitemap.xml.gz differ