Skip to content

Commit

Permalink
add kernel scaling base copy
Browse files Browse the repository at this point in the history
  • Loading branch information
augustes committed Jun 4, 2024
1 parent 85b0a47 commit 1bfb253
Show file tree
Hide file tree
Showing 2 changed files with 339 additions and 0 deletions.
28 changes: 28 additions & 0 deletions configs/conf_mmd_main_scaling_experiment_different_kernels.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Config added together with docs/notebooks/mmd/MMD_scaling_experiment_different_kernels.ipynb.
# The comments below describe how that notebook reads these keys.
# `data`, `augmentation`, `n`, `d` and `mmd_bandwidth` are indexed by dataset position,
# so they need one entry per dataset.
exp_log_name: "MMD_main_scaling_kernel_experiment" # optional but recommended

# datasets to use (one entry per subplot column) and how the comparison set is perturbed
data: ["toy_2d" , "random", "random"]
augmentation: ['gauss', 'one_dim_shift', 'one_dim_shift',]

# number of samples and dimensions, per dataset
n: [1000,1000,1000] # samples per dataset; note that the main figure uses [10000, 10000, 10000]
d: [2, 10, 1000] # dimensions

# MMD kernel bandwidths: one inner list per dataset; the notebook reads
# mmd_bandwidth[dataset index][experiment index], i.e. one bandwidth per subplot row
mmd_bandwidth: [[1, 5, 10],[1, 5, 10],[1, 5, 10]]

# sample size experiments (the notebook runs these for the first two datasets)
experiments: ["ScaleSampleSizeMMD", "ScaleSampleSizeMMD","ScaleSampleSizeMMD"]
sample_size: [50, 100, 200, 500, 1000, 2000, 3000, 4000]
runs: 5 # number of repeated sample selections for error bars; the notebook draws n * runs samples per dataset

# dimensionality experiments (the notebook runs these for the remaining dataset)
experiments_dim: ["ScaleDimMMD", "ScaleDimMMD", "ScaleDimMMD"]
dim_sizes: [5, 10, 50, 100, 500, 1000]
runs_dim: 5 # number of repeated sample selections for error bars

# seed for reproducibility
seed: 0

