Skip to content

Commit

Permalink
Cleaned up dimensionality scaling notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
franzigrkn committed Mar 4, 2024
1 parent 3f72c2f commit de61db0
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 147 deletions.
20 changes: 15 additions & 5 deletions configs/conf_dims.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
exp_log_name: "scaling_dims" # optional but recommended

# datasets to use
data: ["multivariate_normal"]

# number of samples and dimensions
n: [5000]
d: [100]

# MMD bandwidth parameter
mmd_bandwidth: [10., 20., 20.]

# dimensionality scaling experiments
experiments: ["ScaleDimSW", "ScaleDimC2ST", "ScaleDimMMD"]
n: [10000]
d: [1000]
distort: ["shift_one", "shift_all", "increase_var"]
dim_sizes: [5, 10, 50, 100, 500, 1000]
mmd_bandwidth: [10., 10., 10.]
runs: 5 # number of sample selection for errorbars
distort: ["shift_one", "shift_all", "increase_var"]

# seed for reproducibility
seed: 1404
runs: 5


304 changes: 164 additions & 140 deletions docs/notebooks/scaling_dims_franzi.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions labproject/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,15 +275,15 @@ def multivariate_normal(n=3000, dims=100, means=None, vars=None, distort=None):
), "The length of the means vector must be equal to the number of dimensions"
if vars is None:
if distort == "increase_var":
vars = torch.eye(dims) * 1.1
vars = torch.eye(dims) * 2
else:
vars = torch.eye(dims)
else:
assert (
len(vars) == dims
), "The length of the vars vector must be equal to the number of dimensions"
if distort == "increase_var":
vars = torch.diag(vars) * 1.1
vars = torch.diag(vars) * 2
else:
vars = torch.diag(vars)

Expand Down

0 comments on commit de61db0

Please sign in to comment.