diff --git a/E1_TPU_Sample/README.md b/E1_TPU_Sample/README.md
index e708454..3663ea6 100644
--- a/E1_TPU_Sample/README.md
+++ b/E1_TPU_Sample/README.md
@@ -6,7 +6,7 @@
 ### Cloud TPU
 
 **TPU Type:** v2.8
-**Tensorflow Version:** Nightly
+**Tensorflow Version:** 1.14
 
 ### Cloud VM
 
@@ -17,7 +17,7 @@ Launching Instance and VM
 ---------------------------
 - Open Google Cloud Shell
-- `ctpu up -tf-version nightly`
+- `ctpu up -tf-version 1.14`
 - If cloud bucket is not setup automatically, create a cloud storage bucket
 with the same name as TPU and the VM
 - enable HTTP traffic for the VM instance
 
@@ -26,35 +26,6 @@ with the same name as TPU and the VM
 - `pip3 install -r requirements.txt`
 - `export CTPU_NAME=`
 
-Chaning Tensorflow Source Code For Support to Cloud TPU:
---------------------------------------------------------
-TPU is not Officially Supported for Tensorflow 2.0, so it is not exposed in the Public API.
-However in the code, the python files containing the required modules are imported explicitly.
-There's a small bug in `CrossShardOptimizer` which tries to use OptimizerV1 and all Optimizers
-available in the Public API are in V2. To support V2 Optimizers, a small Code Fragment is needed
-to be changed in CrossShardOptimizer's `apply_gradients(...)` function.
-To do that
-- Browse (`cd`) to the installation directory of tensorflow.
-
-**To find the installation directory:**
-```python3
->>> import os
->>> import tensorflow as tf
->>> print(os.path.dirname(str(tf).split(" ")[-1][1:]))
-```
-
-- `cd` to `python/tpu` inside the installation directory
-- open `tpu_optimizer.py` in an editor
-- change line no. 173 (For Tensorflow 2.0 Beta)
-**From**
-```python3
-    return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)
-```
-**To**
-```python3
-    return self._opt.apply_gradients(summed_grads_and_vars, name=name)
-```
-- Save Changes
 
 Running Tensorboard:
 ---------------------
@@ -74,11 +45,30 @@ To view Tensorboard, Browse to the Public IP of the VM Instance
 
 Running the Code:
 ----------------------
+#### Train The Model
+
 ```bash
 $ python3 image_retraining_tpu.py --tpu $CTPU_NAME --use_tpu \
---model_dir gs://$CTPU_NAME/model_dir \
---data_dir gs://$CTPU_NAME/data_dir \
---batch_size 16 \
---iterations 4 \
+--modeldir gs://$CTPU_NAME/modeldir \
+--datadir gs://$CTPU_NAME/datadir \
+--logdir gs://$CTPU_NAME/logdir \
+--num_steps 2000 \
 --dataset horses_or_humans
 ```
+Training saves a single checkpoint at the end of training. This checkpoint can be
+loaded later to export a SavedModel.
+
+#### Export Model
+
+```bash
+$ python3 image_retraining_tpu.py --tpu $CTPU_NAME --use_tpu \
+--modeldir gs://$CTPU_NAME/modeldir \
+--datadir gs://$CTPU_NAME/datadir \
+--logdir gs://$CTPU_NAME/logdir \
+--dataset horses_or_humans \
+--export_only \
+--export_path modeldir/model
+```
+Exporting SavedModel of trained model
+----------------------------
+By default the trained model is exported to `gs://$CTPU_NAME/modeldir/model`, unless a different path is given with `--export_path`.
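+
+#### Inspect The Exported Model
+
+A quick sketch (optional, not required for training or export): the exported SavedModel
+can be copied out of the bucket and its serving signature inspected with the
+`saved_model_cli` tool that ships with TensorFlow:
+
+```bash
+$ gsutil cp -r gs://$CTPU_NAME/modeldir/model ./model
+$ saved_model_cli show --dir ./model --all
+```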
diff --git a/E1_TPU_Sample/image_retraining_tpu.py b/E1_TPU_Sample/image_retraining_tpu.py
deleted file mode 100644
index f66d7a0..0000000
--- a/E1_TPU_Sample/image_retraining_tpu.py
+++ /dev/null
@@ -1,177 +0,0 @@
-# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-from functools import partial
-
-import tensorflow as tf
-import tensorflow_datasets as tfds
-import tensorflow_hub as hub
-from tensorflow.python.tpu import tpu_estimator
-from tensorflow.python.tpu import tpu_optimizer
-from tensorflow.python.tpu import tpu_config
-from absl import flags, app
-
-flags.DEFINE_string("tpu", None, "TPU Address")
-flags.DEFINE_integer("iterations", 2, "Number of Itertions")
-flags.DEFINE_integer("batch_size", 16, "Size of each Batch")
-flags.DEFINE_float("learning_rate", 1e-3, "Learning Rate")
-flags.DEFINE_boolean("use_tpu", True, " Use TPU")
-flags.DEFINE_boolean("use_compat", True, "Use OptimizerV1 from compat module")
-flags.DEFINE_integer(
-    "max_steps",
-    1000,
-    "Maximum Number of Steps for TPU Estimator")
-flags.DEFINE_string(
-    "model_dir",
-    "model_dir/",
-    "Directory to Save the Models and Checkpoint")
-flags.DEFINE_string(
-    "dataset",
-    "horses_or_humans",
-    "TFDS Dataset Name. IMAGE Dimension should be >= 224, channel=3")
-flags.DEFINE_string("data_dir", None, "Directory to Save Data to")
-flags.DEFINE_string("infer", None, "Dummy image file to infer")
-
-FLAGS = flags.FLAGS
-NUM_CLASSES = None
-
-
-def resize_and_scale(image, label):
-    image = tf.image.resize(image, size=[224, 224])
-    image = tf.cast(image, tf.float32)
-    image = image / tf.reduce_max(tf.gather(image, 0))
-    return image, label
-
-
-def input_(mode, batch_size, iterations, **kwargs):
-    global NUM_CLASSES
-    dataset, info = tfds.load(
-        kwargs["dataset"],
-        as_supervised=True,
-        split="train" if mode == tf.estimator.ModeKeys.TRAIN else "test",
-        with_info=True,
-        data_dir=kwargs['data_dir']
-    )
-    NUM_CLASSES = info.features['label'].num_classes
-    dataset = dataset.map(resize_and_scale).shuffle(
-        1000).repeat(iterations).batch(batch_size, drop_remainder=True)
-    return dataset
-
-
-def model_fn(features, labels, mode, params):
-    global NUM_CLASSES
-    assert NUM_CLASSES is not None
-    model = tf.keras.Sequential([
-        hub.KerasLayer("https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4",
-                       output_shape=[2048],
-                       trainable=False
-                       ),
-        tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")
-    ])
-    optimizer = None
-    if mode == tf.estimator.ModeKeys.TRAIN:
-        if not params["use_compat"]:
-            optimizer = tf.optimizers.Adam(params["learning_rate"])
-        else:
-            optimizer = tf.compat.v1.train.AdamOptimizer(
-                params["learning_rate"])
-        if params["use_tpu"]:
-            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
-
-    with tf.GradientTape() as tape:
-        logits = model(features)
-        if mode == tf.estimator.ModeKeys.PREDICT:
-            preds = {
-                "predictions": logits
-            }
-            return tpu_estimator.TPUEstimatorSpec(mode, predictions=preds)
-        loss = tf.keras.losses.SparseCategoricalCrossentropy(
-            from_logits=True)(labels, logits)
-    if mode == tf.estimator.ModeKeys.EVAL:
-        return tpu_estimator.TPUEstimatorSpec(mode, loss=loss)
-
-    def train_fn(use_compat):
-        assert optimizer is not None
-        gradient = tape.gradient(loss, model.trainable_variables)
-        global_step = tf.compat.v1.train.get_global_step()
-        apply_grads = tf.no_op()  # Does Nothing. Initialization only. None would also work
-        if not use_compat:
-            update_global_step = tf.compat.v1.assign(
-                global_step, global_step + 1, name='update_global_step')
-            with tf.control_dependencies([update_global_step]):
-                apply_grads = optimizer.apply_gradients(
-                    zip(gradient, model.trainable_variables))
-        else:
-            apply_grads = optimizer.apply_gradients(
-                zip(gradient, model.trainable_variables),
-                global_step=global_step)
-        return apply_grads
-
-    if mode == tf.estimator.ModeKeys.TRAIN:
-        return tpu_estimator.TPUEstimatorSpec(
-            mode, loss=loss, train_op=train_fn(
-                params['use_compat']))
-
-
-def main(_):
-    os.environ["TFHUB_CACHE_DIR"] = os.path.join(
-        FLAGS.model_dir, "tfhub_modules")
-    os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "True"
-    input_fn = partial(input_, iterations=FLAGS.iterations)
-    cluster = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
-    run_config = tpu_config.RunConfig(
-        model_dir=FLAGS.model_dir,
-        cluster=cluster,
-        tpu_config=tpu_config.TPUConfig(FLAGS.iterations))
-
-    classifier = tpu_estimator.TPUEstimator(
-        model_fn=model_fn,
-        use_tpu=FLAGS.use_tpu,
-        train_batch_size=FLAGS.batch_size,
-        eval_batch_size=FLAGS.batch_size,
-        config=run_config,
-        params={
-            "use_tpu": FLAGS.use_tpu,
-            "data_dir": FLAGS.data_dir,
-            "dataset": FLAGS.dataset,
-            "use_compat": FLAGS.use_compat,
-            "learning_rate": FLAGS.learning_rate
-        }
-    )
-    try:
-        classifier.train(
-            input_fn=lambda params: input_fn(
-                mode=tf.estimator.ModeKeys.TRAIN,
-                **params),
-            max_steps=FLAGS.max_steps)
-    except Exception:
-        pass
-    if FLAGS.infer:
-        def prepare_input_fn(path):
-            img = tf.image.decode_image(tf.io.read_file(path))
-            return resize_and_scale(img, None)
-
-        predictions = classifer.predict(
-            input_fn=lambda params: prepare_input_fn(FLAGS.infer))
-        print(predictions)
-
-
-if __name__ == "__main__":
-    app.run(main)
diff --git a/E1_TPU_Sample/image_retraining_tpu_strategy.py b/E1_TPU_Sample/image_retraining_tpu_strategy.py
new file mode 100644
index 0000000..0ed60aa
--- /dev/null
+++ b/E1_TPU_Sample/image_retraining_tpu_strategy.py
@@ -0,0 +1,268 @@
+# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+""" TensorFlow Sample for running TPU Training """
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+import argparse
+from absl import logging
+import tensorflow as tf
+import tensorflow_hub as hub
+import tensorflow_datasets as tfds
+
+tf.compat.v2.enable_v2_behavior()
+os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "True"
+
+PRETRAINED_KERAS_LAYER = "https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4"
+BATCH_SIZE = 32  # For TPU training, this must be a multiple of 8
+
+
+class SingleDeviceStrategy(object):
+  """ Dummy Class to mimic tf.distribute.Strategy for Single Devices """
+
+  def __enter__(self, *args, **kwargs):
+    pass
+
+  def __exit__(self, *args, **kwargs):
+    pass
+
+  def scope(self):
+    return self
+
+  def experimental_distribute_dataset(self, dataset):
+    return dataset
+
+  def experimental_run_v2(self, func, args, kwargs):
+    return func(*args, **kwargs)
+
+  def reduce(self, reduction_type, distributed_data, axis):  # pylint: disable=unused-argument
+    return distributed_data
+
+
+class Model(tf.keras.models.Model):
+  """ Keras Model class for Image Retraining """
+
+  def __init__(self, num_classes):
+    super(Model, self).__init__()
+    logging.info("Loading Pretrained Image Vectorizer")
+    self._pretrained_layer = hub.KerasLayer(
+        PRETRAINED_KERAS_LAYER,
+        output_shape=[2048],
+        trainable=False)
+    self._dense_1 = tf.keras.layers.Dense(num_classes, activation="sigmoid")
+
+  @tf.function(
+      input_signature=[
+          tf.TensorSpec(
+              shape=[None, None, None, 3],
+              dtype=tf.float32)])
+  def call(self, inputs):
+    return self.unsigned_call(inputs)
+
+  def unsigned_call(self, inputs):
+    intermediate = self._pretrained_layer(inputs)
+    return self._dense_1(intermediate)
+
+
+def connect_to_tpu(tpu=None):
+  """Connects to a TPU cluster if `tpu` is given, else falls back to a single device."""
+  if tpu:
+    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu)
+    tf.config.experimental_connect_to_host(cluster_resolver.get_master())
+    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
+    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
+    return strategy, "/job:worker"
+  return SingleDeviceStrategy(), ""
+
+
+def load_dataset(name, datadir, batch_size=32, shuffle=None):
+  """
+  Loads and preprocesses a dataset from TensorFlow Datasets.
+  Args:
+    name: Name of the dataset to load.
+    datadir: Directory to store the downloaded dataset in.
+    batch_size: Size of each minibatch. Must be a multiple of 8.
+    shuffle: Size of the shuffle buffer to use. Not shuffled if set to None.
+  """
+  dataset, info = tfds.load(
+      name,
+      try_gcs=True,
+      data_dir=datadir,
+      split="train",
+      as_supervised=True,
+      with_info=True)
+  num_classes = info.features["label"].num_classes
+
+  def _scale_fn(image, label):
+    # Scale images from [0, 255] to [-1, 1] and one-hot encode the labels.
+    image = tf.cast(image, tf.float32)
+    image = image / 127.5
+    image -= 1.
+    label = tf.one_hot(label, num_classes)
+    label = tf.cast(label, tf.float32)
+    return image, label
+
+  options = tf.data.Options()
+  if not hasattr(tf.data.Options, "auto_shard"):
+    options.experimental_distribute.auto_shard = False
+  else:
+    options.auto_shard = False
+
+  dataset = (
+      dataset.map(
+          _scale_fn,
+          num_parallel_calls=tf.data.experimental.AUTOTUNE)
+      .with_options(options)
+      .batch(batch_size, drop_remainder=True))
+  if shuffle:
+    dataset = dataset.shuffle(shuffle, reshuffle_each_iteration=True)
+  return dataset.repeat(), num_classes
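+
+
+# Usage sketch (assumed values, not executed by this module on import):
+# `load_dataset` returns an endlessly repeating tf.data.Dataset of
+# (image, one_hot_label) batches with images scaled to [-1, 1], e.g.
+#   dataset, num_classes = load_dataset(
+#       "horses_or_humans", None, batch_size=BATCH_SIZE, shuffle=3 * 32)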
+
+
+def train_and_export(**kwargs):
+  """
+  Trains the model and exports it as a SavedModel.
+  Args:
+    tpu: Name or GRPC address of the TPU to use.
+    logdir: Path to a bucket or directory to store TensorBoard logs.
+    modeldir: Path to a bucket or directory to store the model.
+    datadir: Path to store the downloaded datasets to.
+    dataset: Name of the dataset to load from TensorFlow Datasets.
+    num_steps: Number of steps to train the model for.
+    export_only: If True, skip training and only export from the latest checkpoint.
+    export_path: Explicit path to export the SavedModel to.
+  """
+  if kwargs["tpu"]:
+    # For TPU training the files must be stored in
+    # cloud buckets for the TPU to access them.
+    if not kwargs["logdir"].startswith("gs://"):
+      raise ValueError("To train on TPU, `logdir` must be a cloud bucket")
+    if not kwargs["modeldir"].startswith("gs://"):
+      raise ValueError("To train on TPU, `modeldir` must be a cloud bucket")
+    if kwargs["datadir"]:
+      if not kwargs["datadir"].startswith("gs://"):
+        raise ValueError("To train on TPU, `datadir` must be a cloud bucket")
+
+  os.environ["TFHUB_CACHE_DIR"] = os.path.join(
+      kwargs["modeldir"], "tfhub_cache")
+
+  strategy, device = connect_to_tpu((not kwargs["export_only"]) and kwargs["tpu"])
+  with tf.device(device), strategy.scope():
+    summary_writer = tf.summary.create_file_writer(kwargs["logdir"])
+    dataset, num_classes = load_dataset(
+        kwargs["dataset"],
+        kwargs["datadir"],
+        shuffle=3 * 32,
+        batch_size=BATCH_SIZE)
+    dataset = iter(strategy.experimental_distribute_dataset(dataset))
+    model = Model(num_classes)
+    loss_metric = tf.keras.metrics.Mean()
+    optimizer = tf.keras.optimizers.Adam()
+    ckpt = tf.train.Checkpoint(model=model)
+
+    def distributed_step(images, labels):
+      with tf.GradientTape() as tape:
+        logging.info("Taking predictions")
+        predictions = model.unsigned_call(images)
+        logging.info("Calculating loss")
+        loss = tf.nn.sigmoid_cross_entropy_with_logits(labels, predictions)
+        loss_metric(loss)
+        loss = loss * (1.0 / BATCH_SIZE)
+      logging.info("Calculating gradients")
+      gradient = tape.gradient(loss, model.trainable_variables)
+      logging.info("Applying gradients")
+      train_op = optimizer.apply_gradients(zip(gradient, model.trainable_variables))
+      with tf.control_dependencies([train_op]):
+        return tf.cast(optimizer.iterations, tf.float32)
+
+    @tf.function
+    def train_step(image, label):
+      distributed_metric = strategy.experimental_run_v2(
+          distributed_step, args=[image, label])
+      step = strategy.reduce(
+          tf.distribute.ReduceOp.MEAN, distributed_metric, axis=None)
+      return step
+
+    if not kwargs["export_only"]:
+      logging.info("Starting Training")
+      while not kwargs["export_only"]:
+        image, label = next(dataset)
+        # Keep the step counter as an integer for the logging and stopping checks below.
+        step = tf.cast(train_step(image, label), tf.int64)
+        with summary_writer.as_default():
+          tf.summary.scalar("loss", loss_metric.result(), step=optimizer.iterations)
+        if step % 100 == 0:
+          logging.info("Step: #%f\tLoss: %f" % (step, loss_metric.result()))
+        if step >= kwargs["num_steps"]:
+          ckpt.save(file_prefix=os.path.join(kwargs["modeldir"], "checkpoint"))
+          break
+    logging.info("Exporting Saved Model")
+    export_path = (kwargs["export_path"]
+                   or os.path.join(kwargs["modeldir"], "model"))
+    # `ckpt.save` above writes the checkpoint state file into `modeldir` itself.
+    ckpt.restore(tf.train.latest_checkpoint(kwargs["modeldir"]))
+    logging.info("Consuming checkpoint and tracing function")
+    model(tf.random.normal([1, 200, 200, 3]))
+    tf.saved_model.save(model, export_path)
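+
+
+# Programmatic use (sketch; these are the same keys the CLI flags below supply):
+#   train_and_export(
+#       tpu="my-tpu", logdir="gs://bucket/logdir", modeldir="gs://bucket/modeldir",
+#       datadir="gs://bucket/datadir", dataset="horses_or_humans",
+#       num_steps=2000, export_only=False, export_path=None)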
help="Directory to store the SavedModel to") + parser.add_argument( + "--logdir", + default=None, + help="Directory to store the Tensorboard logs") + parser.add_argument( + "--tpu", + default=None, + help="name or GRPC address of the TPU") + parser.add_argument( + "--num_steps", + default=1000, + type=int, + help="Number of Steps to train the model for") + parser.add_argument( + "--export_path", + default=None, + help="Explicitly specify the export path of the model." + "Else `modeldir/model` wil be used.") + parser.add_argument( + "--export_only", + default=False, + action="store_true", + help="Only export the SavedModel from presaved checkpoints") + parser.add_argument( + "--verbose", + "-v", + default=0, + action="count", + help="increase verbosity. multiple tags to increase more") + flags, unknown = parser.parse_known_args() + log_levels = [logging.FATAL, logging.WARNING, logging.INFO, logging.DEBUG] + log_level = log_levels[min(flags.verbose, len(log_levels) - 1)] + if not flags.modeldir: + logging.fatal("`--modeldir` must be specified") + sys.exit(1) + logging.set_verbosity(log_level) + train_and_export(**vars(flags))