Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Commit

Permalink
Drawing Classifer Compatible with TensorFlow V2 Behavior (#3028)
Browse files Browse the repository at this point in the history
Rather than initialize TensorFlow variables to zero then assign the correct values,
variables must be initialized to the correct values at the begining.
  • Loading branch information
TobyRoseman authored Mar 4, 2020
1 parent 82208b3 commit 5393ebb
Showing 1 changed file with 27 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import turicreate.toolkits._tf_utils as _utils
import tensorflow.compat.v1 as _tf

# This toolkit is compatible with TensorFlow V2 behavior.
# However, until all toolkits are compatible, we must call `disable_v2_behavior()`.
_tf.disable_v2_behavior()


Expand All @@ -20,7 +22,6 @@ def __init__(self, net_params, batch_size, num_classes):
"""
Defines the TensorFlow model, loss, optimisation and accuracy. Then
loads the weights into the model.
"""
self.gpu_policy = _utils.TensorFlowGPUPolicy()
self.gpu_policy.start()
Expand Down Expand Up @@ -54,40 +55,34 @@ def init_drawing_classifier_graph(self, net_params):

# Weights
weights = {
"drawing_conv0_weight": _tf.Variable(
_tf.zeros([3, 3, 1, 16]), name="drawing_conv0_weight"
),
"drawing_conv1_weight": _tf.Variable(
_tf.zeros([3, 3, 16, 32]), name="drawing_conv1_weight"
),
"drawing_conv2_weight": _tf.Variable(
_tf.zeros([3, 3, 32, 64]), name="drawing_conv2_weight"
),
"drawing_dense0_weight": _tf.Variable(
_tf.zeros([576, 128]), name="drawing_dense0_weight"
),
"drawing_dense1_weight": _tf.Variable(
_tf.zeros([128, self.num_classes]), name="drawing_dense1_weight"
),
name: _tf.Variable(_utils.convert_conv2d_coreml_to_tf(net_params[name]), name=name)
for name in ("drawing_conv0_weight",
"drawing_conv1_weight",
"drawing_conv2_weight")
}
weights["drawing_dense1_weight"] = _tf.Variable(
_utils.convert_dense_coreml_to_tf(net_params["drawing_dense1_weight"]), name="drawing_dense1_weight"
)
"""
To make output of CoreML pool3 (NCHW) compatible with TF (NHWC).
Decompose FC weights to NCHW. Transpose to NHWC. Reshape back to FC.
"""
coreml_128_576 = net_params["drawing_dense0_weight"]
coreml_128_576 = _np.reshape(coreml_128_576, (128, 64, 3, 3))
coreml_128_576 = _np.transpose(coreml_128_576, (0, 2, 3, 1))
coreml_128_576 = _np.reshape(coreml_128_576, (128, 576))
weights["drawing_dense0_weight"] = _tf.Variable(
_np.transpose(coreml_128_576, (1, 0)), name="drawing_dense0_weight"
)

# Biases
biases = {
"drawing_conv0_bias": _tf.Variable(
_tf.zeros([16]), name="drawing_conv0_bias"
),
"drawing_conv1_bias": _tf.Variable(
_tf.zeros([32]), name="drawing_conv1_bias"
),
"drawing_conv2_bias": _tf.Variable(
_tf.zeros([64]), name="drawing_conv2_bias"
),
"drawing_dense0_bias": _tf.Variable(
_tf.zeros([128]), name="drawing_dense0_bias"
),
"drawing_dense1_bias": _tf.Variable(
_tf.zeros([self.num_classes]), name="drawing_dense1_bias"
),
name: _tf.Variable(net_params[name], name=name)
for name in ("drawing_conv0_bias",
"drawing_conv1_bias",
"drawing_conv2_bias",
"drawing_dense0_bias",
"drawing_dense1_bias")
}

conv_1 = _tf.nn.conv2d(
Expand Down Expand Up @@ -119,23 +114,19 @@ def init_drawing_classifier_graph(self, net_params):

# Flatten the data to a 1-D vector for the fully connected layer
fc1 = _tf.reshape(pool_3, (-1, 576))

fc1 = _tf.nn.xw_plus_b(
fc1,
weights=weights["drawing_dense0_weight"],
biases=biases["drawing_dense0_bias"],
)

fc1 = _tf.nn.relu(fc1)

out = _tf.nn.xw_plus_b(
fc1,
weights=weights["drawing_dense1_weight"],
biases=biases["drawing_dense1_bias"],
)
softmax_out = _tf.nn.softmax(out)

self.predictions = softmax_out
self.predictions = _tf.nn.softmax(out)

# Loss
self.cost = _tf.losses.softmax_cross_entropy(
Expand All @@ -153,60 +144,6 @@ def init_drawing_classifier_graph(self, net_params):
self.sess = _tf.Session()
self.sess.run(_tf.global_variables_initializer())

# Assign the initialised weights from C++ to tensorflow
layers = [
"drawing_conv0_weight",
"drawing_conv0_bias",
"drawing_conv1_weight",
"drawing_conv1_bias",
"drawing_conv2_weight",
"drawing_conv2_bias",
"drawing_dense0_weight",
"drawing_dense0_bias",
"drawing_dense1_weight",
"drawing_dense1_bias",
]

for key in layers:
if "bias" in key:
self.sess.run(
_tf.assign(
_tf.get_default_graph().get_tensor_by_name(key + ":0"),
net_params[key],
)
)
else:
if "drawing_dense0_weight" in key:
"""
To make output of CoreML pool3 (NCHW) compatible with TF (NHWC).
Decompose FC weights to NCHW. Transpose to NHWC. Reshape back to FC.
"""
coreml_128_576 = net_params[key]
coreml_128_576 = _np.reshape(coreml_128_576, (128, 64, 3, 3))
coreml_128_576 = _np.transpose(coreml_128_576, (0, 2, 3, 1))
coreml_128_576 = _np.reshape(coreml_128_576, (128, 576))
self.sess.run(
_tf.assign(
_tf.get_default_graph().get_tensor_by_name(key + ":0"),
_np.transpose(coreml_128_576, (1, 0)),
)
)
elif "dense" in key:
dense_weights = _utils.convert_dense_coreml_to_tf(net_params[key])
self.sess.run(
_tf.assign(
_tf.get_default_graph().get_tensor_by_name(key + ":0"),
dense_weights,
)
)
else:
self.sess.run(
_tf.assign(
_tf.get_default_graph().get_tensor_by_name(key + ":0"),
_utils.convert_conv2d_coreml_to_tf(net_params[key]),
)
)

def __del__(self):
self.sess.close()
self.gpu_policy.stop()
Expand Down

0 comments on commit 5393ebb

Please sign in to comment.