Skip to content

Commit

Permalink
refactor experiments.py, add gaussian_kl.py in metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
jaivardhankapoor committed Jan 31, 2024
1 parent 13f4a9d commit 3732831
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 33 deletions.
23 changes: 4 additions & 19 deletions labproject/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,13 @@ def scaling_kl_samples(dataset1, dataset2):
return dimensionality, distances


def run_metric_on_datasets(dataset1, dataset2, metric):
    """Evaluate *metric* on the two datasets and return whatever it produces."""
    result = metric(dataset1, dataset2)
    return result


class Experiment:
    """Minimal experiment runner that delegates to a caller-supplied function."""

    def __init__(self):
        # No configuration or state yet; placeholder for future experiment setup.
        pass

    def run_experiment(self, dataset1, dataset2, experiment_fn):
        """Call *experiment_fn* on the dataset pair and return its result."""
        outcome = experiment_fn(dataset1, dataset2)
        return outcome


if __name__ == "__main__":

    experiment = Experiment()
    experiment_results = {}

    # for exp_name in ['scaling_sliced_wasserstein_samples', 'scaling_kl_samples']:
    for exp_name in ["scaling_kl_samples"]:
        # Resolve the experiment function from its name in module scope.
        run_fn = globals()[exp_name]
        for idx_a, ds_a in enumerate([random_dataset(n=100000, d=100)]):
            for idx_b, ds_b in enumerate([random_dataset(n=100000, d=100)]):
                dims, dists = experiment.run_experiment(
                    dataset1=ds_a, dataset2=ds_b, experiment_fn=run_fn
                )
                experiment_results[(exp_name, idx_a, idx_b)] = (dims, dists)

    # single plot
    # plot_scaling_metric_dimensionality(dimensionality, distances, "Sliced Wasserstein", "Random Dataset")
    plot_scaling_metric_dimensionality(dims, dists, "KL", "Random Dataset")
10 changes: 5 additions & 5 deletions labproject/metrics/gaussian_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
def gaussian_kl_divergence(real_samples, fake_samples):
"""
Compute the KL divergence between Gaussian approximations of real and fake samples.
Dimensionality of the samples must be the same and >=2 (for covariance calculation).
Args:
real_samples (torch.Tensor): A tensor representing the real samples.
Expand All @@ -12,19 +13,18 @@ def gaussian_kl_divergence(real_samples, fake_samples):
Returns:
float: The KL divergence between the two Gaussian approximations.
"""
# Calculate mean and covariance of real and fake samples
# print(real_samples.shape, fake_samples.shape)
# calculate mean and covariance of real and fake samples
mu_real = real_samples.mean(dim=0)
mu_fake = fake_samples.mean(dim=0)
cov_real = torch.cov(real_samples.t())
cov_fake = torch.cov(fake_samples.t())

# Ensure the covariance matrices are invertible
# ensure the covariance matrices are invertible
eps = 1e-8
cov_real += torch.eye(cov_real.size(0)) * eps
cov_fake += torch.eye(cov_fake.size(0)) * eps

# Compute KL divergence
# compute KL divergence
inv_cov_fake = torch.inverse(cov_fake)
kl_div = 0.5 * (
torch.trace(inv_cov_fake @ cov_real)
Expand All @@ -37,7 +37,7 @@ def gaussian_kl_divergence(real_samples, fake_samples):


if __name__ == "__main__":
# Example usage
# example usage
real_samples = torch.randn(100, 2) # 100 samples, 2-dimensional
fake_samples = torch.randn(100, 2) # 100 samples, 2-dimensional

Expand Down
41 changes: 32 additions & 9 deletions labproject/run.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,32 @@
from labproject.experiments import scaling_sliced_wasserstein_samples
from labproject.plotting import plot_scaling_metric_dimensionality

print("Running experiments...")

# Run the scaling experiment once and plot how the metric behaves with dimension.
dims, dists = scaling_sliced_wasserstein_samples()
plot_scaling_metric_dimensionality(dims, dists, "Sliced Wasserstein", "Random Dataset")

print("Finished running experiments.")
from labproject.experiments import *
from labproject.plotting import *
import time

if __name__ == "__main__":

    print("Running experiments...")
    experiment = Experiment()
    experiment_results = {}

    # for exp_name in ['scaling_sliced_wasserstein_samples', 'scaling_kl_samples']:
    for exp_name in ["scaling_kl_samples"]:
        started = time.time()
        dataset_pairs = [
            [random_dataset(n=100000, d=100), random_dataset(n=100000, d=100)],
        ]
        for pair_idx, (ds_left, ds_right) in enumerate(dataset_pairs):
            # Resolve the experiment function from its name in module scope.
            run_fn = globals()[exp_name]
            dims, dists = experiment.run_experiment(
                dataset1=ds_left, dataset2=ds_right, experiment_fn=run_fn
            )
            experiment_results[(exp_name, pair_idx)] = (dims, dists)
        finished = time.time()
        print(f"Experiment {exp_name} finished in {finished - started}")

    # single plot
    # plot_scaling_metric_dimensionality(dimensionality, distances, "Sliced Wasserstein", "Random Dataset")
    plot_scaling_metric_dimensionality(dims, dists, "KL", "Random Dataset")

    print("Finished running experiments.")

0 comments on commit 3732831

Please sign in to comment.