Skip to content

Commit

Permalink
add SlimNetsLayer and Inception V3 example / Merge TF-Slim into Tenso…
Browse files Browse the repository at this point in the history
…rLayer
  • Loading branch information
zsdonghao committed Sep 28, 2016
1 parent 7594891 commit b160de7
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 3 deletions.
18 changes: 18 additions & 0 deletions docs/modules/layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ In addition, if you want to update the parameters of previous 2 layers at the sa
FlattenLayer
ConcatLayer
ReshapeLayer
SlimNetsLayer
MultiplexerLayer
EmbeddingAttentionSeq2seqWrapper
flatten_reshape
Expand Down Expand Up @@ -357,10 +358,20 @@ so to implement 1D CNN, you can use Reshape layer as follow.

.. autoclass:: Conv2dLayer

2D Deconvolutional layer
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: DeConv2dLayer


3D Convolutional layer
^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: Conv3dLayer

3D Deconvolutional layer
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: DeConv3dLayer

Pooling layer
Expand Down Expand Up @@ -397,6 +408,13 @@ Reshape layer

.. autoclass:: ReshapeLayer

Merge TF-Slim
^^^^^^^^^^^^^^^

Yes ! TF-Slim models can be merged into TensorLayer, all Google's Pre-trained model can be used easily ,
see `Slim-model <https://github.com/tensorflow/models/tree/master/slim#Install>`_ .

.. autoclass:: SlimNetsLayer

Flow control layer
----------------------
Expand Down
8 changes: 8 additions & 0 deletions docs/user/example.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Computer Vision
- Convolutional Network (CIFAR-10). A Convolutional neural network implementation for classifying CIFAR-10 dataset, see ``tutorial_cifar10.py`` and ``tutorial_cifar10_tfrecord.py``on `GitHub`_.
- VGG 16 (ImageNet). A Convolutional neural network implementation for classifying ImageNet dataset, see ``tutorial_vgg16.py`` on `GitHub`_.
- VGG 19 (ImageNet). A Convolutional neural network implementation for classifying ImageNet dataset, see ``tutorial_vgg19.py`` on `GitHub`_.
- InceptionV3 (ImageNet). A Convolutional neural network implementation for classifying ImageNet dataset, see ``tutorial_inceptionV3_tfslim.py`` on `GitHub`_.


Natural Language Processing
Expand All @@ -36,6 +37,13 @@ Reinforcement Learning
- Deep Reinforcement Learning - Pong Game. Teach a machine to play Pong games, see ``tutorial_atari_pong.py`` on `GitHub`_.


Special Examples
=================

- Merge TF-Slim into TensorLayer. ``tutorial_inceptionV3_tfslim.py`` on `GitHub`_.
- MultiplexerLayer. ``tutorial_mnist_multiplexer.py`` on `GitHub`_.


..
Applications
=============
Expand Down
52 changes: 52 additions & 0 deletions tensorlayer/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,6 +1629,58 @@ def __init__(
self.all_drop = dict(layer.all_drop)
self.all_layers.extend( [self.outputs] )

## TF-Slim layer
class SlimNetsLayer(Layer):
"""
The :class:`SlimNetsLayer` class can be used to merge all TF-Slim nets into
TensorLayer. Model can be found in `slim-model <https://github.com/tensorflow/models/tree/master/slim#Install>`_ , more about slim
see `slim-git <https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim>`_ .
Parameters
----------
layer : a list of :class:`Layer` instances
The `Layer` class feeding into this layer.
slim_layer : a slim network function
The network you want to stack onto, end with ``return net, end_points``.
name : a string or None
An optional name to attach to this layer.
Note
-----
The due to TF-Slim stores the layers as dictionary, the ``all_layers`` in this
network is not in order ! Fortunately, the ``all_params`` are in order.
"""
def __init__(
self,
layer = None,
slim_layer = None,
slim_args = {},
name ='slim_layer',
):
Layer.__init__(self, name=name)
self.inputs = layer.outputs
print(" tensorlayer:Instantiate SlimNetsLayer %s: %s" % (self.name, slim_layer.__name__))

with tf.variable_scope(name) as vs:
net, end_points = slim_layer(self.inputs, **slim_args)
slim_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)

self.outputs = net

slim_layers = []
for v in end_points.values():
tf.contrib.layers.summaries.summarize_activation(v)
slim_layers.append(v)

self.all_layers = list(layer.all_layers)
self.all_params = list(layer.all_params)
self.all_drop = dict(layer.all_drop)

self.all_layers.extend( slim_layers )
self.all_params.extend( slim_variables )


