tensorflow2_mnist.py
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse
import time

import numpy as np
import tensorflow as tf

parser = argparse.ArgumentParser(description='TensorFlow MNIST Example')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                    help='input batch size for training (default: 256)')
parser.add_argument('--epochs', type=int, default=16, metavar='N',
                    help='number of epochs to train (default: 16)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--device', default='gpu',
                    help='whether this is running on cpu or gpu')
parser.add_argument('--num_inter', type=int, default=2,
                    help='number of inter-op parallelism threads (CPU only)')
parser.add_argument('--num_intra', type=int, default=0,
                    help='number of intra-op parallelism threads (CPU only; 0 lets TensorFlow decide)')
args = parser.parse_args()
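# Example invocation (flag values are illustrative; omitted flags fall back to the defaults above):
#   python tensorflow2_mnist.py --device gpu --batch_size 256 --epochs 16 --lr 0.01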
if args.device == 'cpu':
    tf.config.threading.set_intra_op_parallelism_threads(args.num_intra)
    tf.config.threading.set_inter_op_parallelism_threads(args.num_inter)
else:
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        # Allocate GPU memory on demand instead of reserving it all at startup.
        tf.config.experimental.set_memory_growth(gpu, True)
#---------------------------------------------------
# Dataset
#---------------------------------------------------
(mnist_images, mnist_labels), (x_test, y_test) = \
    tf.keras.datasets.mnist.load_data(path='mnist.npz')

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
     tf.cast(mnist_labels, tf.int64))
)
test_dset = tf.data.Dataset.from_tensor_slices(
    (tf.cast(x_test[..., tf.newaxis] / 255.0, tf.float32),
     tf.cast(y_test, tf.int64))
)

# Count samples from the array shapes; len(list(dataset)) would iterate
# over every element just to count them.
nsamples = mnist_images.shape[0]
ntests = x_test.shape[0]
# Shuffle the training set with a shuffle buffer of 10,000 samples.
dataset = dataset.repeat().shuffle(10000).batch(args.batch_size)
test_dset = test_dset.repeat().batch(args.batch_size)
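# Because repeat() precedes batch(), both streams are infinite and every batch
# is full-sized; an "epoch" below is therefore defined by take(nstep), with
# nstep = nsamples // batch_size.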
#----------------------------------------------------
# Model
#----------------------------------------------------
mnist_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, [3, 3], activation='relu'),
    tf.keras.layers.Conv2D(64, [3, 3], activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation='softmax')
])
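# The final layer emits softmax probabilities, so the loss below keeps its
# default from_logits=False; with a plain linear output layer you would pass
# from_logits=True instead.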
loss = tf.losses.SparseCategoricalCrossentropy()
opt = tf.optimizers.Adam(args.lr)
checkpoint_dir = './checkpoints/tf2_mnist'
checkpoint = tf.train.Checkpoint(model=mnist_model, optimizer=opt)
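# To resume from the most recent checkpoint (a sketch, not executed here,
# assuming the checkpoint_dir prefix above):
#   checkpoint.restore(tf.train.latest_checkpoint('./checkpoints'))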
#------------------------------------------------------------------
# Training
#------------------------------------------------------------------
@tf.function
def training_step(images, labels):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)
    # Accuracy is only monitored, so it can be computed outside the tape.
    pred = tf.math.argmax(probs, axis=1)
    equality = tf.math.equal(pred, labels)
    accuracy = tf.math.reduce_mean(tf.cast(equality, tf.float32))
    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))
    return loss_value, accuracy
@tf.function
def validation_step(images, labels):
    probs = mnist_model(images, training=False)
    pred = tf.math.argmax(probs, axis=1)
    equality = tf.math.equal(pred, labels)
    accuracy = tf.math.reduce_mean(tf.cast(equality, tf.float32))
    loss_value = loss(labels, probs)
    return loss_value, accuracy
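# @tf.function compiles both step functions into TensorFlow graphs on their
# first call, removing per-batch Python overhead on subsequent calls.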
t0 = time.time()  # training starts here; dataset and model setup above is excluded
nstep = nsamples // args.batch_size
ntest_step = ntests // args.batch_size
metrics = {
    'train_acc': [],
    'valid_acc': [],
    'train_loss': [],
    'valid_loss': [],
    'time_per_epochs': [],
}
for ep in range(args.epochs):
    training_loss = 0.0
    training_acc = 0.0
    tt0 = time.time()
    for batch, (images, labels) in enumerate(dataset.take(nstep)):
        loss_value, acc = training_step(images, labels)
        training_loss += loss_value / nstep
        training_acc += acc / nstep
        if batch % 100 == 0:
            checkpoint.save(checkpoint_dir)
            print('Epoch - %d, step #%06d/%06d\tLoss: %.6f' % (ep, batch, nstep, loss_value))

    # Validation
    test_acc = 0.0
    test_loss = 0.0
    for batch, (images, labels) in enumerate(test_dset.take(ntest_step)):
        loss_value, acc = validation_step(images, labels)
        test_acc += acc / ntest_step
        test_loss += loss_value / ntest_step
    tt1 = time.time()
    print('E[%d], train Loss: %.6f, train Acc: %.3f, val Loss: %.3f, val Acc: %.3f\t Time: %.3f seconds' % (
        ep, training_loss, training_acc, test_loss, test_acc, tt1 - tt0))
    metrics['train_acc'].append(training_acc.numpy())
    metrics['train_loss'].append(training_loss.numpy())
    metrics['valid_acc'].append(test_acc.numpy())
    metrics['valid_loss'].append(test_loss.numpy())
    metrics['time_per_epochs'].append(tt1 - tt0)

# Final checkpoint once training completes.
checkpoint.save(checkpoint_dir)
np.savetxt("metrics.dat", np.array([
    metrics['train_acc'], metrics['train_loss'],
    metrics['valid_acc'], metrics['valid_loss'],
    metrics['time_per_epochs'],
]).transpose())
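# The saved metrics can be read back for plotting (a sketch; savetxt wrote one
# epoch per row, so unpack=True returns one array per metric column):
#   train_acc, train_loss, valid_acc, valid_loss, secs = np.loadtxt("metrics.dat", unpack=True)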
t1 = time.time()
print("Total training time: %.3f seconds" % (t1 - t0))