Skip to content

Commit

Permalink
got rid of hard coded n=1000 hyperparam exp, dataset1 randperm
Browse files Browse the repository at this point in the history
  • Loading branch information
augustes committed Feb 16, 2024
1 parent c48fe48 commit c5d3802
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
12 changes: 8 additions & 4 deletions labproject/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def run_experiment(self, dataset1, dataset2, dataset_size, nb_runs=5, dim_sizes=
for d in dim_sizes:
# 3000 x 100
data1 = dataset1[torch.randperm(dataset1.size(0))[:n], :d]
data2 = dataset2[torch.randperm(dataset1.size(0))[:n], :d]
data2 = dataset2[
torch.randperm(dataset2.size(0))[:n], :d
] # AS: changed from dataset1 to dataset2 in randperm
distances.append(self.metric_fn(data1, data2, **kwargs))
final_distances.append(distances)
final_distances = torch.transpose(torch.tensor(final_distances), 0, 1)
Expand Down Expand Up @@ -266,10 +268,10 @@ def __init__(
self.value_sizes = list(np.linspace(min_value, max_value, step))
super().__init__()

def run_experiment(self, dataset1, dataset2, nb_runs=5, value_sizes=None, **kwargs):
def run_experiment(self, dataset1, dataset2, nb_runs=5, n=10000, value_sizes=None, **kwargs):
final_distances = []
final_errors = []
n = 1000
# n = 1000 # AS: turned into argument
if value_sizes is None:
value_sizes = self.value_sizes
# print(value_sizes)
Expand All @@ -279,7 +281,9 @@ def run_experiment(self, dataset1, dataset2, nb_runs=5, value_sizes=None, **kwar
# print(v)
# 3000 x 100
data1 = dataset1[torch.randperm(dataset1.size(0))[:n], :]
data2 = dataset2[torch.randperm(dataset1.size(0))[:n], :]
data2 = dataset2[
torch.randperm(dataset2.size(0))[:n], :
] # AS: changed from dataset1 to dataset2 in randperm
distances.append(self.metric_fn(data1, data2, v, **kwargs))

final_distances.append(distances)
Expand Down
12 changes: 12 additions & 0 deletions labproject/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import seaborn as sns
import numpy as np
from matplotlib.ticker import MaxNLocator

####
# global plot params
Expand Down Expand Up @@ -35,6 +36,17 @@ def generate_palette(hex_color, n_colors=5, saturation="light"):
###


def ensure_zero_ytick(axis):
# Get current y-ticks from the axis
current_yticks = axis.get_yticks()

# Ensure 0 is included in y-ticks without duplication
if 0 not in current_yticks:
new_yticks = np.sort(np.append(current_yticks, 0))
axis.set_yticks(new_yticks)
axis.yaxis.set_major_locator(MaxNLocator(nbins=2))


def cm2inch(cm, INCH=2.54):
if isinstance(cm, tuple):
return tuple(i / INCH for i in cm)
Expand Down

0 comments on commit c5d3802

Please sign in to comment.