From 506c4d94179afab4eeb5b0da39b65239e40e25fb Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Thu, 3 Sep 2015 17:49:11 -0700 Subject: [PATCH] Add language model ipython notebook from tutorial --- .../coco_caption/Caffe language model.ipynb | 617 ++++++++++++++++++ 1 file changed, 617 insertions(+) create mode 100644 examples/coco_caption/Caffe language model.ipynb diff --git a/examples/coco_caption/Caffe language model.ipynb b/examples/coco_caption/Caffe language model.ipynb new file mode 100644 index 00000000000..30d2c494a39 --- /dev/null +++ b/examples/coco_caption/Caffe language model.ipynb @@ -0,0 +1,617 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 132, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import random\n", + "\n", + "import sys\n", + "sys.path.append('./python')\n", + "import caffe\n", + "\n", + "sys.path.append('./examples/coco_caption')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\r\n", + "a\r\n", + "on\r\n", + "of\r\n", + "the\r\n", + "in\r\n", + "with\r\n", + "and\r\n", + "is\r\n", + "man\r\n" + ] + } + ], + "source": [ + "!head examples/coco_caption/h5_data/buffer_100/vocabulary.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8801\n" + ] + } + ], + "source": [ + "vocabulary = [''] + [line.strip() for line in\n", + " open('examples/coco_caption/h5_data/buffer_100/vocabulary.txt').readlines()]\n", + "print len(vocabulary)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1, 1, 8801)\n" + ] + } + ], + "source": [ + "iter_num = 110000\n", + "net = caffe.Net('./examples/coco_caption/lstm_lm.deploy.prototxt',\n", + " './examples/coco_caption/lstm_lm_iter_%d.caffemodel' % iter_num, caffe.TEST)\n", + "print net.blobs['probs'].data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def predict_single_word(net, previous_word, output='probs'):\n", + " cont = 0 if previous_word == 0 else 1\n", + " cont_input = np.array([cont])\n", + " word_input = np.array([previous_word])\n", + " net.forward(cont_sentence=cont_input, input_sentence=word_input)\n", + " output_preds = net.blobs[output].data[0, 0, :]\n", + " return output_preds" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "first_word_dist = predict_single_word(net, 0)" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "top_preds = np.argsort(-1 * first_word_dist)" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 2 14 5 13 64 77 30 18 93 142]\n", + "['a', 'two', 'the', 'an', 'there', 'three', 'some', 'people', 'several', 'this']\n" + ] + } + ], + "source": [ + "print top_preds[:10]\n", + "print [vocabulary[index] for index in top_preds[:10]]" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['people', 'men', 'women', 'giraffes', 'zebras', 'young', 'cats', 'elephants', 'horses', 'children']\n" + ] + } + ], + "source": [ + "second_word_dist = predict_single_word(net, vocabulary.index('two'))\n", + "print [vocabulary[index] for index in np.argsort(-1 * second_word_dist)[:10]]" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['standing', 'are', 'in', 'stand', 'walking', 'and', 'eating', 'that', 'walk', 'with']\n" + ] + } + ], + "source": [ + "third_word_dist = predict_single_word(net, vocabulary.index('giraffes'))\n", + "print [vocabulary[index] for index in np.argsort(-1 * second_word_dist)[:10]]" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['leaves', 'from', 'grass', 'hay', 'out', 'some', 'in', 'food', 'off', 'a']\n" + ] + } + ], + "source": [ + "third_word_dist = predict_single_word(net, vocabulary.index('eating'))\n", + "print [vocabulary[index] for index in np.argsort(-1 * second_word_dist)[:10]]" + ] + }, + { + "cell_type": "code", + "execution_count": 136, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def softmax(softmax_inputs, temp):\n", + " shifted_inputs = softmax_inputs - softmax_inputs.max()\n", + " exp_outputs = np.exp(temp * shifted_inputs)\n", + " exp_outputs_sum = exp_outputs.sum()\n", + " if np.isnan(exp_outputs_sum):\n", + " return exp_outputs * float('nan')\n", + " assert exp_outputs_sum > 0\n", + " if np.isinf(exp_outputs_sum):\n", + " return np.zeros_like(exp_outputs)\n", + " eps_sum = 1e-20\n", + " return exp_outputs / max(exp_outputs_sum, eps_sum)\n", + "\n", + "def random_choice_from_probs(softmax_inputs, temp=1):\n", + " # temperature of infinity == take the max\n", + " if temp == float('inf'):\n", + " return np.argmax(softmax_inputs)\n", + " probs = softmax(softmax_inputs, temp)\n", + " r = random.random()\n", + " cum_sum = 0.\n", + " for i, p in enumerate(probs):\n", + " cum_sum += p\n", + " if cum_sum >= r: return i\n", + " return 1 # return UNK?" + ] + }, + { + "cell_type": "code", + "execution_count": 120, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def generate_sentence(net, temp=float('inf'), output='predict', max_words=50):\n", + " cont_input = np.array([0])\n", + " word_input = np.array([0])\n", + " sentence = []\n", + " while len(sentence) < max_words and (not sentence or sentence[-1] != 0):\n", + " net.forward(cont_sentence=cont_input, input_sentence=word_input)\n", + " output_preds = net.blobs[output].data[0, 0, :]\n", + " sentence.append(random_choice_from_probs(output_preds, temp=temp))\n", + " cont_input[0] = 1\n", + " word_input[0] = sentence[-1]\n", + " return sentence" + ] + }, + { + "cell_type": "code", + "execution_count": 121, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2, 10, 9, 16, 6, 2, 35, 7, 2, 118, 0]\n", + "['a', 'man', 'is', 'standing', 'in', 'a', 'field', 'with', 'a', 'frisbee', '']\n" + ] + } + ], + "source": [ + "sentence = generate_sentence(net)\n", + "print sentence\n", + "print [vocabulary[index] for index in sentence]" + ] + }, + { + "cell_type": "code", + "execution_count": 122, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2, 10, 9, 16, 6, 2, 35, 7, 2, 118, 0]\n", + "['a', 'man', 'is', 'standing', 'in', 'a', 'field', 'with', 'a', 'frisbee', '']\n" + ] + } + ], + "source": [ + "sentence = generate_sentence(net)\n", + "print sentence\n", + "print [vocabulary[index] for index in sentence]" + ] + }, + { + "cell_type": "code", + "execution_count": 137, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2, 22, 9, 294, 7, 2, 178, 113, 11, 87, 905, 0]\n", + "['a', 'woman', 'is', 'posing', 'with', 'a', 'cell', 'phone', 'to', 'her', 'ear', '']\n" + ] + } + ], + "source": [ + "sentence = generate_sentence(net, temp=1.0)\n", + "print sentence\n", + "print [vocabulary[index] for index in sentence]" + ] + }, + { + "cell_type": "code", + "execution_count": 138, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2, 28, 26, 2, 38, 209, 3, 2, 38, 152, 0]\n", + "['a', 'person', 'holding', 'a', 'tennis', 'racket', 'on', 'a', 'tennis', 'court', '']\n" + ] + } + ], + "source": [ + "sentence = generate_sentence(net, temp=1.0)\n", + "print sentence\n", + "print [vocabulary[index] for index in sentence]" + ] + }, + { + "cell_type": "code", + "execution_count": 139, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2, 10, 26, 2, 38, 363, 3, 2, 38, 152, 0]\n", + "['a', 'man', 'holding', 'a', 'tennis', 'racquet', 'on', 'a', 'tennis', 'court', '']\n" + ] + } + ], + "source": [ + "sentence = generate_sentence(net, temp=1.5)\n", + "print sentence\n", + "print [vocabulary[index] for index in sentence]" + ] + }, + { + "cell_type": "code", + "execution_count": 140, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2, 33, 4, 18, 12, 106, 2, 23, 7, 60, 0]\n", + "['a', 'group', 'of', 'people', 'sitting', 'around', 'a', 'table', 'with', 'food', '']\n" + ] + } + ], + "source": [ + "sentence = generate_sentence(net, temp=1.5)\n", + "print sentence\n", + "print [vocabulary[index] for index in sentence]" + ] + }, + { + "cell_type": "code", + "execution_count": 141, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2, 10, 6, 2, 261, 8, 217, 16, 6, 2, 43, 0]\n", + "['a', 'man', 'in', 'a', 'suit', 'and', 'tie', 'standing', 'in', 'a', 'room', '']\n" + ] + } + ], + "source": [ + "sentence = generate_sentence(net, temp=3.0)\n", + "print sentence\n", + "print [vocabulary[index] for index in sentence]" + ] + }, + { + "cell_type": "code", + "execution_count": 142, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2, 10, 26, 2, 38, 363, 3, 2, 38, 152, 0]\n", + "['a', 'man', 'holding', 'a', 'tennis', 'racquet', 'on', 'a', 'tennis', 'court', '']\n" + ] + } + ], + "source": [ + "sentence = generate_sentence(net, temp=3.0)\n", + "print sentence\n", + "print [vocabulary[index] for index in sentence]" + ] + }, + { + "cell_type": "code", + "execution_count": 143, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2, 10, 9, 16, 6, 2, 35, 7, 2, 118, 0]\n", + "['a', 'man', 'is', 'standing', 'in', 'a', 'field', 'with', 'a', 'frisbee', '']\n" + ] + } + ], + "source": [ + "sentence = generate_sentence(net, temp=10.0)\n", + "print sentence\n", + "print [vocabulary[index] for index in sentence]" + ] + }, + { + "cell_type": "code", + "execution_count": 144, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1993, 1074, 86, 6, 40, 4, 2, 126, 0]\n", + "['staircase', 'laid', 'out', 'in', 'front', 'of', 'a', 'window', '']\n" + ] + } + ], + "source": [ + "sentence = generate_sentence(net, temp=1.0)\n", + "print sentence\n", + "print [vocabulary[index] for index in sentence]" + ] + }, + { + "cell_type": "code", + "execution_count": 146, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2, 28, 3, 2, 113, 46, 2, 129, 0]\n", + "['a', 'person', 'on', 'a', 'phone', 'riding', 'a', 'car', '']\n" + ] + } + ], + "source": [ + "sentence = generate_sentence(net, temp=0.8)\n", + "print sentence\n", + "print [vocabulary[index] for index in sentence]" + ] + }, + { + "cell_type": "code", + "execution_count": 147, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2, 16, 60, 6, 136, 192, 7, 641, 16, 20, 11, 27, 0]\n", + "['a', 'standing', 'food', 'in', 'each', 'hand', 'with', 'cattle', 'standing', 'next', 'to', 'it', '']\n" + ] + } + ], + "source": [ + "sentence = generate_sentence(net, temp=0.8)\n", + "print sentence\n", + "print [vocabulary[index] for index in sentence]" + ] + }, + { + "cell_type": "code", + "execution_count": 148, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[28, 236, 1042, 7, 69, 1257, 487, 1769, 0]\n", + "['person', 'taking', 'noodles', 'with', 'other', 'homemade', 'birthday', 'cereal', '']\n" + ] + } + ], + "source": [ + "sentence = generate_sentence(net, temp=0.6)\n", + "print sentence\n", + "print [vocabulary[index] for index in sentence]" + ] + }, + { + "cell_type": "code", + "execution_count": 145, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[5623, 1087, 15, 6888, 472, 361, 8634, 8, 7241, 3, 77, 299, 935, 1296, 15, 12, 5165, 2867, 3979, 743, 4991, 4470, 640, 9, 259, 2308, 4386, 2552, 3797, 2448, 15, 3617, 5364, 4267, 4549, 8086, 176, 2529, 6434, 5445, 370, 7959, 5672, 1742, 4041, 4258, 1153, 8, 610, 2044]\n", + "['chilli', 'frosting', ',', 'medley', 'salad', 'items', 'sideboard', 'and', 'garnishes', 'on', 'three', 'colorful', 'gold', 'desserts', ',', 'sitting', 'knifes', 'need', 'workspace', 'where', 'exchanging', 'hoses', 'left', 'is', 'pink', 'clearing', 'obstacles', 'vandalized', 'idly', 'afternoon', ',', 'halloween', 'rich', 'fixed', 'aid', 'advertise', 'light', 'times', 'delicate', 'dealership', 'like', 'snowsuits', 'florida', 'than', 'ornamental', 'dr', 'curtains', 'and', 'multiple', 'electrical']\n" + ] + } + ], + "source": [ + "sentence = generate_sentence(net, temp=0.5)\n", + "print sentence\n", + "print [vocabulary[index] for index in sentence]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}