## Flow control layer
class MultiplexerLayer(Layer):
"""
Expand Down
11 changes: 8 additions & 3 deletions tensorlayer/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,14 @@ def get_site_packages_directory():
"""Print and return the site-packages directory?
"""
import site
loc = site.getsitepackages()
print(loc)
return loc
try:
loc = site.getsitepackages()
print(" tl.ops : site-packages in ", loc)
return loc
except:
p = ' tl.ops : You are using virtual environment'
print(p)
return p



Expand Down
131 changes: 131 additions & 0 deletions tutorial_inceptionV3_tfslim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#! /usr/bin/python
# -*- coding: utf8 -*-


import tensorflow as tf
import tensorlayer as tl
slim = tf.contrib.slim
from tensorflow.contrib.slim.python.slim.nets.alexnet import alexnet_v2
from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_base, inception_v3, inception_v3_arg_scope
# from tensorflow.contrib.slim.python.slim.nets.resnet_v2 import resnet_v2_152
# from tensorflow.contrib.slim.python.slim.nets.vgg import vgg_16
import skimage
import skimage.io
import skimage.transform
import time
from data.imagenet_classes import *
import numpy as np
"""
You will learn:
1. What is TF-Slim ?
1. How to combine TensorLayer and TF-Slim ?
Introduction of Slim : https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim
Slim Pre-trained Models : https://github.com/tensorflow/models/tree/master/slim
With the help of SlimNetsLayer, all Slim Model can be combined into TensorLayer.
All models in the following link, end with `return net, end_points`` are available.
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim/python/slim/nets
Bugs
-----
tf.variable_scope :
https://groups.google.com/a/tensorflow.org/forum/#!topic/discuss/RoxrU3UnbFA
load inception_v3 for prediction:
http://stackoverflow.com/questions/39357454/restore-checkpoint-in-tensorflow-tensor-name-not-found
"""
def load_image(path):
# load image
img = skimage.io.imread(path)
img = img / 255.0
assert (0 <= img).all() and (img <= 1.0).all()
# print "Original Image Shape: ", img.shape
# we crop image from center
short_edge = min(img.shape[:2])
yy = int((img.shape[0] - short_edge) / 2)
xx = int((img.shape[1] - short_edge) / 2)
crop_img = img[yy: yy + short_edge, xx: xx + short_edge]
# resize to 224, 224
resized_img = skimage.transform.resize(crop_img, (299, 299))
return resized_img


def print_prob(prob):
synset = class_names
# print prob
pred = np.argsort(prob)[::-1]
# Get top1 label
top1 = synset[pred[0]]
print("Top1: ", top1, prob[pred[0]])
# Get top5 label
top5 = [(synset[pred[i]], prob[pred[i]]) for i in range(5)]
print("Top5: ", top5)
return top1


## Alexnet_v2 / All Slim nets can be merged into TensorLayer
# x = tf.placeholder(tf.float32, shape=[None, 299, 299, 3])
# net_in = tl.layers.InputLayer(x, name='input_layer')
# network = tl.layers.SlimNetsLayer(layer=net_in, slim_layer=alexnet_v2,
# slim_args= {
# 'num_classes' : 1000,
# 'is_training' : True,
# 'dropout_keep_prob' : 0.5,
# 'spatial_squeeze' : True,
# 'scope' : 'alexnet_v2'
# }
# )
# sess = tf.InteractiveSession()
# sess.run(tf.initialize_all_variables())
# network.print_params()
# exit()

# InceptionV3
x = tf.placeholder(tf.float32, shape=[None, 299, 299, 3])
net_in = tl.layers.InputLayer(x, name='input_layer') # DH
with slim.arg_scope(inception_v3_arg_scope()):
# logits, end_points = inception_v3(X, num_classes=1001,
# is_training=False)
network = tl.layers.SlimNetsLayer(layer=net_in, slim_layer=inception_v3,
slim_args= {
'num_classes' : 1001,
'is_training' : False,
# 'dropout_keep_prob' : 0.8, # for training
# 'min_depth' : 16,
# 'depth_multiplier' : 1.0,
# 'prediction_fn' : slim.softmax,
# 'spatial_squeeze' : True,
# 'reuse' : None,
# 'scope' : 'InceptionV3'
},
name=''
)
saver = tf.train.Saver()

sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())

# with tf.Session() as sess:
saver.restore(sess, "inception_v3.ckpt") # download from https://github.com/tensorflow/models/tree/master/slim#Install
print("Model Restored")
network.print_params(False)


from scipy.misc import imread, imresize
y = network.outputs
probs = tf.nn.softmax(y)
img1 = load_image("data/puzzle.jpeg")
img1 = img1.reshape((1, 299, 299, 3))

start_time = time.time()
prob = sess.run(probs, feed_dict= {x : img1})
print("End time : %.5ss" % (time.time() - start_time))
print_prob(prob[0][1:]) # Note : as it have 1001 outputs, the 1st output is nothing






#

0 comments on commit b160de7

Please sign in to comment.