From 240ffd684e1bc06746f86ec1f1b689829ead6fb1 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Thu, 6 Jun 2024 13:33:43 +0200 Subject: [PATCH] new models --- docs/notebooks/fid/additional_models.ipynb | 842 ++++++++++++++++++-- docs/notebooks/fid/model_table.ipynb | 245 ++++++ docs/notebooks/fid/results_c2st_knn.npy | Bin 0 -> 1232 bytes docs/notebooks/fid/results_c2st_nn.npy | Bin 0 -> 1232 bytes docs/notebooks/fid/results_fid.npy | Bin 0 -> 1212 bytes docs/notebooks/fid/results_mmd_lin.npy | Bin 0 -> 1212 bytes docs/notebooks/fid/results_mmd_poly_kid.npy | Bin 0 -> 1212 bytes docs/notebooks/fid/results_mmd_rbf64.npy | Bin 0 -> 1212 bytes docs/notebooks/fid/results_sw.npy | Bin 0 -> 1212 bytes labproject/data.py | 6 +- 10 files changed, 1025 insertions(+), 68 deletions(-) create mode 100644 docs/notebooks/fid/model_table.ipynb create mode 100644 docs/notebooks/fid/results_c2st_knn.npy create mode 100644 docs/notebooks/fid/results_c2st_nn.npy create mode 100644 docs/notebooks/fid/results_fid.npy create mode 100644 docs/notebooks/fid/results_mmd_lin.npy create mode 100644 docs/notebooks/fid/results_mmd_poly_kid.npy create mode 100644 docs/notebooks/fid/results_mmd_rbf64.npy create mode 100644 docs/notebooks/fid/results_sw.npy diff --git a/docs/notebooks/fid/additional_models.ipynb b/docs/notebooks/fid/additional_models.ipynb index 35ae2aa..43041b1 100644 --- a/docs/notebooks/fid/additional_models.ipynb +++ b/docs/notebooks/fid/additional_models.ipynb @@ -2,22 +2,36 @@ "cells": [ { "cell_type": "code", - "execution_count": 95, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import torch\n", "from labproject.data import get_dataset, DATASETS\n", "from labproject.metrics import METRICS\n", + "import numpy as np\n", "\n", "from labproject.metrics.utils import get_metric\n", "\n", - "import matplotlib.pyplot as plt" + "import matplotlib.pyplot as plt\n", + "\n", + "torch.manual_seed(0)" ] }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -30,7 +44,7 @@ " 'cifar10_test': ,\n", " 'imagenet_real_embeddings': ,\n", " 'imagenet_uncond_embeddings': ,\n", - " 'imagenet_unconditional_model_embedding': ,\n", + " 'imagenet_unconditional_model_embedding': ,\n", " 'imagenet_test_embedding': ,\n", " 'imagenet_validation_embedding': ,\n", " 'imagenet_conditional_model': ,\n", @@ -47,7 +61,7 @@ " 'imagenet_cs100_embedding': }" ] }, - "execution_count": 52, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -58,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 96, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -70,6 +84,7 @@ " 'mmd_polynomial': ,\n", " 'mmd_linear_naive': ,\n", " 'mmd_linear': ,\n", + " 'mmd_energy': ,\n", " 'c2st_nn': torch.Tensor>,\n", " 'c2st_rf': torch.Tensor>,\n", " 'c2st_knn': torch.Tensor>,\n", @@ -80,7 +95,7 @@ " 'wasserstein_sinkhorn': torch.Tensor>}" ] }, - "execution_count": 96, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -91,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -101,16 +116,17 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "datasets = [\"imagenet_unconditional_model_embedding\",\"imagenet_cs1_embedding\", \"imagenet_cs10_embedding\", \"imagenet_biggan_embedding\", \"imagenet_sdv4_embedding\", \"imagenet_sdv5_embedding\", \"imagenet_vqdm_embedding\", \"imagenet_wukong_embedding\", \"imagenet_adm_embedding\", \"imagenet_midjourney_embedding\"]" + "datasets = [\"imagenet_unconditional_model_embedding\",\"imagenet_cs1_embedding\", \"imagenet_cs10_embedding\", \"imagenet_biggan_embedding\", \"imagenet_sdv4_embedding\", \"imagenet_sdv5_embedding\", \"imagenet_vqdm_embedding\", \"imagenet_wukong_embedding\", \"imagenet_adm_embedding\", \"imagenet_midjourney_embedding\"]\n", + "metrics = [\"wasserstein_gauss_squared\", \"sliced_wasserstein\"]" ] }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -119,7 +135,7 @@ }, { "cell_type": "code", - "execution_count": 94, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -130,124 +146,783 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ - "results = {}" + "results_fid = {}" ] }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ + "torch.manual_seed(0)\n", + "metric = metrics[0]\n", + "metric_fn = get_metric(metric)\n", "for dname in datasets:\n", - " data = get_dataset(dname)(20_000, 2048, permute=True)\n", - " " + " metric_values = []\n", + " for j in range(5):\n", + " data_test = testset[j*20_000:(j+1)*20_000]\n", + " if dname == \"imagenet_midjourney_embedding\":\n", + " data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n", + " else:\n", + " data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n", + " m = metric_fn(data_test, data_syn)\n", + " metric_values.append(m)\n", + " results_fid[dname] = np.array(metric_values)" ] }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(array([190., 351., 418., 528., 509., 581., 534., 547., 529., 488., 459.,\n", - " 437., 401., 343., 318., 306., 272., 235., 217., 221., 175., 176.,\n", - " 158., 124., 151., 121., 105., 91., 92., 76., 66., 77., 72.,\n", - " 60., 42., 38., 42., 37., 42., 32., 23., 21., 29., 26.,\n", - " 23., 22., 25., 18., 17., 11., 10., 13., 17., 11., 4.,\n", - " 8., 6., 3., 5., 6., 1., 3., 4., 2., 1., 3.,\n", - " 3., 1., 1., 3., 0., 1., 2., 5., 1., 1., 0.,\n", - " 0., 2., 0., 3., 0., 0., 1., 0., 0., 0., 0.,\n", - " 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0.,\n", - " 1.]),\n", - " array([0. , 0.0224842 , 0.0449684 , 0.0674526 , 0.0899368 ,\n", - " 0.112421 , 0.1349052 , 0.1573894 , 0.1798736 , 0.2023578 ,\n", - " 0.224842 , 0.2473262 , 0.26981041, 0.29229459, 0.3147788 ,\n", - " 0.33726299, 0.3597472 , 0.38223141, 0.4047156 , 0.42719981,\n", - " 0.44968399, 0.47216821, 0.49465239, 0.51713657, 0.53962082,\n", - " 0.562105 , 0.58458918, 0.60707343, 0.62955761, 0.65204179,\n", - " 0.67452598, 0.69701022, 0.7194944 , 0.74197859, 0.76446283,\n", - " 0.78694701, 0.8094312 , 0.83191538, 0.85439962, 0.8768838 ,\n", - " 0.89936799, 0.92185217, 0.94433641, 0.9668206 , 0.98930478,\n", - " 1.01178896, 1.03427315, 1.05675745, 1.07924163, 1.10172582,\n", - " 1.12421 , 1.14669418, 1.16917837, 1.19166255, 1.21414685,\n", - " 1.23663104, 1.25911522, 1.2815994 , 1.30408359, 1.32656777,\n", - " 1.34905195, 1.37153625, 1.39402044, 1.41650462, 1.4389888 ,\n", - " 1.46147299, 1.48395717, 1.50644135, 1.52892566, 1.55140984,\n", - " 1.57389402, 1.59637821, 1.61886239, 1.64134657, 1.66383076,\n", - " 1.68631506, 1.70879924, 1.73128343, 1.75376761, 1.77625179,\n", - " 1.79873598, 1.82122016, 1.84370434, 1.86618865, 1.88867283,\n", - " 1.91115701, 1.9336412 , 1.95612538, 1.97860956, 2.00109386,\n", - " 2.02357793, 2.04606223, 2.0685463 , 2.0910306 , 2.1135149 ,\n", - " 2.13599896, 2.15848327, 2.18096733, 2.20345163, 2.2259357 ,\n", - " 2.24842 ]),\n", - " )" + "{'imagenet_unconditional_model_embedding': array([6.2195807, 6.1363096, 6.1779237, 6.1323247, 6.1145873],\n", + " dtype=float32),\n", + " 'imagenet_cs1_embedding': array([6.430565 , 6.3383527, 6.3908176, 6.392403 , 6.3483315],\n", + " dtype=float32),\n", + " 'imagenet_cs10_embedding': array([6.9664445, 6.934556 , 6.985526 , 6.980446 , 6.945889 ],\n", + " dtype=float32),\n", + " 'imagenet_biggan_embedding': array([12.701342, 12.732329, 12.704566, 12.742165, 12.572661],\n", + " dtype=float32),\n", + " 'imagenet_sdv4_embedding': array([17.149105, 17.233383, 17.262657, 17.13218 , 17.054472],\n", + " dtype=float32),\n", + " 'imagenet_sdv5_embedding': array([17.27084 , 17.254538, 17.438782, 17.217953, 17.170486],\n", + " dtype=float32),\n", + " 'imagenet_vqdm_embedding': array([11.207735, 11.090878, 11.292017, 11.206196, 11.167131],\n", + " dtype=float32),\n", + " 'imagenet_wukong_embedding': array([18.644157, 18.674633, 18.625652, 18.643795, 18.433655],\n", + " dtype=float32),\n", + " 'imagenet_adm_embedding': array([12.630351 , 12.518478 , 12.667346 , 12.5699835, 12.462262 ],\n", + " dtype=float32),\n", + " 'imagenet_midjourney_embedding': array([17.360374, 17.406929, 17.517849, 17.314253, 17.256212],\n", + " dtype=float32)}" ] }, - "execution_count": 83, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" - }, + } + ], + "source": [ + "results_fid" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "np.save(\"results_fid.npy\", results_fid)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "results_sw = {}" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting imagenet_midjourney_embedding\n" + ] + } + ], + "source": [ + "torch.manual_seed(0)\n", + "metric = \"sliced_wasserstein\"\n", + "metric_fn = get_metric(metric)\n", + "for dname in datasets[-1:]:\n", + " print(\"Starting \", dname)\n", + " metric_values = []\n", + " for j in range(5):\n", + " data_test = testset[j*20_000:(j+1)*20_000]\n", + " if dname == \"imagenet_midjourney_embedding\":\n", + " data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n", + " data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]\n", + " else:\n", + " data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n", + " m = metric_fn(data_test, data_syn, num_projections=5000)\n", + " metric_values.append(m)\n", + " results_sw[dname] = np.array(metric_values)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'imagenet_unconditional_model_embedding': array([0.02358142, 0.02302841, 0.02342947, 0.02363985, 0.02274173],\n", + " dtype=float32),\n", + " 'imagenet_cs1_embedding': array([0.02153669, 0.02145257, 0.02175897, 0.02185666, 0.02145214],\n", + " dtype=float32),\n", + " 'imagenet_cs10_embedding': array([0.02553654, 0.02503867, 0.02539671, 0.0255399 , 0.02508791],\n", + " dtype=float32),\n", + " 'imagenet_biggan_embedding': array([0.05029042, 0.05024597, 0.05106073, 0.05075597, 0.05070167],\n", + " dtype=float32),\n", + " 'imagenet_sdv4_embedding': array([0.05571224, 0.05621962, 0.05671528, 0.056228 , 0.05603858],\n", + " dtype=float32),\n", + " 'imagenet_sdv5_embedding': array([0.05642236, 0.05595 , 0.0562194 , 0.05640132, 0.05647483],\n", + " dtype=float32),\n", + " 'imagenet_vqdm_embedding': array([0.04172998, 0.04103179, 0.04130031, 0.04112783, 0.04078227],\n", + " dtype=float32),\n", + " 'imagenet_wukong_embedding': array([0.05560957, 0.05517614, 0.05603101, 0.05579789, 0.05510735],\n", + " dtype=float32),\n", + " 'imagenet_adm_embedding': array([0.05062867, 0.05020323, 0.05051676, 0.05036184, 0.05050731],\n", + " dtype=float32),\n", + " 'imagenet_midjourney_embedding': array([0.0481902 , 0.04775954, 0.04859642, 0.0480745 , 0.04758868],\n", + " dtype=float32)}" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results_sw" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "np.save(\"results_sw.npy\", results_sw)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "results_mmd_rbf64 = {}" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting imagenet_unconditional_model_embedding\n", + "Starting imagenet_cs1_embedding\n", + "Starting imagenet_cs10_embedding\n", + "Starting imagenet_biggan_embedding\n", + "Starting imagenet_sdv4_embedding\n", + "Starting imagenet_sdv5_embedding\n", + "Starting imagenet_vqdm_embedding\n", + "Starting imagenet_wukong_embedding\n", + "Starting imagenet_adm_embedding\n", + "Starting imagenet_midjourney_embedding\n" + ] + } + ], + "source": [ + "torch.manual_seed(0)\n", + "metric = \"mmd_rbf\"\n", + "metric_fn = get_metric(metric)\n", + "for dname in datasets:\n", + " print(\"Starting \", dname)\n", + " metric_values = []\n", + " for j in range(5):\n", + " data_test = testset[j*20_000:(j+1)*20_000]\n", + " if dname == \"imagenet_midjourney_embedding\":\n", + " data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n", + " data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]\n", + " else:\n", + " data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n", + " m = metric_fn(data_test, data_syn, bandwidth=64.0)\n", + " metric_values.append(m)\n", + " results_mmd_rbf64[dname] = np.array(metric_values)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'imagenet_unconditional_model_embedding': array([7.1763992e-05, 6.2704086e-05, 7.2360039e-05, 6.4253807e-05,\n", + " 6.0319901e-05], dtype=float32),\n", + " 'imagenet_cs1_embedding': array([7.1883202e-05, 5.7458878e-05, 5.9604645e-05, 6.4969063e-05,\n", + " 5.7697296e-05], dtype=float32),\n", + " 'imagenet_cs10_embedding': array([9.0479851e-05, 8.4877014e-05, 8.2612038e-05, 8.7141991e-05,\n", + " 8.1181526e-05], dtype=float32),\n", + " 'imagenet_biggan_embedding': array([0.00019705, 0.00018573, 0.00018668, 0.00018907, 0.00017405],\n", + " dtype=float32),\n", + " 'imagenet_sdv4_embedding': array([0.00021684, 0.00020289, 0.00020945, 0.00021148, 0.00020015],\n", + " dtype=float32),\n", + " 'imagenet_sdv5_embedding': array([0.00021875, 0.00020254, 0.00021124, 0.00021267, 0.00020623],\n", + " dtype=float32),\n", + " 'imagenet_vqdm_embedding': array([0.00015748, 0.0001334 , 0.0001483 , 0.00015378, 0.00014079],\n", + " dtype=float32),\n", + " 'imagenet_wukong_embedding': array([0.0002085 , 0.00019038, 0.00019991, 0.00020087, 0.00018454],\n", + " dtype=float32),\n", + " 'imagenet_adm_embedding': array([0.00020254, 0.00019121, 0.00019884, 0.00019133, 0.00017893],\n", + " dtype=float32),\n", + " 'imagenet_midjourney_embedding': array([0.00018728, 0.00016236, 0.00017667, 0.00017893, 0.00017178],\n", + " dtype=float32)}" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results_mmd_rbf64" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "np.save(\"results_mmd_rbf64.npy\", results_mmd_rbf64)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "results_mmd_lin = {}" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting imagenet_unconditional_model_embedding\n", + "Starting imagenet_cs1_embedding\n", + "Starting imagenet_cs10_embedding\n", + "Starting imagenet_biggan_embedding\n", + "Starting imagenet_sdv4_embedding\n", + "Starting imagenet_sdv5_embedding\n", + "Starting imagenet_vqdm_embedding\n", + "Starting imagenet_wukong_embedding\n", + "Starting imagenet_adm_embedding\n", + "Starting imagenet_midjourney_embedding\n" + ] + } + ], + "source": [ + "torch.manual_seed(0)\n", + "metric = \"mmd_linear\"\n", + "metric_fn = get_metric(metric)\n", + "for dname in datasets:\n", + " print(\"Starting \", dname)\n", + " metric_values = []\n", + " for j in range(5):\n", + " data_test = testset[j*20_000:(j+1)*20_000]\n", + " if dname == \"imagenet_midjourney_embedding\":\n", + " data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n", + " data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]\n", + " else:\n", + " data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n", + " m = metric_fn(data_test, data_syn)\n", + " metric_values.append(m)\n", + " results_mmd_lin[dname] = np.array(metric_values)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "np.save(\"results_mmd_lin.npy\", results_mmd_lin)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'imagenet_unconditional_model_embedding': array([0.27161348, 0.23375021, 0.27461752, 0.23648134, 0.22176446],\n", + " dtype=float32),\n", + " 'imagenet_cs1_embedding': array([0.2784213 , 0.2160877 , 0.2256283 , 0.24824372, 0.21793067],\n", + " dtype=float32),\n", + " 'imagenet_cs10_embedding': array([0.3461548 , 0.32180998, 0.31091288, 0.33070827, 0.30549592],\n", + " dtype=float32),\n", + " 'imagenet_biggan_embedding': array([0.6735612, 0.6250146, 0.6302967, 0.6348198, 0.574922 ],\n", + " dtype=float32),\n", + " 'imagenet_sdv4_embedding': array([0.7234071 , 0.6639475 , 0.6890489 , 0.69806087, 0.6532722 ],\n", + " dtype=float32),\n", + " 'imagenet_sdv5_embedding': array([0.73029166, 0.66163504, 0.6962174 , 0.70317215, 0.6768427 ],\n", + " dtype=float32),\n", + " 'imagenet_vqdm_embedding': array([0.55959 , 0.45710978, 0.51890296, 0.5427842 , 0.4893559 ],\n", + " dtype=float32),\n", + " 'imagenet_wukong_embedding': array([0.6912052 , 0.61376953, 0.6529302 , 0.6583701 , 0.5910167 ],\n", + " dtype=float32),\n", + " 'imagenet_adm_embedding': array([0.6968005, 0.6499678, 0.6796049, 0.6463373, 0.5955869],\n", + " dtype=float32),\n", + " 'imagenet_midjourney_embedding': array([0.6486083 , 0.54306364, 0.6036365 , 0.61384624, 0.583929 ],\n", + " dtype=float32)}" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results_mmd_lin" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "results_mmd_poly_kid = {}" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting imagenet_unconditional_model_embedding\n", + "Starting imagenet_cs1_embedding\n", + "Starting imagenet_cs10_embedding\n", + "Starting imagenet_biggan_embedding\n", + "Starting imagenet_sdv4_embedding\n", + "Starting imagenet_sdv5_embedding\n", + "Starting imagenet_vqdm_embedding\n", + "Starting imagenet_wukong_embedding\n", + "Starting imagenet_adm_embedding\n", + "Starting imagenet_midjourney_embedding\n" + ] + } + ], + "source": [ + "torch.manual_seed(0)\n", + "metric = \"mmd_polynomial\"\n", + "metric_fn = get_metric(metric)\n", + "for dname in datasets:\n", + " print(\"Starting \", dname)\n", + " metric_values = []\n", + " for j in range(5):\n", + " data_test = testset[j*20_000:(j+1)*20_000]\n", + " if dname == \"imagenet_midjourney_embedding\":\n", + " data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n", + " data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]\n", + " else:\n", + " data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n", + " m = metric_fn(data_test, data_syn, degree=3, bias=1.)\n", + " metric_values.append(m)\n", + " results_mmd_poly_kid[dname] = np.array(metric_values)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAllklEQVR4nO3df3BU1f3/8VdCfsmP3ZjU7JIxQawiRFEUJFmxrbSBiNGRMbbFYTDtUOkwCS2komaGSquOaaltLDaIOBZsLUPLdLQ1FmgMEVrYgI0wE4MNKnyaWNjEgtkFWhIg9/uH31zdJEA2bLInm+dj5k65957dfV8vy7567rnnxliWZQkAAMAgsZEuAAAAoDsCCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADAOHGRLqA/Ojs7deTIEY0ZM0YxMTGRLgcAAPSBZVk6ceKE0tPTFRt74T6SIRlQjhw5ooyMjEiXAQAA+qG5uVlXXnnlBdsMyYAyZswYSZ8eoMPhiHA1AACgLwKBgDIyMuzf8QsJOaD8+9//1qOPPqotW7bov//9r6655hqtX79e06ZNk/Rp983KlSv14osvqq2tTTNmzNDzzz+va6+91n6P48ePa8mSJXr99dcVGxurgoIC/fKXv9To0aP7VEPXZR2Hw0FAAQBgiOnL8IyQBsl+8sknmjFjhuLj47VlyxYdOHBAP//5z3X55ZfbbVatWqXVq1dr7dq12rNnj0aNGqW8vDydPn3abjN//nw1NDSoqqpKlZWV2rlzpxYtWhRKKQAAIIrFhPI048cee0y7du3S3/72t173W5al9PR0/eAHP9DDDz8sSfL7/XK5XNqwYYPmzZun9957T1lZWXr77bftXpetW7fqrrvu0kcffaT09PSL1hEIBOR0OuX3++lBAQBgiAjl9zukHpQ///nPmjZtmr7+9a8rLS1NN998s1588UV7/+HDh+Xz+ZSbm2tvczqdys7OltfrlSR5vV4lJyfb4USScnNzFRsbqz179vT6ue3t7QoEAkELAACIXiEFlEOHDtnjSbZt26bFixfre9/7nl5++WVJks/nkyS5XK6g17lcLnufz+dTWlpa0P64uDilpKTYbborKyuT0+m0F+7gAQAguoUUUDo7O3XLLbfo6aef1s0336xFixbpoYce0tq1aweqPklSaWmp/H6/vTQ3Nw/o5wEAgMgKKaCMHTtWWVlZQdsmTZqkpqYmSZLb7ZYktbS0BLVpaWmx97ndbrW2tgbtP3v2rI4fP2636S4xMdG+Y4c7dwAAiH4hBZQZM2aosbExaNvBgwc1btw4SdL48ePldrtVXV1t7w8EAtqzZ488Ho8kyePxqK2tTXV1dXab7du3q7OzU9nZ2f0+EAAAED1Cmgdl2bJluu222/T000/rG9/4hvbu3at169Zp3bp1kj69r3np0qV66qmndO2112r8+PH64Q9/qPT0dM2dO1fSpz0ud955p31p6MyZMyouLta8efP6dAcPAACIfiHdZixJlZWVKi0t1fvvv6/x48erpKREDz30kL2/a6K2devWqa2tTbfffrvWrFmjCRMm2G2OHz+u4uLioInaVq9e3eeJ2rjNGACAoSeU3++QA4oJCCgAAAw9AzYPCgAAwGAgoAAAAOMQUAAAgHEIKAAAwDgh3WaMS1RT1nPbzNLBrwMAAMPRgwIAAIxDQAEAAMYhoAAAAOMQUAAAgHEIKAAAwDgEFAAAYBwCCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBx4iJdALqpKeu5bWbp4NcBAEAE0YMCAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAO86BEWm/zngAAMMzRgwIAAIxDQAEAAMbhEk+4DOQU9d3fm6nvAQBRjh4UAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABiHZ/EMpN6ez9NP3kPH7D/Xnj2oZbMmhO29AQAwTUg9KD/60Y8UExMTtEycONHef/r0aRUVFSk1NVWjR49WQUGBWlpagt6jqalJ+fn5GjlypNLS0rR8+XKdPXs2PEcDAACiQsg9KNdff73efPPNz94g7rO3WLZsmd544w1t3rxZTqdTxcXFuu+++7Rr1y5J0rlz55Sfny+3263du3fr6NGjevDBBxUfH6+nn346DIcDAACiQcgBJS4uTm63u8d2v9+vl156SRs3btRXv/pVSdL69es1adIk1dbWKicnR3/961914MABvfnmm3K5XJoyZYqefPJJPfroo/rRj36khISESz8iAAAw5IU8SPb9999Xenq6rr76as2fP19NTU2SpLq6Op05c0a5ubl224kTJyozM1Ner1eS5PV6NXnyZLlcLrtNXl6eAoGAGhoazvuZ7e3tCgQCQUu08R46FrQAADCchRRQsrOztWHDBm3dulXPP/+8Dh8+rC996Us6ceKEfD6fEhISlJycHPQal8sln88nSfL5fEHhpGt/177zKSsrk9PptJeMjIxQygYAAENMSJd45syZY//5xhtvVHZ2tsaNG6c//OEPuuyyy8JeXJfS0lKVlJTY64FAgJACAEAUu6R5UJKTkzVhwgR98MEHcrvd6ujoUFtbW1CblpYWe8yK2+3ucVdP13pv41q6JCYmyuFwBC0AACB6XVJAOXnypD788EONHTtWU6dOVXx8vKqrq+39jY2NampqksfjkSR5PB7V19ertbXVblNVVSWHw6GsrKxLKQUAAESRkC7xPPzww7rnnns0btw4HTlyRCtXrtSIESP0wAMPyOl0auHChSopKVFKSoocDoeWLFkij8ejnJwcSdLs2bOVlZWlBQsWaNWqVfL5fFqxYoWKioqUmJg4IAcIAACGnpACykcffaQHHnhAx44d0xVXXKHbb79dtbW1uuKKKyRJ5eXlio2NVUFBgdrb25WXl6c1a9bYrx8xYoQqKyu1ePFieTwejRo1SoWFhXriiSfCe1SG6e2uHM/VqSG/BgCA4SKkgLJp06YL7k9KSlJFRYUqKirO22bcuHH6y1/+EsrHAgCAYYaHBQIAAOPwsMAhKKdpnVTT7RLRzNLIFAMAwACgBwUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHEIKAAAwDjcZjwAmAUWAIBLQw8KAAAwDj0oEUIvCwAA50cPCgAAMA49KENU9x4Yz8wIFQIAwACgBwUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHG4iyda1JQFr88sjUwdAACEAT0oAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABiHgAIAAIzD04yjVfenG0s84RgAMGQQUIaR8qqDQevLZk2IUCUAAFwYl3gAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHGYSba/eptKHgAAhAU9KAAAwDgEFAAAYBwu8UQp76FjPTdmDn4dAAD0Bz0oAADAOAQUAABgHAIKAAAwziUFlJ/85CeKiYnR0qVL7W2nT59WUVGRUlNTNXr0aBUUFKilpSXodU1NTcrPz9fIkSOVlpam5cuX6+zZs5dSyrDnPXQsaAEAYCjrd0B5++239cILL+jGG28M2r5s2TK9/vrr2rx5s3bs2KEjR47ovvvus/efO3dO+fn56ujo0O7du/Xyyy9rw4YNevzxx/t/FAAAIKr0K6CcPHlS8+fP14svvqjLL7/c3u73+/XSSy/pF7/4hb761a9q6tSpWr9+vXbv3q3a2lpJ0l//+lcdOHBAr7zyiqZMmaI5c+boySefVEVFhTo6OsJzVAAAYEjrV0ApKipSfn6+cnNzg7bX1dXpzJkzQdsnTpyozMxMeb1eSZLX69XkyZPlcrnsNnl5eQoEAmpoaOj189rb2xUIBIIWhC6naV3QAgCAqUKeB2XTpk1655139Pbbb/fY5/P5lJCQoOTk5KDtLpdLPp/PbvP5cNK1v2tfb8rKyvTjH/841FIBAMAQFVIPSnNzs77//e/rd7/7nZKSkgaqph5KS0vl9/vtpbm5edA+GwAADL6QAkpdXZ1aW1t1yy23KC4uTnFxcdqxY4dWr16tuLg4uVwudXR0qK2tLeh1LS0tcrvdkiS3293jrp6u9a423SUmJsrhcAQtAAAgeoUUUL72ta+pvr5e+/fvt5dp06Zp/vz59p/j4+NVXV1tv6axsVFNTU3yeDySJI/Ho/r6erW2ttptqqqq5HA4lJWVFabDAgAAQ1lIY1DGjBmjG264IWjbqFGjlJqaam9fuHChSkpKlJKSIofDoSVLlsjj8SgnJ0eSNHv2bGVlZWnBggVatWqVfD6fVqxYoaKiIiUmJobpsAAAwFAW9ocFlpeXKzY2VgUFBWpvb1deXp7WrFlj7x8xYoQqKyu1ePFieTwejRo1SoWFhXriiSfCXQoAABiiYizLsiJdRKgCgYCcTqf8fn/kxqPUlJ1311CZydWz8JlIlwAAGEZC+f3mWTwAAMA4BBQAAGAcAgoAADBO2AfJDkdDZcwJAABDBT0oAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABiHhwXCVl51MGh92awJEaoEADDc0YMCAACMQ0ABAADGIaAAAADjMAZlGOs+5gQAAFPQgwIAAIxDQAEAAMbhEk9f1JRFugIAAIYVelAAAIBxCCgAAMA4BBQAAGAcxqD0g/fQsUiXEBY5Tet6bKvNXBSBSgAACEYPCgAAMA4BBQAAGIeAAgAAjENAAQAAxiGgAAAA4xBQAACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4cZEuAGbJaVr32UpNqjSzNHLFAACGLXpQAACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYJ6SA8vzzz+vGG2+Uw+GQw+GQx+PRli1b7P2nT59WUVGRUlNTNXr0aBUUFKilpSXoPZqampSfn6+RI0cqLS1Ny5cv19mzZ8NzNAAAICqEdJvxlVdeqZ/85Ce69tprZVmWXn75Zd17773at2+frr/+ei1btkxvvPGGNm/eLKfTqeLiYt13333atWuXJOncuXPKz8+X2+3W7t27dfToUT344IOKj4/X008/PSAHiEtUU9ZzG7ceAwAGWIxlWdalvEFKSop+9rOf6f7779cVV1yhjRs36v7775ck/fOf/9SkSZPk9XqVk5OjLVu26O6779aRI0fkcrkkSWvXrtWjjz6qjz/+WAkJCX36zEAgIKfTKb/fL4fDcSnl9023H2nvoWMD/5kG8Fyd2vsOAgoAoB9C+f3u9xiUc+fOadOmTTp16pQ8Ho/q6up05swZ5ebm2m0mTpyozMxMeb1eSZLX69XkyZPtcCJJeXl5CgQCamho6G8pAAAgyoQ8k2x9fb08Ho9Onz6t0aNH69VXX1VWVpb279+vhIQEJScnB7V3uVzy+XySJJ/PFxROuvZ37Tuf9vZ2tbe32+uBQCDUsgEAwBAScg/Kddddp/3792vPnj1avHixCgsLdeDAgYGozVZWVian02kvGRkZA/p5AAAgskIOKAkJCbrmmms0depUlZWV6aabbtIvf/lLud1udXR0qK2tLah9S0uL3G63JMntdve4q6drvatNb0pLS+X3++2lubk51LIBAMAQcsnzoHR2dqq9vV1Tp05VfHy8qqur7X2NjY1qamqSx+ORJHk8HtXX16u1tdVuU1VVJYfDoaysrPN+RmJion1rc9cCAACiV0hjUEpLSzVnzhxlZmbqxIkT2rhxo9566y1t27ZNTqdTCxcuVElJiVJSUuRwOLRkyRJ5PB7l5ORIkmbPnq2srCwtWLBAq1atks/n04oVK1RUVKTExMQBOUAAADD0hBRQWltb9eCDD+ro0aNyOp268cYbtW3bNs2aNUuSVF5ertjYWBUUFKi9vV15eXlas2aN/foRI0aosrJSixcvlsfj0ahRo1RYWKgnnngivEcFAACGtEueByUSmAclMux5UZgHBQDQD6H8fod8mzGGr65gVnv2oCRp2awJkSwHABDFeFggAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOM8kiZDlN6yRJ3pc+Xa/NXGTvY3ZZAEA40IMCAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHEIKAAAwDg8zbgPvIeORboEAACGFXpQAACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjMFEbLllO07rPVmpSP/3fmaWRKQYAEBUIKAirrll3a88elCQtmzUhkuUAAIYoLvEAAADjEFAAAIBxCCgAAMA4BBQAAGAcBsliUJVXHQxaZxAtAKA39KAAAADj0IOCwVFTJknKaTpmb6rNXBSpagAAhiOgYEB1XdL5fDABAOBiuMQDAACMQw8KBkTQ9PcAAISIHhQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAMYJKaCUlZXp1ltv1ZgxY5SWlqa5c+eqsbExqM3p06dVVFSk1NRUjR49WgUFBWppaQlq09TUpPz8fI0cOVJpaWlavny5zp49e+lHAwAAokJIAWXHjh0qKipSbW2tqqqqdObMGc2ePVunTp2y2yxbtkyvv/66Nm/erB07dujIkSO677777P3nzp1Tfn6+Ojo6tHv3br388svasGGDHn/88fAdFQAAGNJiLMuy+vvijz/+WGlpadqxY4e+/OUvy+/364orrtDGjRt1//33S5L++c9/atKkSfJ6vcrJydGWLVt0991368iRI3K5XJKktWvX6tFHH9XHH3+shISEi35uIBCQ0+mU3++Xw+Hob/l95n3p4QH/jOGoNnMRDwsEgGEklN/vSxqD4vf7JUkpKSmSpLq6Op05c0a5ubl2m4kTJyozM1Ner1eS5PV6NXnyZDucSFJeXp4CgYAaGhp6/Zz29nYFAoGgBQAARK9+B5TOzk4tXbpUM2bM0A033CBJ8vl8SkhIUHJyclBbl8sln89nt/l8OOna37WvN2VlZXI6nfaSkZHR37IBAMAQ0O+AUlRUpHfffVebNm0KZz29Ki0tld/vt5fm5uYB/0wAABA5/XoWT3FxsSorK7Vz505deeWV9na3262Ojg61tbUF9aK0tLTI7Xbbbfbu3Rv0fl13+XS16S4xMVGJiYn9KRUAAAxBIfWgWJal4uJivfrqq9q+fbvGjx8ftH/q1KmKj49XdXW1va2xsVFNTU3yeDySJI/Ho/r6erW2ttptqqqq5HA4lJWVdSnHAgAAokRIPShFRUXauHGj/vSnP2nMmDH2mBGn06nLLrtMTqdTCxcuVElJiVJSUuRwOLRkyRJ5PB7l5ORIkmbPnq2srCwtWLBAq1atks/n04oVK1RUVEQvCQAAkBRiQHn++eclSXfccUfQ9vXr1+tb3/qWJKm8vFyxsbEqKChQe3u78vLytGbNGrvtiBEjVFlZqcWLF8vj8WjUqFEqLCzUE088cWlHAgAAosYlzYMSKQM+D0pNWdCq99Cx8H8GmAcFAIaZQZsHBQAAYCAQUAAAgHEIKAAAwDgEFAAAYJx+TdQGhEt51cGgdQbNAgAkelAAAICBCCgAAMA4BBQAAGAcxqAgYnKa1vXcWJMavD6zdHCKAQAYhR4UAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcZpKF0XjaMQAMTwQUGMV76FjQeo66T4f/zOAVAwCIGC7xAAAA4xBQAACAcQgoAADAOIxBwdBXUxa8PrM0MnUAAMKGHhQAAGAcAgoAADAOAQUAABiHgAIAAIxDQAEAAMYhoAAAAOMQUAAAgHGYBwVDSveHB0rSMv4WA0DUoQcFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcZpDAkJLTtK7nxqtTB78QAMCAIqBgyPMeOha07pkZoUIAAGFDQEH0qSnruW1m6eDXAQDoN8agAAAA4xBQAACAcbjE04vuYxoAAMDgogcFAAAYh4ACAACMQ0ABAADGYQwKop730DHVnj1ory+bNSGC1QAA+oIeFAAAYJyQA8rOnTt1zz33KD09XTExMXrttdeC9luWpccff1xjx47VZZddptzcXL3//vtBbY4fP6758+fL4XAoOTlZCxcu1MmTJy/pQAAAQPQIOaCcOnVKN910kyoqKnrdv2rVKq1evVpr167Vnj17NGrUKOXl5en06dN2m/nz56uhoUFVVVWqrKzUzp07tWjRov4fBXAROU3r7KXXmWYBAEYJeQzKnDlzNGfOnF73WZalZ599VitWrNC9994rSfrNb34jl8ul1157TfPmzdN7772nrVu36u2339a0adMkSc8995zuuusuPfPMM0pPT7+EwwGYxwYAokFYx6AcPnxYPp9Pubm59jan06ns7Gx5vV5JktfrVXJysh1OJCk3N1exsbHas2dPr+/b3t6uQCAQtAAAgOgV1oDi8/kkSS6XK2i7y+Wy9/l8PqWlpQXtj4uLU0pKit2mu7KyMjmdTnvJyMgIZ9kAAMAwQ+IuntLSUvn9fntpbm6OdEkAAGAAhTWguN1uSVJLS0vQ9paWFnuf2+1Wa2tr0P6zZ8/q+PHjdpvuEhMT5XA4ghYAABC9whpQxo8fL7fbrerqantbIBDQnj175PF4JEkej0dtbW2qq6uz22zfvl2dnZ3Kzs4OZzkAAGCICvkunpMnT+qDDz6w1w8fPqz9+/crJSVFmZmZWrp0qZ566ilde+21Gj9+vH74wx8qPT1dc+fOlSRNmjRJd955px566CGtXbtWZ86cUXFxsebNm8cdPAAAQFI/Aso//vEPzZw5014vKSmRJBUWFmrDhg165JFHdOrUKS1atEhtbW26/fbbtXXrViUlJdmv+d3vfqfi4mJ97WtfU2xsrAoKCrR69eowHA5wcd5Dx6RDD9vrnqtTe284s3SQKgIAdBdjWZYV6SJCFQgE5HQ65ff7B2Q8ivelhy/eCFHLDiwEFAAIq1B+v4fEXTwAAGB4IaAAAADjEFAAAIBxCCgAAMA4BBQAAGAcAgoAADAOAQUAABgn5InagGGjpqzntu5zo/SlDQAgZPSgAAAA49CDAnTjPXTsvPtqzx6UJC2bNWGwygGAYYmAAoRb98s+XPIBgJBxiQcAABiHgAIAAIzDJR4gBDlN6z79Q01qZAsBgChHDwoAADAOPShAP3S/08dzNT0qABBO9KAAAADj0IMCDDRmmwWAkNGDAgAAjENAAQAAxiGgAAAA4xBQAACAcRgkC0RAedXBoHUePggAwQgoQBhc6AnIEvOkAECouMQDAACMQ0ABAADG4RIPYCjGqQAYzggowBBFgAEQzQgoQATkNK0LWi+vWnTRNtIzF25Tk8oU+gCiBgEFGAQXu8sHABCMgAIYoGdvCQAMbwQUYIjwvvRw0HpOhOoAgMHAbcYAAMA4BBQAAGAcLvEA0aSmrOc27uwBMATRgwIAAIxDDwoQJbrfyhzSAwrpeQFgGAIKEKW6Akvt2U9nnGWmWQBDCQEFQO+696r0pUeFnhgAYUJAAaKcPQlczecu+fQnNBA+AAwiAgqA/usttABAGBBQgOHo/weLSxpYCwADiIAC4LzCEmD6M5YFwLBHQAGGCWOeqByusSyMiQGiGhO1AQAA49CDAmBoYEAuMKwQUADYInEZyHvomD2ZnMSEcgA+RUAB0GehBpjeBtX29h72XC1S8HwtvbS/4EBdBuQCUYOAAmDA9KdHZsB7cXoJMeVVB4M20YsDRF5EB8lWVFToqquuUlJSkrKzs7V3795IlgNgCPAeOha0AIhOEetB+f3vf6+SkhKtXbtW2dnZevbZZ5WXl6fGxkalpaVFqiwAUSCkByXWlCmnqXvQeaZHm97e//M8C5/psS1qcSkNgyDGsiwrEh+cnZ2tW2+9Vb/61a8kSZ2dncrIyNCSJUv02GOPXfC1gUBATqdTfr9fDocj7LV5X3o47O8JYGD0dZzLQKvNXNRj2+fH1px37Ex/H8IYjvfpb7CIZEBh/pshLZTf74j0oHR0dKiurk6lpZ/9pYqNjVVubq68Xm+P9u3t7Wpvb7fX/X6/pE8PdCCc+l/7xRsBMMKbDUciXYIk6fSpkz22ff7fksCp072/sHJl8PqXfyBJqtj+gb3p1o96P8bpV6V8ttKXfw+719CX1+z8edDq3v873rOGcP1b3O2zJNn/PWy9/XccoN8ChF/X73af+kasCPj3v/9tSbJ2794dtH358uXW9OnTe7RfuXKlJYmFhYWFhYUlCpbm5uaLZoUhcRdPaWmpSkpK7PXOzk4dP35cqampiomJCetnBQIBZWRkqLm5eUAuHyE0nA+zcD7MwvkwD+fkwizL0okTJ5Senn7RthEJKF/4whc0YsQItbS0BG1vaWmR2+3u0T4xMVGJiYlB25KTkweyRDkcDv5yGYTzYRbOh1k4H+bhnJyf0+nsU7uI3GackJCgqVOnqrq62t7W2dmp6upqeTyeSJQEAAAMErFLPCUlJSosLNS0adM0ffp0Pfvsszp16pS+/e1vR6okAABgiIgFlG9+85v6+OOP9fjjj8vn82nKlCnaunWrXC5XpEqS9OnlpJUrV/a4pITI4HyYhfNhFs6HeTgn4ROxeVAAAADOJ6JT3QMAAPSGgAIAAIxDQAEAAMYhoAAAAOMMy4BSUVGhq666SklJScrOztbevXsv2H7z5s2aOHGikpKSNHnyZP3lL38ZpEqHh1DOx4YNGxQTExO0JCUlDWK10W3nzp265557lJ6erpiYGL322msXfc1bb72lW265RYmJibrmmmu0YcOGAa9zuAj1fLz11ls9vh8xMTHy+XyDU3CUKysr06233qoxY8YoLS1Nc+fOVWNj40Vfx29I/wy7gPL73/9eJSUlWrlypd555x3ddNNNysvLU2tra6/td+/erQceeEALFy7Uvn37NHfuXM2dO1fvvvvuIFcenUI9H9KnMzQePXrUXv71r38NYsXR7dSpU7rppptUUVHRp/aHDx9Wfn6+Zs6cqf3792vp0qX6zne+o23btg1wpcNDqOejS2NjY9B3JC0tbYAqHF527NihoqIi1dbWqqqqSmfOnNHs2bN16tSp876G35BLEJ7H/w0d06dPt4qKiuz1c+fOWenp6VZZWVmv7b/xjW9Y+fn5Qduys7Ot7373uwNa53AR6vlYv3695XQ6B6m64U2S9eqrr16wzSOPPGJdf/31Qdu++c1vWnl5eQNY2fDUl/NRU1NjSbI++eSTQalpuGttbbUkWTt27DhvG35D+m9Y9aB0dHSorq5Oubm59rbY2Fjl5ubK6/X2+hqv1xvUXpLy8vLO2x5915/zIUknT57UuHHjlJGRoXvvvVcNDQ2DUS56wffDTFOmTNHYsWM1a9Ys7dq1K9LlRC2/3y9JSklJOW8bviP9N6wCyn/+8x+dO3eux2y1LpfrvNdofT5fSO3Rd/05H9ddd51+/etf609/+pNeeeUVdXZ26rbbbtNHH300GCWjm/N9PwKBgP73v/9FqKrha+zYsVq7dq3++Mc/6o9//KMyMjJ0xx136J133ol0aVGns7NTS5cu1YwZM3TDDTectx2/If0Xsanugf7weDxBD5S87bbbNGnSJL3wwgt68sknI1gZEHnXXXedrrvuOnv9tttu04cffqjy8nL99re/jWBl0aeoqEjvvvuu/v73v0e6lKg1rHpQvvCFL2jEiBFqaWkJ2t7S0iK3293ra9xud0jt0Xf9OR/dxcfH6+abb9YHH3wwECXiIs73/XA4HLrssssiVBU+b/r06Xw/wqy4uFiVlZWqqanRlVdeecG2/Ib037AKKAkJCZo6daqqq6vtbZ2dnaqurg76f+Wf5/F4gtpLUlVV1Xnbo+/6cz66O3funOrr6zV27NiBKhMXwPfDfPv37+f7ESaWZam4uFivvvqqtm/frvHjx1/0NXxHLkGkR+kOtk2bNlmJiYnWhg0brAMHDliLFi2ykpOTLZ/PZ1mWZS1YsMB67LHH7Pa7du2y4uLirGeeecZ67733rJUrV1rx8fFWfX19pA4hqoR6Pn784x9b27Ztsz788EOrrq7OmjdvnpWUlGQ1NDRE6hCiyokTJ6x9+/ZZ+/btsyRZv/jFL6x9+/ZZ//rXvyzLsqzHHnvMWrBggd3+0KFD1siRI63ly5db7733nlVRUWGNGDHC2rp1a6QOIaqEej7Ky8ut1157zXr//fet+vp66/vf/74VGxtrvfnmm5E6hKiyePFiy+l0Wm+99ZZ19OhRe/nvf/9rt+E3JHyGXUCxLMt67rnnrMzMTCshIcGaPn26VVtba+/7yle+YhUWFga1/8Mf/mBNmDDBSkhIsK6//nrrjTfeGOSKo1so52Pp0qV2W5fLZd11113WO++8E4Gqo1PXbardl65zUFhYaH3lK1/p8ZopU6ZYCQkJ1tVXX22tX79+0OuOVqGej5/+9KfWF7/4RSspKclKSUmx7rjjDmv79u2RKT4K9XYuJAX9nec3JHxiLMuyBrvXBgAA4EKG1RgUAAAwNBBQAACAcQgoAADAOAQUAABgHAIKAAAwDgEFAAAYh4ACAACMQ0ABAADGIaAAAADjEFAAAIBxCCgAAMA4BBQAAGCc/wcAxJnJocKCGQAAAABJRU5ErkJggg==", "text/plain": [ - "
" + "{'imagenet_unconditional_model_embedding': array([13181., 10870., 12008., 10130., 9628.], dtype=float32),\n", + " 'imagenet_cs1_embedding': array([21638., 15387., 14466., 17863., 15719.], dtype=float32),\n", + " 'imagenet_cs10_embedding': array([17963., 15701., 13769., 15598., 14699.], dtype=float32),\n", + " 'imagenet_biggan_embedding': array([33609., 29164., 29311., 30098., 26772.], dtype=float32),\n", + " 'imagenet_sdv4_embedding': array([39859., 34397., 34759., 37063., 34212.], dtype=float32),\n", + " 'imagenet_sdv5_embedding': array([39529., 34206., 34731., 36771., 35564.], dtype=float32),\n", + " 'imagenet_vqdm_embedding': array([25483., 19751., 22153., 24993., 22030.], dtype=float32),\n", + " 'imagenet_wukong_embedding': array([42367., 35982., 37357., 38585., 34325.], dtype=float32),\n", + " 'imagenet_adm_embedding': array([37571., 34480., 34681., 33519., 29859.], dtype=float32),\n", + " 'imagenet_midjourney_embedding': array([40835., 32620., 35570., 36513., 35667.], dtype=float32)}" ] }, + "execution_count": 24, "metadata": {}, - "output_type": "display_data" + "output_type": "execute_result" } ], "source": [ - "plt.hist(data[:, 0].numpy(), bins=100, alpha=0.5, label=\"data\")\n", - "plt.hist(testset[:10_000, 0].numpy(), bins=100, alpha=0.5, label=\"testset\")" + "results_mmd_poly_kid" ] }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "np.save(\"results_mmd_poly_kid.npy\", results_mmd_poly_kid)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "results_c2st_knn = {}" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting imagenet_unconditional_model_embedding\n", + "Starting imagenet_cs1_embedding\n", + "Starting imagenet_cs10_embedding\n", + "Starting imagenet_biggan_embedding\n", + "Starting imagenet_sdv4_embedding\n", + "Starting imagenet_sdv5_embedding\n", + "Starting imagenet_vqdm_embedding\n", + "Starting imagenet_wukong_embedding\n", + "Starting imagenet_adm_embedding\n", + "Starting imagenet_midjourney_embedding\n" + ] + } + ], + "source": [ + "torch.manual_seed(0)\n", + "metric = \"c2st_knn\"\n", + "metric_fn = get_metric(metric)\n", + "for dname in datasets:\n", + " print(\"Starting \", dname)\n", + " metric_values = []\n", + " for j in range(5):\n", + " data_test = testset[j*20_000:(j+1)*20_000]\n", + " if dname == \"imagenet_midjourney_embedding\":\n", + " data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n", + " data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]\n", + " else:\n", + " data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n", + " m = metric_fn(data_test, data_syn, n_folds=2)\n", + " metric_values.append(m)\n", + " results_c2st_knn[dname] = np.array(metric_values)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor(6.3563)" + "{'imagenet_unconditional_model_embedding': array([[0.65345 ],\n", + " [0.64765 ],\n", + " [0.650725],\n", + " [0.652725],\n", + " [0.652275]], dtype=float32),\n", + " 'imagenet_cs1_embedding': array([[0.63035 ],\n", + " [0.62975 ],\n", + " [0.63135 ],\n", + " [0.629375],\n", + " [0.63405 ]], dtype=float32),\n", + " 'imagenet_cs10_embedding': array([[0.6517 ],\n", + " [0.6506 ],\n", + " [0.6573 ],\n", + " [0.654 ],\n", + " [0.65605]], dtype=float32),\n", + " 'imagenet_biggan_embedding': array([[0.75415 ],\n", + " [0.749725],\n", + " [0.7522 ],\n", + " [0.750975],\n", + " [0.754475]], dtype=float32),\n", + " 'imagenet_sdv4_embedding': array([[0.7913 ],\n", + " [0.78645 ],\n", + " [0.78665 ],\n", + " [0.78685 ],\n", + " [0.790925]], dtype=float32),\n", + " 'imagenet_sdv5_embedding': array([[0.7921 ],\n", + " [0.786525],\n", + " [0.7872 ],\n", + " [0.790975],\n", + " [0.792425]], dtype=float32),\n", + " 'imagenet_vqdm_embedding': array([[0.774775],\n", + " [0.76665 ],\n", + " [0.76835 ],\n", + " [0.773775],\n", + " [0.77215 ]], dtype=float32),\n", + " 'imagenet_wukong_embedding': array([[0.790025],\n", + " [0.786625],\n", + " [0.7873 ],\n", + " [0.7881 ],\n", + " [0.790225]], dtype=float32),\n", + " 'imagenet_adm_embedding': array([[0.76605 ],\n", + " [0.7602 ],\n", + " [0.763725],\n", + " [0.760025],\n", + " [0.7645 ]], dtype=float32),\n", + " 'imagenet_midjourney_embedding': array([[0.796975],\n", + " [0.791325],\n", + " [0.791325],\n", + " [0.79765 ],\n", + " [0.79655 ]], dtype=float32)}" ] }, - "execution_count": 90, + "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "metric_fn(testset[:20_000], data)" + "results_c2st_knn" ] }, { "cell_type": "code", - "execution_count": 88, + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "np.save(\"results_c2st_knn.npy\", results_c2st_knn)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "results_c2st_nn = {}" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting imagenet_unconditional_model_embedding\n", + "Starting imagenet_cs1_embedding\n", + "Starting imagenet_cs10_embedding\n", + "Starting imagenet_biggan_embedding\n", + "Starting imagenet_sdv4_embedding\n", + "Starting imagenet_sdv5_embedding\n", + "Starting imagenet_vqdm_embedding\n", + "Starting imagenet_wukong_embedding\n", + "Starting imagenet_adm_embedding\n", + "Starting imagenet_midjourney_embedding\n" + ] + } + ], + "source": [ + "torch.manual_seed(0)\n", + "metric = \"c2st_nn\"\n", + "metric_fn = get_metric(metric)\n", + "for dname in datasets:\n", + " print(\"Starting \", dname)\n", + " metric_values = []\n", + " for j in range(5):\n", + " data_test = testset[j*20_000:(j+1)*20_000]\n", + " if dname == \"imagenet_midjourney_embedding\":\n", + " data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n", + " data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]\n", + " else:\n", + " data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n", + " m = metric_fn(data_test, data_syn, n_folds=2)\n", + " metric_values.append(m)\n", + " results_c2st_nn[dname] = np.array(metric_values)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "np.save(\"results_c2st_nn.npy\", results_c2st_nn)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor(0.0216)" + "{'imagenet_unconditional_model_embedding': array([[0.7191 ],\n", + " [0.71785 ],\n", + " [0.72095 ],\n", + " [0.716875],\n", + " [0.7206 ]], dtype=float32),\n", + " 'imagenet_cs1_embedding': array([[0.771325],\n", + " [0.761475],\n", + " [0.771525],\n", + " [0.767575],\n", + " [0.7738 ]], dtype=float32),\n", + " 'imagenet_cs10_embedding': array([[0.7608 ],\n", + " [0.7621 ],\n", + " [0.764 ],\n", + " [0.7643 ],\n", + " [0.757775]], dtype=float32),\n", + " 'imagenet_biggan_embedding': array([[0.8693 ],\n", + " [0.864775],\n", + " [0.862925],\n", + " [0.861475],\n", + " [0.854925]], dtype=float32),\n", + " 'imagenet_sdv4_embedding': array([[0.92205 ],\n", + " [0.920025],\n", + " [0.932025],\n", + " [0.92415 ],\n", + " [0.92075 ]], dtype=float32),\n", + " 'imagenet_sdv5_embedding': array([[0.920525],\n", + " [0.930475],\n", + " [0.923075],\n", + " [0.920125],\n", + " [0.92495 ]], dtype=float32),\n", + " 'imagenet_vqdm_embedding': array([[0.8448 ],\n", + " [0.8455 ],\n", + " [0.84815],\n", + " [0.8496 ],\n", + " [0.84925]], dtype=float32),\n", + " 'imagenet_wukong_embedding': array([[0.920125],\n", + " [0.914625],\n", + " [0.92035 ],\n", + " [0.917275],\n", + " [0.91415 ]], dtype=float32),\n", + " 'imagenet_adm_embedding': array([[0.85505 ],\n", + " [0.855725],\n", + " [0.854125],\n", + " [0.8574 ],\n", + " [0.849125]], dtype=float32),\n", + " 'imagenet_midjourney_embedding': array([[0.94325 ],\n", + " [0.9454 ],\n", + " [0.9513 ],\n", + " [0.94335 ],\n", + " [0.940975]], dtype=float32)}" ] }, - "execution_count": 88, + "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "metric_fn2(testset[:data.shape[0]], data, num_projections=5000)" + "results_c2st_nn" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [], + "source": [ + "results_c2st_rf = {}" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting imagenet_unconditional_model_embedding\n", + "Starting imagenet_cs1_embedding\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[50], line 14\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 13\u001b[0m data_syn \u001b[38;5;241m=\u001b[39m get_dataset(dname)(\u001b[38;5;241m20_000\u001b[39m, \u001b[38;5;241m2048\u001b[39m, permute\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m---> 14\u001b[0m m \u001b[38;5;241m=\u001b[39m \u001b[43mmetric_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_test\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata_syn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_folds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 15\u001b[0m metric_values\u001b[38;5;241m.\u001b[39mappend(m)\n\u001b[1;32m 16\u001b[0m results_c2st_rf[dname] \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(metric_values)\n", + "File \u001b[0;32m/mnt/c/Users/manug/OneDrive/Desktop_backup/labproject/labproject/labproject/metrics/utils.py:24\u001b[0m, in \u001b[0;36mregister_metric..decorator..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 22\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 23\u001b[0m \u001b[38;5;66;03m# Call the original function\u001b[39;00m\n\u001b[0;32m---> 24\u001b[0m metric \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;66;03m# Convert output to tensor\u001b[39;00m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(metric, torch\u001b[38;5;241m.\u001b[39mTensor):\n", + "File \u001b[0;32m/mnt/c/Users/manug/OneDrive/Desktop_backup/labproject/labproject/labproject/metrics/c2st.py:168\u001b[0m, in \u001b[0;36mc2st_rf\u001b[0;34m(X, Y, seed, n_folds, metric, z_score, n_estimators, clf_kwargs)\u001b[0m\n\u001b[1;32m 165\u001b[0m clf_class \u001b[38;5;241m=\u001b[39m RandomForestClassifier\n\u001b[1;32m 166\u001b[0m clf_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_estimators\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m n_estimators\n\u001b[0;32m--> 168\u001b[0m scores_ \u001b[38;5;241m=\u001b[39m \u001b[43mc2st_scores\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 170\u001b[0m \u001b[43m \u001b[49m\u001b[43mY\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 171\u001b[0m \u001b[43m \u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mseed\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 172\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_folds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_folds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 173\u001b[0m \u001b[43m \u001b[49m\u001b[43mmetric\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetric\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 174\u001b[0m \u001b[43m \u001b[49m\u001b[43mz_score\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mz_score\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[43mnoise_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 176\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbosity\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 177\u001b[0m \u001b[43m \u001b[49m\u001b[43mclf_class\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclf_class\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 178\u001b[0m \u001b[43m \u001b[49m\u001b[43mclf_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclf_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 179\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 181\u001b[0m scores \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmean(scores_)\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[1;32m 182\u001b[0m value \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mfrom_numpy(np\u001b[38;5;241m.\u001b[39matleast_1d(scores))\n", + "File \u001b[0;32m/mnt/c/Users/manug/OneDrive/Desktop_backup/labproject/labproject/labproject/metrics/c2st.py:357\u001b[0m, in \u001b[0;36mc2st_scores\u001b[0;34m(X, Y, seed, n_folds, metric, z_score, noise_scale, verbosity, clf_class, clf_kwargs)\u001b[0m\n\u001b[1;32m 354\u001b[0m target \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mconcatenate((np\u001b[38;5;241m.\u001b[39mzeros((X\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m],)), np\u001b[38;5;241m.\u001b[39mones((Y\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m],))))\n\u001b[1;32m 356\u001b[0m shuffle \u001b[38;5;241m=\u001b[39m KFold(n_splits\u001b[38;5;241m=\u001b[39mn_folds, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, random_state\u001b[38;5;241m=\u001b[39mseed)\n\u001b[0;32m--> 357\u001b[0m scores \u001b[38;5;241m=\u001b[39m \u001b[43mcross_val_score\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcv\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mshuffle\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscoring\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetric\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbosity\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 359\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m scores\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/utils/_param_validation.py:213\u001b[0m, in \u001b[0;36mvalidate_params..decorator..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 208\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m 209\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m 210\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 211\u001b[0m )\n\u001b[1;32m 212\u001b[0m ):\n\u001b[0;32m--> 213\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m InvalidParameterError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 215\u001b[0m \u001b[38;5;66;03m# When the function is just a wrapper around an estimator, we allow\u001b[39;00m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;66;03m# the function to delegate validation to the estimator, but we replace\u001b[39;00m\n\u001b[1;32m 217\u001b[0m \u001b[38;5;66;03m# the name of the estimator by the name of the function in the error\u001b[39;00m\n\u001b[1;32m 218\u001b[0m \u001b[38;5;66;03m# message to avoid confusion.\u001b[39;00m\n\u001b[1;32m 219\u001b[0m msg \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msub(\n\u001b[1;32m 220\u001b[0m \u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124mw+ must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 221\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 222\u001b[0m \u001b[38;5;28mstr\u001b[39m(e),\n\u001b[1;32m 223\u001b[0m )\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/model_selection/_validation.py:714\u001b[0m, in \u001b[0;36mcross_val_score\u001b[0;34m(estimator, X, y, groups, scoring, cv, n_jobs, verbose, fit_params, params, pre_dispatch, error_score)\u001b[0m\n\u001b[1;32m 711\u001b[0m \u001b[38;5;66;03m# To ensure multimetric format is not supported\u001b[39;00m\n\u001b[1;32m 712\u001b[0m scorer \u001b[38;5;241m=\u001b[39m check_scoring(estimator, scoring\u001b[38;5;241m=\u001b[39mscoring)\n\u001b[0;32m--> 714\u001b[0m cv_results \u001b[38;5;241m=\u001b[39m \u001b[43mcross_validate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 715\u001b[0m \u001b[43m \u001b[49m\u001b[43mestimator\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 716\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 717\u001b[0m \u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 718\u001b[0m \u001b[43m \u001b[49m\u001b[43mgroups\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 719\u001b[0m \u001b[43m \u001b[49m\u001b[43mscoring\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mscore\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mscorer\u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 720\u001b[0m \u001b[43m \u001b[49m\u001b[43mcv\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 721\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_jobs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_jobs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 722\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 723\u001b[0m \u001b[43m \u001b[49m\u001b[43mfit_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfit_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 724\u001b[0m \u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 725\u001b[0m \u001b[43m \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpre_dispatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 726\u001b[0m \u001b[43m \u001b[49m\u001b[43merror_score\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merror_score\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 727\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 728\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m cv_results[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_score\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/utils/_param_validation.py:213\u001b[0m, in \u001b[0;36mvalidate_params..decorator..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 208\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m 209\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m 210\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 211\u001b[0m )\n\u001b[1;32m 212\u001b[0m ):\n\u001b[0;32m--> 213\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m InvalidParameterError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 215\u001b[0m \u001b[38;5;66;03m# When the function is just a wrapper around an estimator, we allow\u001b[39;00m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;66;03m# the function to delegate validation to the estimator, but we replace\u001b[39;00m\n\u001b[1;32m 217\u001b[0m \u001b[38;5;66;03m# the name of the estimator by the name of the function in the error\u001b[39;00m\n\u001b[1;32m 218\u001b[0m \u001b[38;5;66;03m# message to avoid confusion.\u001b[39;00m\n\u001b[1;32m 219\u001b[0m msg \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msub(\n\u001b[1;32m 220\u001b[0m \u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124mw+ must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 221\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 222\u001b[0m \u001b[38;5;28mstr\u001b[39m(e),\n\u001b[1;32m 223\u001b[0m )\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/model_selection/_validation.py:425\u001b[0m, in \u001b[0;36mcross_validate\u001b[0;34m(estimator, X, y, groups, scoring, cv, n_jobs, verbose, fit_params, params, pre_dispatch, return_train_score, return_estimator, return_indices, error_score)\u001b[0m\n\u001b[1;32m 422\u001b[0m \u001b[38;5;66;03m# We clone the estimator to make sure that all the folds are\u001b[39;00m\n\u001b[1;32m 423\u001b[0m \u001b[38;5;66;03m# independent, and that it is pickle-able.\u001b[39;00m\n\u001b[1;32m 424\u001b[0m parallel \u001b[38;5;241m=\u001b[39m Parallel(n_jobs\u001b[38;5;241m=\u001b[39mn_jobs, verbose\u001b[38;5;241m=\u001b[39mverbose, pre_dispatch\u001b[38;5;241m=\u001b[39mpre_dispatch)\n\u001b[0;32m--> 425\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[43mparallel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 426\u001b[0m \u001b[43m \u001b[49m\u001b[43mdelayed\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_fit_and_score\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 427\u001b[0m \u001b[43m \u001b[49m\u001b[43mclone\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 428\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 429\u001b[0m \u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 430\u001b[0m \u001b[43m \u001b[49m\u001b[43mscorer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mscorers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 431\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 432\u001b[0m \u001b[43m \u001b[49m\u001b[43mtest\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 433\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 434\u001b[0m \u001b[43m \u001b[49m\u001b[43mparameters\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 435\u001b[0m \u001b[43m \u001b[49m\u001b[43mfit_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrouted_params\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mestimator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 436\u001b[0m \u001b[43m \u001b[49m\u001b[43mscore_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrouted_params\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscorer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscore\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 437\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_train_score\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_train_score\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 438\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_times\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 439\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_estimator\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_estimator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 440\u001b[0m \u001b[43m \u001b[49m\u001b[43merror_score\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merror_score\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 441\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 442\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mindices\u001b[49m\n\u001b[1;32m 443\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 445\u001b[0m _warn_or_raise_about_fit_failures(results, error_score)\n\u001b[1;32m 447\u001b[0m \u001b[38;5;66;03m# For callable scoring, the return type is only know after calling. If the\u001b[39;00m\n\u001b[1;32m 448\u001b[0m \u001b[38;5;66;03m# return type is a dictionary, the error scores can now be inserted with\u001b[39;00m\n\u001b[1;32m 449\u001b[0m \u001b[38;5;66;03m# the correct key.\u001b[39;00m\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/utils/parallel.py:67\u001b[0m, in \u001b[0;36mParallel.__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 62\u001b[0m config \u001b[38;5;241m=\u001b[39m get_config()\n\u001b[1;32m 63\u001b[0m iterable_with_config \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 64\u001b[0m (_with_config(delayed_func, config), args, kwargs)\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m delayed_func, args, kwargs \u001b[38;5;129;01min\u001b[39;00m iterable\n\u001b[1;32m 66\u001b[0m )\n\u001b[0;32m---> 67\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43miterable_with_config\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/joblib/parallel.py:1863\u001b[0m, in \u001b[0;36mParallel.__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 1861\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_sequential_output(iterable)\n\u001b[1;32m 1862\u001b[0m \u001b[38;5;28mnext\u001b[39m(output)\n\u001b[0;32m-> 1863\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturn_generator \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43moutput\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1865\u001b[0m \u001b[38;5;66;03m# Let's create an ID that uniquely identifies the current call. If the\u001b[39;00m\n\u001b[1;32m 1866\u001b[0m \u001b[38;5;66;03m# call is interrupted early and that the same instance is immediately\u001b[39;00m\n\u001b[1;32m 1867\u001b[0m \u001b[38;5;66;03m# re-used, this id will be used to prevent workers that were\u001b[39;00m\n\u001b[1;32m 1868\u001b[0m \u001b[38;5;66;03m# concurrently finalizing a task from the previous call to run the\u001b[39;00m\n\u001b[1;32m 1869\u001b[0m \u001b[38;5;66;03m# callback.\u001b[39;00m\n\u001b[1;32m 1870\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock:\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/joblib/parallel.py:1792\u001b[0m, in \u001b[0;36mParallel._get_sequential_output\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 1790\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_dispatched_batches \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1791\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_dispatched_tasks \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m-> 1792\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1793\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_completed_tasks \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1794\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint_progress()\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/utils/parallel.py:129\u001b[0m, in \u001b[0;36m_FuncWrapper.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 127\u001b[0m config \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mconfig):\n\u001b[0;32m--> 129\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/model_selection/_validation.py:890\u001b[0m, in \u001b[0;36m_fit_and_score\u001b[0;34m(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, score_params, return_train_score, return_parameters, return_n_test_samples, return_times, return_estimator, split_progress, candidate_progress, error_score)\u001b[0m\n\u001b[1;32m 888\u001b[0m estimator\u001b[38;5;241m.\u001b[39mfit(X_train, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mfit_params)\n\u001b[1;32m 889\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 890\u001b[0m \u001b[43mestimator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfit_params\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 892\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n\u001b[1;32m 893\u001b[0m \u001b[38;5;66;03m# Note fit time as time until error\u001b[39;00m\n\u001b[1;32m 894\u001b[0m fit_time \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime() \u001b[38;5;241m-\u001b[39m start_time\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/base.py:1351\u001b[0m, in \u001b[0;36m_fit_context..decorator..wrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1344\u001b[0m estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[1;32m 1346\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m 1347\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m 1348\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 1349\u001b[0m )\n\u001b[1;32m 1350\u001b[0m ):\n\u001b[0;32m-> 1351\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/ensemble/_forest.py:489\u001b[0m, in \u001b[0;36mBaseForest.fit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m 478\u001b[0m trees \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_make_estimator(append\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, random_state\u001b[38;5;241m=\u001b[39mrandom_state)\n\u001b[1;32m 480\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n_more_estimators)\n\u001b[1;32m 481\u001b[0m ]\n\u001b[1;32m 483\u001b[0m \u001b[38;5;66;03m# Parallel loop: we prefer the threading backend as the Cython code\u001b[39;00m\n\u001b[1;32m 484\u001b[0m \u001b[38;5;66;03m# for fitting the trees is internally releasing the Python GIL\u001b[39;00m\n\u001b[1;32m 485\u001b[0m \u001b[38;5;66;03m# making threading more efficient than multiprocessing in\u001b[39;00m\n\u001b[1;32m 486\u001b[0m \u001b[38;5;66;03m# that case. However, for joblib 0.12+ we respect any\u001b[39;00m\n\u001b[1;32m 487\u001b[0m \u001b[38;5;66;03m# parallel_backend contexts set at a higher level,\u001b[39;00m\n\u001b[1;32m 488\u001b[0m \u001b[38;5;66;03m# since correctness does not rely on using threads.\u001b[39;00m\n\u001b[0;32m--> 489\u001b[0m trees \u001b[38;5;241m=\u001b[39m \u001b[43mParallel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 490\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_jobs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_jobs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 491\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 492\u001b[0m \u001b[43m \u001b[49m\u001b[43mprefer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mthreads\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 493\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 494\u001b[0m \u001b[43m \u001b[49m\u001b[43mdelayed\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_parallel_build_trees\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 495\u001b[0m \u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 496\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbootstrap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 497\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 498\u001b[0m \u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 499\u001b[0m \u001b[43m \u001b[49m\u001b[43msample_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 500\u001b[0m \u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 501\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtrees\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 502\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 503\u001b[0m \u001b[43m \u001b[49m\u001b[43mclass_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 504\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_samples_bootstrap\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_samples_bootstrap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 505\u001b[0m \u001b[43m \u001b[49m\u001b[43mmissing_values_in_feature_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmissing_values_in_feature_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 506\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 507\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtrees\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 508\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 510\u001b[0m \u001b[38;5;66;03m# Collect newly grown trees\u001b[39;00m\n\u001b[1;32m 511\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mestimators_\u001b[38;5;241m.\u001b[39mextend(trees)\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/utils/parallel.py:67\u001b[0m, in \u001b[0;36mParallel.__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 62\u001b[0m config \u001b[38;5;241m=\u001b[39m get_config()\n\u001b[1;32m 63\u001b[0m iterable_with_config \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 64\u001b[0m (_with_config(delayed_func, config), args, kwargs)\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m delayed_func, args, kwargs \u001b[38;5;129;01min\u001b[39;00m iterable\n\u001b[1;32m 66\u001b[0m )\n\u001b[0;32m---> 67\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43miterable_with_config\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/joblib/parallel.py:1863\u001b[0m, in \u001b[0;36mParallel.__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 1861\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_sequential_output(iterable)\n\u001b[1;32m 1862\u001b[0m \u001b[38;5;28mnext\u001b[39m(output)\n\u001b[0;32m-> 1863\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturn_generator \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43moutput\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1865\u001b[0m \u001b[38;5;66;03m# Let's create an ID that uniquely identifies the current call. If the\u001b[39;00m\n\u001b[1;32m 1866\u001b[0m \u001b[38;5;66;03m# call is interrupted early and that the same instance is immediately\u001b[39;00m\n\u001b[1;32m 1867\u001b[0m \u001b[38;5;66;03m# re-used, this id will be used to prevent workers that were\u001b[39;00m\n\u001b[1;32m 1868\u001b[0m \u001b[38;5;66;03m# concurrently finalizing a task from the previous call to run the\u001b[39;00m\n\u001b[1;32m 1869\u001b[0m \u001b[38;5;66;03m# callback.\u001b[39;00m\n\u001b[1;32m 1870\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock:\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/joblib/parallel.py:1792\u001b[0m, in \u001b[0;36mParallel._get_sequential_output\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 1790\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_dispatched_batches \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1791\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_dispatched_tasks \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m-> 1792\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1793\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_completed_tasks \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1794\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint_progress()\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/utils/parallel.py:129\u001b[0m, in \u001b[0;36m_FuncWrapper.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 127\u001b[0m config \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mconfig):\n\u001b[0;32m--> 129\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/ensemble/_forest.py:192\u001b[0m, in \u001b[0;36m_parallel_build_trees\u001b[0;34m(tree, bootstrap, X, y, sample_weight, tree_idx, n_trees, verbose, class_weight, n_samples_bootstrap, missing_values_in_feature_mask)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m class_weight \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbalanced_subsample\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 190\u001b[0m curr_sample_weight \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m=\u001b[39m compute_sample_weight(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbalanced\u001b[39m\u001b[38;5;124m\"\u001b[39m, y, indices\u001b[38;5;241m=\u001b[39mindices)\n\u001b[0;32m--> 192\u001b[0m \u001b[43mtree\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 193\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 194\u001b[0m \u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 195\u001b[0m \u001b[43m \u001b[49m\u001b[43msample_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcurr_sample_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 196\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_input\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 197\u001b[0m \u001b[43m \u001b[49m\u001b[43mmissing_values_in_feature_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmissing_values_in_feature_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 198\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 199\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 200\u001b[0m tree\u001b[38;5;241m.\u001b[39m_fit(\n\u001b[1;32m 201\u001b[0m X,\n\u001b[1;32m 202\u001b[0m y,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 205\u001b[0m missing_values_in_feature_mask\u001b[38;5;241m=\u001b[39mmissing_values_in_feature_mask,\n\u001b[1;32m 206\u001b[0m )\n", + "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/tree/_classes.py:472\u001b[0m, in \u001b[0;36mBaseDecisionTree._fit\u001b[0;34m(self, X, y, sample_weight, check_input, missing_values_in_feature_mask)\u001b[0m\n\u001b[1;32m 461\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 462\u001b[0m builder \u001b[38;5;241m=\u001b[39m BestFirstTreeBuilder(\n\u001b[1;32m 463\u001b[0m splitter,\n\u001b[1;32m 464\u001b[0m min_samples_split,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 469\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmin_impurity_decrease,\n\u001b[1;32m 470\u001b[0m )\n\u001b[0;32m--> 472\u001b[0m \u001b[43mbuilder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbuild\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtree_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msample_weight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmissing_values_in_feature_mask\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 474\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_outputs_ \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m is_classifier(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 475\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_classes_ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_classes_[\u001b[38;5;241m0\u001b[39m]\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "torch.manual_seed(0)\n", + "metric = \"c2st_rf\"\n", + "metric_fn = get_metric(metric)\n", + "for dname in datasets:\n", + " print(\"Starting \", dname)\n", + " metric_values = []\n", + " for j in range(5):\n", + " data_test = testset[j*20_000:(j+1)*20_000]\n", + " if dname == \"imagenet_midjourney_embedding\":\n", + " data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n", + " data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]\n", + " else:\n", + " data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n", + " m = metric_fn(data_test, data_syn, n_folds=2)\n", + " metric_values.append(m)\n", + " results_c2st_rf[dname] = np.array(metric_values)" ] }, { @@ -255,7 +930,40 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "np.save(\"results_c2st_rf.npy\", results_c2st_rf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results_c2st_rf" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.9410])" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metric_fn = get_metric(\"c2st_nn\")\n", + "m = metric_fn(data_test, data_syn, n_folds=2)\n", + "m" + ] } ], "metadata": { diff --git a/docs/notebooks/fid/model_table.ipynb b/docs/notebooks/fid/model_table.ipynb new file mode 100644 index 0000000..0db4bfb --- /dev/null +++ b/docs/notebooks/fid/model_table.ipynb @@ -0,0 +1,245 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np " + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [], + "source": [ + "results_fid = np.load('results_fid.npy', allow_pickle=True).item()" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [], + "source": [ + "def print_mean(means, scientific=True):\n", + " out = ''\n", + " i = np.argmin(means[:3])\n", + " j = np.argmin(means[3:]) + 3\n", + " for k,mean in enumerate(means):\n", + " \n", + " if k == i:\n", + " out += '$\\\\mathbf{'\n", + " elif k == j:\n", + " out += '$\\\\mathbf{'\n", + " else:\n", + " out += '$'\n", + " \n", + " if scientific:\n", + " out += f'{mean:.1e} ' # Scientific notation for numbers smaller than 1e-3\n", + " else:\n", + " out += f'{mean:.2f} '\n", + " \n", + " if k == i or k == j:\n", + " out += '}$ &'\n", + " else:\n", + " out += '$ &'\n", + " \n", + " out = out[:-2] # Remove the last ' & '\n", + " out += '\\\\\\\\' # Add '\\\\'\n", + " print(out)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [], + "source": [ + "means = []\n", + "stds = []\n", + "for key in results_fid.keys():\n", + " means.append(np.mean(results_fid[key]))\n", + " stds.append(np.std(results_fid[key]))" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$\\mathbf{6.2e+00 }$ &$6.4e+00 $ &$7.0e+00 $ &$1.3e+01 $ &$1.7e+01 $ &$1.7e+01 $ &$\\mathbf{1.1e+01 }$ &$1.9e+01 $ &$1.3e+01 $ &$1.7e+01 $\\\\\n" + ] + } + ], + "source": [ + "print_mean(means)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$2.3e-02 $ &$\\mathbf{2.2e-02 }$ &$2.5e-02 $ &$5.1e-02 $ &$5.6e-02 $ &$5.6e-02 $ &$\\mathbf{4.1e-02 }$ &$5.6e-02 $ &$5.0e-02 $ &$4.8e-02 $\\\\\n" + ] + } + ], + "source": [ + "results_fid = np.load('results_sw.npy', allow_pickle=True).item()\n", + "means = []\n", + "stds = []\n", + "for key in results_fid.keys():\n", + " means.append(np.mean(results_fid[key]))\n", + " stds.append(np.std(results_fid[key]))\n", + "print_mean(means)" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$6.6e-05 $ &$\\mathbf{6.2e-05 }$ &$8.5e-05 $ &$1.9e-04 $ &$2.1e-04 $ &$2.1e-04 $ &$\\mathbf{1.5e-04 }$ &$2.0e-04 $ &$1.9e-04 $ &$1.8e-04 $\\\\\n" + ] + } + ], + "source": [ + "results_fid = np.load('results_mmd_rbf64.npy', allow_pickle=True).item()\n", + "means = []\n", + "stds = []\n", + "for key in results_fid.keys():\n", + " means.append(np.mean(results_fid[key]))\n", + " stds.append(np.std(results_fid[key]))\n", + "print_mean(means)" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$2.5e-01 $ &$\\mathbf{2.4e-01 }$ &$3.2e-01 $ &$6.3e-01 $ &$6.9e-01 $ &$6.9e-01 $ &$\\mathbf{5.1e-01 }$ &$6.4e-01 $ &$6.5e-01 $ &$6.0e-01 $\\\\\n" + ] + } + ], + "source": [ + "results_fid = np.load('results_mmd_lin.npy', allow_pickle=True).item()\n", + "means = []\n", + "stds = []\n", + "for key in results_fid.keys():\n", + " means.append(np.mean(results_fid[key]))\n", + " stds.append(np.std(results_fid[key]))\n", + "print_mean(means)" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$\\mathbf{1.1e+04 }$ &$1.7e+04 $ &$1.6e+04 $ &$3.0e+04 $ &$3.6e+04 $ &$3.6e+04 $ &$\\mathbf{2.3e+04 }$ &$3.8e+04 $ &$3.4e+04 $ &$3.6e+04 $\\\\\n" + ] + } + ], + "source": [ + "results_fid = np.load('results_mmd_poly_kid.npy', allow_pickle=True).item()\n", + "means = []\n", + "stds = []\n", + "for key in results_fid.keys():\n", + " means.append(np.mean(results_fid[key]))\n", + " stds.append(np.std(results_fid[key]))\n", + "print_mean(means)" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$0.65 $ &$\\mathbf{0.63 }$ &$0.65 $ &$\\mathbf{0.75 }$ &$0.79 $ &$0.79 $ &$0.77 $ &$0.79 $ &$0.76 $ &$0.79 $\\\\\n" + ] + } + ], + "source": [ + "results_fid = np.load('results_c2st_knn.npy', allow_pickle=True).item()\n", + "means = []\n", + "stds = []\n", + "for key in results_fid.keys():\n", + " means.append(np.mean(results_fid[key]))\n", + " stds.append(np.std(results_fid[key]))\n", + "print_mean(means, False)" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$\\mathbf{0.72 }$ &$0.77 $ &$0.76 $ &$0.86 $ &$0.92 $ &$0.92 $ &$\\mathbf{0.85 }$ &$0.92 $ &$0.85 $ &$0.94 $\\\\\n" + ] + } + ], + "source": [ + "results_fid = np.load('results_c2st_nn.npy', allow_pickle=True).item()\n", + "means = []\n", + "stds = []\n", + "for key in results_fid.keys():\n", + " means.append(np.mean(results_fid[key]))\n", + " stds.append(np.std(results_fid[key]))\n", + "print_mean(means, False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "labproject", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/notebooks/fid/results_c2st_knn.npy b/docs/notebooks/fid/results_c2st_knn.npy new file mode 100644 index 0000000000000000000000000000000000000000..e7868e26bff3f4fa888c7ccee77105438e57e3f2 GIT binary patch literal 1232 zcmbu-OH9*n9Ki7o6m`B(0Z|k;m31ggKorEs54Jh!HV0H5GHT+*#G4wUH*Y4y0~asGD;|g@#+#2s+hS{-NB>Ei^w)nr`R1Sh%}C$D?m-(h zN}ZwwF&kItHV1vWi>`Cf!?HpYPKwEjAZp_Pm(GgnJe%Z3MD^QUTjv;a{NJ*V7ssVs zW+d;4%Zlj93>B+%CH$%>6}Mh?>A6pN+@j3pkI(`Rss<_DCcW{HZSe6#{~b`N<(%GKNdt zQJ1pY)ZDchVrdse#$sF8EPN!`bgHz)FH^_-KW- zDxPh$RID{Bsn@zVycc2UWUh31Er@$RW zA!Bj(8ty*L1s2;)w?lUy4FsUx=Z8B_{7_vHfcua419}Pi3QqWV?wBkkEVAE__iJ*1 z#RI0?smo;$1n&_B&Rk~T<0}Rl!fG)?Fj#QaxW9r!1{>DcgDgf&wo7NT`!zE z=L71x7Xr_HfZYUp_~?4;mB^%oqjFA>#JnY%HS|N8-pgX2srT#p)sip-ZUyycswoUl RXG4Gk1fzV8BRI%=egPS-h8q9? literal 0 HcmV?d00001 diff --git a/docs/notebooks/fid/results_c2st_nn.npy b/docs/notebooks/fid/results_c2st_nn.npy new file mode 100644 index 0000000000000000000000000000000000000000..efe8738dfe544b937b68985bf77ba3cc2f7d3606 GIT binary patch literal 1232 zcmbu-%TE(g6u{vrSka<_%2P$EA}xp&1f(KzghEv+1xL{-59>^4XiM8EXL^e?f~X0P zYJA18Fk#`w5M8-)r5lV39}8pRPFHH;l0QJBGlMe@b?IG9GUwiW`Q_%`_x5%7`b^Xq zb=k%XiGXBlw%R6oY&BNfa9ko1EBfOSFDPR-8%+rEJP~3?1^M4zRb!pB{@=1(%?^mk z*l5ZTh)aSamW+}xD@kn1;+KR#Tucxt86Xyr1TqX{)XK@XUX1m8! z_2&>zQlo+e*D_fYMfEr_Cp|Zd>u6kW_hd6Bc{YI-uh-l4OJ3^70|axq^bl^C#Eo{Q zQa%!nu|Yu;h(9UH8GM+8<02dN$Kt#Y^$Rgh;Q6o^#7!Y8q^9JBI@azw2JLY#SR+F^opIKLYcjt?1Y@N?dSl-UZcalSo}T+57UD%^r;_kKS4)Eae^OfFt`Uaw^niMXgsL9^{P9b8-=Br2-H0g zp=3S^?q?Ce27-Yy4$R}vx^Z}s2AapHwwMCA~5qh3V4KIONKj! ze9Yh;)!bIa1sad(Zky`9?_l7m#K54{2Z1RcBtkyGc7m>q8@`l`#KoXN?$G2;MRwD8 zT$fL%vQqiEVg&LQB9Q(t0?*AOfF}udWjr-@Rj@~6Pbq9Sjb5GYQQ13R`{8Z=04y)| z!@Vm5(0)-4?j?AdW6F$AA{ORH;z>yqQif<+)6XdSSsMFv{hX?Q97j1+g5kM@d>sfc RzKjB%C&+L~mY|Pw{08QGjc5P> literal 0 HcmV?d00001 diff --git a/docs/notebooks/fid/results_fid.npy b/docs/notebooks/fid/results_fid.npy new file mode 100644 index 0000000000000000000000000000000000000000..900a6a1ab8d257b683143a43a3f746b71ad614ac GIT binary patch literal 1212 zcmbu-TTIhn6bJAQ1a%@_5mCXZz=k4nQxT{1h{~lfwBAM{F6~;jDQh|1HnD~q&ykMff`=)JaW0{YB4^4Y|@;m>0Im5ma%|}`tR2Oy8 zEr_w8;;weNds^JZE_b`E5QUQhvLc9j`7qZJ6SaCQ%yo*|x2LGs)$97d8EeV9_~;4!ry7qFP-VH`zf zlR&yVMJJ~FQz(jRsltWoKtC>Gaj}O<^*gj$4?2B5U*m7>7!Rg~;1XVK!;D^B>hTw9 zEs-b}5+#uY;*u6ah>(aZaUFrEEQlQeG0KaA5RpQd8K%NUlx1Np$8r`|XgQd)K3o}| z8{b}utMu`!Si6Vy>AXFPx>JrEfmcZ@;{wN3MLJZw-B*2MXTPxRIcX_h%=er8aew zrEb>MEiC4nYJs8V{`^YM%udmgI!%8LOw$twztUJp(9QdE>?$7#g|r>|dt)BUEz;es zEN(O1V#9q}#lYPc49p&4;M=tJUeo}%onT2)+L+K)X;XJtYN@W4u~=@Z6^44IOod-l zD*VVv!1GKMZps6Il>~SGtM0U^yDW9Lu2!+Q$5g8g_0iS@FpCp#@9H3&d8vY@O$FRb z(3_-op%AsH`z&?8uGX+fn<^OU3|j$SQU>eBD`BLm0&b6&10En)n{>111Z>HJ`-&*(C=A| z@R+HG*uzG+pVt6Uu@3ME!K1uC*S-ePh;UYpE0Wl43(Q*fG2K4S;tA6}Y1kPv2@qfq ZhVrk&oofl0$VdR5BIxJi9Klw;RqHV~B|_z8IfO{1f~SOiYXqz8b=dKcl>8TiRIXqxYdn&pr8^Z*I=R=*Y?8^ERfy zT=xiKE}?k3T%O4xPqWK&SyrgRNpV>bM6Eo?Wpbig&!xCAQT_M&nq5<_|2vMG_Jowr zj!pOzvLgDk`3y~Sio#7e;)y6la zMw37%#zY5}#!485846+9+?`om$KrZ##2&M$wMleDqtWm$_2{1nP%P)?MsdRwZuG|7 zYD+rHB}GZ3@w}wQ5YjX)OI#+Nl?5>q7qh%52x%#an^H_lk5Z9Rb5ydpS4;x;~3W6hMv zwOKN3H#5{}8FsK(Z)Dh+s#ek7h5Mv-aFo=QC&}W08zkh56Wm47#mAgQXsbnSFx5s) z-Ob`2Lv7O4VE0RM^z}3HbN^HFn0-#Zr=JjXQ}pn$ZB~^}CzI+9{dqC3>H0KxFN^yO zw^?_mrDpiXxWILx0TymGg416G*h0}?lr|?6f);hZskUlr8;b!$J)o=YZJjXw!w1u1 zE9`s~fOkC}z#zr;f7N!2deBrmG&RIxr=fP~s`#}FmM-}q+8qF=JqQm5S^&E#_7tfF z6tWhz*Hnq70*i+XwNF>WeH9RKERi>F*TBaw4tPJmLhvxf{-T#1&tH|Lq{TgAx&xX! z$l_7MJ*K;FKDPqnst4Kaflqh6@Zp3L@HoYA(Q$(;Y2k{TS0r)55|}mZ)0%yT#Sz0ktJ{HcH-t*dVR?2t Wlr_~ue6kww9K{%)=O~`%{l5Y734RO! literal 0 HcmV?d00001 diff --git a/docs/notebooks/fid/results_mmd_poly_kid.npy b/docs/notebooks/fid/results_mmd_poly_kid.npy new file mode 100644 index 0000000000000000000000000000000000000000..1557e05b857daad7f704e9f28ee69881ab59bb09 GIT binary patch literal 1212 zcmbu-PfXKr6bJA(aDW5lk1*u_RA8GTf`THTDCqA8Oa_*J<3N^nE!%YK_PSp&ix5o= zL89T}%^0K6ixCqOFM8&{*&BDg7&RuwGuqNNmU;A-v`Jro`Mhsl-lO33*u$RS|UZ=?ggg<|UD3 z(~=7dLU{l%-h+i}_gArq!~MQMe#oKKme3Up21kEuXZIBtU|1}y&11K|5ZI6iwleezp&Gj?#G_GdhyG5?XSsgeJ;v(WZ0gQHX|-}+0;{(+M%nR9G*7SGlu%Ej>3Eag{?yEDWmYxNwJGz_rGelO+9O= z=XAA)!}F%vYpC&B3QwIBUTf-$QVQQ(6#E$V=ctP)CT;2kOTDP80~}s5)yszJ-Xaiu zOrYWofx%4z^&12S84l%4_Lh1xB}Z-U70bPPG}0SmpRqhlNfO?2EqH!CQ-)dri@M(`t#0{KF4g5h-`)L>tOWL&(FQWaUs s+5&TyJ*nGM98R0|jA7qv(pF)Iz)u&2=LacF7E_#M7!p*T;hfO^2hnSA{{R30 literal 0 HcmV?d00001 diff --git a/docs/notebooks/fid/results_mmd_rbf64.npy b/docs/notebooks/fid/results_mmd_rbf64.npy new file mode 100644 index 0000000000000000000000000000000000000000..b6ee852c8767989e57e36b9a70647116b5f4625a GIT binary patch literal 1212 zcmbu-TTc@~6bJArh?nA>n}}#pp#`ym7Tf`pOA*jYff01kkg_-0)pnP&oyAlkX<`s~ zHqlp}O?>geHxoXJ-$c*S*|^k4CpkOW*~xGIdorKWOSALKA+kZ`OxEl(-=qJithoU(46Nb#N|9}-NrqIY-GO5a>SC!-x7eif(z+3b<8lolWKMz; zd%HVulEJB1s#Xr++9rh4>GYG|I1V4FQ4aO$-U^)Ff-^C>3%8gBC2Lp)&$t%G&`oaI zmXgmDY~9FbjDl+Dx@l$MY>wm(qnyiOj`IvIV2)JXHeAeAj$cbagLix*gC_s@5okVo zJmJuy?j;*Z`LjgIA0%?{1?!_k$d&|`IJByAb1;)u>}P26GF)NME;3xrHF@X( zq6DzsNn~snXGel-93raRQib*f)DB<0?x{Bz+!Si3P=BKI9mQ`L^#{&hCFtT1Rpr*8 zs+!p>-l4x8i}`N1=k_qTCEVVMyPu}yU;-;i$=DPnFZ_bx5cNC99~GQ6w46D^nQe zFjA#%fL;iw4}A5Zr;aih6Y6+HjdWtZdQ{s`?7++o6ec(%t2+DEePdhMfII2Ck34sZ z!L)E^L}y>3HI21^#n2z8@R-Ass>`Ju37hiKrye@XAT7{2(E{)HXRziG#d&|2!aRos zRc;SngMz8Qwq3_EN`b)4x1V|TB7-GiKNt2SHlPQK9=arwpzwl&thx$^Wi|c>9iVNR literal 0 HcmV?d00001 diff --git a/docs/notebooks/fid/results_sw.npy b/docs/notebooks/fid/results_sw.npy new file mode 100644 index 0000000000000000000000000000000000000000..aeb3ddea28dac832ef1090da291e7d0fb921d8b7 GIT binary patch literal 1212 zcmbu--%Aux6bJBG&3>$!KT6HaS}U7fHB-}4yC*AuWNkZZlhsudmKC?)kstzMLIV z(~05H`bbiv^@;Qli%Xg&jk?1cjU?3+)6x;% zH?&P}i6^8eRVfRnRXv6hXYr&e4TTd)g${*jLZ*rmSEIN(#>LDi1u;EGp@>C#4u8%B z7RQ#x4>sc(WBghXOYHFkOaC6vF)Wj(JJ&95pC?7P=ZNq7GjjUQJb8EI3Bh#?%jHn1 zGgBnh=*Y0%&QM`w*dSu1m0@G7L`NSc?h*gfS@OQ)K6#q=fY7;Ff}0q6-rB<8jvZh7yvGNsJ$a_Ny+7<~v@rvL!hP)gqcdBwc8r65`&x?6&*Jrpj zBGy`No$2=f48Wt;0m!dD3YR_x;riVm;C6=f8EI3>NV7xTVXF;>+9+a^rS3GluVzAqM@yL3lVt0h<}N{HwM&)Lpi^+fZ9YY_rsMQ?06q!rNL3L0vA+i2^T10rxN@ z8R`fs35N={y4O%UMBHbo`%QKAw;K3S=!L5_KA1^+p)Z#QJixFs<7F?T&nMNW!|k%& zZo};n@u21Qn(n)o5je}q@ZvIs{%Z