diff --git a/docs/modules/layers.rst b/docs/modules/layers.rst index 31a8fac3a..d0f30daed 100644 --- a/docs/modules/layers.rst +++ b/docs/modules/layers.rst @@ -270,6 +270,7 @@ In addition, if you want to update the parameters of previous 2 layers at the sa FlattenLayer ConcatLayer ReshapeLayer + SlimNetsLayer MultiplexerLayer EmbeddingAttentionSeq2seqWrapper flatten_reshape @@ -357,10 +358,20 @@ so to implement 1D CNN, you can use Reshape layer as follow. .. autoclass:: Conv2dLayer +2D Deconvolutional layer +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: DeConv2dLayer + + 3D Convolutional layer ^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: Conv3dLayer + +3D Deconvolutional layer +^^^^^^^^^^^^^^^^^^^^^^^^^^ + .. autoclass:: DeConv3dLayer Pooling layer @@ -397,6 +408,13 @@ Reshape layer .. autoclass:: ReshapeLayer +Merge TF-Slim +^^^^^^^^^^^^^^^ + +Yes ! TF-Slim models can be merged into TensorLayer, all Google's Pre-trained model can be used easily , +see `Slim-model `_ . + +.. autoclass:: SlimNetsLayer Flow control layer ---------------------- diff --git a/docs/user/example.rst b/docs/user/example.rst index 13fd36981..11ccea67b 100755 --- a/docs/user/example.rst +++ b/docs/user/example.rst @@ -19,6 +19,7 @@ Computer Vision - Convolutional Network (CIFAR-10). A Convolutional neural network implementation for classifying CIFAR-10 dataset, see ``tutorial_cifar10.py`` and ``tutorial_cifar10_tfrecord.py``on `GitHub`_. - VGG 16 (ImageNet). A Convolutional neural network implementation for classifying ImageNet dataset, see ``tutorial_vgg16.py`` on `GitHub`_. - VGG 19 (ImageNet). A Convolutional neural network implementation for classifying ImageNet dataset, see ``tutorial_vgg19.py`` on `GitHub`_. + - InceptionV3 (ImageNet). A Convolutional neural network implementation for classifying ImageNet dataset, see ``tutorial_inceptionV3_tfslim.py`` on `GitHub`_. Natural Language Processing @@ -36,6 +37,13 @@ Reinforcement Learning - Deep Reinforcement Learning - Pong Game. Teach a machine to play Pong games, see ``tutorial_atari_pong.py`` on `GitHub`_. +Special Examples +================= + + - Merge TF-Slim into TensorLayer. ``tutorial_inceptionV3_tfslim.py`` on `GitHub`_. + - MultiplexerLayer. ``tutorial_mnist_multiplexer.py`` on `GitHub`_. + + .. Applications ============= diff --git a/tensorlayer/layers.py b/tensorlayer/layers.py index 40c0313fe..da105ac5b 100644 --- a/tensorlayer/layers.py +++ b/tensorlayer/layers.py @@ -1629,6 +1629,58 @@ def __init__( self.all_drop = dict(layer.all_drop) self.all_layers.extend( [self.outputs] ) +## TF-Slim layer +class SlimNetsLayer(Layer): + """ + The :class:`SlimNetsLayer` class can be used to merge all TF-Slim nets into + TensorLayer. Model can be found in `slim-model `_ , more about slim + see `slim-git `_ . + + Parameters + ---------- + layer : a list of :class:`Layer` instances + The `Layer` class feeding into this layer. + slim_layer : a slim network function + The network you want to stack onto, end with ``return net, end_points``. + name : a string or None + An optional name to attach to this layer. + + Note + ----- + The due to TF-Slim stores the layers as dictionary, the ``all_layers`` in this + network is not in order ! Fortunately, the ``all_params`` are in order. + + """ + def __init__( + self, + layer = None, + slim_layer = None, + slim_args = {}, + name ='slim_layer', + ): + Layer.__init__(self, name=name) + self.inputs = layer.outputs + print(" tensorlayer:Instantiate SlimNetsLayer %s: %s" % (self.name, slim_layer.__name__)) + + with tf.variable_scope(name) as vs: + net, end_points = slim_layer(self.inputs, **slim_args) + slim_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name) + + self.outputs = net + + slim_layers = [] + for v in end_points.values(): + tf.contrib.layers.summaries.summarize_activation(v) + slim_layers.append(v) + + self.all_layers = list(layer.all_layers) + self.all_params = list(layer.all_params) + self.all_drop = dict(layer.all_drop) + + self.all_layers.extend( slim_layers ) + self.all_params.extend( slim_variables ) + + ## Flow control layer class MultiplexerLayer(Layer): """ diff --git a/tensorlayer/ops.py b/tensorlayer/ops.py index 4818d3c80..50129422f 100644 --- a/tensorlayer/ops.py +++ b/tensorlayer/ops.py @@ -156,9 +156,14 @@ def get_site_packages_directory(): """Print and return the site-packages directory? """ import site - loc = site.getsitepackages() - print(loc) - return loc + try: + loc = site.getsitepackages() + print(" tl.ops : site-packages in ", loc) + return loc + except: + p = ' tl.ops : You are using virtual environment' + print(p) + return p diff --git a/tutorial_inceptionV3_tfslim.py b/tutorial_inceptionV3_tfslim.py new file mode 100644 index 000000000..61e1fb247 --- /dev/null +++ b/tutorial_inceptionV3_tfslim.py @@ -0,0 +1,131 @@ +#! /usr/bin/python +# -*- coding: utf8 -*- + + +import tensorflow as tf +import tensorlayer as tl +slim = tf.contrib.slim +from tensorflow.contrib.slim.python.slim.nets.alexnet import alexnet_v2 +from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_base, inception_v3, inception_v3_arg_scope +# from tensorflow.contrib.slim.python.slim.nets.resnet_v2 import resnet_v2_152 +# from tensorflow.contrib.slim.python.slim.nets.vgg import vgg_16 +import skimage +import skimage.io +import skimage.transform +import time +from data.imagenet_classes import * +import numpy as np +""" +You will learn: +1. What is TF-Slim ? +1. How to combine TensorLayer and TF-Slim ? + +Introduction of Slim : https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim +Slim Pre-trained Models : https://github.com/tensorflow/models/tree/master/slim + +With the help of SlimNetsLayer, all Slim Model can be combined into TensorLayer. +All models in the following link, end with `return net, end_points`` are available. +https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim/python/slim/nets + + +Bugs +----- +tf.variable_scope : + https://groups.google.com/a/tensorflow.org/forum/#!topic/discuss/RoxrU3UnbFA +load inception_v3 for prediction: + http://stackoverflow.com/questions/39357454/restore-checkpoint-in-tensorflow-tensor-name-not-found +""" +def load_image(path): + # load image + img = skimage.io.imread(path) + img = img / 255.0 + assert (0 <= img).all() and (img <= 1.0).all() + # print "Original Image Shape: ", img.shape + # we crop image from center + short_edge = min(img.shape[:2]) + yy = int((img.shape[0] - short_edge) / 2) + xx = int((img.shape[1] - short_edge) / 2) + crop_img = img[yy: yy + short_edge, xx: xx + short_edge] + # resize to 224, 224 + resized_img = skimage.transform.resize(crop_img, (299, 299)) + return resized_img + + +def print_prob(prob): + synset = class_names + # print prob + pred = np.argsort(prob)[::-1] + # Get top1 label + top1 = synset[pred[0]] + print("Top1: ", top1, prob[pred[0]]) + # Get top5 label + top5 = [(synset[pred[i]], prob[pred[i]]) for i in range(5)] + print("Top5: ", top5) + return top1 + + +## Alexnet_v2 / All Slim nets can be merged into TensorLayer +# x = tf.placeholder(tf.float32, shape=[None, 299, 299, 3]) +# net_in = tl.layers.InputLayer(x, name='input_layer') +# network = tl.layers.SlimNetsLayer(layer=net_in, slim_layer=alexnet_v2, +# slim_args= { +# 'num_classes' : 1000, +# 'is_training' : True, +# 'dropout_keep_prob' : 0.5, +# 'spatial_squeeze' : True, +# 'scope' : 'alexnet_v2' +# } +# ) +# sess = tf.InteractiveSession() +# sess.run(tf.initialize_all_variables()) +# network.print_params() +# exit() + +# InceptionV3 +x = tf.placeholder(tf.float32, shape=[None, 299, 299, 3]) +net_in = tl.layers.InputLayer(x, name='input_layer') # DH +with slim.arg_scope(inception_v3_arg_scope()): + # logits, end_points = inception_v3(X, num_classes=1001, + # is_training=False) + network = tl.layers.SlimNetsLayer(layer=net_in, slim_layer=inception_v3, + slim_args= { + 'num_classes' : 1001, + 'is_training' : False, + # 'dropout_keep_prob' : 0.8, # for training + # 'min_depth' : 16, + # 'depth_multiplier' : 1.0, + # 'prediction_fn' : slim.softmax, + # 'spatial_squeeze' : True, + # 'reuse' : None, + # 'scope' : 'InceptionV3' + }, + name='' + ) +saver = tf.train.Saver() + +sess = tf.InteractiveSession() +sess.run(tf.initialize_all_variables()) + +# with tf.Session() as sess: +saver.restore(sess, "inception_v3.ckpt") # download from https://github.com/tensorflow/models/tree/master/slim#Install +print("Model Restored") +network.print_params(False) + + +from scipy.misc import imread, imresize +y = network.outputs +probs = tf.nn.softmax(y) +img1 = load_image("data/puzzle.jpeg") +img1 = img1.reshape((1, 299, 299, 3)) + +start_time = time.time() +prob = sess.run(probs, feed_dict= {x : img1}) +print("End time : %.5ss" % (time.time() - start_time)) +print_prob(prob[0][1:]) # Note : as it have 1001 outputs, the 1st output is nothing + + + + + + +#