Skip to content

Commit

Permalink
Merge pull request #25 from mackelab/franzi
Browse files Browse the repository at this point in the history
Updated dimensionality scaling plotting supp
  • Loading branch information
franzigrkn authored Feb 12, 2024
2 parents 9fb6ba7 + 36374a5 commit cefe018
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 77 deletions.
5 changes: 3 additions & 2 deletions configs/conf_dims.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
exp_log_name: "scaling_dims" # optional but recommended
data: ["multivariate_normal"]
experiments: ["ScaleDimSW", "ScaleDimMMD"]
n: [3000]
experiments: ["ScaleDimSW", "ScaleDimC2ST", "ScaleDimMMD"]
n: [1000]
d: [100]
distort: ["shift_one", "shift_all"]
dim_sizes: [5, 10, 15, 20, 25, 30, 40, 60, 80, 100]
mmd_bandwidth: [10., 10.]

seed: 1404
runs: 5
Expand Down
125 changes: 55 additions & 70 deletions docs/notebooks/scaling_dims_franzi.ipynb

Large diffs are not rendered by default.

11 changes: 7 additions & 4 deletions labproject/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,6 @@ def multivariate_normal(n=3000, dims=100, means=None, vars=None, distort=None):
idx = 0
shift = torch.zeros(n) + 1
samples[:, idx] = samples[:, idx] + shift
print(f"First 5 rows of dataset distorted: {samples[:5, :5]}")
return samples


Expand Down Expand Up @@ -527,7 +526,9 @@ def imagenet_validation_embedding(n, d=2048, device="cpu", save_path="data"):


@register_dataset("imagenet_conditional_model")
def imagenet_conditional_model(n, d=2048, label:Optional[int]=None, device="cpu", permute_if_no_label=True, save_path="data"):
def imagenet_conditional_model(
n, d=2048, label: Optional[int] = None, device="cpu", permute_if_no_label=True, save_path="data"
):
r"""Get the conditional model embeddings for ImageNet
Args:
Expand All @@ -549,9 +550,11 @@ def imagenet_conditional_model(n, d=2048, label:Optional[int]=None, device="cpu"
if label is not None:
conditional_embeddings = conditional_embeddings[label]
else:
conditional_embeddings = conditional_embeddings.flatten(0,1)
conditional_embeddings = conditional_embeddings.flatten(0, 1)
if permute_if_no_label:
conditional_embeddings = conditional_embeddings[torch.randperm(conditional_embeddings.shape[0])]
conditional_embeddings = conditional_embeddings[
torch.randperm(conditional_embeddings.shape[0])
]

max_n = conditional_embeddings.shape[0]

Expand Down
2 changes: 1 addition & 1 deletion labproject/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from labproject.plotting import plot_scaling_metric_dimensionality, plot_scaling_metric_sample_size
from labproject.metrics.gaussian_squared_wasserstein import gaussian_squared_w2_distance
import pickle
import math


class Experiment:
Expand Down Expand Up @@ -55,7 +56,6 @@ def run_experiment(self, dataset1, dataset2, dataset_size, nb_runs=5, dim_sizes=
else torch.zeros_like(torch.tensor(dim_sizes))
)
final_distances = torch.tensor([torch.mean(d) for d in final_distances])
print(f"Final errors: {final_errors}")
return dim_sizes, final_distances, final_errors

def plot_experiment(
Expand Down

0 comments on commit cefe018

Please sign in to comment.