import os
import random
from collections import defaultdict
from functools import reduce
from itertools import combinations, product

import numpy as np
from sklearn.cluster import KMeans
from tensorflow.python.platform import flags
from tqdm import tqdm

FLAGS = flags.FLAGS

os.environ['JOBLIB_TEMP_FOLDER'] = '/tmp'  # the default temp folder runs out of space during parallel clustering


class TaskGenerator(object):
    """Generates N-way classification tasks from partitions of (unlabeled) data indices."""

    def __init__(self, num_classes, num_train_samples_per_class, num_samples_per_class):
        self.num_classes = num_classes  # N, the number of classes per task
        self.num_train_samples_per_class = num_train_samples_per_class  # K, the number of training shots
        self.num_samples_per_class = num_samples_per_class  # training plus query examples per class
def get_task(self, partition):
"""
A task consists of training and testing examples and labels. Instead of examples, we pass indices.
"""
train_indices, train_labels, test_indices, test_labels = [], [], [], []
classes = list(partition.keys())
sampled_classes = random.sample(classes, self.num_classes)
random.shuffle(sampled_classes) # the same classes given a different label ordering is a new task
for label, cls in zip(range(self.num_classes), sampled_classes):
class_indices = random.sample(partition[cls], self.num_samples_per_class)
train_indices.extend(class_indices[:self.num_train_samples_per_class])
test_indices.extend(class_indices[self.num_train_samples_per_class:])
            train_labels.extend([label] * self.num_train_samples_per_class)
            test_labels.extend([label] * (self.num_samples_per_class - self.num_train_samples_per_class))
return train_indices, train_labels, test_indices, test_labels
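
    # For example (a minimal sketch with made-up sizes): with num_classes=2,
    # num_train_samples_per_class=1, and num_samples_per_class=3, a task consists of
    # 2 train indices with labels [0, 1] and 4 test indices with labels [0, 0, 1, 1].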
    def get_tasks(self, num_tasks, partitions):
        """
        Samples num_tasks tasks. If num_tasks equals the number of partitions, each partition is
        used exactly once; otherwise a partition is drawn at random for each task.
        """
        tasks = []
        for i_task in tqdm(range(num_tasks), desc='get_tasks'):
            if num_tasks == len(partitions):
                i_partition = i_task
            else:
                i_partition = random.randrange(len(partitions))
            task = self.get_task(partition=partitions[i_partition])
            tasks.append(task)
        return tasks
def get_splits_hyperplanes(self, encodings, num_splits, margin):
"""
A split is a tuple of two zones, each of which is an array of indices whose encodings are more than margin away
from a random hyperplane.
"""
assert margin >= 0
n, d = encodings.shape
splits = []
        bad_splits = 0
        min_samples_per_zone = self.num_samples_per_class * 10  # demand ample candidates on each side
for i_split in tqdm(range(num_splits), desc='get_splits_hyperplanes'):
while True:
normal_vector = np.random.uniform(low=-1.0, high=1.0, size=(d,))
unit_normal_vector = normal_vector / np.linalg.norm(normal_vector)
                if FLAGS.encoder == 'deepcluster':  # encodings are whitened and normalized, so use the origin
                    point_on_plane = np.zeros(d)
                else:
                    point_on_plane = np.random.uniform(low=-0.8, high=0.8, size=(d,))
relative_vector = encodings - point_on_plane # broadcasted
signed_distance = np.dot(relative_vector, unit_normal_vector)
below = np.where(signed_distance <= -margin)[0]
above = np.where(signed_distance >= margin)[0]
                if len(below) < min_samples_per_zone or len(above) < min_samples_per_zone:
                    bad_splits += 1  # too few points clear the margin; resample the hyperplane
                else:
                    splits.append((below, above))
                    break
print("Generated {} random splits, with {} failed splits.".format(num_splits, bad_splits))
return splits
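
    # A worked illustration of the margin test above (a minimal sketch with made-up numbers):
    # with unit normal u = (1, 0), point_on_plane p = (0, 0), and margin = 0.5, an encoding
    # x = (2, 3) has signed distance dot(x - p, u) = 2.0 and lands in the "above" zone, while
    # x = (-1, 0) has signed distance -1.0 and lands in the "below" zone; x = (0.2, 5) clears
    # neither margin and belongs to no zone.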
def get_partitions_hyperplanes(self, encodings, num_splits, margin, num_partitions):
"""Create partitions where each element must be a certain margin away from all split-defining hyperplanes."""
splits = self.get_splits_hyperplanes(encodings=encodings, num_splits=num_splits, margin=margin)
bad_partitions = 0
partitions = []
for i in tqdm(range(num_partitions), desc='get_partitions_hyperplanes'):
partition, num_failed = self.get_partition_from_splits(splits)
partitions.append(partition)
bad_partitions += num_failed
            if (i + 1) % max(num_partitions // 10, 1) == 0:  # max() guards against num_partitions < 10
                tqdm.write('\t completed partitions: {}, failed attempts: {}'.format(i + 1, bad_partitions))
print("Generated {} partitions respecting margin {}, with {} failed partitions.".format(num_partitions, margin, bad_partitions))
return partitions
    def get_partition_from_splits(self, splits):
        """
        Builds one partition by intersecting the zones of ceil(log2(num_classes)) randomly chosen
        splits; each of the 2^k above/below sign patterns defines a candidate class.
        """
        num_splits = len(splits)
        splits_per_partition = int(np.ceil(np.log2(self.num_classes)))  # np.int was removed in NumPy 1.24
        num_failed = 0
while True:
which_splits = np.random.choice(num_splits, splits_per_partition, replace=False)
splits_for_this_partition = [splits[i] for i in which_splits]
partition = defaultdict(list)
num_big_enough_classes = 0
for i_class, above_or_belows in enumerate(product([0, 1], repeat=splits_per_partition)):
zones = [splits_for_this_partition[i][above_or_belows[i]] for i in range(splits_per_partition)]
indices = reduce(np.intersect1d, zones)
if len(indices) >= self.num_samples_per_class:
num_big_enough_classes += 1
partition[i_class].extend(indices.tolist())
if num_big_enough_classes >= self.num_classes:
break
else:
num_failed += 1
return partition, num_failed
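
    # For example (a minimal sketch): with num_classes = 4, splits_per_partition = 2, so the sign
    # patterns (0, 0), (0, 1), (1, 0), (1, 1) each define a candidate class; e.g. pattern (0, 1)
    # collects the indices np.intersect1d(split_a[0], split_b[1]) that lie below the first
    # hyperplane and above the second.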
    def get_partitions_kmeans(self, encodings, train):
        """Clusters the encodings with k-means, then turns each clustering into a partition."""
        if FLAGS.on_pixels:  # "encodings" are raw images
            encodings = np.reshape(encodings, (encodings.shape[0], -1)).astype(np.float32)
            # per-example contrast normalization followed by ZCA whitening (Coates and Ng, 2012)
            mean = np.mean(encodings, axis=1)
            var = np.var(encodings, axis=1)
            encodings = ((encodings.T - mean.T) / np.sqrt(var.T + 10)).T
            cov = np.cov(encodings, rowvar=False)
            U, S, V = np.linalg.svd(cov)
            epsilon = 1e-5
            ZCA = np.dot(U, np.dot(np.diag(1.0 / np.sqrt(S + epsilon)), U.T))  # the ZCA whitening matrix
            encodings = np.dot(ZCA, encodings.T).T
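            # To see why this whitens (a sketch with hypothetical numbers): if cov = diag(2, 0.5),
            # then S = (2, 0.5), U = I, and ZCA = diag(1/sqrt(2), 1/sqrt(0.5)), which rescales each
            # principal direction to unit variance while rotating the data as little as possible.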
encodings_list = [encodings]
        if train:
            if FLAGS.scaled_encodings:  # diversify partitions by randomly re-weighting each encoding dimension
                n_clusters_list = [FLAGS.num_clusters]
                for i in range(FLAGS.num_partitions - 1):
                    weight_vector = np.random.uniform(low=0.0, high=1.0, size=encodings.shape[1])
                    encodings_list.append(np.multiply(encodings, weight_vector))
else:
n_clusters_list = [FLAGS.num_clusters] * FLAGS.num_partitions
else:
n_clusters_list = [FLAGS.num_clusters_test]
assert len(encodings_list) * len(n_clusters_list) == FLAGS.num_partitions
        if FLAGS.dataset == 'celeba' or FLAGS.num_partitions != 1 or FLAGS.on_pixels:
            n_init = 1  # a single k-means restart, so clustering doesn't take forever
        else:
            n_init = 10
        init = 'k-means++'
        print('Number of encoding variants: {}, number of n_clusters values: {}, number of inits: {}'.format(
            len(encodings_list), len(n_clusters_list), n_init))
kmeans_list = []
for n_clusters in tqdm(n_clusters_list, desc='get_partitions_kmeans_n_clusters'):
for encodings in tqdm(encodings_list, desc='get_partitions_kmeans_encodings'):
while True:
                    # precompute_distances and n_jobs were removed in scikit-learn 1.0; this call
                    # assumes the older scikit-learn the repository was written against
                    kmeans = KMeans(n_clusters=n_clusters, init=init, precompute_distances=True, n_jobs=40,
                                    n_init=n_init, max_iter=3000).fit(encodings)
uniques, counts = np.unique(kmeans.labels_, return_counts=True)
num_big_enough_clusters = np.sum(counts > self.num_samples_per_class)
if num_big_enough_clusters > 0.75 * n_clusters or FLAGS.on_pixels:
break
else:
tqdm.write("Too few classes ({}) with greater than {} examples.".format(num_big_enough_clusters,
self.num_samples_per_class))
tqdm.write('Frequency: {}'.format(counts))
kmeans_list.append(kmeans)
partitions = []
for kmeans in kmeans_list:
partition = self.get_partition_from_labels(kmeans.labels_)
partitions.append(partition)
return partitions
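
    # Example usage (a minimal sketch; assumes the relevant FLAGS, e.g. num_clusters and
    # num_partitions, have been parsed, and that `encodings` is an (N, d) float array):
    #   partitions = task_generator.get_partitions_kmeans(encodings, train=True)
    #   tasks = task_generator.get_tasks(num_tasks=1000, partitions=partitions)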
def get_partition_from_labels(self, labels):
"""
Constructs a partition of the set of indices in labels, grouping indices according to their label.
:param labels: np.array of labels, whose i-th element is the label for the i-th datapoint
:return: a dictionary mapping class label to a list of indices that have that label
"""
partition = defaultdict(list)
for ind, label in enumerate(labels):
partition[label].append(ind)
        self.clean_partition(partition)  # mutates the dict in place, dropping too-small classes
return partition
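
    # For example (a minimal sketch): labels = np.array([1, 0, 1, 1]) groups into
    # {1: [0, 2, 3], 0: [1]}; any class with fewer than num_samples_per_class indices is then
    # removed by clean_partition.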
    def clean_partition(self, partition):
        """
        Removes classes with fewer than num_samples_per_class indices from a partition.
        """
        for cls in list(partition.keys()):
            if len(partition[cls]) < self.num_samples_per_class:
                del partition[cls]
        return partition
    def get_celeba_task_pool(self, attributes, order=3, print_partition=None):
        """
        Produces binary partitions from every combination of `order` CelebA attributes: a list of
        dictionaries (key: 0 or 1, value: list of data indices), compatible with the other methods
        of this class. If print_partition is an int, the attribute combination defining that
        partition is printed for inspection.
        """
num_pools = 0
partitions = []
from scipy.special import comb
for attr_comb in tqdm(combinations(range(attributes.shape[1]), order), desc='get_task_pool', total=comb(attributes.shape[1], order)):
for booleans in product(range(2), repeat=order-1):
booleans = (0,) + booleans # only the half of the cartesian products that start with 0
positive = np.where(np.all([attributes[:, attr] == i_booleans for (attr, i_booleans) in zip(attr_comb, booleans)], axis=0))[0]
if len(positive) < self.num_samples_per_class:
continue
negative = np.where(np.all([attributes[:, attr] == 1 - i_booleans for (attr, i_booleans) in zip(attr_comb, booleans)], axis=0))[0]
if len(negative) < self.num_samples_per_class:
continue
                partitions.append({0: list(negative), 1: list(positive)})
num_pools += 1
if num_pools == print_partition:
print(attr_comb, booleans)
        print('Generated {} binary partitions, each defined by {} of the {} attributes'.format(num_pools, order, attributes.shape[1]))
return partitions
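

if __name__ == '__main__':
    # A minimal smoke test of the partition/task plumbing (a sketch with made-up numbers; the
    # hyperplane, k-means, and CelebA paths additionally require FLAGS to be parsed).
    tg = TaskGenerator(num_classes=2, num_train_samples_per_class=1, num_samples_per_class=3)
    labels = np.array([0, 0, 0, 1, 1, 1, 2])  # class 2 has too few examples and is cleaned away
    partition = tg.get_partition_from_labels(labels)
    train_indices, train_labels, test_indices, test_labels = tg.get_task(partition)
    print('partition: {}'.format(dict(partition)))
    print('train: {} -> {}, test: {} -> {}'.format(train_indices, train_labels, test_indices, test_labels))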