forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_student.py
212 lines (166 loc) · 8.9 KB
/
train_student.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.python.platform import flags
from tensorflow.python.platform import app
import aggregation
import deep_cnn
import input
import metrics
# Shared handle on the flag values; the functions below read their settings
# (dataset, directories, step counts, ...) from it.
FLAGS = flags.FLAGS

# Dataset selection and output size
flags.DEFINE_string('dataset', 'svhn', 'The name of the dataset to use')
flags.DEFINE_integer('nb_labels', 10, 'Number of output classes')

# Filesystem locations
flags.DEFINE_string('data_dir', '/tmp', 'Temporary storage')
flags.DEFINE_string('train_dir', '/tmp/train_dir',
                    'Where model chkpt are saved')
flags.DEFINE_string('teachers_dir', '/tmp/train_dir',
                    'Directory where teachers checkpoints are stored.')

# Training lengths and ensemble size
flags.DEFINE_integer('teachers_max_steps', 3000,
                     """Number of steps teachers were ran.""")
flags.DEFINE_integer('max_steps', 3000, """Number of steps to run student.""")
flags.DEFINE_integer('nb_teachers', 10, """Teachers in the ensemble.""")

# Registered through the same `flags` module as every other flag (the
# original went through `tf.app.flags` for this one definition only).
flags.DEFINE_integer('stdnt_share', 1000,
                     """Student share (last index) of the test data""")
flags.DEFINE_integer('lap_scale', 10,
                     """Scale of the Laplacian noise added for privacy""")

# Optional behaviors
flags.DEFINE_boolean('save_labels', False,
                     """Dump numpy arrays of labels and clean teacher votes""")
flags.DEFINE_boolean('deeper', False, """Activate deeper CNN model""")
def ensemble_preds(dataset, nb_teachers, stdnt_data):
  """
  Query every teacher of the ensemble for predictions on the given data and
  return all predictions in a single array. (That array can then be
  aggregated into one single prediction per input using aggregation.py, cf.
  function prepare_student_data() below.)

  The checkpoint of each teacher is looked up under FLAGS.teachers_dir; its
  file name is derived from the dataset name, the ensemble size, the teacher
  id, FLAGS.deeper and FLAGS.teachers_max_steps.

  :param dataset: string corresponding to mnist, cifar10, or svhn
  :param nb_teachers: number of teachers (in the ensemble) to learn from
  :param stdnt_data: unlabeled student training data
  :return: 3d float32 array (teacher id, sample id, probability per class)
  """
  # One slot per teacher, per training point, and per output class
  result_shape = (nb_teachers, len(stdnt_data), FLAGS.nb_labels)
  result = np.zeros(result_shape, dtype=np.float32)

  # Only the teacher id varies between checkpoint paths, so build the
  # invariant prefix and suffix once instead of on every iteration.
  ckpt_prefix = (FLAGS.teachers_dir + '/' + str(dataset) + '_'
                 + str(nb_teachers) + '_teachers_')
  ckpt_kind = '_deep.ckpt-' if FLAGS.deeper else '.ckpt-'
  ckpt_suffix = ckpt_kind + str(FLAGS.teachers_max_steps - 1)

  # `range` instead of `xrange`: the latter does not exist on Python 3,
  # which this file otherwise supports through its __future__ imports.
  for teacher_id in range(nb_teachers):
    ckpt_path = ckpt_prefix + str(teacher_id) + ckpt_suffix

    # Get predictions on our training data and store in result array
    result[teacher_id] = deep_cnn.softmax_preds(stdnt_data, ckpt_path)

    # This can take a while when there are a lot of teachers so output status
    print("Computed Teacher " + str(teacher_id) + " softmax predictions")

  return result
def _dump_npy(filepath, array):
  """Write `array` to `filepath` in numpy's .npy format through gfile."""
  with gfile.Open(filepath, mode='w') as file_obj:
    np.save(file_obj, array)


