-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_mlp_pytorch.py
126 lines (105 loc) · 3.74 KB
/
train_mlp_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
This module implements training and evaluation of a multi-layer perceptron in PyTorch.
You should fill in code into indicated sections.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os

import numpy as np
import torch

import cifar10_utils
from mlp_pytorch import MLP
# Default constants
DNN_HIDDEN_UNITS_DEFAULT = '100'
LEARNING_RATE_DEFAULT = 2e-3
MAX_STEPS_DEFAULT = 1500
BATCH_SIZE_DEFAULT = 200
EVAL_FREQ_DEFAULT = 100
NEG_SLOPE_DEFAULT = 0.02
# Directory in which cifar data is saved
DATA_DIR_DEFAULT = './cifar10/cifar-10-batches-py'
FLAGS = None
def accuracy(predictions, targets):
  """
  Computes the prediction accuracy, i.e. the average of correct predictions
  of the network.
  Args:
    predictions: 2D float array of size [batch_size, n_classes]
    targets: 2D int array of size [batch_size, n_classes]
             with one-hot encoding. Ground truth labels for
             each sample in the batch
  Returns:
    accuracy: scalar float, the accuracy of predictions,
              i.e. the average correct predictions over the whole batch
  """
  # The predicted class is the index of the highest score per row; the true
  # class is the index of the 1 in the one-hot target row.
  predicted_classes = np.argmax(predictions, axis=1)
  true_classes = np.argmax(targets, axis=1)
  # Mean of the boolean match vector == fraction of correct predictions.
  # Cast to a plain Python float as promised by the docstring.
  return float(np.mean(predicted_classes == true_classes))
def train():
  """
  Performs training and evaluation of MLP model.

  Trains an MLP with plain SGD on mini-batches of flattened CIFAR-10 images
  and evaluates it on the whole test set every FLAGS.eval_freq iterations
  (and on the final iteration). Progress is printed to stdout.
  """
  ### DO NOT CHANGE SEEDS!
  # Set the random seeds for reproducibility
  np.random.seed(42)
  # Additionally seed torch so weight initialization is reproducible
  # (the numpy seed above is kept unchanged).
  torch.manual_seed(42)

  ## Prepare all functions
  # Get number of units in each hidden layer specified in the string such as 100,100
  if FLAGS.dnn_hidden_units:
    dnn_hidden_units = FLAGS.dnn_hidden_units.split(",")
    dnn_hidden_units = [int(dnn_hidden_unit_) for dnn_hidden_unit_ in dnn_hidden_units]
  else:
    dnn_hidden_units = []
  # Get negative slope parameter for LeakyReLU
  neg_slope = FLAGS.neg_slope

  # NOTE(review): assumes cifar10_utils.get_cifar10 returns dict-like splits
  # ('train'/'test') exposing .next_batch(batch_size) and .images/.labels
  # with one-hot labels -- confirm against cifar10_utils.
  cifar10 = cifar10_utils.get_cifar10(FLAGS.data_dir, one_hot=True)

  n_inputs = 3 * 32 * 32  # flattened CIFAR-10 image: 3 channels of 32x32 pixels
  n_classes = 10
  # NOTE(review): assumes MLP signature (n_inputs, n_hidden, n_classes, neg_slope)
  # as in the companion mlp_pytorch module -- confirm.
  model = MLP(n_inputs, dnn_hidden_units, n_classes, neg_slope)
  loss_module = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(model.parameters(), lr=FLAGS.learning_rate)

  # Pre-convert the full test split once; it is reused at every evaluation.
  test_images = cifar10['test'].images
  x_test = torch.from_numpy(test_images.reshape(test_images.shape[0], -1))
  y_test_onehot = cifar10['test'].labels

  for step in range(FLAGS.max_steps):
    x_np, y_np = cifar10['train'].next_batch(FLAGS.batch_size)
    x = torch.from_numpy(x_np.reshape(FLAGS.batch_size, -1))
    # CrossEntropyLoss expects integer class indices, not one-hot vectors.
    y = torch.from_numpy(np.argmax(y_np, axis=1)).long()

    optimizer.zero_grad()
    loss = loss_module(model(x), y)
    loss.backward()
    optimizer.step()

    # Evaluate on the whole test set each eval_freq iterations (and at the end).
    if step % FLAGS.eval_freq == 0 or step == FLAGS.max_steps - 1:
      model.eval()
      with torch.no_grad():
        test_predictions = model(x_test).numpy()
      test_acc = accuracy(test_predictions, y_test_onehot)
      print('step {}: loss = {:.4f}, test accuracy = {:.4f}'.format(
          step, loss.item(), test_acc))
      model.train()
def print_flags():
  """
  Prints all entries in FLAGS variable, one "name : value" line per flag.
  """
  for name, val in vars(FLAGS).items():
    print('{} : {}'.format(name, val))
def main():
  """
  Main function: echo the parsed flags, make sure the data directory
  exists, then run training.
  """
  # Print all Flags to confirm parameter settings
  print_flags()
  # exist_ok makes this a no-op when the directory is already present,
  # matching the original exists-then-create behavior.
  os.makedirs(FLAGS.data_dir, exist_ok=True)
  # Run the training operation
  train()
if __name__ == '__main__':
  # Declarative table of command-line flags: (name, type, default, help).
  _arg_specs = [
      ('--dnn_hidden_units', str, DNN_HIDDEN_UNITS_DEFAULT,
       'Comma separated list of number of units in each hidden layer'),
      ('--learning_rate', float, LEARNING_RATE_DEFAULT,
       'Learning rate'),
      ('--max_steps', int, MAX_STEPS_DEFAULT,
       'Number of steps to run trainer.'),
      ('--batch_size', int, BATCH_SIZE_DEFAULT,
       'Batch size to run trainer.'),
      ('--eval_freq', int, EVAL_FREQ_DEFAULT,
       'Frequency of evaluation on the test set'),
      ('--data_dir', str, DATA_DIR_DEFAULT,
       'Directory for storing input data'),
      ('--neg_slope', float, NEG_SLOPE_DEFAULT,
       'Negative slope parameter for LeakyReLU'),
  ]
  parser = argparse.ArgumentParser()
  for _name, _type, _default, _help in _arg_specs:
    parser.add_argument(_name, type=_type, default=_default, help=_help)
  FLAGS, unparsed = parser.parse_known_args()
  main()