-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
339 additions
and
0 deletions.
There are no files selected for viewing
28 changes: 28 additions & 0 deletions
28
configs/conf_mmd_main_scaling_experiment_different_kernels.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
exp_log_name: "MMD_main_scaling_kernel_experiment" # optional but recommended | ||
|
||
# datasets to use | ||
data: ["toy_2d" , "random", "random"] | ||
augmentation: ['gauss', 'one_dim_shift', 'one_dim_shift',] | ||
|
||
# number of samples and dimensions | ||
n: [1000,1000,1000] #[10000, 10000, 10000] #samples Note that for main figure 10k | ||
d: [2, 10, 1000] # dimensions | ||
|
||
mmd_bandwidth: [[1, 5, 10],[1, 5, 10],[1, 5, 10]] | ||
|
||
# sample size experiments | ||
experiments: ["ScaleSampleSizeMMD", "ScaleSampleSizeMMD","ScaleSampleSizeMMD"] | ||
sample_size: [50, 100, 200, 500, 1000, 2000, 3000, 4000] | ||
runs: 5 # number of sample selection for errorbars | ||
|
||
# dimensionality experiments | ||
experiments_dim: ["ScaleDimMMD", "ScaleDimMMD", "ScaleDimMMD"] | ||
dim_sizes: [5, 10, 50, 100, 500, 1000] | ||
runs_dim: 5 # number of sample selection for errorbars | ||
|
||
# seed for reproducibility | ||
seed: 0 | ||
|
||
# for the reduced sample size experiments | ||
#sample_size: [8, 10, 20, 50, 80] | ||
#n: [500, 500, 500] |
311 changes: 311 additions & 0 deletions
311
docs/notebooks/mmd/MMD_scaling_experiment_different_kernels.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,311 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Test MMD bandwidth sensitivity in all scaling experiments" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Load the configs and set up the plotting " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Running experiments...\n", | ||
"Seed: 0\n", | ||
"Experiments: ['ScaleSampleSizeMMD', 'ScaleSampleSizeMMD', 'ScaleSampleSizeMMD']\n", | ||
"Data: ['toy_2d', 'random', 'random']\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import os\n", | ||
"import time\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import numpy as np\n", | ||
"from omegaconf import OmegaConf\n", | ||
"from torch.distributions import MultivariateNormal\n", | ||
"\n", | ||
"from labproject.data import DATASETS, DISTRIBUTIONS, get_dataset\n", | ||
"from labproject.experiments import *\n", | ||
"from labproject.plotting import cm2inch, generate_palette\n", | ||
"from labproject.utils import get_cfg, get_cfg_from_file, get_log_path, set_seed\n", | ||
"\n", | ||
"# inline plotting\n", | ||
"%matplotlib inline\n", | ||
"\n", | ||
"print(\"Running experiments...\")\n", | ||
"# load the config file\n", | ||
"cfg = get_cfg_from_file(\"conf_mmd_main_scaling_experiment\")\n", | ||
"cfg.running_user = 'MMD_main_scaling_kernel_experiment'\n", | ||
"seed = cfg.seed\n", | ||
"\n", | ||
"set_seed(seed)\n", | ||
"print(f\"Seed: {seed}\")\n", | ||
"print(f\"Experiments: {cfg.experiments}\") \n", | ||
"print(f\"Data: {cfg.data}\")\n", | ||
"\n", | ||
"# assert cfg.data is list\n", | ||
"assert len(cfg.data) == len(cfg.n) == len(cfg.d), \"Data, n and d must be lists of the same length\"\n", | ||
" \n", | ||
"# setup colors and labels for plotting\n", | ||
"color_dict = {\"wasserstein\": \"#cc241d\",\n", | ||
" \"mmd\": \"#eebd35\",\n", | ||
" \"c2st\": \"#458588\",\n", | ||
" \"fid\": \"#8ec07c\", \n", | ||
" \"kl\": \"#8ec07c\"}\n", | ||
"\n", | ||
"col_map = {'ScaleSampleSizeKL':'kl', 'ScaleSampleSizeSW':'wasserstein',\n", | ||
" 'ScaleSampleSizeMMD':'mmd', 'ScaleSampleSizeC2ST':'c2st',\n", | ||
" 'ScaleSampleSizeFID':'fid', 'ScaleDimKL':'kl', 'ScaleDimSW':'wasserstein',\n", | ||
" 'ScaleDimMMD':'mmd', 'ScaleDimC2ST':'c2st', 'ScaleGammaMMD':'mmd',\n", | ||
" 'ScaleDimFID':'fid',}\n", | ||
"\n", | ||
"mapping = {'ScaleSampleSizeKL':'KL', 'ScaleSampleSizeSW':'SW',\n", | ||
" 'ScaleSampleSizeMMD':'MMD', 'ScaleSampleSizeC2ST':'C2ST',\n", | ||
" 'ScaleSampleSizeFID':'FD', 'ScaleDimKL':'KL', 'ScaleDimSW':'SW',\n", | ||
" 'ScaleDimMMD':'MMD', 'ScaleDimC2ST':'C2ST',\n", | ||
" 'ScaleDimFID':'FD', 'ScaleGammaMMD':'MMD'}\n", | ||
"\n", | ||
"# dark and light colors for inter vs. intra comparisons \n", | ||
"col_dark = {}\n", | ||
"col_light = {}\n", | ||
"for e, exp_name in enumerate(cfg.experiments):\n", | ||
" col_dark[exp_name] = generate_palette(color_dict[col_map[exp_name]], saturation='dark')[2]\n", | ||
" col_light[exp_name] = generate_palette(color_dict[col_map[exp_name]], saturation='light')[-1]\n", | ||
"for e, exp_name in enumerate(cfg.experiments_dim):\n", | ||
" col_dark[exp_name] = generate_palette(color_dict[col_map[exp_name]], saturation='dark')[2]\n", | ||
" col_light[exp_name] = generate_palette(color_dict[col_map[exp_name]], saturation='light')[-1]\n", | ||
" \n", | ||
"color_list = [col_light, col_dark] # make this a list to account for true and shifted\n", | ||
"\n", | ||
"label_true = {}\n", | ||
"label_shift = {}\n", | ||
"for e, data_name in enumerate(cfg.data):\n", | ||
" label_true[data_name] = \"true\"\n", | ||
" label_shift[data_name] = \"generated\"\n", | ||
" \n", | ||
"label_list = [label_true, label_shift]\n", | ||
"label_list[1]['toy_2d'] = 'approx.'\n", | ||
"label_list[1]['random'] = 'shifted'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Loop over the three datasets for respective MMD kernel implementations" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# make comparison plots\n", | ||
"fig, axes = plt.subplots(3, 3, figsize=cm2inch((20, 10)), sharex='col')\n", | ||
"for ax in axes.flatten():\n", | ||
" # move spines outward\n", | ||
" ax.spines['bottom'].set_position(('outward', 4))\n", | ||
" ax.spines['left'].set_position(('outward', 4))\n", | ||
" ax.locator_params(nbins=3)\n", | ||
"\n", | ||
"# Loop over all three datasets\n", | ||
"for dd, ds in enumerate(cfg.data):\n", | ||
"\n", | ||
" # specify the bandwidth parameters to vary for each dataset\n", | ||
" # custom_values = np.linspace(cfg.val_min[dd], cfg.val_max[dd], cfg.val_step[dd]) # switch for linspacing\n", | ||
"\n", | ||
" dataset_fn = get_dataset(cfg.data[dd])\n", | ||
" # sample double the number of samples to ensure variability at the highest samples set size\n", | ||
" n_samples = cfg.n[dd] * cfg.runs\n", | ||
"\n", | ||
" # generate the ground truth and the two approximations inter and intra\n", | ||
" dataset_gt = dataset_fn(n_samples, cfg.d[dd])\n", | ||
" dataset_intra = dataset_fn(n_samples, cfg.d[dd])\n", | ||
"\n", | ||
" print(cfg.data[dd], n_samples, cfg.d[dd])\n", | ||
"\n", | ||
" # generate the inter dataset\n", | ||
" if cfg.data[dd] == \"toy_2d\":\n", | ||
" dataset_inter = MultivariateNormal(\n", | ||
" torch.mean(dataset_gt, axis=0).T, torch.cov(dataset_gt.T)\n", | ||
" ).sample((n_samples,))\n", | ||
" elif cfg.data[dd] == \"random\" and cfg.augmentation[dd] == \"mean_shift\":\n", | ||
" # shift the mean by 1 for all dimensions\n", | ||
" dataset_inter = dataset_fn(n_samples, cfg.d[dd]) + 1\n", | ||
" elif cfg.data[dd] == \"random\" and cfg.augmentation[dd] == \"one_dim_shift\":\n", | ||
" # just shift the first dimension by 1\n", | ||
" dataset_inter = dataset_fn(n_samples, cfg.d[dd])\n", | ||
" dataset_inter[:, 0] += 1 # just shift the mean of first dim by 1\n", | ||
"\n", | ||
" if dd < 2: # for the first two datasets, we compare sample sizes\n", | ||
" for e, exp_name in enumerate(cfg.experiments):\n", | ||
" experiment = globals()[exp_name]()\n", | ||
" ax = axes[e, dd]\n", | ||
" for dc, data_comp in enumerate([dataset_intra, dataset_inter]):\n", | ||
" assert (\n", | ||
" dataset_gt.shape == data_comp.shape\n", | ||
" ), f\"Dataset shapes do not match: {dataset_gt.shape} vs. {data_comp.shape}\"\n", | ||
" time_start = time.time()\n", | ||
" if mapping[exp_name] == \"MMD\":\n", | ||
" print(f\"MMD {cfg.data[dd]} {dd} {cfg.mmd_bandwidth[dd][e]}\")\n", | ||
" output = experiment.run_experiment(\n", | ||
" dataset1=dataset_gt,\n", | ||
" dataset2=data_comp,\n", | ||
" sample_sizes=cfg.sample_size,\n", | ||
" nb_runs=cfg.runs,\n", | ||
" bandwidth=cfg.mmd_bandwidth[dd][e],\n", | ||
" )\n", | ||
" else:\n", | ||
" output = experiment.run_experiment(\n", | ||
" dataset1=dataset_gt,\n", | ||
" dataset2=data_comp,\n", | ||
" sample_sizes=cfg.sample_size,\n", | ||
" nb_runs=cfg.runs,\n", | ||
" )\n", | ||
" time_end = time.time()\n", | ||
" print(f\"Experiment {exp_name} finished in {time_end - time_start}\")\n", | ||
"\n", | ||
" log_path = get_log_path(\n", | ||
" cfg, tag=f\"_{mapping[exp_name]}_{cfg.data[dd]}_ds_{dd}_bw_{cfg.mmd_bandwidth[dd][e]}_{dc}\", timestamp=False\n", | ||
" )\n", | ||
" os.makedirs(os.path.dirname(log_path), exist_ok=True)\n", | ||
" experiment.log_results(output, log_path)\n", | ||
" print(f\"Numerical results saved to {log_path}\")\n", | ||
"\n", | ||
" experiment.plot_experiment(\n", | ||
" *output,\n", | ||
" cfg.data[dd],\n", | ||
" ax=ax,\n", | ||
" color=color_list[dc][exp_name],\n", | ||
" label=label_list[dc][cfg.data[dd]],\n", | ||
" linestyle=\"-\" if dc == 0 else \"--\",\n", | ||
" lw=2,\n", | ||
" marker=\"o\",\n", | ||
" )\n", | ||
" ax.set_ylabel(mapping[exp_name] + str(cfg.mmd_bandwidth[dd][e]))\n", | ||
" ax.set_xlabel(\"\")\n", | ||
" if mapping[exp_name] == \"C2ST\":\n", | ||
" ax.set_ylim([0.45, 1])\n", | ||
" ax.set_yticks([0.5, 1])\n", | ||
" ax.legend()\n", | ||
" else: # for the last dataset, we compare dimensions\n", | ||
" for e, exp_name in enumerate(cfg.experiments_dim):\n", | ||
" experiment = globals()[exp_name]()\n", | ||
" ax = axes[e, 2]\n", | ||
" ax.set_xscale(\"log\")\n", | ||
" for dc, data_comp in enumerate([dataset_intra, dataset_inter]):\n", | ||
" assert (\n", | ||
" dataset_gt.shape == data_comp.shape\n", | ||
" ), f\"Dataset shapes do not match: {dataset_gt.shape} vs. {data_comp.shape}\"\n", | ||
" time_start = time.time()\n", | ||
" if exp_name == \"ScaleDimMMD\":\n", | ||
" output = experiment.run_experiment(\n", | ||
" dataset1=dataset_gt,\n", | ||
" dataset2=data_comp,\n", | ||
" dataset_size=cfg.n[dd],\n", | ||
" dim_sizes=cfg.dim_sizes,\n", | ||
" nb_runs=cfg.runs_dim, # deterministic\n", | ||
" bandwidth=cfg.mmd_bandwidth[dd][e],\n", | ||
" )\n", | ||
" print(f\"MMD {cfg.data[dd]} {dd} {cfg.mmd_bandwidth[dd][e]}\")\n", | ||
" else:\n", | ||
" output = experiment.run_experiment(\n", | ||
" dataset1=dataset_gt,\n", | ||
" dataset2=data_comp,\n", | ||
" dataset_size=cfg.n[dd],\n", | ||
" dim_sizes=cfg.dim_sizes,\n", | ||
" nb_runs=cfg.runs_dim,\n", | ||
" )\n", | ||
" time_end = time.time()\n", | ||
" print(f\"Experiment {exp_name} finished in {time_end - time_start}\")\n", | ||
"\n", | ||
" log_path = get_log_path(\n", | ||
" cfg, tag=f\"_{mapping[exp_name]}_{cfg.data[dd]}_ds_{dd}_bw_{cfg.mmd_bandwidth[dd][e]}_{dc}\", timestamp=False\n", | ||
" )\n", | ||
" os.makedirs(os.path.dirname(log_path), exist_ok=True)\n", | ||
" experiment.log_results(output, log_path)\n", | ||
" print(f\"Numerical results saved to {log_path}\")\n", | ||
" experiment.plot_experiment(\n", | ||
" *output,\n", | ||
" cfg.data[dd],\n", | ||
" ax=ax,\n", | ||
" color=color_list[dc][exp_name],\n", | ||
" label=label_list[dc][cfg.data[dd]],\n", | ||
" linestyle=\"-\" if dc == 0 else \"--\",\n", | ||
" lw=2,\n", | ||
" marker=\"o\",\n", | ||
" )\n", | ||
" ax.set_ylabel(mapping[exp_name] + str(cfg.mmd_bandwidth[dd][e]))\n", | ||
" ax.set_xlabel(\"\")\n", | ||
" if mapping[exp_name] == \"C2ST\":\n", | ||
" ax.set_ylim([0.45, 1])\n", | ||
" ax.set_yticks([0.5, 1])\n", | ||
"\n", | ||
" ax.legend()\n", | ||
"\n", | ||
"axes[-1, -1].set_xlabel(\"dimensions\")\n", | ||
"axes[-1, 0].set_xlabel(\"sample size\")\n", | ||
"axes[-1, 1].set_xlabel(\"sample size\")\n", | ||
"\n", | ||
"\n", | ||
"os.makedirs(\"./results/plots\", exist_ok=True)\n", | ||
"fig.tight_layout()\n", | ||
"fig.savefig(\n", | ||
" f\"./results/plots/MMD_scaling_{cfg.mmd_bandwidth}_{cfg.n[0]}.png\", dpi=300\n", | ||
")\n", | ||
"fig.savefig(\n", | ||
" f\"./results/plots/MMD_scaling_{cfg.mmd_bandwidth}_{cfg.n[0]}.pdf\", dpi=300\n", | ||
")\n", | ||
"\n", | ||
"print(\"Finished running experiments.\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "labproject", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.18" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |