From 32585cb1fe2ecae4623b71172451812cb9da0747 Mon Sep 17 00:00:00 2001 From: Sander Dieleman Date: Sat, 22 Aug 2015 23:46:50 +0200 Subject: [PATCH 1/2] highway networks implementation --- papers/Highway Networks.ipynb | 640 ++++++++++++++++++++++++++++++++++ 1 file changed, 640 insertions(+) create mode 100644 papers/Highway Networks.ipynb diff --git a/papers/Highway Networks.ipynb b/papers/Highway Networks.ipynb new file mode 100644 index 0000000..a891bab --- /dev/null +++ b/papers/Highway Networks.ipynb @@ -0,0 +1,640 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Highway networks\n", + "**A quick example of how to implement [highway networks](http://arxiv.org/abs/1505.00387) in Lasagne.**\n", + "\n", + "## What's a highway network?\n", + "The paper linked above introduces a new type of neural network layer, which works roughly as follows.\n", + "\n", + "Let $x$ be the input to a layer. Then a typical neural network layer computes some nonlinear transform of this input $y = H(x)$.\n", + "\n", + "A highway layer also computes an additional nonlinear transform $T(x)$, which in practice is constrained to the interval $[0, 1]$. The output of the layer is then $y = T(x) \\cdot H(x) + (1 - T(x)) \\cdot x$, where the multiplication is elementwise.\n", + "\n", + "In other words, **depending on the gate values $T(x)$, the layer behaves as a traditional layer would ($T(x) = 1$), or passes its input through unchanged ($T(x) = 0$)**. This idea is inspired by the gates in LSTM units. According to the authors, **it enables gradient descent-based training of much deeper networks** with as many as 900 layers.\n", + "\n", + "Note that a highway layer needs to have as many outputs as inputs: the shapes of $x$, $H(x)$ and $T(x)$ all have to have matching shapes. To change the dimensionality in a highway network, the authors suggest inserting a traditional neural network layer.\n", + "\n", + "## Purpose\n", + "I read the paper before and wanted to try it out. I figured this would be a good way of showing how to implement a new concept or idea in Lasagne. **This use case of trying out new ideas and implementing new types of layers is extremely important to the Lasagne development team, and we are trying to make it as easy as possible to use Lasagne in this way.**\n", + "\n", + "## Approach\n", + "I decided to implement this idea in two steps. First, I added a `MultiplicativeGatingLayer`, which performs the following operation: $y(t, x_1, x_2) = t \\cdot x_1 + (1 - t) \\cdot x_2$. In other words, the first input $t$ multiplicatively gates between the others $x_1$ and $x_2$.\n", + "\n", + "This then makes it possible to use any layer we like for computing $t$ and $x_1$ (and $x_2$ is taken to be the output of the previous layer). I implemented two \"macro functions\" on top of this: `highway_dense` and `highway_conv2d`. They create fully connected and 2D convolutional highway layers respectively.\n", + "\n", + "This two-step approach allows for some code reuse and easy implementation of different types of highway layers." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using gpu device 0: GeForce GTX 980\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import theano\n", + "import theano.tensor as T\n", + "import lasagne as nn" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The MNIST dataset (15MB) can be downloaded with:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2015-08-22 23:00:18-- http://deeplearning.net/data/mnist/mnist.pkl.gz\n", + "Resolving deeplearning.net (deeplearning.net)... 132.204.26.28\n", + "Connecting to deeplearning.net (deeplearning.net)|132.204.26.28|:80... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 16168813 (15M) [application/x-gzip]\n", + "Server file no newer than local file ‘mnist.pkl.gz’ -- not retrieving.\n", + "\n" + ] + } + ], + "source": [ + "!wget -N http://deeplearning.net/data/mnist/mnist.pkl.gz" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import gzip\n", + "import cPickle as pickle\n", + "import sys\n", + "\n", + "PY2 = sys.version_info[0] == 2 # check if we're running Python 2 or 3\n", + "# we need to know this because unpickling is slightly different in both cases\n", + "\n", + "if PY2:\n", + " def pickle_load(f, encoding):\n", + " return pickle.load(f)\n", + "else:\n", + " def pickle_load(f, encoding):\n", + " return pickle.load(f, encoding=encoding)\n", + "\n", + "def load_data():\n", + " \"\"\"Get data with labels, split into training, validation and test set.\"\"\"\n", + " with gzip.open('mnist.pkl.gz', 'rb') as f:\n", + " data = pickle_load(f, encoding='latin-1')\n", + " X_train, y_train = data[0]\n", + " X_valid, y_valid = data[1]\n", + " X_test, y_test = data[2]\n", + "\n", + " return dict(\n", + " X_train=theano.shared(nn.utils.floatX(X_train)),\n", + " y_train=T.cast(theano.shared(y_train), 'int32'),\n", + " X_valid=theano.shared(nn.utils.floatX(X_valid)),\n", + " y_valid=T.cast(theano.shared(y_valid), 'int32'),\n", + " X_test=theano.shared(nn.utils.floatX(X_test)),\n", + " y_test=T.cast(theano.shared(y_test), 'int32'),\n", + " num_examples_train=X_train.shape[0],\n", + " num_examples_valid=X_valid.shape[0],\n", + " num_examples_test=X_test.shape[0],\n", + " input_dim=X_train.shape[1],\n", + " output_dim=10,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**First, create a custom layer class for the multiplicative gating operation.** This is a layer with multiple input layers, three to be precise: $x$, $H(x)$ and $T(x)$. In Lasagne, this means it needs to inherit from `MergeLayer`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class MultiplicativeGatingLayer(nn.layers.MergeLayer):\n", + " \"\"\"\n", + " Generic layer that combines its 3 inputs t, h1, h2 as follows:\n", + " y = t * h1 + (1 - t) * h2\n", + " \"\"\"\n", + " def __init__(self, gate, input1, input2, **kwargs):\n", + " incomings = [gate, input1, input2]\n", + " super(MultiplicativeGatingLayer, self).__init__(incomings, **kwargs)\n", + " assert gate.output_shape == input1.output_shape == input2.output_shape\n", + " \n", + " def get_output_shape_for(self, input_shapes):\n", + " return input_shapes[0]\n", + " \n", + " def get_output_for(self, inputs, **kwargs):\n", + " return inputs[0] * inputs[1] + (1 - inputs[0]) * inputs[2]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Now we can define a macro function to create a dense highway layer.** Note that it does not take a `num_units` input argument: the number of outputs should always be the same as the number of inputs, so it is redundant.\n", + "\n", + "We initialize the biases of the gates to `-4.0` to disable all of them initially. This means all layers will basically pass through the inputs (and gradients) unchanged at the start of training." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def highway_dense(incoming, Wh=nn.init.Orthogonal(), bh=nn.init.Constant(0.0),\n", + " Wt=nn.init.Orthogonal(), bt=nn.init.Constant(-4.0),\n", + " nonlinearity=nn.nonlinearities.rectify, **kwargs):\n", + " num_inputs = int(np.prod(incoming.output_shape[1:]))\n", + " # regular layer\n", + " l_h = nn.layers.DenseLayer(incoming, num_units=num_inputs, W=Wh, b=bh,\n", + " nonlinearity=nonlinearity)\n", + " # gate layer\n", + " l_t = nn.layers.DenseLayer(incoming, num_units=num_inputs, W=Wt, b=bt,\n", + " nonlinearity=T.nnet.sigmoid)\n", + " \n", + " return MultiplicativeGatingLayer(gate=l_t, input1=l_h, input2=incoming)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**We can easily do the same for a 2D convolution highway layer.** As mentioned in the paper, we need to use 'same' convolutions here to ensure that the shape of $H(x)$ and $T(x)$ matches that of $x$.\n", + "\n", + "Unfortunately the implementation of 'same' convolutions in Theano using the default convolution operations `T.nnet.conv.conv2d` is a bit challenging. The default approach in Lasagne is to perform a 'full' convolution and then crop it, which can be slow. This is implemented in `lasagne.layers.Conv2DLayer`.\n", + "\n", + "To get an actual 'same' convolution, you could use one of the alternative convolution layer implementations that Lasagne provides, such as `lasagne.layers.dnn.Conv2DDNNLayer`, `lasagne.layers.corrmm.Conv2DMMLayer` or `lasagne.layers.cuda_convnet.Conv2DCCLayer`, all of which support the 'same' convolution mode properly." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def highway_conv2d(incoming, filter_size,\n", + " Wh=nn.init.Orthogonal(), bh=nn.init.Constant(0.0),\n", + " Wt=nn.init.Orthogonal(), bt=nn.init.Constant(-4.0),\n", + " nonlinearity=nn.nonlinearities.rectify, **kwargs):\n", + " num_channels = incoming.output_shape[1]\n", + " # regular layer\n", + " l_h = nn.layers.Conv2DLayer(incoming, num_filters=num_channels,\n", + " filter_size=filter_size,\n", + " border_mode='same', W=Wh, b=bh,\n", + " nonlinearity=nonlinearity)\n", + " # gate layer\n", + " l_t = nn.layers.Conv2DLayer(incoming, num_filters=num_channels,\n", + " filter_size=filter_size,\n", + " border_mode='same', W=wt, b=bt,\n", + " nonlinearity=T.nnet.sigmoid)\n", + " \n", + " return MultiplicativeGatingLayer(gate=l_t, input1=l_h, input2=incoming)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": true + }, + "source": [ + "Now let's **build a model** with a number of dense highway layers." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def build_model(input_dim, output_dim, batch_size,\n", + " num_hidden_units, num_hidden_layers):\n", + " \"\"\"Create a symbolic representation of a neural network with `intput_dim`\n", + " input nodes, `output_dim` output nodes, `num_hidden_layers` hidden layers\n", + " and `num_hidden_units` per hidden layer.\n", + " \n", + " The training function of this model must have a mini-batch size of\n", + " `batch_size`.\n", + " \"\"\"\n", + " l_in = nn.layers.InputLayer((batch_size, input_dim))\n", + " \n", + " # first, project it down to the desired number of units per layer\n", + " l_hidden1 = nn.layers.DenseLayer(l_in, num_units=num_hidden_units)\n", + " \n", + " # then stack highway layers on top of this\n", + " l_current = l_hidden1\n", + " for k in range(num_hidden_layers - 1):\n", + " l_current = highway_dense(l_current)\n", + " \n", + " # finally add an output layer\n", + " l_out = nn.layers.DenseLayer(\n", + " l_current, num_units=output_dim,\n", + " nonlinearity=nn.nonlinearities.softmax,\n", + " )\n", + " \n", + " return l_in, l_out" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can **load the data, build the model and compile the necessary Theano functions.**" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading data...\n", + "Building model and compiling functions...\n" + ] + } + ], + "source": [ + "num_epochs = 50\n", + "batch_size = 100\n", + "learning_rate = 0.01\n", + "momentum = 0.9\n", + "\n", + "print(\"Loading data...\")\n", + "dataset = load_data()\n", + "\n", + "print(\"Building model and compiling functions...\")\n", + "l_in, l_out = build_model(\n", + " input_dim=dataset['input_dim'],\n", + " output_dim=dataset['output_dim'],\n", + " batch_size=batch_size,\n", + " num_hidden_units=40,\n", + " num_hidden_layers=50,\n", + ")\n", + "\n", + "x = l_in.input_var\n", + "y = T.ivector('y')\n", + "y_pred = nn.layers.get_output(l_out)\n", + "loss = T.mean(nn.objectives.categorical_crossentropy(y_pred, y))\n", + "params = nn.layers.get_all_params(l_out)\n", + "updates = nn.updates.nesterov_momentum(loss, params, learning_rate, momentum)\n", + "\n", + "# compile iteration functions\n", + "batch_index = T.iscalar('batch_index')\n", + "batch_slice = slice(batch_index * batch_size,\n", + " (batch_index + 1) * batch_size)\n", + "\n", + "pred = T.argmax(y_pred, axis=1)\n", + "accuracy = T.mean(T.eq(pred, y), dtype=theano.config.floatX)\n", + "\n", + "iter_train = theano.function(\n", + " [batch_index], loss,\n", + " updates=updates,\n", + " givens={\n", + " x: dataset['X_train'][batch_slice],\n", + " y: dataset['y_train'][batch_slice],\n", + " },\n", + ")\n", + "\n", + "iter_valid = theano.function(\n", + " [batch_index], [loss, accuracy],\n", + " givens={\n", + " x: dataset['X_valid'][batch_slice],\n", + " y: dataset['y_valid'][batch_slice],\n", + " },\n", + ")\n", + "\n", + "iter_test = theano.function(\n", + " [batch_index], [loss, accuracy],\n", + " givens={\n", + " x: dataset['X_test'][batch_slice],\n", + " y: dataset['y_test'][batch_slice],\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, here's the **main training loop**." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting training...\n", + "Epoch 1 of 50 took 5.509 s\n", + " training loss:\t\t0.903543\n", + " validation loss:\t\t0.337289\n", + " validation accuracy:\t\t90.37 %\n", + "Epoch 2 of 50 took 5.529 s\n", + " training loss:\t\t0.314176\n", + " validation loss:\t\t0.232568\n", + " validation accuracy:\t\t92.90 %\n", + "Epoch 3 of 50 took 5.574 s\n", + " training loss:\t\t0.239013\n", + " validation loss:\t\t0.190169\n", + " validation accuracy:\t\t94.42 %\n", + "Epoch 4 of 50 took 5.588 s\n", + " training loss:\t\t0.196143\n", + " validation loss:\t\t0.166020\n", + " validation accuracy:\t\t95.14 %\n", + "Epoch 5 of 50 took 5.602 s\n", + " training loss:\t\t0.166447\n", + " validation loss:\t\t0.151690\n", + " validation accuracy:\t\t95.66 %\n", + "Epoch 6 of 50 took 5.618 s\n", + " training loss:\t\t0.144108\n", + " validation loss:\t\t0.140717\n", + " validation accuracy:\t\t96.02 %\n", + "Epoch 7 of 50 took 5.608 s\n", + " training loss:\t\t0.126167\n", + " validation loss:\t\t0.134176\n", + " validation accuracy:\t\t96.18 %\n", + "Epoch 8 of 50 took 5.614 s\n", + " training loss:\t\t0.111839\n", + " validation loss:\t\t0.129229\n", + " validation accuracy:\t\t96.30 %\n", + "Epoch 9 of 50 took 5.623 s\n", + " training loss:\t\t0.099244\n", + " validation loss:\t\t0.125851\n", + " validation accuracy:\t\t96.53 %\n", + "Epoch 10 of 50 took 5.599 s\n", + " training loss:\t\t0.089900\n", + " validation loss:\t\t0.126563\n", + " validation accuracy:\t\t96.49 %\n", + "Epoch 11 of 50 took 5.611 s\n", + " training loss:\t\t0.082267\n", + " validation loss:\t\t0.124533\n", + " validation accuracy:\t\t96.67 %\n", + "Epoch 12 of 50 took 5.617 s\n", + " training loss:\t\t0.074911\n", + " validation loss:\t\t0.121908\n", + " validation accuracy:\t\t96.75 %\n", + "Epoch 13 of 50 took 5.596 s\n", + " training loss:\t\t0.069080\n", + " validation loss:\t\t0.123284\n", + " validation accuracy:\t\t96.76 %\n", + "Epoch 14 of 50 took 5.600 s\n", + " training loss:\t\t0.062342\n", + " validation loss:\t\t0.125314\n", + " validation accuracy:\t\t96.75 %\n", + "Epoch 15 of 50 took 5.604 s\n", + " training loss:\t\t0.058737\n", + " validation loss:\t\t0.122159\n", + " validation accuracy:\t\t96.89 %\n", + "Epoch 16 of 50 took 5.599 s\n", + " training loss:\t\t0.056600\n", + " validation loss:\t\t0.126851\n", + " validation accuracy:\t\t96.65 %\n", + "Epoch 17 of 50 took 5.601 s\n", + " training loss:\t\t0.049401\n", + " validation loss:\t\t0.127822\n", + " validation accuracy:\t\t96.80 %\n", + "Epoch 18 of 50 took 5.597 s\n", + " training loss:\t\t0.051736\n", + " validation loss:\t\t0.124382\n", + " validation accuracy:\t\t96.93 %\n", + "Epoch 19 of 50 took 5.728 s\n", + " training loss:\t\t0.053274\n", + " validation loss:\t\t0.121892\n", + " validation accuracy:\t\t96.90 %\n", + "Epoch 20 of 50 took 5.606 s\n", + " training loss:\t\t0.041346\n", + " validation loss:\t\t0.132626\n", + " validation accuracy:\t\t96.83 %\n", + "Epoch 21 of 50 took 5.606 s\n", + " training loss:\t\t0.042719\n", + " validation loss:\t\t0.136224\n", + " validation accuracy:\t\t96.62 %\n", + "Epoch 22 of 50 took 5.596 s\n", + " training loss:\t\t0.036241\n", + " validation loss:\t\t0.128471\n", + " validation accuracy:\t\t96.92 %\n", + "Epoch 23 of 50 took 5.625 s\n", + " training loss:\t\t0.036981\n", + " validation loss:\t\t0.134090\n", + " validation accuracy:\t\t96.96 %\n", + "Epoch 24 of 50 took 5.596 s\n", + " training loss:\t\t0.036138\n", + " validation loss:\t\t0.131586\n", + " validation accuracy:\t\t96.97 %\n", + "Epoch 25 of 50 took 5.590 s\n", + " training loss:\t\t0.034088\n", + " validation loss:\t\t0.129220\n", + " validation accuracy:\t\t97.04 %\n", + "Epoch 26 of 50 took 5.596 s\n", + " training loss:\t\t0.030038\n", + " validation loss:\t\t0.139834\n", + " validation accuracy:\t\t96.84 %\n", + "Epoch 27 of 50 took 5.595 s\n", + " training loss:\t\t0.027435\n", + " validation loss:\t\t0.139281\n", + " validation accuracy:\t\t96.94 %\n", + "Epoch 28 of 50 took 5.593 s\n", + " training loss:\t\t0.026112\n", + " validation loss:\t\t0.134875\n", + " validation accuracy:\t\t97.11 %\n", + "Epoch 29 of 50 took 5.600 s\n", + " training loss:\t\t0.031704\n", + " validation loss:\t\t0.132471\n", + " validation accuracy:\t\t97.04 %\n", + "Epoch 30 of 50 took 5.596 s\n", + " training loss:\t\t0.027024\n", + " validation loss:\t\t0.137223\n", + " validation accuracy:\t\t97.09 %\n", + "Epoch 31 of 50 took 5.603 s\n", + " training loss:\t\t0.026549\n", + " validation loss:\t\t0.145496\n", + " validation accuracy:\t\t96.97 %\n", + "Epoch 32 of 50 took 5.582 s\n", + " training loss:\t\t0.024687\n", + " validation loss:\t\t0.135853\n", + " validation accuracy:\t\t97.10 %\n", + "Epoch 33 of 50 took 5.599 s\n", + " training loss:\t\t0.023170\n", + " validation loss:\t\t0.130362\n", + " validation accuracy:\t\t97.35 %\n", + "Epoch 34 of 50 took 5.593 s\n", + " training loss:\t\t0.021415\n", + " validation loss:\t\t0.134248\n", + " validation accuracy:\t\t97.24 %\n", + "Epoch 35 of 50 took 5.605 s\n", + " training loss:\t\t0.015472\n", + " validation loss:\t\t0.145639\n", + " validation accuracy:\t\t97.19 %\n", + "Epoch 36 of 50 took 5.600 s\n", + " training loss:\t\t0.019704\n", + " validation loss:\t\t0.158388\n", + " validation accuracy:\t\t96.85 %\n", + "Epoch 37 of 50 took 5.601 s\n", + " training loss:\t\t0.021744\n", + " validation loss:\t\t0.148030\n", + " validation accuracy:\t\t96.96 %\n", + "Epoch 38 of 50 took 5.602 s\n", + " training loss:\t\t0.022655\n", + " validation loss:\t\t0.143620\n", + " validation accuracy:\t\t97.22 %\n", + "Epoch 39 of 50 took 5.610 s\n", + " training loss:\t\t0.017915\n", + " validation loss:\t\t0.143880\n", + " validation accuracy:\t\t97.18 %\n", + "Epoch 40 of 50 took 5.631 s\n", + " training loss:\t\t0.014413\n", + " validation loss:\t\t0.150207\n", + " validation accuracy:\t\t97.29 %\n", + "Epoch 41 of 50 took 5.613 s\n", + " training loss:\t\t0.010106\n", + " validation loss:\t\t0.153401\n", + " validation accuracy:\t\t97.32 %\n", + "Epoch 42 of 50 took 5.615 s\n", + " training loss:\t\t0.016331\n", + " validation loss:\t\t0.152813\n", + " validation accuracy:\t\t97.12 %\n", + "Epoch 43 of 50 took 5.617 s\n", + " training loss:\t\t0.017011\n", + " validation loss:\t\t0.168792\n", + " validation accuracy:\t\t96.96 %\n", + "Epoch 44 of 50 took 5.612 s\n", + " training loss:\t\t0.010532\n", + " validation loss:\t\t0.158618\n", + " validation accuracy:\t\t97.13 %\n", + "Epoch 45 of 50 took 5.618 s\n", + " training loss:\t\t0.009678\n", + " validation loss:\t\t0.171994\n", + " validation accuracy:\t\t97.08 %\n", + "Epoch 46 of 50 took 5.628 s\n", + " training loss:\t\t0.007927\n", + " validation loss:\t\t0.173770\n", + " validation accuracy:\t\t97.03 %\n", + "Epoch 47 of 50 took 5.613 s\n", + " training loss:\t\t0.032664\n", + " validation loss:\t\t0.153496\n", + " validation accuracy:\t\t97.03 %\n", + "Epoch 48 of 50 took 5.613 s\n", + " training loss:\t\t0.016316\n", + " validation loss:\t\t0.162986\n", + " validation accuracy:\t\t97.08 %\n", + "Epoch 49 of 50 took 5.615 s\n", + " training loss:\t\t0.007451\n", + " validation loss:\t\t0.160675\n", + " validation accuracy:\t\t97.15 %\n", + "Epoch 50 of 50 took 5.619 s\n", + " training loss:\t\t0.004508\n", + " validation loss:\t\t0.167268\n", + " validation accuracy:\t\t97.20 %\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "num_batches_train = dataset['num_examples_train'] // batch_size\n", + "num_batches_valid = dataset['num_examples_valid'] // batch_size\n", + "\n", + "print(\"Starting training...\")\n", + "now = time.time()\n", + "\n", + "try:\n", + " for epoch in range(num_epochs):\n", + " batch_train_losses = []\n", + " for b in range(num_batches_train):\n", + " batch_train_loss = iter_train(b)\n", + " batch_train_losses.append(batch_train_loss)\n", + "\n", + " avg_train_loss = np.mean(batch_train_losses)\n", + "\n", + " batch_valid_losses = []\n", + " batch_valid_accuracies = []\n", + " for b in range(num_batches_valid):\n", + " batch_valid_loss, batch_valid_accuracy = iter_valid(b)\n", + " batch_valid_losses.append(batch_valid_loss)\n", + " batch_valid_accuracies.append(batch_valid_accuracy)\n", + "\n", + " avg_valid_loss = np.mean(batch_valid_losses)\n", + " avg_valid_accuracy = np.mean(batch_valid_accuracies)\n", + "\n", + " print(\"Epoch %d of %d took %.3f s\" % (epoch + 1, num_epochs, time.time() - now))\n", + " now = time.time()\n", + " print(\" training loss:\\t\\t%.6f\" % avg_train_loss)\n", + " print(\" validation loss:\\t\\t%.6f\" % avg_valid_loss)\n", + " print(\" validation accuracy:\\t\\t%.2f %%\" % (avg_valid_accuracy * 100))\n", + "except KeyboardInterrupt:\n", + " pass" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From ad5b93d897551fcf1602fe44eebbb6fb903be0eb Mon Sep 17 00:00:00 2001 From: Sander Dieleman Date: Sun, 30 Aug 2015 19:28:02 +0200 Subject: [PATCH 2/2] changed some params to match the paper a bit better --- papers/Highway Networks.ipynb | 396 +++++++++++++++++----------------- 1 file changed, 198 insertions(+), 198 deletions(-) diff --git a/papers/Highway Networks.ipynb b/papers/Highway Networks.ipynb index a891bab..5e31190 100644 --- a/papers/Highway Networks.ipynb +++ b/papers/Highway Networks.ipynb @@ -40,7 +40,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Using gpu device 0: GeForce GTX 980\n" + "Using gpu device 1: Tesla K40c\n" ] } ], @@ -69,7 +69,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "--2015-08-22 23:00:18-- http://deeplearning.net/data/mnist/mnist.pkl.gz\n", + "--2015-08-30 18:23:32-- http://deeplearning.net/data/mnist/mnist.pkl.gz\n", "Resolving deeplearning.net (deeplearning.net)... 132.204.26.28\n", "Connecting to deeplearning.net (deeplearning.net)|132.204.26.28|:80... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", @@ -166,7 +166,7 @@ "source": [ "**Now we can define a macro function to create a dense highway layer.** Note that it does not take a `num_units` input argument: the number of outputs should always be the same as the number of inputs, so it is redundant.\n", "\n", - "We initialize the biases of the gates to `-4.0` to disable all of them initially. This means all layers will basically pass through the inputs (and gradients) unchanged at the start of training." + "We initialize the biases of the gates to `-4.0` to disable all of them initially. This means all layers will basically pass through the inputs (and gradients) unchanged at the start of training. In the paper, an initial value of `-2.0` is used for the MNIST experiments, but we found this to slow down convergence." ] }, { @@ -248,7 +248,7 @@ "source": [ "def build_model(input_dim, output_dim, batch_size,\n", " num_hidden_units, num_hidden_layers):\n", - " \"\"\"Create a symbolic representation of a neural network with `intput_dim`\n", + " \"\"\"Create a symbolic representation of a neural network with `input_dim`\n", " input nodes, `output_dim` output nodes, `num_hidden_layers` hidden layers\n", " and `num_hidden_units` per hidden layer.\n", " \n", @@ -300,7 +300,7 @@ "source": [ "num_epochs = 50\n", "batch_size = 100\n", - "learning_rate = 0.01\n", + "learning_rate = 0.005\n", "momentum = 0.9\n", "\n", "print(\"Loading data...\")\n", @@ -311,7 +311,7 @@ " input_dim=dataset['input_dim'],\n", " output_dim=dataset['output_dim'],\n", " batch_size=batch_size,\n", - " num_hidden_units=40,\n", + " num_hidden_units=50,\n", " num_hidden_layers=50,\n", ")\n", "\n", @@ -375,206 +375,206 @@ "output_type": "stream", "text": [ "Starting training...\n", - "Epoch 1 of 50 took 5.509 s\n", - " training loss:\t\t0.903543\n", - " validation loss:\t\t0.337289\n", - " validation accuracy:\t\t90.37 %\n", - "Epoch 2 of 50 took 5.529 s\n", - " training loss:\t\t0.314176\n", - " validation loss:\t\t0.232568\n", - " validation accuracy:\t\t92.90 %\n", - "Epoch 3 of 50 took 5.574 s\n", - " training loss:\t\t0.239013\n", - " validation loss:\t\t0.190169\n", - " validation accuracy:\t\t94.42 %\n", - "Epoch 4 of 50 took 5.588 s\n", - " training loss:\t\t0.196143\n", - " validation loss:\t\t0.166020\n", - " validation accuracy:\t\t95.14 %\n", - "Epoch 5 of 50 took 5.602 s\n", - " training loss:\t\t0.166447\n", - " validation loss:\t\t0.151690\n", - " validation accuracy:\t\t95.66 %\n", - "Epoch 6 of 50 took 5.618 s\n", - " training loss:\t\t0.144108\n", - " validation loss:\t\t0.140717\n", - " validation accuracy:\t\t96.02 %\n", - "Epoch 7 of 50 took 5.608 s\n", - " training loss:\t\t0.126167\n", - " validation loss:\t\t0.134176\n", - " validation accuracy:\t\t96.18 %\n", - "Epoch 8 of 50 took 5.614 s\n", - " training loss:\t\t0.111839\n", - " validation loss:\t\t0.129229\n", - " validation accuracy:\t\t96.30 %\n", - "Epoch 9 of 50 took 5.623 s\n", - " training loss:\t\t0.099244\n", - " validation loss:\t\t0.125851\n", - " validation accuracy:\t\t96.53 %\n", - "Epoch 10 of 50 took 5.599 s\n", - " training loss:\t\t0.089900\n", - " validation loss:\t\t0.126563\n", - " validation accuracy:\t\t96.49 %\n", - "Epoch 11 of 50 took 5.611 s\n", - " training loss:\t\t0.082267\n", - " validation loss:\t\t0.124533\n", - " validation accuracy:\t\t96.67 %\n", - "Epoch 12 of 50 took 5.617 s\n", - " training loss:\t\t0.074911\n", - " validation loss:\t\t0.121908\n", - " validation accuracy:\t\t96.75 %\n", - "Epoch 13 of 50 took 5.596 s\n", - " training loss:\t\t0.069080\n", - " validation loss:\t\t0.123284\n", - " validation accuracy:\t\t96.76 %\n", - "Epoch 14 of 50 took 5.600 s\n", - " training loss:\t\t0.062342\n", - " validation loss:\t\t0.125314\n", - " validation accuracy:\t\t96.75 %\n", - "Epoch 15 of 50 took 5.604 s\n", - " training loss:\t\t0.058737\n", - " validation loss:\t\t0.122159\n", - " validation accuracy:\t\t96.89 %\n", - "Epoch 16 of 50 took 5.599 s\n", - " training loss:\t\t0.056600\n", - " validation loss:\t\t0.126851\n", - " validation accuracy:\t\t96.65 %\n", - "Epoch 17 of 50 took 5.601 s\n", - " training loss:\t\t0.049401\n", - " validation loss:\t\t0.127822\n", - " validation accuracy:\t\t96.80 %\n", - "Epoch 18 of 50 took 5.597 s\n", - " training loss:\t\t0.051736\n", - " validation loss:\t\t0.124382\n", - " validation accuracy:\t\t96.93 %\n", - "Epoch 19 of 50 took 5.728 s\n", - " training loss:\t\t0.053274\n", - " validation loss:\t\t0.121892\n", + "Epoch 1 of 50 took 8.463 s\n", + " training loss:\t\t1.241869\n", + " validation loss:\t\t0.529914\n", + " validation accuracy:\t\t87.34 %\n", + "Epoch 2 of 50 took 8.474 s\n", + " training loss:\t\t0.439579\n", + " validation loss:\t\t0.330878\n", + " validation accuracy:\t\t90.30 %\n", + "Epoch 3 of 50 took 8.498 s\n", + " training loss:\t\t0.330963\n", + " validation loss:\t\t0.270475\n", + " validation accuracy:\t\t91.88 %\n", + "Epoch 4 of 50 took 8.506 s\n", + " training loss:\t\t0.274753\n", + " validation loss:\t\t0.231958\n", + " validation accuracy:\t\t93.08 %\n", + "Epoch 5 of 50 took 8.503 s\n", + " training loss:\t\t0.236217\n", + " validation loss:\t\t0.206835\n", + " validation accuracy:\t\t93.90 %\n", + "Epoch 6 of 50 took 8.509 s\n", + " training loss:\t\t0.208123\n", + " validation loss:\t\t0.188295\n", + " validation accuracy:\t\t94.54 %\n", + "Epoch 7 of 50 took 8.517 s\n", + " training loss:\t\t0.186583\n", + " validation loss:\t\t0.171706\n", + " validation accuracy:\t\t95.12 %\n", + "Epoch 8 of 50 took 8.523 s\n", + " training loss:\t\t0.168530\n", + " validation loss:\t\t0.158668\n", + " validation accuracy:\t\t95.47 %\n", + "Epoch 9 of 50 took 8.503 s\n", + " training loss:\t\t0.152036\n", + " validation loss:\t\t0.149293\n", + " validation accuracy:\t\t95.69 %\n", + "Epoch 10 of 50 took 8.521 s\n", + " training loss:\t\t0.139339\n", + " validation loss:\t\t0.142494\n", + " validation accuracy:\t\t95.85 %\n", + "Epoch 11 of 50 took 8.517 s\n", + " training loss:\t\t0.127803\n", + " validation loss:\t\t0.136691\n", + " validation accuracy:\t\t95.96 %\n", + "Epoch 12 of 50 took 8.510 s\n", + " training loss:\t\t0.117640\n", + " validation loss:\t\t0.132784\n", + " validation accuracy:\t\t96.11 %\n", + "Epoch 13 of 50 took 8.519 s\n", + " training loss:\t\t0.107974\n", + " validation loss:\t\t0.130798\n", + " validation accuracy:\t\t96.21 %\n", + "Epoch 14 of 50 took 8.513 s\n", + " training loss:\t\t0.099903\n", + " validation loss:\t\t0.127288\n", + " validation accuracy:\t\t96.28 %\n", + "Epoch 15 of 50 took 8.513 s\n", + " training loss:\t\t0.093034\n", + " validation loss:\t\t0.123553\n", + " validation accuracy:\t\t96.43 %\n", + "Epoch 16 of 50 took 8.516 s\n", + " training loss:\t\t0.090243\n", + " validation loss:\t\t0.123225\n", + " validation accuracy:\t\t96.55 %\n", + "Epoch 17 of 50 took 8.506 s\n", + " training loss:\t\t0.082680\n", + " validation loss:\t\t0.120509\n", + " validation accuracy:\t\t96.61 %\n", + "Epoch 18 of 50 took 8.513 s\n", + " training loss:\t\t0.076955\n", + " validation loss:\t\t0.122340\n", + " validation accuracy:\t\t96.51 %\n", + "Epoch 19 of 50 took 8.507 s\n", + " training loss:\t\t0.074587\n", + " validation loss:\t\t0.119997\n", + " validation accuracy:\t\t96.77 %\n", + "Epoch 20 of 50 took 8.522 s\n", + " training loss:\t\t0.069306\n", + " validation loss:\t\t0.116630\n", + " validation accuracy:\t\t96.77 %\n", + "Epoch 21 of 50 took 8.519 s\n", + " training loss:\t\t0.068775\n", + " validation loss:\t\t0.119293\n", " validation accuracy:\t\t96.90 %\n", - "Epoch 20 of 50 took 5.606 s\n", - " training loss:\t\t0.041346\n", - " validation loss:\t\t0.132626\n", - " validation accuracy:\t\t96.83 %\n", - "Epoch 21 of 50 took 5.606 s\n", - " training loss:\t\t0.042719\n", - " validation loss:\t\t0.136224\n", - " validation accuracy:\t\t96.62 %\n", - "Epoch 22 of 50 took 5.596 s\n", - " training loss:\t\t0.036241\n", - " validation loss:\t\t0.128471\n", - " validation accuracy:\t\t96.92 %\n", - "Epoch 23 of 50 took 5.625 s\n", - " training loss:\t\t0.036981\n", - " validation loss:\t\t0.134090\n", - " validation accuracy:\t\t96.96 %\n", - "Epoch 24 of 50 took 5.596 s\n", - " training loss:\t\t0.036138\n", - " validation loss:\t\t0.131586\n", - " validation accuracy:\t\t96.97 %\n", - "Epoch 25 of 50 took 5.590 s\n", - " training loss:\t\t0.034088\n", - " validation loss:\t\t0.129220\n", - " validation accuracy:\t\t97.04 %\n", - "Epoch 26 of 50 took 5.596 s\n", - " training loss:\t\t0.030038\n", - " validation loss:\t\t0.139834\n", - " validation accuracy:\t\t96.84 %\n", - "Epoch 27 of 50 took 5.595 s\n", - " training loss:\t\t0.027435\n", - " validation loss:\t\t0.139281\n", - " validation accuracy:\t\t96.94 %\n", - "Epoch 28 of 50 took 5.593 s\n", - " training loss:\t\t0.026112\n", - " validation loss:\t\t0.134875\n", - " validation accuracy:\t\t97.11 %\n", - "Epoch 29 of 50 took 5.600 s\n", - " training loss:\t\t0.031704\n", - " validation loss:\t\t0.132471\n", + "Epoch 22 of 50 took 8.520 s\n", + " training loss:\t\t0.060607\n", + " validation loss:\t\t0.114335\n", " validation accuracy:\t\t97.04 %\n", - "Epoch 30 of 50 took 5.596 s\n", - " training loss:\t\t0.027024\n", - " validation loss:\t\t0.137223\n", + "Epoch 23 of 50 took 8.520 s\n", + " training loss:\t\t0.061950\n", + " validation loss:\t\t0.117707\n", + " validation accuracy:\t\t96.92 %\n", + "Epoch 24 of 50 took 8.516 s\n", + " training loss:\t\t0.055509\n", + " validation loss:\t\t0.119587\n", + " validation accuracy:\t\t96.92 %\n", + "Epoch 25 of 50 took 8.513 s\n", + " training loss:\t\t0.053691\n", + " validation loss:\t\t0.115627\n", " validation accuracy:\t\t97.09 %\n", - "Epoch 31 of 50 took 5.603 s\n", - " training loss:\t\t0.026549\n", - " validation loss:\t\t0.145496\n", - " validation accuracy:\t\t96.97 %\n", - "Epoch 32 of 50 took 5.582 s\n", - " training loss:\t\t0.024687\n", - " validation loss:\t\t0.135853\n", + "Epoch 26 of 50 took 8.529 s\n", + " training loss:\t\t0.051308\n", + " validation loss:\t\t0.119400\n", + " validation accuracy:\t\t97.15 %\n", + "Epoch 27 of 50 took 8.507 s\n", + " training loss:\t\t0.050237\n", + " validation loss:\t\t0.118064\n", + " validation accuracy:\t\t97.14 %\n", + "Epoch 28 of 50 took 8.507 s\n", + " training loss:\t\t0.049942\n", + " validation loss:\t\t0.119468\n", + " validation accuracy:\t\t97.07 %\n", + "Epoch 29 of 50 took 8.518 s\n", + " training loss:\t\t0.043013\n", + " validation loss:\t\t0.118137\n", + " validation accuracy:\t\t97.11 %\n", + "Epoch 30 of 50 took 8.506 s\n", + " training loss:\t\t0.043521\n", + " validation loss:\t\t0.123761\n", + " validation accuracy:\t\t97.07 %\n", + "Epoch 31 of 50 took 8.470 s\n", + " training loss:\t\t0.038424\n", + " validation loss:\t\t0.120082\n", " validation accuracy:\t\t97.10 %\n", - "Epoch 33 of 50 took 5.599 s\n", - " training loss:\t\t0.023170\n", - " validation loss:\t\t0.130362\n", - " validation accuracy:\t\t97.35 %\n", - "Epoch 34 of 50 took 5.593 s\n", - " training loss:\t\t0.021415\n", - " validation loss:\t\t0.134248\n", + "Epoch 32 of 50 took 8.511 s\n", + " training loss:\t\t0.035319\n", + " validation loss:\t\t0.123719\n", + " validation accuracy:\t\t97.07 %\n", + "Epoch 33 of 50 took 8.523 s\n", + " training loss:\t\t0.033492\n", + " validation loss:\t\t0.127172\n", " validation accuracy:\t\t97.24 %\n", - "Epoch 35 of 50 took 5.605 s\n", - " training loss:\t\t0.015472\n", - " validation loss:\t\t0.145639\n", - " validation accuracy:\t\t97.19 %\n", - "Epoch 36 of 50 took 5.600 s\n", - " training loss:\t\t0.019704\n", - " validation loss:\t\t0.158388\n", + "Epoch 34 of 50 took 8.514 s\n", + " training loss:\t\t0.032029\n", + " validation loss:\t\t0.134384\n", + " validation accuracy:\t\t96.87 %\n", + "Epoch 35 of 50 took 8.514 s\n", + " training loss:\t\t0.036389\n", + " validation loss:\t\t0.127195\n", + " validation accuracy:\t\t97.10 %\n", + "Epoch 36 of 50 took 8.514 s\n", + " training loss:\t\t0.029351\n", + " validation loss:\t\t0.132363\n", + " validation accuracy:\t\t97.06 %\n", + "Epoch 37 of 50 took 8.520 s\n", + " training loss:\t\t0.029062\n", + " validation loss:\t\t0.134932\n", + " validation accuracy:\t\t97.07 %\n", + "Epoch 38 of 50 took 8.516 s\n", + " training loss:\t\t0.038298\n", + " validation loss:\t\t0.145781\n", " validation accuracy:\t\t96.85 %\n", - "Epoch 37 of 50 took 5.601 s\n", - " training loss:\t\t0.021744\n", - " validation loss:\t\t0.148030\n", - " validation accuracy:\t\t96.96 %\n", - "Epoch 38 of 50 took 5.602 s\n", - " training loss:\t\t0.022655\n", - " validation loss:\t\t0.143620\n", + "Epoch 39 of 50 took 8.505 s\n", + " training loss:\t\t0.032814\n", + " validation loss:\t\t0.133435\n", + " validation accuracy:\t\t96.99 %\n", + "Epoch 40 of 50 took 8.519 s\n", + " training loss:\t\t0.027054\n", + " validation loss:\t\t0.148125\n", + " validation accuracy:\t\t97.04 %\n", + "Epoch 41 of 50 took 8.518 s\n", + " training loss:\t\t0.029643\n", + " validation loss:\t\t0.137073\n", + " validation accuracy:\t\t96.88 %\n", + "Epoch 42 of 50 took 8.521 s\n", + " training loss:\t\t0.022993\n", + " validation loss:\t\t0.136357\n", + " validation accuracy:\t\t97.14 %\n", + "Epoch 43 of 50 took 8.524 s\n", + " training loss:\t\t0.021812\n", + " validation loss:\t\t0.138383\n", " validation accuracy:\t\t97.22 %\n", - "Epoch 39 of 50 took 5.610 s\n", - " training loss:\t\t0.017915\n", - " validation loss:\t\t0.143880\n", - " validation accuracy:\t\t97.18 %\n", - "Epoch 40 of 50 took 5.631 s\n", - " training loss:\t\t0.014413\n", - " validation loss:\t\t0.150207\n", - " validation accuracy:\t\t97.29 %\n", - "Epoch 41 of 50 took 5.613 s\n", - " training loss:\t\t0.010106\n", - " validation loss:\t\t0.153401\n", - " validation accuracy:\t\t97.32 %\n", - "Epoch 42 of 50 took 5.615 s\n", - " training loss:\t\t0.016331\n", - " validation loss:\t\t0.152813\n", - " validation accuracy:\t\t97.12 %\n", - "Epoch 43 of 50 took 5.617 s\n", - " training loss:\t\t0.017011\n", - " validation loss:\t\t0.168792\n", - " validation accuracy:\t\t96.96 %\n", - "Epoch 44 of 50 took 5.612 s\n", - " training loss:\t\t0.010532\n", - " validation loss:\t\t0.158618\n", + "Epoch 44 of 50 took 8.515 s\n", + " training loss:\t\t0.040693\n", + " validation loss:\t\t0.130088\n", " validation accuracy:\t\t97.13 %\n", - "Epoch 45 of 50 took 5.618 s\n", - " training loss:\t\t0.009678\n", - " validation loss:\t\t0.171994\n", - " validation accuracy:\t\t97.08 %\n", - "Epoch 46 of 50 took 5.628 s\n", - " training loss:\t\t0.007927\n", - " validation loss:\t\t0.173770\n", - " validation accuracy:\t\t97.03 %\n", - "Epoch 47 of 50 took 5.613 s\n", - " training loss:\t\t0.032664\n", - " validation loss:\t\t0.153496\n", - " validation accuracy:\t\t97.03 %\n", - "Epoch 48 of 50 took 5.613 s\n", - " training loss:\t\t0.016316\n", - " validation loss:\t\t0.162986\n", - " validation accuracy:\t\t97.08 %\n", - "Epoch 49 of 50 took 5.615 s\n", - " training loss:\t\t0.007451\n", - " validation loss:\t\t0.160675\n", - " validation accuracy:\t\t97.15 %\n", - "Epoch 50 of 50 took 5.619 s\n", - " training loss:\t\t0.004508\n", - " validation loss:\t\t0.167268\n", - " validation accuracy:\t\t97.20 %\n" + "Epoch 45 of 50 took 8.516 s\n", + " training loss:\t\t0.028012\n", + " validation loss:\t\t0.133577\n", + " validation accuracy:\t\t97.20 %\n", + "Epoch 46 of 50 took 8.526 s\n", + " training loss:\t\t0.022198\n", + " validation loss:\t\t0.140025\n", + " validation accuracy:\t\t97.22 %\n", + "Epoch 47 of 50 took 8.511 s\n", + " training loss:\t\t0.030115\n", + " validation loss:\t\t0.144627\n", + " validation accuracy:\t\t97.02 %\n", + "Epoch 48 of 50 took 8.506 s\n", + " training loss:\t\t0.021628\n", + " validation loss:\t\t0.141951\n", + " validation accuracy:\t\t97.11 %\n", + "Epoch 49 of 50 took 8.521 s\n", + " training loss:\t\t0.027799\n", + " validation loss:\t\t0.140850\n", + " validation accuracy:\t\t96.99 %\n", + "Epoch 50 of 50 took 8.517 s\n", + " training loss:\t\t0.020513\n", + " validation loss:\t\t0.155682\n", + " validation accuracy:\t\t96.76 %\n" ] } ],