# alternative values for the reduced sample size experiments
#sample_size: [8, 10, 20, 50, 80]
#n: [500, 500, 500]
311 changes: 311 additions & 0 deletions docs/notebooks/mmd/MMD_scaling_experiment_different_kernels.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test MMD bandwidth sensitivity in all scaling experiments"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the configs and set up the plotting "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running experiments...\n",
"Seed: 0\n",
"Experiments: ['ScaleSampleSizeMMD', 'ScaleSampleSizeMMD', 'ScaleSampleSizeMMD']\n",
"Data: ['toy_2d', 'random', 'random']\n"
]
}
],
"source": [
"import os\n",
"import time\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from omegaconf import OmegaConf\n",
"from torch.distributions import MultivariateNormal\n",
"\n",
"from labproject.data import DATASETS, DISTRIBUTIONS, get_dataset\n",
"from labproject.experiments import *\n",
"from labproject.plotting import cm2inch, generate_palette\n",
"from labproject.utils import get_cfg, get_cfg_from_file, get_log_path, set_seed\n",
"\n",
"# inline plotting\n",
"%matplotlib inline\n",
"\n",
"print(\"Running experiments...\")\n",
"# load the config file\n",
"cfg = get_cfg_from_file(\"conf_mmd_main_scaling_experiment\")\n",
"cfg.running_user = 'MMD_main_scaling_kernel_experiment'\n",
"seed = cfg.seed\n",
"\n",
"set_seed(seed)\n",
"print(f\"Seed: {seed}\")\n",
"print(f\"Experiments: {cfg.experiments}\") \n",
"print(f\"Data: {cfg.data}\")\n",
"\n",
"# assert cfg.data is list\n",
"assert len(cfg.data) == len(cfg.n) == len(cfg.d), \"Data, n and d must be lists of the same length\"\n",
" \n",
"# setup colors and labels for plotting\n",
"color_dict = {\"wasserstein\": \"#cc241d\",\n",
" \"mmd\": \"#eebd35\",\n",
" \"c2st\": \"#458588\",\n",
" \"fid\": \"#8ec07c\", \n",
" \"kl\": \"#8ec07c\"}\n",
"\n",
"col_map = {'ScaleSampleSizeKL':'kl', 'ScaleSampleSizeSW':'wasserstein',\n",
" 'ScaleSampleSizeMMD':'mmd', 'ScaleSampleSizeC2ST':'c2st',\n",
" 'ScaleSampleSizeFID':'fid', 'ScaleDimKL':'kl', 'ScaleDimSW':'wasserstein',\n",
" 'ScaleDimMMD':'mmd', 'ScaleDimC2ST':'c2st', 'ScaleGammaMMD':'mmd',\n",
" 'ScaleDimFID':'fid',}\n",
"\n",
"mapping = {'ScaleSampleSizeKL':'KL', 'ScaleSampleSizeSW':'SW',\n",
" 'ScaleSampleSizeMMD':'MMD', 'ScaleSampleSizeC2ST':'C2ST',\n",
" 'ScaleSampleSizeFID':'FD', 'ScaleDimKL':'KL', 'ScaleDimSW':'SW',\n",
" 'ScaleDimMMD':'MMD', 'ScaleDimC2ST':'C2ST',\n",
" 'ScaleDimFID':'FD', 'ScaleGammaMMD':'MMD'}\n",
"\n",
"# dark and light colors for inter vs. intra comparisons \n",
"col_dark = {}\n",
"col_light = {}\n",
"for e, exp_name in enumerate(cfg.experiments):\n",
" col_dark[exp_name] = generate_palette(color_dict[col_map[exp_name]], saturation='dark')[2]\n",
" col_light[exp_name] = generate_palette(color_dict[col_map[exp_name]], saturation='light')[-1]\n",
"for e, exp_name in enumerate(cfg.experiments_dim):\n",
" col_dark[exp_name] = generate_palette(color_dict[col_map[exp_name]], saturation='dark')[2]\n",
" col_light[exp_name] = generate_palette(color_dict[col_map[exp_name]], saturation='light')[-1]\n",
" \n",
"color_list = [col_light, col_dark] # make this a list to account for true and shifted\n",
"\n",
"label_true = {}\n",
"label_shift = {}\n",
"for e, data_name in enumerate(cfg.data):\n",
" label_true[data_name] = \"true\"\n",
" label_shift[data_name] = \"generated\"\n",
" \n",
"label_list = [label_true, label_shift]\n",
"label_list[1]['toy_2d'] = 'approx.'\n",
"label_list[1]['random'] = 'shifted'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loop over the three datasets for respective MMD kernel implementations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# make comparison plots\n",
"fig, axes = plt.subplots(3, 3, figsize=cm2inch((20, 10)), sharex='col')\n",
"for ax in axes.flatten():\n",
" # move spines outward\n",
" ax.spines['bottom'].set_position(('outward', 4))\n",
" ax.spines['left'].set_position(('outward', 4))\n",
" ax.locator_params(nbins=3)\n",
"\n",
"# Loop over all three datasets\n",
"for dd, ds in enumerate(cfg.data):\n",
"\n",
" # specify the bandwidth parameters to vary for each dataset\n",
" # custom_values = np.linspace(cfg.val_min[dd], cfg.val_max[dd], cfg.val_step[dd]) # switch for linspacing\n",
"\n",
" dataset_fn = get_dataset(cfg.data[dd])\n",
" # sample double the number of samples to ensure variability at the highest samples set size\n",
" n_samples = cfg.n[dd] * cfg.runs\n",
"\n",
" # generate the ground truth and the two approximations inter and intra\n",
" dataset_gt = dataset_fn(n_samples, cfg.d[dd])\n",
" dataset_intra = dataset_fn(n_samples, cfg.d[dd])\n",
"\n",
" print(cfg.data[dd], n_samples, cfg.d[dd])\n",
"\n",
" # generate the inter dataset\n",
" if cfg.data[dd] == \"toy_2d\":\n",
" dataset_inter = MultivariateNormal(\n",
" torch.mean(dataset_gt, axis=0).T, torch.cov(dataset_gt.T)\n",
" ).sample((n_samples,))\n",
" elif cfg.data[dd] == \"random\" and cfg.augmentation[dd] == \"mean_shift\":\n",
" # shift the mean by 1 for all dimensions\n",
" dataset_inter = dataset_fn(n_samples, cfg.d[dd]) + 1\n",
" elif cfg.data[dd] == \"random\" and cfg.augmentation[dd] == \"one_dim_shift\":\n",
" # just shift the first dimension by 1\n",
" dataset_inter = dataset_fn(n_samples, cfg.d[dd])\n",
" dataset_inter[:, 0] += 1 # just shift the mean of first dim by 1\n",
"\n",
" if dd < 2: # for the first two datasets, we compare sample sizes\n",
" for e, exp_name in enumerate(cfg.experiments):\n",
" experiment = globals()[exp_name]()\n",
" ax = axes[e, dd]\n",
" for dc, data_comp in enumerate([dataset_intra, dataset_inter]):\n",
" assert (\n",
" dataset_gt.shape == data_comp.shape\n",
" ), f\"Dataset shapes do not match: {dataset_gt.shape} vs. {data_comp.shape}\"\n",
" time_start = time.time()\n",
" if mapping[exp_name] == \"MMD\":\n",
" print(f\"MMD {cfg.data[dd]} {dd} {cfg.mmd_bandwidth[dd][e]}\")\n",
" output = experiment.run_experiment(\n",
" dataset1=dataset_gt,\n",
" dataset2=data_comp,\n",
" sample_sizes=cfg.sample_size,\n",
" nb_runs=cfg.runs,\n",
" bandwidth=cfg.mmd_bandwidth[dd][e],\n",
" )\n",
" else:\n",
" output = experiment.run_experiment(\n",
" dataset1=dataset_gt,\n",
" dataset2=data_comp,\n",
" sample_sizes=cfg.sample_size,\n",
" nb_runs=cfg.runs,\n",
" )\n",
" time_end = time.time()\n",
" print(f\"Experiment {exp_name} finished in {time_end - time_start}\")\n",
"\n",
" log_path = get_log_path(\n",
" cfg, tag=f\"_{mapping[exp_name]}_{cfg.data[dd]}_ds_{dd}_bw_{cfg.mmd_bandwidth[dd][e]}_{dc}\", timestamp=False\n",
" )\n",
" os.makedirs(os.path.dirname(log_path), exist_ok=True)\n",
" experiment.log_results(output, log_path)\n",
" print(f\"Numerical results saved to {log_path}\")\n",
"\n",
" experiment.plot_experiment(\n",
" *output,\n",
" cfg.data[dd],\n",
" ax=ax,\n",
" color=color_list[dc][exp_name],\n",
" label=label_list[dc][cfg.data[dd]],\n",
" linestyle=\"-\" if dc == 0 else \"--\",\n",
" lw=2,\n",
" marker=\"o\",\n",
" )\n",
" ax.set_ylabel(mapping[exp_name] + str(cfg.mmd_bandwidth[dd][e]))\n",
" ax.set_xlabel(\"\")\n",
" if mapping[exp_name] == \"C2ST\":\n",
" ax.set_ylim([0.45, 1])\n",
" ax.set_yticks([0.5, 1])\n",
" ax.legend()\n",
" else: # for the last dataset, we compare dimensions\n",
" for e, exp_name in enumerate(cfg.experiments_dim):\n",
" experiment = globals()[exp_name]()\n",
" ax = axes[e, 2]\n",
" ax.set_xscale(\"log\")\n",
" for dc, data_comp in enumerate([dataset_intra, dataset_inter]):\n",
" assert (\n",
" dataset_gt.shape == data_comp.shape\n",
" ), f\"Dataset shapes do not match: {dataset_gt.shape} vs. {data_comp.shape}\"\n",
" time_start = time.time()\n",
" if exp_name == \"ScaleDimMMD\":\n",
" output = experiment.run_experiment(\n",
" dataset1=dataset_gt,\n",
" dataset2=data_comp,\n",
" dataset_size=cfg.n[dd],\n",
" dim_sizes=cfg.dim_sizes,\n",
" nb_runs=cfg.runs_dim, # deterministic\n",
" bandwidth=cfg.mmd_bandwidth[dd][e],\n",
" )\n",
" print(f\"MMD {cfg.data[dd]} {dd} {cfg.mmd_bandwidth[dd][e]}\")\n",
" else:\n",
" output = experiment.run_experiment(\n",
" dataset1=dataset_gt,\n",
" dataset2=data_comp,\n",
" dataset_size=cfg.n[dd],\n",
" dim_sizes=cfg.dim_sizes,\n",
" nb_runs=cfg.runs_dim,\n",
" )\n",
" time_end = time.time()\n",
" print(f\"Experiment {exp_name} finished in {time_end - time_start}\")\n",
"\n",
" log_path = get_log_path(\n",
" cfg, tag=f\"_{mapping[exp_name]}_{cfg.data[dd]}_ds_{dd}_bw_{cfg.mmd_bandwidth[dd][e]}_{dc}\", timestamp=False\n",
" )\n",
" os.makedirs(os.path.dirname(log_path), exist_ok=True)\n",
" experiment.log_results(output, log_path)\n",
" print(f\"Numerical results saved to {log_path}\")\n",
" experiment.plot_experiment(\n",
" *output,\n",
" cfg.data[dd],\n",
" ax=ax,\n",
" color=color_list[dc][exp_name],\n",
" label=label_list[dc][cfg.data[dd]],\n",
" linestyle=\"-\" if dc == 0 else \"--\",\n",
" lw=2,\n",
" marker=\"o\",\n",
" )\n",
" ax.set_ylabel(mapping[exp_name] + str(cfg.mmd_bandwidth[dd][e]))\n",
" ax.set_xlabel(\"\")\n",
" if mapping[exp_name] == \"C2ST\":\n",
" ax.set_ylim([0.45, 1])\n",
" ax.set_yticks([0.5, 1])\n",
"\n",
" ax.legend()\n",
"\n",
"axes[-1, -1].set_xlabel(\"dimensions\")\n",
"axes[-1, 0].set_xlabel(\"sample size\")\n",
"axes[-1, 1].set_xlabel(\"sample size\")\n",
"\n",
"\n",
"os.makedirs(\"./results/plots\", exist_ok=True)\n",
"fig.tight_layout()\n",
"fig.savefig(\n",
" f\"./results/plots/MMD_scaling_{cfg.mmd_bandwidth}_{cfg.n[0]}.png\", dpi=300\n",
")\n",
"fig.savefig(\n",
" f\"./results/plots/MMD_scaling_{cfg.mmd_bandwidth}_{cfg.n[0]}.pdf\", dpi=300\n",
")\n",
"\n",
"print(\"Finished running experiments.\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "labproject",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 1bfb253

Please sign in to comment.