def prepare_student_data(dataset, nb_teachers, save=False):
  """
  Takes a dataset name and the size of the teacher ensemble and prepares
  training data for the student model, according to parameters indicated
  in flags above.

  The first FLAGS.stdnt_share test samples become the student's training set,
  labeled by the noisy aggregation of the teachers' predictions; the rest of
  the test set is returned for evaluating the student.

  :param dataset: string corresponding to mnist, cifar10, or svhn
  :param nb_teachers: number of teachers (in the ensemble) to learn from
  :param save: if set to True, will dump student training labels predicted by
               the ensemble of teachers (with Laplacian noise) as npy files.
               It also dumps the clean votes for each class (without noise)
               and the labels assigned by teachers
  :return: pairs of (data, labels) to be used for student training and
           testing, as the tuple (stdnt_data, stdnt_labels, stdnt_test_data,
           stdnt_test_labels); False if the dataset name is not recognized
  :raises AssertionError: if the training directory could not be prepared or
                          if FLAGS.stdnt_share leaves no data for testing
  """
  # The call is made outside of an `assert` statement so that the directory
  # is still created when Python runs with -O (which strips asserts).
  if not input.create_dir_if_needed(FLAGS.train_dir):
    raise AssertionError('Could not prepare directory ' + str(FLAGS.train_dir))

  # Load the dataset
  if dataset == 'svhn':
    test_data, test_labels = input.ld_svhn(test_only=True)
  elif dataset == 'cifar10':
    test_data, test_labels = input.ld_cifar10(test_only=True)
  elif dataset == 'mnist':
    test_data, test_labels = input.ld_mnist(test_only=True)
  else:
    print("Check value of dataset flag")
    return False

  # Make sure there is data leftover to be used as a test set
  if not FLAGS.stdnt_share < len(test_data):
    raise AssertionError('stdnt_share must be smaller than the test set size')

  # Prepare [unlabeled] student training data (subset of test set)
  stdnt_data = test_data[:FLAGS.stdnt_share]

  # Compute teacher predictions for student training data
  teachers_preds = ensemble_preds(dataset, nb_teachers, stdnt_data)

  # All dumps share the same location and naming scheme; only the middle
  # part of the file name changes.
  dump_prefix = FLAGS.data_dir + "/" + str(dataset) + '_' + str(nb_teachers)
  dump_suffix = '_lap_' + str(FLAGS.lap_scale) + '.npy'

  # Aggregate teacher predictions to get student training labels
  if not save:
    stdnt_labels = aggregation.noisy_max(teachers_preds, FLAGS.lap_scale)
  else:
    # Request clean votes and clean labels as well
    stdnt_labels, clean_votes, labels_for_dump = aggregation.noisy_max(
        teachers_preds, FLAGS.lap_scale, return_clean_votes=True)

    # Dump clean votes, then the labels assigned by the teachers
    _dump_npy(dump_prefix + '_student_clean_votes' + dump_suffix, clean_votes)
    _dump_npy(dump_prefix + '_teachers_labels' + dump_suffix, labels_for_dump)

  # Print accuracy of aggregated labels
  ac_ag_labels = metrics.accuracy(stdnt_labels,
                                  test_labels[:FLAGS.stdnt_share])
  print("Accuracy of the aggregated labels: " + str(ac_ag_labels))

  # Store unused part of test set for use as a test set after student training
  stdnt_test_data = test_data[FLAGS.stdnt_share:]
  stdnt_test_labels = test_labels[FLAGS.stdnt_share:]

  if save:
    # Dump the labels produced by the noisy aggregation
    _dump_npy(dump_prefix + '_student_labels' + dump_suffix, stdnt_labels)

  return stdnt_data, stdnt_labels, stdnt_test_data, stdnt_test_labels
def train_student(dataset, nb_teachers):
  """
  This function trains a student using predictions made by an ensemble of
  teachers. The student and teacher models are trained using the same
  neural network architecture.

  After training, the student is evaluated on the part of the test set that
  was not used to train it and the resulting accuracy is printed.

  :param dataset: string corresponding to mnist, cifar10, or svhn
  :param nb_teachers: number of teachers (in the ensemble) to learn from
  :return: True if student training went well, False if the student data
           could not be prepared (e.g. unknown dataset name)
  :raises AssertionError: if the training directory could not be prepared or
                          if deep_cnn.train() reports a failure
  """
  # Side-effecting calls are kept out of `assert` statements: with
  # `python -O` asserts are stripped, which would silently skip them.
  if not input.create_dir_if_needed(FLAGS.train_dir):
    raise AssertionError('Could not prepare directory ' + str(FLAGS.train_dir))

  # Call helper function to prepare student data using teacher predictions
  stdnt_dataset = prepare_student_data(dataset, nb_teachers, save=True)

  # The helper returns False (instead of a 4-tuple) for an unknown dataset;
  # report the failure rather than crashing while unpacking.
  if stdnt_dataset is False:
    return False

  # Unpack the student dataset
  stdnt_data, stdnt_labels, stdnt_test_data, stdnt_test_labels = stdnt_dataset

  # Prepare checkpoint filename and path
  ckpt_name = '_student_deeper.ckpt' if FLAGS.deeper else '_student.ckpt'
  ckpt_path = (FLAGS.train_dir + '/' + str(dataset) + '_' + str(nb_teachers)
               + ckpt_name)

  # Start student training
  if not deep_cnn.train(stdnt_data, stdnt_labels, ckpt_path):
    raise AssertionError('Student training failed')

  # Compute final checkpoint name for student (with max number of steps)
  ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)

  # Compute student label predictions on remaining chunk of test set
  student_preds = deep_cnn.softmax_preds(stdnt_test_data, ckpt_path_final)

  # Compute student accuracy on the held-out part of the test set
  precision = metrics.accuracy(student_preds, stdnt_test_labels)
  print('Precision of student after training: ' + str(precision))

  return True
def main(argv=None):  # pylint: disable=unused-argument
  """
  Entry point: run student training according to values specified in flags.

  :param argv: command-line arguments; unused
  :raises AssertionError: if train_student() does not report success
  """
  # Call outside of `assert`: under `python -O` an `assert train_student(...)`
  # is stripped entirely and the script would do nothing at all.
  succeeded = train_student(FLAGS.dataset, FLAGS.nb_teachers)
  if not succeeded:
    raise AssertionError('Student training did not complete successfully')
# Only start the program when this file is executed as a script (not when it
# is imported). Control is handed to TensorFlow's app runner, which is
# presumably what parses the flags and invokes main() -- confirm against the
# installed TensorFlow version.
if __name__ == '__main__':
  app.run()