Skip to content

Commit

Permalink
Merge branch 'main' into c2st
Browse files Browse the repository at this point in the history
  • Loading branch information
felixp8 authored Feb 5, 2024
2 parents e4b26fb + adeae75 commit 34a7809
Show file tree
Hide file tree
Showing 9 changed files with 495 additions and 27 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,6 @@ figures/

.idea/

secrets.py
secrets.py

results/
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ conda activate labproject

# install labproject package with dependencies
python3 -m pip install --upgrade pip
cd labproject
pip install -e ".[dev,docs]"

# install pre-commit hooks for black auto-formatting
Expand Down
1 change: 1 addition & 0 deletions configs/conf_default.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

exp_log_name: "default" # optional but recommended
data: "random"
experiments: ["ScaleDimKL"]
n: 10000
Expand Down
6 changes: 6 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ Here all functions will be documented that are part of the public API of the lab
options:
heading_level: 4

### Gaussian Wasserstein

::: labproject.metrics.gaussian_squared_wasserstein
options:
heading_level: 4

### Sliced Wasserstein

::: labproject.metrics.sliced_wasserstein
Expand Down
37 changes: 26 additions & 11 deletions labproject/experiments.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from metrics import sliced_wasserstein_distance, gaussian_kl_divergence
from plotting import plot_scaling_metric_dimensionality
from .metrics import sliced_wasserstein_distance, gaussian_kl_divergence
from .plotting import plot_scaling_metric_dimensionality
import pickle


class Experiment:
Expand All @@ -9,32 +10,46 @@ def __init__(self):

def run_experiment(self, metric, dataset1, dataset2):
raise NotImplementedError("Subclasses must implement this method")

def plot_experiment(self):
raise NotImplementedError("Subclasses must implement this method")


def log_results(self, results, log_path):
raise NotImplementedError("Subclasses must implement this method")


class ScaleDim(Experiment):

def __init__(self, metric_name, metric_fn, min_dim=1, max_dim=1000, step=100):
self.metric_name = metric_name
self.metric_fn = metric_fn
self.dimensionality = list(range(min_dim, max_dim, step))
super().__init__()

def run_experiment(self, dataset1, dataset2):
distances = []
for d in self.dimensionality:
distances.append(self.metric_fn(dataset1[:, :d], dataset2[:, :d]))
return self.dimensionality, distances

def plot_experiment(self, dimensionality, distances, dataset_name):
plot_scaling_metric_dimensionality(dimensionality, distances, self.metric_name, dataset_name)

plot_scaling_metric_dimensionality(
dimensionality, distances, self.metric_name, dataset_name
)

def log_results(self, results, log_path):
"""
Save the results to a file.
"""
with open(log_path, "wb") as f:
pickle.dump(results, f)


class ScaleDimKL(ScaleDim):
def __init__(self):
super().__init__("KL", gaussian_kl_divergence, min_dim=2)



class ScaleDimSW(ScaleDim):
def __init__(self):
super().__init__("Sliced Wasserstein", sliced_wasserstein_distance)
super().__init__("Sliced Wasserstein", sliced_wasserstein_distance)
Loading

0 comments on commit 34a7809

Please sign in to comment.