Skip to content

Commit

Permalink
include the energy and linear kernels as experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
augustes committed Jun 4, 2024
1 parent 56946e8 commit 8fdb0b8
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions labproject/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
gaussian_kl_divergence,
c2st_nn,
compute_rbf_mmd,
compute_linear_mmd,
compute_energy_mmd,
)
from labproject.plotting import plot_scaling_metric_dimensionality, plot_scaling_metric_sample_size
from labproject.metrics.gaussian_squared_wasserstein import gaussian_squared_w2_distance
Expand Down Expand Up @@ -115,6 +117,16 @@ def __init__(self, min_dim=2, **kwargs):
super().__init__("MMD", compute_rbf_mmd, **kwargs)


class ScaleDimMMDenergy(ScaleDim):
def __init__(self, min_dim=2, **kwargs):
super().__init__("MMD", compute_energy_mmd, **kwargs)


class ScaleDimMMDlinear(ScaleDim):
def __init__(self, min_dim=2, **kwargs):
super().__init__("MMD", compute_linear_mmd, **kwargs)


"""class ScaleDimMMD(ScaleDim):
def __init__(self, min_dim=2, **kwargs):
super().__init__("FID", compute_rbf_mmd, **kwargs)"""
Expand Down Expand Up @@ -229,6 +241,20 @@ def __init__(self, min_samples=3, sample_sizes=None, **kwargs):
)


class ScaleSampleSizeMMDenergy(ScaleSampleSize):
def __init__(self, min_samples=3, sample_sizes=None, **kwargs):
super().__init__(
"MMD", compute_energy_mmd, min_samples=min_samples, sample_sizes=sample_sizes, **kwargs
)


class ScaleSampleSizeMMDlinear(ScaleSampleSize):
def __init__(self, min_samples=3, sample_sizes=None, **kwargs):
super().__init__(
"MMD", compute_linear_mmd, min_samples=min_samples, sample_sizes=sample_sizes, **kwargs
)


class ScaleSampleSizeFID(ScaleSampleSize):
def __init__(self, min_samples=3, sample_sizes=None, **kwargs):
super().__init__(
Expand Down

0 comments on commit 8fdb0b8

Please sign in to comment.