diff --git a/docs/conf.py b/docs/conf.py
index 665409f1f..d41761007 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -67,9 +67,9 @@
 # built documents.
 #
 # The short X.Y version.
-version = '1.8.1'
+version = '1.8.2'
 # The full version, including alpha/beta/rc tags.
-release = '1.8.1'
+release = '1.8.2'
 
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
@@ -143,7 +143,7 @@
 # The name for this set of Sphinx documents.
 # "<project> v<release> documentation" by default.
 #
-# html_title = 'TensorLayer v1.8.1'
+# html_title = 'TensorLayer v1.8.2'
 
 # A shorter title for the navigation bar. Default is the same as html_title.
 #
diff --git a/docs/modules/activation.rst b/docs/modules/activation.rst
index 8d7474674..51a8ba863 100644
--- a/docs/modules/activation.rst
+++ b/docs/modules/activation.rst
@@ -31,6 +31,7 @@ For more complex activation, TensorFlow API will be required.
    leaky_relu
    swish
    sign
+   hard_tanh
    pixel_wise_softmax
 
 Identity
@@ -53,6 +54,10 @@ Sign
 ---------------------
 .. autofunction:: sign
 
+Hard Tanh
+---------------------
+.. autofunction:: hard_tanh
+
 Pixel-wise softmax
 --------------------
 .. autofunction:: pixel_wise_softmax
diff --git a/docs/modules/layers.rst b/docs/modules/layers.rst
index bb9918412..d87e5a8ac 100644
--- a/docs/modules/layers.rst
+++ b/docs/modules/layers.rst
@@ -152,7 +152,7 @@ At the end, for a layer with parameters, we also append the parameters into ``al
           name ='simple_dense',
       ):
           # check layer name (fixed)
-          Layer.__init__(self, name=name)
+          Layer.__init__(self, layer=layer, name=name)
 
           # the input of this layer is the output of previous layer (fixed)
           self.inputs = layer.outputs
@@ -169,11 +169,6 @@ At the end, for a layer with parameters, we also append the parameters into ``al
           # tensor operation
           self.outputs = act(tf.matmul(self.inputs, W) + b)
 
-          # get stuff from previous layer (fixed)
-          self.all_layers = list(layer.all_layers)
-          self.all_params = list(layer.all_params)
-          self.all_drop = dict(layer.all_drop)
-
           # update layer (customized)
           self.all_layers.extend( [self.outputs] )
           self.all_params.extend( [W, b] )
@@ -336,6 +331,11 @@ Layer list
 
    SlimNetsLayer
 
+   BinaryDenseLayer
+   BinaryConv2d
+   SignLayer
+   ScaleLayer
+
    PReluLayer
 
    MultiplexerLayer
@@ -799,6 +799,38 @@ see `Slim-model
 
 .. autoclass:: KerasLayer
 
+Binary Nets
+------------------
+
+Read Me
+^^^^^^^^^^^^^^
+
+This is an experimental API package for building binary nets. At the moment it
+uses matrix multiplication rather than add/minus and bit-count operations, so
+these APIs will not speed up inference. For production, you can train a model
+with TensorLayer and deploy it in a customized C/C++ implementation (we may
+provide an extra C/C++ binary-net framework that can load models from
+TensorLayer).
+
+Note that these experimental APIs may change at any time.
+
+Binarized Dense
+^^^^^^^^^^^^^^^^^
+.. autoclass:: BinaryDenseLayer
+
+
+Binarized Conv2d
+^^^^^^^^^^^^^^^^^^
+.. autoclass:: BinaryConv2d
+
+
+Sign
+^^^^^^^^^^^^^^
+.. autoclass:: SignLayer
+
+
+Scale
+^^^^^^^^^^^^^^
+.. autoclass:: ScaleLayer
+
+
 Parametric activation layer
 ---------------------------
diff --git a/example/tutorial_binarynet_mnist_cnn.py b/example/tutorial_binarynet_mnist_cnn.py
index 326c0aa31..541739c9e 100644
--- a/example/tutorial_binarynet_mnist_cnn.py
+++ b/example/tutorial_binarynet_mnist_cnn.py
@@ -7,6 +7,7 @@
 
 X_train, y_train, X_val, y_val, X_test, y_test = \
     tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
+# X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False)
 
 sess = tf.InteractiveSession()
 
@@ -17,25 +18,29 @@
 
 
 def model(x, is_train=True, reuse=False):
+    # In a BNN, all layer inputs are binary, with the exception of the first layer.
+    # ref: https://github.com/itayhubara/BinaryNet.tf/blob/master/models/BNN_cifar10.py
     with tf.variable_scope("binarynet", reuse=reuse):
         net = tl.layers.InputLayer(x, name='input')
         net = tl.layers.BinaryConv2d(net, 32, (5, 5), (1, 1), padding='SAME', name='bcnn1')
         net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool1')
+        net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn1')
-        net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bn')
-        net = tl.layers.SignLayer(net, name='sign2')
+        net = tl.layers.SignLayer(net)
         net = tl.layers.BinaryConv2d(net, 64, (5, 5), (1, 1), padding='SAME', name='bcnn2')
         net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool2')
+        net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn2')
 
-        net = tl.layers.SignLayer(net, name='sign2')
         net = tl.layers.FlattenLayer(net, name='flatten')
-        net = tl.layers.DropoutLayer(net, 0.5, True, is_train, name='drop1')
-        # net = tl.layers.DenseLayer(net, 256, act=tf.nn.relu, name='dense')
+        net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop1')
+        net = tl.layers.SignLayer(net)
         net = tl.layers.BinaryDenseLayer(net, 256, name='dense')
-        net = tl.layers.DropoutLayer(net, 0.5, True, is_train, name='drop2')
-        # net = tl.layers.DenseLayer(net, 10, act=tf.identity, name='output')
+        net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn3')
+
+        net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop2')
+        net = tl.layers.SignLayer(net)
         net = tl.layers.BinaryDenseLayer(net, 10, name='bout')
-        # net = tl.layers.ScaleLayer(net, name='scale')
+        net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bno')
     return net
@@ -66,7 +71,7 @@ def model(x, is_train=True, reuse=False):
 
 n_epoch = 200
 print_freq = 5
 
-# print(sess.run(net_test.all_params))  # print real value of parameters
+# print(sess.run(net_test.all_params))  # print real values of parameters
 
 for epoch in range(n_epoch):
     start_time = time.time()
diff --git a/setup.py b/setup.py
index 64c026ecb..6483d7c97 100755
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
 
 setup(
     name="tensorlayer",
-    version="1.8.1",
+    version="1.8.2",
     include_package_data=True,
     author='TensorLayer Contributors',
     author_email='hao.dong11@imperial.ac.uk',
diff --git a/tensorlayer/__init__.py b/tensorlayer/__init__.py
index 799353f75..5d3d8cc86 100644
--- a/tensorlayer/__init__.py
+++ b/tensorlayer/__init__.py
@@ -23,7 +23,7 @@
 act = activation
 vis = visualize
 
-__version__ = "1.8.1"
+__version__ = "1.8.2"
 
 global_flag = {}
 global_dict = {}
diff --git a/tensorlayer/activation.py b/tensorlayer/activation.py
index 9a0c12896..2bf741b97 100644
--- a/tensorlayer/activation.py
+++ b/tensorlayer/activation.py
@@ -168,6 +168,28 @@ def sign(x):
 # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models
 # return tf.sign(x), grad
 
+def hard_tanh(x, name='htanh'):
+    """Hard tanh activation function.
+
+    This is a ramp function with a lower bound of -1 and an upper bound of 1; the shortcut is ``htanh``.
+
+    Parameters
+    ----------
+    x : Tensor
+        input.
+    name : str
+        The function name (optional).
+
+    Returns
+    -------
+    Tensor
+        A ``Tensor`` with the same type as ``x``.
+
+    """
+    # with tf.variable_scope("hard_tanh"):
+    return tf.clip_by_value(x, -1, 1, name=name)
+
+
 @deprecated("2018-06-30", "This API will be deprecated soon as tf.nn.softmax can do the same thing.")
 def pixel_wise_softmax(x, name='pixel_wise_softmax'):
     """Return the softmax outputs of images, every pixels have multiple label, the sum of a pixel is 1.
@@ -204,3 +226,4 @@
 # Alias
 linear = identity
 lrelu = leaky_relu
+htanh = hard_tanh
diff --git a/tensorlayer/layers/binary.py b/tensorlayer/layers/binary.py
index 9d0578ee3..405e2d1fe 100644
--- a/tensorlayer/layers/binary.py
+++ b/tensorlayer/layers/binary.py
@@ -5,9 +5,9 @@
 
 __all__ = [
     'BinaryDenseLayer',
+    'BinaryConv2d',
     'SignLayer',
     'ScaleLayer',
-    'BinaryConv2d',
 ]
 
 
@@ -142,6 +142,18 @@ class BinaryConv2d(Layer):
     name : str
         A unique layer name.
 
+    Examples
+    ---------
+    >>> net = tl.layers.InputLayer(x, name='input')
+    >>> net = tl.layers.BinaryConv2d(net, 32, (5, 5), (1, 1), padding='SAME', name='bcnn1')
+    >>> net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool1')
+    >>> net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn1')
+    ...
+    >>> net = tl.layers.SignLayer(net)
+    >>> net = tl.layers.BinaryConv2d(net, 64, (5, 5), (1, 1), padding='SAME', name='bcnn2')
+    >>> net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool2')
+    >>> net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn2')
+
     """
 
     def __init__(
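
The new ``hard_tanh`` is one line of TensorFlow, but its role in the BinaryNet example is easy to miss: as the activation of each ``BatchNormLayer`` it clamps pre-binarization values to [-1, 1], the usual straight-through companion to ``sign`` (gradients pass through unchanged inside the ramp and vanish outside it). A minimal NumPy sketch of the forward behaviour, illustrative only and not TensorLayer API:

import numpy as np

def hard_tanh(x):
    # Same maths as the diff's tf.clip_by_value(x, -1, 1).
    return np.clip(x, -1.0, 1.0)

x = np.array([-2.5, -1.0, -0.3, 0.0, 0.7, 1.8])
print(hard_tanh(x))  # [-1.  -1.  -0.3  0.   0.7  1. ]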
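
The Read Me above notes that these layers still compute floating-point matrix multiplications rather than XNOR and bit-count. The identity a customized C/C++ deployment would exploit: for vectors with entries in {-1, +1}, a dot product equals 2 * popcount(xnor(a, b)) - n. A small self-checking sketch in plain NumPy, illustrative only:

import numpy as np

n = 64
rng = np.random.default_rng(0)
a = rng.choice([-1, 1], size=n)
b = rng.choice([-1, 1], size=n)

# Float path: what the experimental TensorLayer layers compute today.
dot_float = int(np.dot(a, b))

# Bit path: encode +1 as True and -1 as False, XNOR, then count matches.
matches = int(np.sum(~((a > 0) ^ (b > 0))))  # popcount of XNOR
dot_bits = 2 * matches - n  # (#matches) - (#mismatches)

assert dot_float == dot_bits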