Skip to content

Commit

Permalink
Merge branch 'main' of github.com:mackelab/labproject into main
Browse files Browse the repository at this point in the history
  • Loading branch information
lappalainenj committed Feb 16, 2024
2 parents c128df8 + ee79b34 commit 7401404
Show file tree
Hide file tree
Showing 6 changed files with 1,128 additions and 2 deletions.
21 changes: 21 additions & 0 deletions configs/conf_scale_gamma.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
exp_log_name: "MMD_scale_gamma" # optional but recommended
data: ["toy_2d" , "random", "random"]

experiments: ["ScaleGammaMMD", "ScaleGammaMMD", "ScaleGammaMMD"]
dim_sizes: [10, 100]
sample_size: [10000]

n: [10000, 10000, 10000]
d: [2, 10, 100]
val_min: [0.1, 1, 1]
val_max: [5, 20, 50]
val_step: [10, 10, 10]
augmentation: ['gauss', 'one_dim_shift', 'one_dim_shift',]

seed: 0
runs: 5



value_sizes: [[0.1, 0.2, 0.5, 0.75, 0.9, 1.0, 1.25, 1.5, 2, 2.5, 3.0, 4.0, 5.0],
[0.1,1,1.5,2,2.5,3,4,5,6,8,10,12,14,16,18,20], [1,3,5,6,7,8,9,10,12,15,20,25,30,35,40]]
2 changes: 1 addition & 1 deletion configs/conf_scale_reduced.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ augmentation: ['gauss', 'one_dim_shift', 'one_dim_shift',]

seed: 0
runs: 5
runs_dim: 1
runs_dim: 5
369 changes: 369 additions & 0 deletions docs/notebooks/gamma_scaling_MMD.ipynb

Large diffs are not rendered by default.

641 changes: 641 additions & 0 deletions docs/notebooks/mode_comparison.ipynb

Large diffs are not rendered by default.

85 changes: 84 additions & 1 deletion labproject/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from labproject.metrics.gaussian_squared_wasserstein import gaussian_squared_w2_distance
import pickle
import math
import numpy as np


class Experiment:
Expand Down Expand Up @@ -46,7 +47,9 @@ def run_experiment(self, dataset1, dataset2, dataset_size, nb_runs=5, dim_sizes=
for d in dim_sizes:
# 3000 x 100
data1 = dataset1[torch.randperm(dataset1.size(0))[:n], :d]
data2 = dataset2[torch.randperm(dataset1.size(0))[:n], :d]
data2 = dataset2[
torch.randperm(dataset2.size(0))[:n], :d
] # AS: changed from dataset1 to dataset2 in randperm
distances.append(self.metric_fn(data1, data2, **kwargs))
final_distances.append(distances)
final_distances = torch.transpose(torch.tensor(final_distances), 0, 1)
Expand Down Expand Up @@ -251,3 +254,83 @@ def log_results(self, fid_metric, log_path):

def plot_experiment(self, fid_metric, dataset_name):
pass


class ScaleHyperparameter(Experiment):
def __init__(
self, metric_name, metric_fn, value_sizes=None, min_value=0.2, max_value=50, step=10
):
self.metric_name = metric_name
self.metric_fn = metric_fn
if value_sizes is not None:
self.value_sizes = value_sizes
else:
self.value_sizes = list(np.linspace(min_value, max_value, step))
super().__init__()

def run_experiment(self, dataset1, dataset2, nb_runs=5, n=10000, value_sizes=None, **kwargs):
final_distances = []
final_errors = []
# n = 1000 # AS: turned into argument
if value_sizes is None:
value_sizes = self.value_sizes
# print(value_sizes)
for idx in range(nb_runs):
distances = []
for v in value_sizes:
# print(v)
# 3000 x 100
data1 = dataset1[torch.randperm(dataset1.size(0))[:n], :]
data2 = dataset2[
torch.randperm(dataset2.size(0))[:n], :
] # AS: changed from dataset1 to dataset2 in randperm
distances.append(self.metric_fn(data1, data2, v, **kwargs))

final_distances.append(distances)

final_distances = torch.transpose(torch.tensor(final_distances), 0, 1)
final_errors = (
torch.tensor([torch.std(d) for d in final_distances])
if nb_runs > 1
else torch.zeros_like(torch.tensor(value_sizes))
)
final_distances = torch.tensor([torch.mean(d) for d in final_distances])
return value_sizes, final_distances, final_errors

def plot_experiment(
self,
value_sizes,
distances,
errors,
dataset_name,
ax=None,
color=None,
label=None,
linestyle="-",
**kwargs,
):

plot_scaling_metric_dimensionality(
value_sizes,
distances,
errors,
self.metric_name,
dataset_name,
ax=ax,
color=color,
label=label,
linestyle=linestyle,
**kwargs,
)

def log_results(self, results, log_path):
"""
Save the results to a file.
"""
with open(log_path, "wb") as f:
pickle.dump(results, f)


class ScaleGammaMMD(ScaleHyperparameter):
def __init__(self, **kwargs):
super().__init__("MMD", compute_rbf_mmd, **kwargs)
12 changes: 12 additions & 0 deletions labproject/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import seaborn as sns
import numpy as np
from matplotlib.ticker import MaxNLocator

####
# global plot params
Expand Down Expand Up @@ -35,6 +36,17 @@ def generate_palette(hex_color, n_colors=5, saturation="light"):
###


def ensure_zero_ytick(axis):
# Get current y-ticks from the axis
current_yticks = axis.get_yticks()

# Ensure 0 is included in y-ticks without duplication
if 0 not in current_yticks:
new_yticks = np.sort(np.append(current_yticks, 0))
axis.set_yticks(new_yticks)
axis.yaxis.set_major_locator(MaxNLocator(nbins=2))


def cm2inch(cm, INCH=2.54):
if isinstance(cm, tuple):
return tuple(i / INCH for i in cm)
Expand Down

0 comments on commit 7401404

Please sign in to comment.