This repository has been archived by the owner on Oct 11, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 18
/
train_and_export.py
82 lines (65 loc) · 3.07 KB
/
train_and_export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import tensorflow as tf
def my_freeze_graph(output_node_names, destination, name="frozen_model.pb"):
    """Export the current Keras session graph as a frozen binary protobuf.

    The variables of the graph held by the Keras backend session are
    converted to constants, so the written file carries the weights together
    with the graph definition. The Keras session is cleared afterwards.

    :param output_node_names: Names of the output nodes we are interested in
    :param destination: Destination folder or remote service (e.g. gs://)
    :param name: File name of the saved graph
    :return: None
    """
    keras_backend = tf.keras.backend
    keras_backend.set_learning_phase(0)  # 0 selects the inference phase

    session = keras_backend.get_session()
    with session.as_default():
        # Turn every variable into a constant so its value is stored in the
        # graph def itself.
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            session,
            session.graph.as_graph_def(),
            output_node_names=output_node_names,
        )
        # as_text=False writes the binary protobuf form.
        tf.train.write_graph(
            graph_or_graph_def=frozen_graph_def,
            logdir=destination,
            name=name,
            as_text=False,
        )

    keras_backend.clear_session()
# Training hyper-parameters
batch_size = 128
num_classes = 10
epochs = 12

# Input image dimensions, plus a trailing single-channel axis
img_rows, img_cols = 28, 28
input_shape = (img_rows, img_cols, 1)

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Give every sample the (rows, cols, 1) layout declared as the model input
x_train = x_train.reshape((x_train.shape[0],) + input_shape)
x_test = x_test.reshape((x_test.shape[0],) + input_shape)
print(f'input_shape={input_shape}')

# Cast to float32 and divide by 255
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

# Two convolutions, pooling, then a dense classifier with dropout in between
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(tf.keras.layers.Dropout(0.25))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))

model.compile(
    loss=tf.keras.losses.categorical_crossentropy,
    optimizer=tf.keras.optimizers.Adadelta(),
    metrics=['accuracy'],
)

# The test split is also passed as the validation data during training
model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    validation_data=(x_test, y_test),
)

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

print(f'input_layer_name={model.input.name}')
# Keep only the part of the output tensor name before the first ':'
output_layer_name = model.output.name.partition(':')[0]
print(f'output_layer_name={output_layer_name}')

# Now we save the current weights; in a more realistic scenario we would reload
# a checkpoint containing the best weights according to some measure of goodness.
my_freeze_graph([output_layer_name], destination='/tmp', name="frozen_model.pb")